1use crate::{
4 bigint::{prelude::*, Limb, NonZero},
5 rand_core::{CryptoRng, RngCore},
6 subtle::{
7 Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
8 CtOption,
9 },
10 Curve, Error, FieldBytes, IsHigh, Result,
11};
12use base16ct::HexDisplay;
13use core::{
14 cmp::Ordering,
15 fmt,
16 ops::{Add, AddAssign, Neg, Sub, SubAssign},
17 str,
18};
19use generic_array::GenericArray;
20use zeroize::DefaultIsZeroes;
21
22#[cfg(feature = "arithmetic")]
23use {
24 super::{Scalar, ScalarArithmetic},
25 group::ff::PrimeField,
26};
27
28#[cfg(feature = "serde")]
29use serdect::serde::{de, ser, Deserialize, Serialize};
30
31#[derive(Copy, Clone, Debug, Default)]
46#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
47pub struct ScalarCore<C: Curve> {
48 inner: C::UInt,
50}
51
52impl<C> ScalarCore<C>
53where
54 C: Curve,
55{
56 pub const ZERO: Self = Self {
58 inner: C::UInt::ZERO,
59 };
60
61 pub const ONE: Self = Self {
63 inner: C::UInt::ONE,
64 };
65
66 pub const MODULUS: C::UInt = C::ORDER;
68
69 pub fn random(rng: impl CryptoRng + RngCore) -> Self {
71 Self {
72 inner: C::UInt::random_mod(rng, &NonZero::new(Self::MODULUS).unwrap()),
73 }
74 }
75
76 pub fn new(uint: C::UInt) -> CtOption<Self> {
78 CtOption::new(Self { inner: uint }, uint.ct_lt(&Self::MODULUS))
79 }
80
81 pub fn from_be_bytes(bytes: FieldBytes<C>) -> CtOption<Self> {
83 Self::new(C::UInt::from_be_byte_array(bytes))
84 }
85
86 pub fn from_be_slice(slice: &[u8]) -> Result<Self> {
88 if slice.len() == C::UInt::BYTE_SIZE {
89 Option::from(Self::from_be_bytes(GenericArray::clone_from_slice(slice))).ok_or(Error)
90 } else {
91 Err(Error)
92 }
93 }
94
95 pub fn from_le_bytes(bytes: FieldBytes<C>) -> CtOption<Self> {
97 Self::new(C::UInt::from_le_byte_array(bytes))
98 }
99
100 pub fn from_le_slice(slice: &[u8]) -> Result<Self> {
102 if slice.len() == C::UInt::BYTE_SIZE {
103 Option::from(Self::from_le_bytes(GenericArray::clone_from_slice(slice))).ok_or(Error)
104 } else {
105 Err(Error)
106 }
107 }
108
109 pub fn as_uint(&self) -> &C::UInt {
111 &self.inner
112 }
113
114 pub fn as_limbs(&self) -> &[Limb] {
116 self.inner.as_ref()
117 }
118
119 pub fn is_zero(&self) -> Choice {
121 self.inner.is_zero()
122 }
123
124 pub fn is_even(&self) -> Choice {
126 self.inner.is_even()
127 }
128
129 pub fn is_odd(&self) -> Choice {
131 self.inner.is_odd()
132 }
133
134 pub fn to_be_bytes(self) -> FieldBytes<C> {
136 self.inner.to_be_byte_array()
137 }
138
139 pub fn to_le_bytes(self) -> FieldBytes<C> {
141 self.inner.to_le_byte_array()
142 }
143}
144
#[cfg(feature = "arithmetic")]
impl<C: Curve + ScalarArithmetic> ScalarCore<C> {
    /// Converts into the curve's arithmetic `Scalar<C>` type via its `from_repr`.
    ///
    /// # Panics
    ///
    /// Panics if `Scalar::<C>::from_repr` rejects the big-endian encoding of `self`.
    /// NOTE(review): this relies on the curve's `PrimeField::Repr` being the
    /// big-endian `FieldBytes` encoding; confirm this for each curve.
    pub(super) fn to_scalar(self) -> Scalar<C> {
        let repr = self.to_be_bytes();
        Scalar::<C>::from_repr(repr).unwrap()
    }
}
156
157impl<C> AsRef<[Limb]> for ScalarCore<C>
159where
160 C: Curve,
161{
162 fn as_ref(&self) -> &[Limb] {
163 self.as_limbs()
164 }
165}
166
167impl<C> ConditionallySelectable for ScalarCore<C>
168where
169 C: Curve,
170{
171 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
172 Self {
173 inner: C::UInt::conditional_select(&a.inner, &b.inner, choice),
174 }
175 }
176}
177
178impl<C> ConstantTimeEq for ScalarCore<C>
179where
180 C: Curve,
181{
182 fn ct_eq(&self, other: &Self) -> Choice {
183 self.inner.ct_eq(&other.inner)
184 }
185}
186
187impl<C> ConstantTimeLess for ScalarCore<C>
188where
189 C: Curve,
190{
191 fn ct_lt(&self, other: &Self) -> Choice {
192 self.inner.ct_lt(&other.inner)
193 }
194}
195
196impl<C> ConstantTimeGreater for ScalarCore<C>
197where
198 C: Curve,
199{
200 fn ct_gt(&self, other: &Self) -> Choice {
201 self.inner.ct_gt(&other.inner)
202 }
203}
204
205impl<C: Curve> DefaultIsZeroes for ScalarCore<C> {}
206
207impl<C: Curve> Eq for ScalarCore<C> {}
208
209impl<C> PartialEq for ScalarCore<C>
210where
211 C: Curve,
212{
213 fn eq(&self, other: &Self) -> bool {
214 self.ct_eq(other).into()
215 }
216}
217
218impl<C> PartialOrd for ScalarCore<C>
219where
220 C: Curve,
221{
222 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
223 Some(self.cmp(other))
224 }
225}
226
227impl<C> Ord for ScalarCore<C>
228where
229 C: Curve,
230{
231 fn cmp(&self, other: &Self) -> Ordering {
232 self.inner.cmp(&other.inner)
233 }
234}
235
236impl<C> From<u64> for ScalarCore<C>
237where
238 C: Curve,
239{
240 fn from(n: u64) -> Self {
241 Self {
242 inner: C::UInt::from(n),
243 }
244 }
245}
246
247impl<C> Add<ScalarCore<C>> for ScalarCore<C>
248where
249 C: Curve,
250{
251 type Output = Self;
252
253 fn add(self, other: Self) -> Self {
254 self.add(&other)
255 }
256}
257
258impl<C> Add<&ScalarCore<C>> for ScalarCore<C>
259where
260 C: Curve,
261{
262 type Output = Self;
263
264 fn add(self, other: &Self) -> Self {
265 Self {
266 inner: self.inner.add_mod(&other.inner, &Self::MODULUS),
267 }
268 }
269}
270
271impl<C> AddAssign<ScalarCore<C>> for ScalarCore<C>
272where
273 C: Curve,
274{
275 fn add_assign(&mut self, other: Self) {
276 *self = *self + other;
277 }
278}
279
280impl<C> AddAssign<&ScalarCore<C>> for ScalarCore<C>
281where
282 C: Curve,
283{
284 fn add_assign(&mut self, other: &Self) {
285 *self = *self + other;
286 }
287}
288
289impl<C> Sub<ScalarCore<C>> for ScalarCore<C>
290where
291 C: Curve,
292{
293 type Output = Self;
294
295 fn sub(self, other: Self) -> Self {
296 self.sub(&other)
297 }
298}
299
300impl<C> Sub<&ScalarCore<C>> for ScalarCore<C>
301where
302 C: Curve,
303{
304 type Output = Self;
305
306 fn sub(self, other: &Self) -> Self {
307 Self {
308 inner: self.inner.sub_mod(&other.inner, &Self::MODULUS),
309 }
310 }
311}
312
313impl<C> SubAssign<ScalarCore<C>> for ScalarCore<C>
314where
315 C: Curve,
316{
317 fn sub_assign(&mut self, other: Self) {
318 *self = *self - other;
319 }
320}
321
322impl<C> SubAssign<&ScalarCore<C>> for ScalarCore<C>
323where
324 C: Curve,
325{
326 fn sub_assign(&mut self, other: &Self) {
327 *self = *self - other;
328 }
329}
330
331impl<C> Neg for ScalarCore<C>
332where
333 C: Curve,
334{
335 type Output = Self;
336
337 fn neg(self) -> Self {
338 Self {
339 inner: self.inner.neg_mod(&Self::MODULUS),
340 }
341 }
342}
343
344impl<C> Neg for &ScalarCore<C>
345where
346 C: Curve,
347{
348 type Output = ScalarCore<C>;
349
350 fn neg(self) -> ScalarCore<C> {
351 -*self
352 }
353}
354
355impl<C> IsHigh for ScalarCore<C>
356where
357 C: Curve,
358{
359 fn is_high(&self) -> Choice {
360 let n_2 = C::ORDER >> 1;
361 self.inner.ct_gt(&n_2)
362 }
363}
364
365impl<C> fmt::Display for ScalarCore<C>
366where
367 C: Curve,
368{
369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 write!(f, "{:X}", self)
371 }
372}
373
374impl<C> fmt::LowerHex for ScalarCore<C>
375where
376 C: Curve,
377{
378 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379 write!(f, "{:x}", HexDisplay(&self.to_be_bytes()))
380 }
381}
382
383impl<C> fmt::UpperHex for ScalarCore<C>
384where
385 C: Curve,
386{
387 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388 write!(f, "{:X}", HexDisplay(&self.to_be_bytes()))
389 }
390}
391
392impl<C> str::FromStr for ScalarCore<C>
393where
394 C: Curve,
395{
396 type Err = Error;
397
398 fn from_str(hex: &str) -> Result<Self> {
399 let mut bytes = FieldBytes::<C>::default();
400 base16ct::lower::decode(hex, &mut bytes)?;
401 Option::from(Self::from_be_bytes(bytes)).ok_or(Error)
402 }
403}
404
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
/// Serializes the big-endian byte encoding via `serdect::array::serialize_hex_upper_or_bin`.
impl<C: Curve> Serialize for ScalarCore<C> {
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let encoded = self.to_be_bytes();
        serdect::array::serialize_hex_upper_or_bin(&encoded, serializer)
    }
}
418
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
/// Deserializes a big-endian byte encoding via `serdect::array::deserialize_hex_or_bin`.
///
/// Values not below `MODULUS` are rejected with the custom error "scalar out of range".
impl<'de, C: Curve> Deserialize<'de> for ScalarCore<C> {
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let mut bytes = FieldBytes::<C>::default();
        serdect::array::deserialize_hex_or_bin(&mut bytes, deserializer)?;
        let scalar: Option<Self> = Self::from_be_bytes(bytes).into();
        scalar.ok_or_else(|| de::Error::custom("scalar out of range"))
    }
}