use crate::{Error, Result};
use core::{
cmp::Ordering,
fmt::{self, Debug},
ops::Add,
str,
};
use generic_array::{
typenum::{U1, U28, U32, U48, U66},
ArrayLength, GenericArray,
};
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "serde")]
use serde::{de, ser, Deserialize, Serialize};
#[cfg(feature = "subtle")]
use subtle::{Choice, ConditionallySelectable};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
/// Trait for the field modulus sizes (in bytes) supported by [`EncodedPoint`].
///
/// Each implementation precomputes, at the type level, the buffer sizes needed by the
/// different point encodings so they can be used as `GenericArray` lengths.
pub trait ModulusSize: 'static + ArrayLength<u8> + Copy + Debug {
/// Size of a compressed point for this modulus: one tag byte plus one field element.
type CompressedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
/// Size of an uncompressed point for this modulus: one tag byte plus two field elements.
type UncompressedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
/// Size of an untagged point for this modulus: two field elements and no tag byte.
type UntaggedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
}
// Implements `ModulusSize` for every typenum size in the comma-separated list, deriving
// the three encoding sizes from the modulus size with type-level addition.
macro_rules! impl_modulus_size {
($($size:ty),+) => {
$(impl ModulusSize for $size {
// Tag byte + one field element.
type CompressedPointSize = <$size as Add<U1>>::Output;
// Tag byte + the untagged (x || y) encoding.
type UncompressedPointSize = <Self::UntaggedPointSize as Add<U1>>::Output;
// `Add` with its default right-hand side (`Self`): size + size, i.e. x || y.
type UntaggedPointSize = <$size as Add>::Output;
})+
}
}
// Modulus sizes of 28, 32, 48 and 66 bytes are supported.
impl_modulus_size!(U28, U32, U48, U66);
/// Tagged byte encoding of an elliptic curve point.
///
/// The point is stored in a fixed buffer large enough for the uncompressed form; the
/// leading tag byte determines how many of those bytes are meaningful (see `len`).
/// The `Default` value is an all-zero buffer, whose tag byte `0` is the identity point.
#[derive(Clone, Default)]
pub struct EncodedPoint<Size>
where
Size: ModulusSize,
{
// Tag byte followed by the coordinate bytes; bytes past `len()` are unused padding.
bytes: GenericArray<u8, Size::UncompressedPointSize>,
}
#[allow(clippy::len_without_is_empty)]
impl<Size> EncodedPoint<Size>
where
Size: ModulusSize,
{
pub fn from_bytes(input: impl AsRef<[u8]>) -> Result<Self> {
let input = input.as_ref();
let tag = input
.first()
.cloned()
.ok_or(Error::PointEncoding)
.and_then(Tag::from_u8)?;
let expected_len = tag.message_len(Size::to_usize());
if input.len() != expected_len {
return Err(Error::PointEncoding);
}
let mut bytes = GenericArray::default();
bytes[..expected_len].copy_from_slice(input);
Ok(Self { bytes })
}
pub fn from_untagged_bytes(bytes: &GenericArray<u8, Size::UntaggedPointSize>) -> Self {
let (x, y) = bytes.split_at(Size::to_usize());
Self::from_affine_coordinates(x.into(), y.into(), false)
}
pub fn from_affine_coordinates(
x: &GenericArray<u8, Size>,
y: &GenericArray<u8, Size>,
compress: bool,
) -> Self {
let tag = if compress {
Tag::compress_y(y.as_slice())
} else {
Tag::Uncompressed
};
let mut bytes = GenericArray::default();
bytes[0] = tag.into();
bytes[1..(Size::to_usize() + 1)].copy_from_slice(x);
if !compress {
bytes[(Size::to_usize() + 1)..].copy_from_slice(y);
}
Self { bytes }
}
pub fn identity() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.tag().message_len(Size::to_usize())
}
pub fn as_bytes(&self) -> &[u8] {
&self.bytes[..self.len()]
}
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub fn to_bytes(&self) -> Box<[u8]> {
self.as_bytes().to_vec().into_boxed_slice()
}
pub fn is_compact(&self) -> bool {
self.tag().is_compact()
}
pub fn is_compressed(&self) -> bool {
self.tag().is_compressed()
}
pub fn is_identity(&self) -> bool {
self.tag().is_identity()
}
pub fn compress(&self) -> Self {
match self.coordinates() {
Coordinates::Compressed { .. }
| Coordinates::Compact { .. }
| Coordinates::Identity => self.clone(),
Coordinates::Uncompressed { x, y } => Self::from_affine_coordinates(x, y, true),
}
}
pub fn tag(&self) -> Tag {
Tag::from_u8(self.bytes[0]).expect("invalid tag")
}
#[inline]
pub fn coordinates(&self) -> Coordinates<'_, Size> {
if self.is_identity() {
return Coordinates::Identity;
}
let (x, y) = self.bytes[1..].split_at(Size::to_usize());
if self.is_compressed() {
Coordinates::Compressed {
x: x.into(),
y_is_odd: self.tag() as u8 & 1 == 1,
}
} else if self.is_compact() {
Coordinates::Compact { x: x.into() }
} else {
Coordinates::Uncompressed {
x: x.into(),
y: y.into(),
}
}
}
pub fn x(&self) -> Option<&GenericArray<u8, Size>> {
match self.coordinates() {
Coordinates::Identity => None,
Coordinates::Compressed { x, .. } => Some(x),
Coordinates::Uncompressed { x, .. } => Some(x),
Coordinates::Compact { x } => Some(x),
}
}
pub fn y(&self) -> Option<&GenericArray<u8, Size>> {
match self.coordinates() {
Coordinates::Compressed { .. } | Coordinates::Identity => None,
Coordinates::Uncompressed { y, .. } => Some(y),
Coordinates::Compact { .. } => None,
}
}
}
impl<Size> AsRef<[u8]> for EncodedPoint<Size>
where
Size: ModulusSize,
{
#[inline]
fn as_ref(&self) -> &[u8] {
self.as_bytes()
}
}
/// Selects between two encoded points byte by byte over the whole backing buffer,
/// delegating each byte to `u8::conditional_select`.
#[cfg(feature = "subtle")]
impl<Size> ConditionallySelectable for EncodedPoint<Size>
where
    Size: ModulusSize,
    <Size::UncompressedPointSize as ArrayLength<u8>>::ArrayType: Copy,
{
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut bytes = GenericArray::default();

        // All three buffers have the same length, so zipping visits every byte.
        let pairs = a.bytes.iter().zip(b.bytes.iter());
        for (selected, (lhs, rhs)) in bytes.iter_mut().zip(pairs) {
            *selected = u8::conditional_select(lhs, rhs, choice);
        }

        Self { bytes }
    }
}
/// `EncodedPoint` is `Copy` whenever the backing array type for its buffer is `Copy`.
impl<Size> Copy for EncodedPoint<Size>
where
Size: ModulusSize,
// The struct's only field is the byte buffer, so its array type must be `Copy` as well.
<Size::UncompressedPointSize as ArrayLength<u8>>::ArrayType: Copy,
{
}
/// Formats the point as `EncodedPoint(..)` wrapping the `Debug` output of its
/// decoded [`Coordinates`].
impl<Size> Debug for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let coordinates = self.coordinates();
        f.write_fmt(format_args!("EncodedPoint({:?})", coordinates))
    }
}
/// Marks the byte-wise equality of encoded points as a full equivalence relation.
impl<Size> Eq for EncodedPoint<Size> where Size: ModulusSize {}
/// Two encoded points are equal when their meaningful bytes (tag included) are equal;
/// unused trailing buffer bytes are not compared.
impl<Size> PartialEq for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs == rhs
    }
}
/// Partial ordering that always agrees with the total order given by `Ord`.
impl<Size> PartialOrd for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Encoded points are ordered by comparing their meaningful bytes as slices.
impl<Size> Ord for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.cmp(rhs)
    }
}
impl<Size: ModulusSize> TryFrom<&[u8]> for EncodedPoint<Size>
where
Size: ModulusSize,
{
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self> {
Self::from_bytes(bytes)
}
}
/// Wipes the encoded point and leaves it holding the identity encoding.
#[cfg(feature = "zeroize")]
impl<Size> Zeroize for EncodedPoint<Size>
where
Size: ModulusSize,
{
fn zeroize(&mut self) {
// Wipe the whole backing buffer, including bytes past `len()`.
self.bytes.zeroize();
// Then reset to `identity()`, which is the all-zero `Default` value.
*self = Self::identity();
}
}
/// Displays the encoded point as upper-case hexadecimal, exactly like `{:X}`.
impl<Size> fmt::Display for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The `UpperHex` impl formats each byte with its own fixed spec, so delegating
        // directly produces the same text as going through `{:X}`.
        fmt::UpperHex::fmt(self, f)
    }
}
/// Formats the meaningful bytes of the encoding as lower-case hex, two digits per byte.
impl<Size> fmt::LowerHex for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stops at the first write error, like the `?` in an explicit loop.
        self.as_bytes()
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
/// Formats the meaningful bytes of the encoding as upper-case hex, two digits per byte.
impl<Size> fmt::UpperHex for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stops at the first write error, like the `?` in an explicit loop.
        self.as_bytes()
            .iter()
            .try_for_each(|byte| write!(f, "{:02X}", byte))
    }
}
impl<Size> str::FromStr for EncodedPoint<Size>
where
Size: ModulusSize,
{
type Err = Error;
fn from_str(hex: &str) -> Result<Self> {
let mut buffer = GenericArray::<u8, Size::UncompressedPointSize>::default();
let decoded_len = hex.as_bytes().len() / 2;
if hex.as_bytes().len() % 2 != 0 || decoded_len > buffer.len() {
return Err(Error::PointEncoding);
}
let mut upper_case = None;
for &byte in hex.as_bytes() {
match byte {
b'0'..=b'9' => (),
b'a'..=b'z' => match upper_case {
Some(true) => return Err(Error::PointEncoding),
Some(false) => (),
None => upper_case = Some(false),
},
b'A'..=b'Z' => match upper_case {
Some(true) => (),
Some(false) => return Err(Error::PointEncoding),
None => upper_case = Some(true),
},
_ => return Err(Error::PointEncoding),
}
}
for (digit, byte) in hex.as_bytes().chunks_exact(2).zip(buffer.iter_mut()) {
*byte = str::from_utf8(digit)
.ok()
.and_then(|s| u8::from_str_radix(s, 16).ok())
.ok_or(Error::PointEncoding)?;
}
Self::from_bytes(&buffer[..decoded_len])
}
}
/// Serializes the encoded point.
///
/// With the `alloc` feature and a human-readable serializer the point is written as its
/// `Display` string (upper-case hex); in every other case the meaningful bytes are
/// serialized as a byte slice.
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<Size> Serialize for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        // Producing the hex string needs an allocator, so this path is compiled out
        // when `alloc` is disabled.
        #[cfg(feature = "alloc")]
        {
            use alloc::string::ToString;

            if serializer.is_human_readable() {
                return self.to_string().serialize(serializer);
            }
        }

        self.as_bytes().serialize(serializer)
    }
}
/// Deserializes an encoded point.
///
/// With the `alloc` feature and a human-readable deserializer a hex string is expected
/// (mirroring `Serialize`); in every other case raw bytes are expected. Decoded input is
/// validated by the `FromStr` impl or [`EncodedPoint::from_bytes`], and any failure is
/// surfaced as a custom deserializer error.
///
/// Input is taken through a visitor rather than as `&'de str` / `&'de [u8]`, so
/// deserializers that can only hand out transient or owned data (reader-based
/// deserializers, strings containing escapes, ...) are accepted as well as zero-copy
/// ones.
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de, Size> Deserialize<'de> for EncodedPoint<Size>
where
    Size: ModulusSize,
{
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        /// Visitor that decodes a point from string or byte input of any lifetime.
        struct PointVisitor<Size> {
            /// `true` when the input is hex text, `false` when it is the raw encoding.
            hex: bool,
            size: core::marker::PhantomData<Size>,
        }

        impl<'de, Size> de::Visitor<'de> for PointVisitor<Size>
        where
            Size: ModulusSize,
        {
            type Value = EncodedPoint<Size>;

            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                if self.hex {
                    f.write_str("a hex-encoded elliptic curve point")
                } else {
                    f.write_str("a byte-encoded elliptic curve point")
                }
            }

            // The `visit_borrowed_str` / `visit_string` defaults forward here.
            fn visit_str<E>(self, value: &str) -> core::result::Result<Self::Value, E>
            where
                E: de::Error,
            {
                if self.hex {
                    value.parse::<EncodedPoint<Size>>().map_err(E::custom)
                } else {
                    // Kept for compatibility: a string handed over while bytes were
                    // requested is treated as the raw encoding, as `&[u8]` did before.
                    EncodedPoint::<Size>::from_bytes(value.as_bytes()).map_err(E::custom)
                }
            }

            // The `visit_borrowed_bytes` / `visit_byte_buf` defaults forward here.
            fn visit_bytes<E>(self, value: &[u8]) -> core::result::Result<Self::Value, E>
            where
                E: de::Error,
            {
                if self.hex {
                    // Kept for compatibility: bytes handed over while a string was
                    // requested must be UTF-8 hex text, as `&str` required before.
                    let text = str::from_utf8(value)
                        .map_err(|_| E::invalid_value(de::Unexpected::Bytes(value), &self))?;
                    text.parse::<EncodedPoint<Size>>().map_err(E::custom)
                } else {
                    EncodedPoint::<Size>::from_bytes(value).map_err(E::custom)
                }
            }
        }

        // Hex text is only produced by `Serialize` when `alloc` is enabled, so it is
        // only expected here under the same condition.
        #[cfg(feature = "alloc")]
        let hex = deserializer.is_human_readable();
        #[cfg(not(feature = "alloc"))]
        let hex = false;

        let visitor = PointVisitor::<Size> {
            hex,
            size: core::marker::PhantomData,
        };

        if hex {
            deserializer.deserialize_str(visitor)
        } else {
            deserializer.deserialize_bytes(visitor)
        }
    }
}
/// Borrowed view of the coordinates stored in an [`EncodedPoint`], one variant per
/// kind of encoding.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Coordinates<'a, Size: ModulusSize> {
/// The identity point, which carries no coordinates.
Identity,
/// Compact encoding: only the x-coordinate is stored.
Compact {
/// Bytes of the x-coordinate.
x: &'a GenericArray<u8, Size>,
},
/// Compressed encoding: the x-coordinate plus the parity of y.
Compressed {
/// Bytes of the x-coordinate.
x: &'a GenericArray<u8, Size>,
/// Whether the tag marks the omitted y-coordinate as odd.
y_is_odd: bool,
},
/// Uncompressed encoding: both coordinates are stored.
Uncompressed {
/// Bytes of the x-coordinate.
x: &'a GenericArray<u8, Size>,
/// Bytes of the y-coordinate.
y: &'a GenericArray<u8, Size>,
},
}
impl<'a, Size: ModulusSize> Coordinates<'a, Size> {
    /// Return the [`Tag`] that corresponds to this kind of coordinates.
    pub fn tag(&self) -> Tag {
        match *self {
            Coordinates::Identity => Tag::Identity,
            Coordinates::Compact { .. } => Tag::Compact,
            Coordinates::Compressed { y_is_odd: true, .. } => Tag::CompressedOddY,
            Coordinates::Compressed { y_is_odd: false, .. } => Tag::CompressedEvenY,
            Coordinates::Uncompressed { .. } => Tag::Uncompressed,
        }
    }
}
/// Tag byte that starts every encoded point and identifies how the rest is encoded.
///
/// The discriminants are the on-the-wire byte values.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Tag {
/// Identity point (`0x00`); no coordinate bytes follow.
Identity = 0,
/// Compressed point with even y (`0x02`); followed by one field element.
CompressedEvenY = 2,
/// Compressed point with odd y (`0x03`); followed by one field element.
CompressedOddY = 3,
/// Uncompressed point (`0x04`); followed by two field elements.
Uncompressed = 4,
/// Compact point (`0x05`); followed by one field element.
Compact = 5,
}
impl Tag {
pub fn from_u8(byte: u8) -> Result<Self> {
match byte {
0 => Ok(Tag::Identity),
2 => Ok(Tag::CompressedEvenY),
3 => Ok(Tag::CompressedOddY),
4 => Ok(Tag::Uncompressed),
5 => Ok(Tag::Compact),
_ => Err(Error::PointEncoding),
}
}
pub fn is_compact(self) -> bool {
matches!(self, Tag::Compact)
}
pub fn is_compressed(self) -> bool {
matches!(self, Tag::CompressedEvenY | Tag::CompressedOddY)
}
pub fn is_identity(self) -> bool {
self == Tag::Identity
}
pub fn message_len(self, field_element_size: usize) -> usize {
1 + match self {
Tag::Identity => 0,
Tag::CompressedEvenY | Tag::CompressedOddY => field_element_size,
Tag::Uncompressed => field_element_size * 2,
Tag::Compact => field_element_size,
}
}
fn compress_y(y: &[u8]) -> Self {
if y.as_ref().last().expect("empty y-coordinate") & 1 == 1 {
Tag::CompressedOddY
} else {
Tag::CompressedEvenY
}
}
}
/// Converts a [`Tag`] into its byte value (the enum discriminant).
impl From<Tag> for u8 {
    fn from(tag: Tag) -> Self {
        tag as Self
    }
}
#[cfg(test)]
mod tests {
    use super::{Coordinates, Tag};
    use core::str::FromStr;
    use generic_array::{typenum::U32, GenericArray};
    use hex_literal::hex;

