1use crate::{Error, Result};
9use base16ct::HexDisplay;
10use core::{
11 cmp::Ordering,
12 fmt::{self, Debug},
13 ops::Add,
14 str,
15};
16use generic_array::{
17 typenum::{U1, U28, U32, U48, U66},
18 ArrayLength, GenericArray,
19};
20
21#[cfg(feature = "alloc")]
22use alloc::boxed::Box;
23
24#[cfg(feature = "serde")]
25use serdect::serde::{de, ser, Deserialize, Serialize};
26
27#[cfg(feature = "subtle")]
28use subtle::{Choice, ConditionallySelectable};
29
30#[cfg(feature = "zeroize")]
31use zeroize::Zeroize;
32
33pub trait ModulusSize: 'static + ArrayLength<u8> + Copy + Debug {
37 type CompressedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
41
42 type UncompressedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
46
47 type UntaggedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
50}
51
52macro_rules! impl_modulus_size {
53 ($($size:ty),+) => {
54 $(impl ModulusSize for $size {
55 type CompressedPointSize = <$size as Add<U1>>::Output;
56 type UncompressedPointSize = <Self::UntaggedPointSize as Add<U1>>::Output;
57 type UntaggedPointSize = <$size as Add>::Output;
58 })+
59 }
60}
61
62impl_modulus_size!(U28, U32, U48, U66);
63
64#[derive(Clone, Default)]
70pub struct EncodedPoint<Size>
71where
72 Size: ModulusSize,
73{
74 bytes: GenericArray<u8, Size::UncompressedPointSize>,
75}
76
77#[allow(clippy::len_without_is_empty)]
78impl<Size> EncodedPoint<Size>
79where
80 Size: ModulusSize,
81{
82 pub fn from_bytes(input: impl AsRef<[u8]>) -> Result<Self> {
89 let input = input.as_ref();
90
91 let tag = input
93 .first()
94 .cloned()
95 .ok_or(Error::PointEncoding)
96 .and_then(Tag::from_u8)?;
97
98 let expected_len = tag.message_len(Size::to_usize());
100
101 if input.len() != expected_len {
102 return Err(Error::PointEncoding);
103 }
104
105 let mut bytes = GenericArray::default();
106 bytes[..expected_len].copy_from_slice(input);
107 Ok(Self { bytes })
108 }
109
110 pub fn from_untagged_bytes(bytes: &GenericArray<u8, Size::UntaggedPointSize>) -> Self {
114 let (x, y) = bytes.split_at(Size::to_usize());
115 Self::from_affine_coordinates(x.into(), y.into(), false)
116 }
117
118 pub fn from_affine_coordinates(
121 x: &GenericArray<u8, Size>,
122 y: &GenericArray<u8, Size>,
123 compress: bool,
124 ) -> Self {
125 let tag = if compress {
126 Tag::compress_y(y.as_slice())
127 } else {
128 Tag::Uncompressed
129 };
130
131 let mut bytes = GenericArray::default();
132 bytes[0] = tag.into();
133 bytes[1..(Size::to_usize() + 1)].copy_from_slice(x);
134
135 if !compress {
136 bytes[(Size::to_usize() + 1)..].copy_from_slice(y);
137 }
138
139 Self { bytes }
140 }
141
142 pub fn identity() -> Self {
145 Self::default()
146 }
147
148 pub fn len(&self) -> usize {
150 self.tag().message_len(Size::to_usize())
151 }
152
153 pub fn as_bytes(&self) -> &[u8] {
155 &self.bytes[..self.len()]
156 }
157
158 #[cfg(feature = "alloc")]
160 #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
161 pub fn to_bytes(&self) -> Box<[u8]> {
162 self.as_bytes().to_vec().into_boxed_slice()
163 }
164
165 pub fn is_compact(&self) -> bool {
167 self.tag().is_compact()
168 }
169
170 pub fn is_compressed(&self) -> bool {
172 self.tag().is_compressed()
173 }
174
175 pub fn is_identity(&self) -> bool {
177 self.tag().is_identity()
178 }
179
180 pub fn compress(&self) -> Self {
182 match self.coordinates() {
183 Coordinates::Compressed { .. }
184 | Coordinates::Compact { .. }
185 | Coordinates::Identity => self.clone(),
186 Coordinates::Uncompressed { x, y } => Self::from_affine_coordinates(x, y, true),
187 }
188 }
189
190 pub fn tag(&self) -> Tag {
192 Tag::from_u8(self.bytes[0]).expect("invalid tag")
194 }
195
196 #[inline]
198 pub fn coordinates(&self) -> Coordinates<'_, Size> {
199 if self.is_identity() {
200 return Coordinates::Identity;
201 }
202
203 let (x, y) = self.bytes[1..].split_at(Size::to_usize());
204
205 if self.is_compressed() {
206 Coordinates::Compressed {
207 x: x.into(),
208 y_is_odd: self.tag() as u8 & 1 == 1,
209 }
210 } else if self.is_compact() {
211 Coordinates::Compact { x: x.into() }
212 } else {
213 Coordinates::Uncompressed {
214 x: x.into(),
215 y: y.into(),
216 }
217 }
218 }
219
220 pub fn x(&self) -> Option<&GenericArray<u8, Size>> {
224 match self.coordinates() {
225 Coordinates::Identity => None,
226 Coordinates::Compressed { x, .. } => Some(x),
227 Coordinates::Uncompressed { x, .. } => Some(x),
228 Coordinates::Compact { x } => Some(x),
229 }
230 }
231
232 pub fn y(&self) -> Option<&GenericArray<u8, Size>> {
236 match self.coordinates() {
237 Coordinates::Compressed { .. } | Coordinates::Identity => None,
238 Coordinates::Uncompressed { y, .. } => Some(y),
239 Coordinates::Compact { .. } => None,
240 }
241 }
242}
243
244impl<Size> AsRef<[u8]> for EncodedPoint<Size>
245where
246 Size: ModulusSize,
247{
248 #[inline]
249 fn as_ref(&self) -> &[u8] {
250 self.as_bytes()
251 }
252}
253
#[cfg(feature = "subtle")]
impl<Size> ConditionallySelectable for EncodedPoint<Size>
where
    Size: ModulusSize,
    <Size::UncompressedPointSize as ArrayLength<u8>>::ArrayType: Copy,
{
    /// Return `a` when `choice` is 0 and `b` when it is 1.
    ///
    /// The selection is done byte by byte over the entire backing buffer,
    /// unused trailing bytes included.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut selected = GenericArray::<u8, Size::UncompressedPointSize>::default();
        let pairs = a.bytes.iter().zip(b.bytes.iter());

        for (out, (a_byte, b_byte)) in selected.iter_mut().zip(pairs) {
            *out = u8::conditional_select(a_byte, b_byte, choice);
        }

        Self { bytes: selected }
    }
}
270
271impl<Size> Copy for EncodedPoint<Size>
272where
273 Size: ModulusSize,
274 <Size::UncompressedPointSize as ArrayLength<u8>>::ArrayType: Copy,
275{
276}
277
278impl<Size> Debug for EncodedPoint<Size>
279where
280 Size: ModulusSize,
281{
282 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283 write!(f, "EncodedPoint({:?})", self.coordinates())
284 }
285}
286
287impl<Size: ModulusSize> Eq for EncodedPoint<Size> {}
288
289impl<Size> PartialEq for EncodedPoint<Size>
290where
291 Size: ModulusSize,
292{
293 fn eq(&self, other: &Self) -> bool {
294 self.as_bytes() == other.as_bytes()
295 }
296}
297
298impl<Size: ModulusSize> PartialOrd for EncodedPoint<Size>
299where
300 Size: ModulusSize,
301{
302 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
303 Some(self.cmp(other))
304 }
305}
306
307impl<Size: ModulusSize> Ord for EncodedPoint<Size>
308where
309 Size: ModulusSize,
310{
311 fn cmp(&self, other: &Self) -> Ordering {
312 self.as_bytes().cmp(other.as_bytes())
313 }
314}
315
316impl<Size: ModulusSize> TryFrom<&[u8]> for EncodedPoint<Size>
317where
318 Size: ModulusSize,
319{
320 type Error = Error;
321
322 fn try_from(bytes: &[u8]) -> Result<Self> {
323 Self::from_bytes(bytes)
324 }
325}
326
#[cfg(feature = "zeroize")]
impl<Size: ModulusSize> Zeroize for EncodedPoint<Size> {
    /// Wipe the whole backing buffer and leave `self` as the identity point.
    fn zeroize(&mut self) {
        // Clear every byte (padding included) through `Zeroize`...
        self.bytes.zeroize();
        // ...then reset to the identity encoding.
        *self = Self::identity();
    }
}
337
338impl<Size> fmt::Display for EncodedPoint<Size>
339where
340 Size: ModulusSize,
341{
342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343 write!(f, "{:X}", self)
344 }
345}
346
347impl<Size> fmt::LowerHex for EncodedPoint<Size>
348where
349 Size: ModulusSize,
350{
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 write!(f, "{:x}", HexDisplay(self.as_bytes()))
353 }
354}
355
356impl<Size> fmt::UpperHex for EncodedPoint<Size>
357where
358 Size: ModulusSize,
359{
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 write!(f, "{:X}", HexDisplay(self.as_bytes()))
362 }
363}
364
365impl<Size> str::FromStr for EncodedPoint<Size>
370where
371 Size: ModulusSize,
372{
373 type Err = Error;
374
375 fn from_str(hex: &str) -> Result<Self> {
376 let mut buf = GenericArray::<u8, Size::UncompressedPointSize>::default();
377 base16ct::mixed::decode(hex, &mut buf)
378 .map_err(|_| Error::PointEncoding)
379 .and_then(Self::from_bytes)
380 }
381}
382
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<Size: ModulusSize> Serialize for EncodedPoint<Size> {
    /// Serialize the encoding as upper-case hex for human-readable formats,
    /// or as raw bytes otherwise (as decided by `serdect`).
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let encoded: &[u8] = self.as_bytes();
        serdect::slice::serialize_hex_upper_or_bin(&encoded, serializer)
    }
}
396
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de, Size: ModulusSize> Deserialize<'de> for EncodedPoint<Size> {
    /// Accept hex or raw bytes (via `serdect`), then validate them with
    /// `EncodedPoint::from_bytes`.
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let decoded = serdect::slice::deserialize_hex_or_bin_vec(deserializer)?;
        Self::from_bytes(decoded.as_slice()).map_err(de::Error::custom)
    }
}
411
412#[derive(Copy, Clone, Debug, Eq, PartialEq)]
415pub enum Coordinates<'a, Size: ModulusSize> {
416 Identity,
418
419 Compact {
421 x: &'a GenericArray<u8, Size>,
423 },
424
425 Compressed {
427 x: &'a GenericArray<u8, Size>,
429
430 y_is_odd: bool,
432 },
433
434 Uncompressed {
436 x: &'a GenericArray<u8, Size>,
438
439 y: &'a GenericArray<u8, Size>,
441 },
442}
443
444impl<'a, Size: ModulusSize> Coordinates<'a, Size> {
445 pub fn tag(&self) -> Tag {
447 match self {
448 Coordinates::Compact { .. } => Tag::Compact,
449 Coordinates::Compressed { y_is_odd, .. } => {
450 if *y_is_odd {
451 Tag::CompressedOddY
452 } else {
453 Tag::CompressedEvenY
454 }
455 }
456 Coordinates::Identity => Tag::Identity,
457 Coordinates::Uncompressed { .. } => Tag::Uncompressed,
458 }
459 }
460}
461
/// Leading byte of an encoded point; selects which coordinates follow.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Tag {
    /// `0x00`: the identity point; nothing follows.
    Identity = 0x00,

    /// `0x02`: compressed point whose y-coordinate is even; x follows.
    CompressedEvenY = 0x02,

    /// `0x03`: compressed point whose y-coordinate is odd; x follows.
    CompressedOddY = 0x03,

    /// `0x04`: uncompressed point; x then y follow.
    Uncompressed = 0x04,

    /// `0x05`: compact point; only x follows.
    Compact = 0x05,
}
481
482impl Tag {
483 pub fn from_u8(byte: u8) -> Result<Self> {
485 match byte {
486 0 => Ok(Tag::Identity),
487 2 => Ok(Tag::CompressedEvenY),
488 3 => Ok(Tag::CompressedOddY),
489 4 => Ok(Tag::Uncompressed),
490 5 => Ok(Tag::Compact),
491 _ => Err(Error::PointEncoding),
492 }
493 }
494
495 pub fn is_compact(self) -> bool {
497 matches!(self, Tag::Compact)
498 }
499
500 pub fn is_compressed(self) -> bool {
502 matches!(self, Tag::CompressedEvenY | Tag::CompressedOddY)
503 }
504
505 pub fn is_identity(self) -> bool {
507 self == Tag::Identity
508 }
509
510 pub fn message_len(self, field_element_size: usize) -> usize {
514 1 + match self {
515 Tag::Identity => 0,
516 Tag::CompressedEvenY | Tag::CompressedOddY => field_element_size,
517 Tag::Uncompressed => field_element_size * 2,
518 Tag::Compact => field_element_size,
519 }
520 }
521
522 fn compress_y(y: &[u8]) -> Self {
524 if y.as_ref().last().expect("empty y-coordinate") & 1 == 1 {
526 Tag::CompressedOddY
527 } else {
528 Tag::CompressedEvenY
529 }
530 }
531}
532
533impl TryFrom<u8> for Tag {
534 type Error = Error;
535
536 fn try_from(byte: u8) -> Result<Self> {
537 Self::from_u8(byte)
538 }
539}
540
541impl From<Tag> for u8 {
542 fn from(tag: Tag) -> u8 {
543 tag as u8
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::{Coordinates, Tag};
    use core::str::FromStr;
    use generic_array::{typenum::U32, GenericArray};
    use hex_literal::hex;

    #[cfg(feature = "alloc")]
    use alloc::string::ToString;

    #[cfg(feature = "subtle")]
    use subtle::ConditionallySelectable;

    type EncodedPoint = super::EncodedPoint<U32>;

    /// Identity point
    const IDENTITY_BYTES: [u8; 1] = [0];

    /// Example uncompressed point
    const UNCOMPRESSED_BYTES: [u8; 65] = hex!("0411111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");

    /// Example compressed point: `UNCOMPRESSED_BYTES` after point compression
    const COMPRESSED_BYTES: [u8; 33] =
        hex!("021111111111111111111111111111111111111111111111111111111111111111");

    #[test]
    fn decode_compressed_point() {
        // Even y
        let compressed_even_y_bytes =
            hex!("020100000000000000000000000000000000000000000000000000000000000000");

        let compressed_even_y = EncodedPoint::from_bytes(&compressed_even_y_bytes[..]).unwrap();

        assert!(compressed_even_y.is_compressed());
        assert_eq!(compressed_even_y.tag(), Tag::CompressedEvenY);
        assert_eq!(compressed_even_y.len(), 33);
        assert_eq!(compressed_even_y.as_bytes(), &compressed_even_y_bytes[..]);

        assert_eq!(
            compressed_even_y.coordinates(),
            Coordinates::Compressed {
                x: &hex!("0100000000000000000000000000000000000000000000000000000000000000").into(),
                y_is_odd: false
            }
        );

        assert_eq!(
            compressed_even_y.x().unwrap(),
            &hex!("0100000000000000000000000000000000000000000000000000000000000000").into()
        );
        assert_eq!(compressed_even_y.y(), None);

        // Odd y
        let compressed_odd_y_bytes =
            hex!("030200000000000000000000000000000000000000000000000000000000000000");

        let compressed_odd_y = EncodedPoint::from_bytes(&compressed_odd_y_bytes[..]).unwrap();

        assert!(compressed_odd_y.is_compressed());
        assert_eq!(compressed_odd_y.tag(), Tag::CompressedOddY);
        assert_eq!(compressed_odd_y.len(), 33);
        assert_eq!(compressed_odd_y.as_bytes(), &compressed_odd_y_bytes[..]);

        assert_eq!(
            compressed_odd_y.coordinates(),
            Coordinates::Compressed {
                x: &hex!("0200000000000000000000000000000000000000000000000000000000000000").into(),
                y_is_odd: true
            }
        );

        assert_eq!(
            compressed_odd_y.x().unwrap(),
            &hex!("0200000000000000000000000000000000000000000000000000000000000000").into()
        );
        assert_eq!(compressed_odd_y.y(), None);
    }

    #[test]
    fn decode_compact_point() {
        // Same x-coordinate as `COMPRESSED_BYTES`, but under the compact tag.
        let mut compact_bytes = COMPRESSED_BYTES;
        compact_bytes[0] = Tag::Compact.into();

        let compact_point = EncodedPoint::from_bytes(&compact_bytes[..]).unwrap();

        assert!(compact_point.is_compact());
        assert!(!compact_point.is_compressed());
        assert!(!compact_point.is_identity());
        assert_eq!(compact_point.tag(), Tag::Compact);
        assert_eq!(compact_point.len(), 33);
        assert_eq!(compact_point.as_bytes(), &compact_bytes[..]);

        let x: &GenericArray<u8, U32> = GenericArray::from_slice(&COMPRESSED_BYTES[1..]);
        assert_eq!(compact_point.coordinates(), Coordinates::Compact { x });
        assert_eq!(compact_point.x(), Some(x));
        assert_eq!(compact_point.y(), None);

        // A compact point has no y-coordinate to drop, so compression is a no-op.
        assert_eq!(compact_point.compress(), compact_point);
    }

    #[test]
    fn decode_uncompressed_point() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        assert!(!uncompressed_point.is_compressed());
        assert_eq!(uncompressed_point.tag(), Tag::Uncompressed);
        assert_eq!(uncompressed_point.len(), 65);
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        assert_eq!(
            uncompressed_point.coordinates(),
            Coordinates::Uncompressed {
                x: &hex!("1111111111111111111111111111111111111111111111111111111111111111").into(),
                y: &hex!("2222222222222222222222222222222222222222222222222222222222222222").into()
            }
        );

        assert_eq!(
            uncompressed_point.x().unwrap(),
            &hex!("1111111111111111111111111111111111111111111111111111111111111111").into()
        );
        assert_eq!(
            uncompressed_point.y().unwrap(),
            &hex!("2222222222222222222222222222222222222222222222222222222222222222").into()
        );
    }

    #[test]
    fn decode_identity() {
        let identity_point = EncodedPoint::from_bytes(&IDENTITY_BYTES[..]).unwrap();
        assert!(identity_point.is_identity());
        assert_eq!(identity_point.tag(), Tag::Identity);
        assert_eq!(identity_point.len(), 1);
        assert_eq!(identity_point.as_bytes(), &IDENTITY_BYTES[..]);
        assert_eq!(identity_point.coordinates(), Coordinates::Identity);
        assert_eq!(identity_point.x(), None);
        assert_eq!(identity_point.y(), None);
    }

    #[test]
    fn decode_invalid_tag() {
        // The const arrays are `Copy`, so plain bindings give mutable copies.
        let mut compressed_bytes = COMPRESSED_BYTES;
        let mut uncompressed_bytes = UNCOMPRESSED_BYTES;

        for bytes in &mut [&mut compressed_bytes[..], &mut uncompressed_bytes[..]] {
            for tag in 0..=0xFFu8 {
                // Skip the valid non-identity tags. The identity tag (0x00) is
                // still rejected here because these inputs are too long for it.
                if matches!(tag, 2..=5) {
                    continue;
                }

                (*bytes)[0] = tag;
                let decode_result = EncodedPoint::from_bytes(&*bytes);
                assert!(decode_result.is_err());
            }
        }
    }

    #[test]
    fn decode_truncated_point() {
        for bytes in &[&COMPRESSED_BYTES[..], &UNCOMPRESSED_BYTES[..]] {
            for len in 0..bytes.len() {
                let decode_result = EncodedPoint::from_bytes(&bytes[..len]);
                assert!(decode_result.is_err());
            }
        }
    }

    #[test]
    fn from_untagged_point() {
        let untagged_bytes = hex!("11111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");
        let uncompressed_point =
            EncodedPoint::from_untagged_bytes(GenericArray::from_slice(&untagged_bytes[..]));
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[test]
    fn from_affine_coordinates() {
        let x = hex!("1111111111111111111111111111111111111111111111111111111111111111");
        let y = hex!("2222222222222222222222222222222222222222222222222222222222222222");

        let uncompressed_point = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false);
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        let compressed_point = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), true);
        assert_eq!(compressed_point.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[test]
    fn compress() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        let compressed_point = uncompressed_point.compress();
        assert_eq!(compressed_point.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "subtle")]
    #[test]
    fn conditional_select() {
        let a = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        let b = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        let a_selected = EncodedPoint::conditional_select(&a, &b, 0.into());
        assert_eq!(a, a_selected);

        let b_selected = EncodedPoint::conditional_select(&a, &b, 1.into());
        assert_eq!(b, b_selected);
    }

    #[test]
    fn identity() {
        let identity_point = EncodedPoint::identity();
        assert_eq!(identity_point.tag(), Tag::Identity);
        assert_eq!(identity_point.len(), 1);
        assert_eq!(identity_point.as_bytes(), &IDENTITY_BYTES[..]);

        // identity is default
        assert_eq!(identity_point, EncodedPoint::default());
    }

    #[test]
    fn decode_hex() {
        let point = EncodedPoint::from_str(
            "021111111111111111111111111111111111111111111111111111111111111111",
        )
        .unwrap();
        assert_eq!(point.as_bytes(), COMPRESSED_BYTES);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_bytes() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        assert_eq!(&*uncompressed_point.to_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_string() {
        let point = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        assert_eq!(
            point.to_string(),
            "021111111111111111111111111111111111111111111111111111111111111111"
        );
    }
}