    #[cfg(feature = "alloc")]
    use alloc::string::ToString;

    #[cfg(feature = "subtle")]
    use subtle::ConditionallySelectable;

    /// Encoded point for a 32-byte modulus, used by every test below.
    type EncodedPoint = super::EncodedPoint<U32>;

    /// A single 32-byte coordinate.
    type FieldBytes = GenericArray<u8, U32>;

    /// Identity point: just the tag byte.
    const IDENTITY_BYTES: [u8; 1] = [0];

    /// Uncompressed point with x = 0x11.. and y = 0x22..
    const UNCOMPRESSED_BYTES: [u8; 65] = hex!("0411111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");

    /// Compressed (even y) form of the point in `UNCOMPRESSED_BYTES`.
    const COMPRESSED_BYTES: [u8; 33] =
        hex!("021111111111111111111111111111111111111111111111111111111111111111");

    #[test]
    fn decode_compressed_point() {
        // (encoded point, expected x, expected tag, expected y parity)
        let cases = [
            (
                hex!("020100000000000000000000000000000000000000000000000000000000000000"),
                hex!("0100000000000000000000000000000000000000000000000000000000000000"),
                Tag::CompressedEvenY,
                false,
            ),
            (
                hex!("030200000000000000000000000000000000000000000000000000000000000000"),
                hex!("0200000000000000000000000000000000000000000000000000000000000000"),
                Tag::CompressedOddY,
                true,
            ),
        ];

        for (encoded, x_bytes, tag, y_is_odd) in cases.iter() {
            let expected_x = FieldBytes::from(*x_bytes);
            let point = EncodedPoint::from_bytes(&encoded[..]).unwrap();

            assert!(point.is_compressed());
            assert_eq!(point.tag(), *tag);
            assert_eq!(point.len(), 33);
            assert_eq!(point.as_bytes(), &encoded[..]);
            assert_eq!(
                point.coordinates(),
                Coordinates::Compressed {
                    x: &expected_x,
                    y_is_odd: *y_is_odd
                }
            );
            assert_eq!(point.x().unwrap(), &expected_x);
            assert_eq!(point.y(), None);
        }
    }

    #[test]
    fn decode_uncompressed_point() {
        let expected_x = FieldBytes::from(hex!(
            "1111111111111111111111111111111111111111111111111111111111111111"
        ));
        let expected_y = FieldBytes::from(hex!(
            "2222222222222222222222222222222222222222222222222222222222222222"
        ));

        let point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        assert!(!point.is_compressed());
        assert_eq!(point.tag(), Tag::Uncompressed);
        assert_eq!(point.len(), 65);
        assert_eq!(point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
        assert_eq!(
            point.coordinates(),
            Coordinates::Uncompressed {
                x: &expected_x,
                y: &expected_y
            }
        );
        assert_eq!(point.x().unwrap(), &expected_x);
        assert_eq!(point.y().unwrap(), &expected_y);
    }

    #[test]
    fn decode_identity() {
        let point = EncodedPoint::from_bytes(&IDENTITY_BYTES[..]).unwrap();

        assert!(point.is_identity());
        assert_eq!(point.tag(), Tag::Identity);
        assert_eq!(point.len(), 1);
        assert_eq!(point.as_bytes(), &IDENTITY_BYTES[..]);
        assert_eq!(point.coordinates(), Coordinates::Identity);
        assert_eq!(point.x(), None);
        assert_eq!(point.y(), None);
    }

    #[test]
    fn decode_invalid_tag() {
        // Arrays are `Copy`, so plain assignment yields independent mutable copies.
        let mut compressed = COMPRESSED_BYTES;
        let mut uncompressed = UNCOMPRESSED_BYTES;
        let mut fixtures = [&mut compressed[..], &mut uncompressed[..]];

        for encoded in fixtures.iter_mut() {
            for tag in 0..=0xFFu8 {
                // Tags 2 through 5 are skipped; every other tag byte must be rejected
                // for these inputs.
                if matches!(tag, 2..=5) {
                    continue;
                }

                encoded[0] = tag;
                assert!(EncodedPoint::from_bytes(&**encoded).is_err());
            }
        }
    }

    #[test]
    fn decode_truncated_point() {
        let fixtures = [&COMPRESSED_BYTES[..], &UNCOMPRESSED_BYTES[..]];

        for encoded in fixtures.iter() {
            // Every strict prefix, including the empty one, has to be rejected.
            for truncated_len in 0..encoded.len() {
                let prefix = &encoded[..truncated_len];
                assert!(EncodedPoint::from_bytes(prefix).is_err());
            }
        }
    }

    #[test]
    fn from_untagged_point() {
        let untagged = hex!("11111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");
        let point = EncodedPoint::from_untagged_bytes(GenericArray::from_slice(&untagged[..]));

        assert_eq!(point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[test]
    fn from_affine_coordinates() {
        let x = FieldBytes::from(hex!(
            "1111111111111111111111111111111111111111111111111111111111111111"
        ));
        let y = FieldBytes::from(hex!(
            "2222222222222222222222222222222222222222222222222222222222222222"
        ));

        let uncompressed = EncodedPoint::from_affine_coordinates(&x, &y, false);
        assert_eq!(uncompressed.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        let compressed = EncodedPoint::from_affine_coordinates(&x, &y, true);
        assert_eq!(compressed.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[test]
    fn compress() {
        let compressed = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..])
            .unwrap()
            .compress();

        assert_eq!(compressed.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "subtle")]
    #[test]
    fn conditional_select() {
        let a = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        let b = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        // A choice of 0 selects the first argument, 1 the second.
        let picked_a = EncodedPoint::conditional_select(&a, &b, 0.into());
        assert_eq!(a, picked_a);

        let picked_b = EncodedPoint::conditional_select(&a, &b, 1.into());
        assert_eq!(b, picked_b);
    }

    #[test]
    fn identity() {
        let point = EncodedPoint::identity();

        assert_eq!(point.tag(), Tag::Identity);
        assert_eq!(point.len(), 1);
        assert_eq!(point.as_bytes(), &IDENTITY_BYTES[..]);

        // The identity is also what `Default` produces.
        assert_eq!(point, EncodedPoint::default());
    }

    #[test]
    fn decode_hex() {
        let hex_text = "021111111111111111111111111111111111111111111111111111111111111111";
        let point = EncodedPoint::from_str(hex_text).unwrap();

        assert_eq!(point.as_bytes(), COMPRESSED_BYTES);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_bytes() {
        let point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        let boxed = point.to_bytes();

        assert_eq!(&*boxed, &UNCOMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_string() {
        let point = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();

        assert_eq!(
            point.to_string(),
            "021111111111111111111111111111111111111111111111111111111111111111"
        );
    }
}