use anyhow::{format_err, Error};
use bssl_sys::{
BN_CTX_free, BN_CTX_new, BN_add, BN_asc2bn, BN_bin2bn, BN_bn2bin, BN_bn2dec, BN_cmp, BN_copy,
BN_equal_consttime, BN_free, BN_is_odd, BN_is_one, BN_is_zero, BN_mod_add, BN_mod_exp,
BN_mod_inverse, BN_mod_mul, BN_mod_sqr, BN_mod_sqrt, BN_new, BN_nnmod, BN_num_bits,
BN_num_bytes, BN_one, BN_rand_range, BN_rshift1, BN_set_negative, BN_set_u64, BN_sub, BN_zero,
EC_GROUP_free, EC_GROUP_get_curve_GFp, EC_GROUP_get_order, EC_GROUP_new_by_curve_name,
EC_POINT_add, EC_POINT_free, EC_POINT_get_affine_coordinates_GFp, EC_POINT_invert,
EC_POINT_is_at_infinity, EC_POINT_mul, EC_POINT_new, EC_POINT_set_affine_coordinates_GFp,
ERR_get_error, ERR_reason_error_string, NID_X9_62_prime256v1, NID_secp384r1, NID_secp521r1,
OPENSSL_free, BIGNUM, BN_CTX, EC_GROUP, EC_POINT,
};
use num_derive::{FromPrimitive, ToPrimitive};
use std::cmp::Ordering;
use std::convert::TryInto;
use std::ffi::CString;
use std::fmt;
use std::ptr::NonNull;
/// Drains BoringSSL's error queue and returns the reason string of the oldest
/// queued error, if there was one and it has a reason string.
///
/// The entire queue is emptied, not just its first entry, so that errors left
/// behind by one failure are not reported as the cause of a later, unrelated
/// failure.
fn take_boringssl_error_reason() -> Option<String> {
    let mut oldest_reason = None;
    loop {
        // SAFETY: ERR_get_error takes no arguments; it pops one queued error code.
        let error_code = unsafe { ERR_get_error() };
        if error_code == 0 {
            // A zero code means the queue is empty.
            return oldest_reason;
        }
        if oldest_reason.is_none() {
            // SAFETY: the returned pointer is either NULL or a NUL-terminated string
            // owned by BoringSSL; it is only read (and copied) here, never freed.
            let reason_ptr = unsafe { ERR_reason_error_string(error_code) };
            if !reason_ptr.is_null() {
                let reason = unsafe { std::ffi::CStr::from_ptr(reason_ptr) };
                oldest_reason = Some(reason.to_string_lossy().into_owned());
            }
        }
    }
}

/// Converts a pointer returned by BoringSSL into a `NonNull`, treating NULL as
/// failure.
///
/// # Errors
/// Returns an error when `ptr` is NULL. The reason string of the oldest queued
/// BoringSSL error, if any, is appended to the message, and the error queue is
/// drained so the failure cannot be misattributed later.
fn ptr_or_error<T>(ptr: *mut T) -> Result<NonNull<T>, Error> {
    NonNull::new(ptr).ok_or_else(|| match take_boringssl_error_reason() {
        Some(reason) => format_err!("Found null pointer from BoringSSL: {}", reason),
        None => format_err!("Found null pointer from BoringSSL"),
    })
}

/// Converts a BoringSSL status code into a `Result`; only `1` means success.
///
/// # Errors
/// Returns an error for any other value. The message includes the reason
/// string of the oldest queued BoringSSL error when one is available, and the
/// error queue is drained.
fn one_or_error(res: std::os::raw::c_int) -> Result<(), Error> {
    if res == 1 {
        return Ok(());
    }
    Err(match take_boringssl_error_reason() {
        Some(reason) => format_err!("BoringSSL failed to perform an operation: {}", reason),
        None => format_err!("BoringSSL failed to perform an operation."),
    })
}
/// Owned handle to a BoringSSL `BIGNUM`; the allocation is released on drop.
pub struct Bignum(NonNull<BIGNUM>);

impl Drop for Bignum {
    fn drop(&mut self) {
        // SAFETY: the pointer was produced by BoringSSL, is owned solely by this
        // value, and is never touched again after being freed here.
        let raw = self.0.as_ptr();
        unsafe { BN_free(raw) };
    }
}
/// Owned handle to a BoringSSL `BN_CTX`, the context object passed to the
/// modular-arithmetic and curve operations; released on drop.
pub struct BignumCtx(NonNull<BN_CTX>);

impl Drop for BignumCtx {
    fn drop(&mut self) {
        // SAFETY: the pointer came from BN_CTX_new and is owned solely by `self`.
        let raw = self.0.as_ptr();
        unsafe { BN_CTX_free(raw) };
    }
}

impl BignumCtx {
    /// Allocates a new context with `BN_CTX_new`.
    ///
    /// # Errors
    /// Returns an error if BoringSSL returns NULL.
    pub fn new() -> Result<Self, Error> {
        let raw = unsafe { BN_CTX_new() };
        ptr_or_error(raw).map(BignumCtx)
    }
}
impl Bignum {
pub fn new() -> Result<Self, Error> {
ptr_or_error(unsafe { BN_new() }).map(Self)
}
#[allow(dead_code)]
pub fn zero() -> Result<Self, Error> {
let result = Self::new()?;
unsafe {
BN_zero(result.0.as_ptr());
}
Ok(result)
}
pub fn one() -> Result<Self, Error> {
let result = Self::new()?;
one_or_error(unsafe { BN_one(result.0.as_ptr()) })?;
Ok(result)
}
pub fn rand(max: &Bignum) -> Result<Self, Error> {
let result = Self::new()?;
one_or_error(unsafe { BN_rand_range(result.0.as_ptr(), max.0.as_ptr()) })?;
Ok(result)
}
pub fn new_from_slice(bytes: &[u8]) -> Result<Self, Error> {
if bytes.is_empty() {
Self::new_from_u64(0)
} else {
let bytes_len = bytes.len().try_into().unwrap();
ptr_or_error(unsafe {
BN_bin2bn(&bytes[0] as *const u8, bytes_len, std::ptr::null_mut())
})
.map(Self)
}
}
#[allow(dead_code)]
pub fn new_from_string(ascii: &str) -> Result<Self, Error> {
let mut bignum = std::ptr::null_mut();
let ascii = CString::new(ascii)?;
one_or_error(unsafe { BN_asc2bn(&mut bignum as *mut *mut BIGNUM, ascii.as_ptr()) })?;
ptr_or_error(bignum).map(Bignum)
}
pub fn new_from_u64(value: u64) -> Result<Self, Error> {
let mut bignum = Self::new()?;
one_or_error(unsafe { BN_set_u64(bignum.0.as_mut(), value) })?;
Ok(bignum)
}
pub fn set_negative(self) -> Self {
unsafe { BN_set_negative(self.0.as_ptr(), 1) };
self
}
pub fn copy(&self) -> Result<Self, Error> {
let mut copy = Self::new()?;
ptr_or_error(unsafe { BN_copy(copy.0.as_mut(), self.0.as_ptr()) })?;
Ok(copy)
}
pub fn add(&self, mut b: Self) -> Result<Self, Error> {
one_or_error(unsafe { BN_add(b.0.as_mut(), self.0.as_ptr(), b.0.as_ptr()) })?;
Ok(b)
}
pub fn sub(&self, mut b: Self) -> Result<Self, Error> {
one_or_error(unsafe { BN_sub(b.0.as_mut(), self.0.as_ptr(), b.0.as_ptr()) })?;
Ok(b)
}
pub fn mod_nonnegative(&self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
one_or_error(unsafe {
BN_nnmod(result.0.as_mut(), self.0.as_ptr(), m.0.as_ptr(), ctx.0.as_ptr())
})?;
Ok(result)
}
pub fn mod_add(&self, b: &Self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
one_or_error(unsafe {
BN_mod_add(
result.0.as_mut(),
self.0.as_ptr(),
b.0.as_ptr(),
m.0.as_ptr(),
ctx.0.as_ptr(),
)
})?;
Ok(result)
}
pub fn mod_mul(&self, b: &Self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
one_or_error(unsafe {
BN_mod_mul(
result.0.as_mut(),
self.0.as_ptr(),
b.0.as_ptr(),
m.0.as_ptr(),
ctx.0.as_ptr(),
)
})?;
Ok(result)
}
pub fn mod_inverse(&self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
ptr_or_error(unsafe {
BN_mod_inverse(result.0.as_mut(), self.0.as_ptr(), m.0.as_ptr(), ctx.0.as_ptr())
})?;
Ok(result)
}
pub fn mod_exp(&self, p: &Self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
one_or_error(unsafe {
BN_mod_exp(
result.0.as_mut(),
self.0.as_ptr(),
p.0.as_ptr(),
m.0.as_ptr(),
ctx.0.as_ptr(),
)
})?;
Ok(result)
}
pub fn mod_square(&self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
one_or_error(unsafe {
BN_mod_sqr(result.0.as_mut(), self.0.as_ptr(), m.0.as_ptr(), ctx.0.as_ptr())
})?;
Ok(result)
}
pub fn mod_sqrt(&self, m: &Self, ctx: &BignumCtx) -> Result<Self, Error> {
let mut result = Self::new()?;
ptr_or_error(unsafe {
BN_mod_sqrt(result.0.as_mut(), self.0.as_ptr(), m.0.as_ptr(), ctx.0.as_ptr())
})?;
Ok(result)
}
pub fn rshift1(&self) -> Result<Self, Error> {
let mut result = Self::new()?;
one_or_error(unsafe { BN_rshift1(result.0.as_mut(), self.0.as_ptr()) })?;
Ok(result)
}
pub fn is_one(&self) -> bool {
unsafe { BN_is_one(self.0.as_ptr()) == 1 }
}
pub fn is_zero(&self) -> bool {
unsafe { BN_is_zero(self.0.as_ptr()) == 1 }
}
pub fn is_odd(&self) -> bool {
unsafe { BN_is_odd(self.0.as_ptr()) == 1 }
}
pub fn len(&self) -> usize {
unsafe { BN_num_bytes(self.0.as_ptr()) as usize }
}
pub fn bits(&self) -> usize {
unsafe { BN_num_bits(self.0.as_ptr()) as usize }
}
pub fn to_be_vec(&self, min_length: usize) -> Vec<u8> {
let len = self.len();
let padded_len = std::cmp::max(len, min_length);
let mut out = vec![0; padded_len];
if len != 0 {
unsafe {
BN_bn2bin(self.0.as_ptr(), &mut out[padded_len - len] as *mut u8);
}
}
out
}
}
/// Value equality, decided by `BN_equal_consttime`.
impl PartialEq for Bignum {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.as_ptr(), other.0.as_ptr());
        let verdict = unsafe { BN_equal_consttime(lhs, rhs) };
        verdict == 1
    }
}

impl Eq for Bignum {}
/// Signed numeric ordering, decided by `BN_cmp`.
impl Ord for Bignum {
    fn cmp(&self, other: &Self) -> Ordering {
        // BN_cmp returns a negative, zero, or positive int (memcmp-style).
        let sign = unsafe { BN_cmp(self.0.as_ptr(), other.0.as_ptr()) };
        match sign {
            s if s < 0 => Ordering::Less,
            0 => Ordering::Equal,
            _ => Ordering::Greater,
        }
    }
}

impl PartialOrd for Bignum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl fmt::Display for Bignum {
    /// Formats the value as the decimal string produced by `BN_bn2dec`.
    ///
    /// Returns `fmt::Error` if BoringSSL cannot allocate the string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ptr = unsafe { BN_bn2dec(self.0.as_ptr()) };
        if ptr.is_null() {
            // BN_bn2dec returns NULL on failure; passing NULL to CStr::from_ptr
            // would be undefined behaviour.
            return Err(fmt::Error);
        }
        // SAFETY: `ptr` is a non-null, NUL-terminated string allocated by
        // BoringSSL. The borrowed text is dropped at the end of this statement,
        // before the buffer is freed below.
        let res = unsafe { std::ffi::CStr::from_ptr(ptr) }.to_string_lossy().fmt(f);
        // SAFETY: `ptr` was allocated by BoringSSL and is freed exactly once.
        unsafe { OPENSSL_free(ptr as *mut std::os::raw::c_void) };
        res
    }
}
/// Debug output wraps the decimal `Display` form, e.g. `Bignum(42)`.
impl fmt::Debug for Bignum {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Bignum({self})")
    }
}
/// Elliptic curves that can be instantiated with `EcGroup::new`.
///
/// The explicit discriminants are the values `FromPrimitive` / `ToPrimitive`
/// convert to and from. NOTE(review): they look like identifiers assigned by
/// an external protocol rather than BoringSSL NIDs — confirm before changing.
#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)]
pub enum EcGroupId {
    /// P-256 (BoringSSL `NID_X9_62_prime256v1`).
    P256 = 19,
    /// P-384 (BoringSSL `NID_secp384r1`).
    P384 = 20,
    /// P-521 (BoringSSL `NID_secp521r1`).
    P521 = 21,
}

impl EcGroupId {
    /// BoringSSL curve NID for this group, as passed to
    /// `EC_GROUP_new_by_curve_name`.
    fn nid(&self) -> u32 {
        match *self {
            Self::P256 => NID_X9_62_prime256v1,
            Self::P384 => NID_secp384r1,
            Self::P521 => NID_secp521r1,
        }
    }
}
/// Owned handle to a BoringSSL `EC_GROUP`; released on drop.
pub struct EcGroup(NonNull<EC_GROUP>);

/// Curve parameters returned by `EC_GROUP_get_curve_GFp`: the field prime `p`
/// and the coefficients `a` and `b`.
pub struct EcGroupParams {
    /// Field prime.
    pub p: Bignum,
    /// Curve coefficient `a`.
    pub a: Bignum,
    /// Curve coefficient `b`.
    pub b: Bignum,
}

impl Drop for EcGroup {
    fn drop(&mut self) {
        // SAFETY: the pointer came from EC_GROUP_new_by_curve_name and is owned
        // solely by `self`.
        let raw = self.0.as_ptr();
        unsafe { EC_GROUP_free(raw) };
    }
}

impl EcGroup {
    /// Instantiates the named curve.
    ///
    /// # Errors
    /// Returns an error if BoringSSL returns NULL.
    pub fn new(id: EcGroupId) -> Result<Self, Error> {
        let nid = id.nid() as i32;
        let raw = unsafe { EC_GROUP_new_by_curve_name(nid) };
        ptr_or_error(raw).map(EcGroup)
    }

    /// Returns the prime `p` and coefficients `a`, `b` of this curve.
    pub fn get_params(&self, ctx: &BignumCtx) -> Result<EcGroupParams, Error> {
        let (p, a, b) = (Bignum::new()?, Bignum::new()?, Bignum::new()?);
        let status = unsafe {
            EC_GROUP_get_curve_GFp(
                self.0.as_ptr(),
                p.0.as_ptr(),
                a.0.as_ptr(),
                b.0.as_ptr(),
                ctx.0.as_ptr(),
            )
        };
        one_or_error(status).map(|()| EcGroupParams { p, a, b })
    }

    /// Returns the order of the group.
    pub fn get_order(&self, ctx: &BignumCtx) -> Result<Bignum, Error> {
        let order = Bignum::new()?;
        let status =
            unsafe { EC_GROUP_get_order(self.0.as_ptr(), order.0.as_ptr(), ctx.0.as_ptr()) };
        one_or_error(status).map(|()| order)
    }
}
/// Owned handle to a BoringSSL `EC_POINT`; released on drop.
///
/// Every operation takes the `EcGroup` explicitly; pass the group the point
/// was created with.
pub struct EcPoint(NonNull<EC_POINT>);

impl Drop for EcPoint {
    fn drop(&mut self) {
        // SAFETY: the pointer came from EC_POINT_new and is owned solely by `self`.
        let raw = self.0.as_ptr();
        unsafe { EC_POINT_free(raw) };
    }
}

impl EcPoint {
    /// Allocates a new point for `group`.
    ///
    /// # Errors
    /// Returns an error if BoringSSL returns NULL.
    pub fn new(group: &EcGroup) -> Result<Self, Error> {
        let raw = unsafe { EC_POINT_new(group.0.as_ptr()) };
        ptr_or_error(raw).map(EcPoint)
    }

    /// Builds the point with affine coordinates `(x, y)` on `group`.
    ///
    /// # Errors
    /// Returns an error if BoringSSL rejects the coordinates, e.g. when the
    /// point is not on the curve.
    pub fn new_from_affine_coords(
        x: Bignum,
        y: Bignum,
        group: &EcGroup,
        ctx: &BignumCtx,
    ) -> Result<Self, Error> {
        let point = Self::new(group)?;
        let status = unsafe {
            EC_POINT_set_affine_coordinates_GFp(
                group.0.as_ptr(),
                point.0.as_ptr(),
                x.0.as_ptr(),
                y.0.as_ptr(),
                ctx.0.as_ptr(),
            )
        };
        one_or_error(status).map(|()| point)
    }

    /// Returns the affine `(x, y)` coordinates of this point.
    ///
    /// # Errors
    /// Returns an error if BoringSSL reports failure.
    pub fn to_affine_coords(
        &self,
        group: &EcGroup,
        ctx: &BignumCtx,
    ) -> Result<(Bignum, Bignum), Error> {
        let (x, y) = (Bignum::new()?, Bignum::new()?);
        let status = unsafe {
            EC_POINT_get_affine_coordinates_GFp(
                group.0.as_ptr(),
                self.0.as_ptr(),
                x.0.as_ptr(),
                y.0.as_ptr(),
                ctx.0.as_ptr(),
            )
        };
        one_or_error(status).map(|()| (x, y))
    }

    /// Returns the scalar multiple `m * self` as a new point.
    pub fn mul(&self, group: &EcGroup, m: &Bignum, ctx: &BignumCtx) -> Result<EcPoint, Error> {
        let product = Self::new(group)?;
        // The generator scalar is NULL, so only `m * self` is computed.
        let status = unsafe {
            EC_POINT_mul(
                group.0.as_ptr(),
                product.0.as_ptr(),
                std::ptr::null_mut(),
                self.0.as_ptr(),
                m.0.as_ptr(),
                ctx.0.as_ptr(),
            )
        };
        one_or_error(status).map(|()| product)
    }

    /// Returns the sum `self + b` as a new point.
    pub fn add(&self, group: &EcGroup, b: &EcPoint, ctx: &BignumCtx) -> Result<EcPoint, Error> {
        let sum = Self::new(group)?;
        let status = unsafe {
            EC_POINT_add(
                group.0.as_ptr(),
                sum.0.as_ptr(),
                self.0.as_ptr(),
                b.0.as_ptr(),
                ctx.0.as_ptr(),
            )
        };
        one_or_error(status).map(|()| sum)
    }

    /// Negates this point in place and returns it.
    pub fn invert(self, group: &EcGroup, ctx: &BignumCtx) -> Result<EcPoint, Error> {
        let status = unsafe { EC_POINT_invert(group.0.as_ptr(), self.0.as_ptr(), ctx.0.as_ptr()) };
        one_or_error(status).map(|()| self)
    }

    /// Returns true if this is the point at infinity of `group`.
    pub fn is_point_at_infinity(&self, group: &EcGroup) -> bool {
        let flag = unsafe { EC_POINT_is_at_infinity(group.0.as_ptr(), self.0.as_ptr()) };
        flag == 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a decimal or `0x`-prefixed hex literal, panicking on failure.
    fn bn(value: &str) -> Bignum {
        Bignum::new_from_string(value).unwrap()
    }

    #[test]
    fn bignum_lifetime() {
        for _ in 0..10 {
            let bignum = Bignum::new().unwrap();
            std::mem::drop(bignum);
        }
    }

    #[test]
    fn bignum_new() {
        assert_eq!(Bignum::new_from_string("100").unwrap(), bn("100"));
        assert_eq!(Bignum::new_from_u64(100).unwrap(), bn("100"));
        assert_eq!(Bignum::new_from_slice(&[0xff, 0xff][..]).unwrap(), bn("65535"));
        assert_eq!(Bignum::new_from_slice(&[]).unwrap(), bn("0"));
    }

    #[test]
    fn bignum_new_from_string_rejects_invalid_input() {
        // No digits at all: BN_asc2bn fails.
        assert!(Bignum::new_from_string("abc").is_err());
        assert!(Bignum::new_from_string("").is_err());
        // Interior NUL byte: rejected before reaching BoringSSL.
        assert!(Bignum::new_from_string("1\02").is_err());
    }

    #[test]
    fn bignum_set_negative() {
        assert_eq!(bn("100").set_negative(), bn("-100"));
        assert_eq!(bn("-100").set_negative(), bn("-100"));
        assert_eq!(bn("0").set_negative(), bn("0"));
    }

    #[test]
    fn bignum_format() {
        assert_eq!(format!("{}", bn("100")), "100");
        assert_eq!(format!("{}", bn("0x100")), "256");
        assert_eq!(format!("{}", bn("-42")), "-42");
        assert_eq!(format!("{:?}", bn("7")), "Bignum(7)");
    }

    #[test]
    fn bignum_add() {
        let bn1 = bn("1000000000000000000000");
        let bn2 = bn("1000000000001234567890");
        let sum = bn1.add(bn2).unwrap();
        assert_eq!(sum, bn("2000000000001234567890"));
        let bn1 = bn("-1000000000000000000000");
        let bn2 = bn("1000000000001234567890");
        let sum = bn1.add(bn2).unwrap();
        assert_eq!(sum, bn("1234567890"));
    }

    #[test]
    fn bignum_sub() {
        let bn1 = bn("3000000000000987654321");
        let bn2 = bn("2000000000000000000000");
        let diff = bn1.sub(bn2).unwrap();
        assert_eq!(diff, bn("1000000000000987654321"));
        let bn1 = bn("2000000000000012345678");
        let bn2 = bn("-3000000000000987654321");
        let diff = bn1.sub(bn2).unwrap();
        assert_eq!(diff, bn("5000000000000999999999"));
    }

    #[test]
    fn bignum_mod_nonnegative() {
        let ctx = BignumCtx::new().unwrap();
        let bn1 = bn("12");
        let bn2 = bn("5");
        let mod_nonnegative = bn1.mod_nonnegative(&bn2, &ctx).unwrap();
        assert_eq!(mod_nonnegative, bn("2"));
        let bn1 = bn("-12");
        let bn2 = bn("5");
        let mod_nonnegative = bn1.mod_nonnegative(&bn2, &ctx).unwrap();
        assert_eq!(mod_nonnegative, bn("3"));
    }

    #[test]
    fn bignum_mod_add() {
        let ctx = BignumCtx::new().unwrap();
        let bn1 = bn("1000000000000000000000");
        let bn2 = bn("1000000000001234567890");
        let m = bn("2000000000000000000000");
        let value = bn1.mod_add(&bn2, &m, &ctx).unwrap();
        assert_eq!(value, bn("1234567890"));
    }

    #[test]
    fn bignum_mod_mul() {
        let ctx = BignumCtx::new().unwrap();
        let value = bn("4").mod_mul(&bn("5"), &bn("12"), &ctx).unwrap();
        assert_eq!(value, bn("8"));
    }

    #[test]
    fn bignum_mod_inverse() {
        let ctx = BignumCtx::new().unwrap();
        assert_eq!(bn("3").mod_inverse(&bn("7"), &ctx).unwrap(), bn("5"));
        // gcd(2, 4) != 1, so no inverse exists.
        assert!(bn("2").mod_inverse(&bn("4"), &ctx).is_err());
    }

    #[test]
    fn bignum_mod_exp() {
        let ctx = BignumCtx::new().unwrap();
        let value = bn("4").mod_exp(&bn("2"), &bn("10"), &ctx).unwrap();
        assert_eq!(value, bn("6"));
    }

    #[test]
    fn bignum_mod_square() {
        let ctx = BignumCtx::new().unwrap();
        assert_eq!(bn("11").mod_square(&bn("17"), &ctx).unwrap(), bn("2"));
    }

    #[test]
    fn bignum_mod_sqrt() {
        let ctx = BignumCtx::new().unwrap();
        let m = bn("13");
        // Non-zero quadratic residues modulo 13.
        let quadratic_residues = [1, 3, 4, 9, 10, 12];
        // Cover every non-zero residue class 1..=12, including the residue 12.
        for i in 1..13 {
            let i_bn = Bignum::new_from_u64(i).unwrap();
            let sqrt = i_bn.mod_sqrt(&m, &ctx);
            if quadratic_residues.contains(&i) {
                assert!(sqrt.is_ok());
                assert_eq!(sqrt.unwrap().mod_exp(&bn("2"), &m, &ctx).unwrap(), i_bn);
            } else {
                assert!(sqrt.is_err());
            }
        }
    }

    #[test]
    fn bignum_mod_sqrt_non_prime() {
        let ctx = BignumCtx::new().unwrap();
        let m = bn("100");
        assert!(bn("16").mod_sqrt(&m, &ctx).is_err())
    }

    #[test]
    fn bignum_rshift1() {
        assert_eq!(bn("100").rshift1().unwrap(), bn("50"));
        assert_eq!(bn("101").rshift1().unwrap(), bn("50"));
    }

    #[test]
    fn bignum_simple_fns() {
        assert!(bn("1").is_one());
        assert!(!bn("100000").is_one());
        assert!(Bignum::one().unwrap().is_one());
        assert!(bn("0").is_zero());
        assert!(!bn("1").is_zero());
        assert!(Bignum::zero().unwrap().is_zero());
        assert!(bn("1000001").is_odd());
        assert!(!bn("1000002").is_odd());
    }

    #[test]
    fn bignum_ord() {
        let neg = bn("-100");
        let zero = bn("0");
        let pos = bn("100");
        assert!(neg < zero);
        assert!(pos > neg);
        assert_eq!(neg, neg);
        assert_eq!(zero, zero);
        assert_eq!(pos, pos);
    }

    #[test]
    fn bignum_to_be_vec() {
        assert_eq!(bn("0xff00").to_be_vec(1), vec![0xff, 0x00]);
        assert_eq!(bn("0xff00").to_be_vec(4), vec![0x00, 0x00, 0xff, 0x00]);
        assert_eq!(bn("0").to_be_vec(4), vec![0x00, 0x00, 0x00, 0x00]);
        assert_eq!(bn("0").to_be_vec(0), Vec::<u8>::new());
    }

    const P: &str = "0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF";
    const B: &str = "0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B";
    const ORDER: &str = "0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551";
    const GX: &str = "0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296";
    const GY: &str = "0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5";
    const I: &str = "0xC88F01F510D9AC3F70A292DAA2316DE544E9AAB8AFE84049C62A9C57862D1433";
    const GIX: &str = "0xDAD0B65394221CF9B051E1FECA5787D098DFE637FC90B9EF945D0C3772581180";
    const GIY: &str = "0x5271A0461CDB8252D61F1C456FA3E59AB1F45B33ACCF5F58389E0577B8990BB3";

    #[test]
    fn ec_group_params() {
        let group = EcGroup::new(EcGroupId::P256).unwrap();
        let ctx = BignumCtx::new().unwrap();
        let params = group.get_params(&ctx).unwrap();
        let order = group.get_order(&ctx).unwrap();
        assert_eq!(params.p, bn(P));
        assert_eq!(params.a, params.p.sub(bn("3")).unwrap());
        assert_eq!(params.b, bn(B));
        assert_eq!(order, bn(ORDER));
    }

    #[test]
    fn ec_point_to_coords() {
        let group = EcGroup::new(EcGroupId::P256).unwrap();
        let ctx = BignumCtx::new().unwrap();
        let point = EcPoint::new_from_affine_coords(bn(GIX), bn(GIY), &group, &ctx).unwrap();
        let (x, y) = point.to_affine_coords(&group, &ctx).unwrap();
        assert_eq!(x, bn(GIX));
        assert_eq!(y, bn(GIY));
    }

    #[test]
    fn ec_wrong_coords_to_point_err() {
        let group = EcGroup::new(EcGroupId::P256).unwrap();
        let ctx = BignumCtx::new().unwrap();
        let result =
            EcPoint::new_from_affine_coords(bn(GIX).add(bn("1")).unwrap(), bn(GIY), &group, &ctx);
        let Err(err) = result else { panic!("Expected error") };
        assert!(format!("{:?}", err).contains("POINT_IS_NOT_ON_CURVE"));
    }

    #[test]
    fn error_from_earlier_failure_is_not_reported_later() {
        let ctx = BignumCtx::new().unwrap();
        // 2 is not a quadratic residue mod 13; this failure queues a BoringSSL error.
        assert!(bn("2").mod_sqrt(&bn("13"), &ctx).is_err());
        let group = EcGroup::new(EcGroupId::P256).unwrap();
        let result =
            EcPoint::new_from_affine_coords(bn(GIX).add(bn("1")).unwrap(), bn(GIY), &group, &ctx);
        let Err(err) = result else { panic!("Expected error") };
        // The reported reason must belong to this failure, not the earlier one.
        assert!(format!("{:?}", err).contains("POINT_IS_NOT_ON_CURVE"));
    }

    #[test]
    fn ec_point_mul() {
        let group = EcGroup::new(EcGroupId::P256).unwrap();
        let ctx = BignumCtx::new().unwrap();
        let g = EcPoint::new_from_affine_coords(bn(GX), bn(GY), &group, &ctx).unwrap();
        let gi = g.mul(&group, &bn(I), &ctx).unwrap();
        let (gix, giy) = gi.to_affine_coords(&group, &ctx).unwrap();
        assert_eq!(gix, bn(GIX));
        assert_eq!(giy, bn(GIY));
    }

    #[test]
    fn ec_point_add_and_invert() {
        let group = EcGroup::new(EcGroupId::P256).unwrap();
        let ctx = BignumCtx::new().unwrap();
        let g = EcPoint::new_from_affine_coords(bn(GX), bn(GY), &group, &ctx).unwrap();
        assert!(!g.is_point_at_infinity(&group));

        // G + G == 2 * G.
        let doubled = g.add(&group, &g, &ctx).unwrap();
        let twice = g.mul(&group, &bn("2"), &ctx).unwrap();
        let (dx, dy) = doubled.to_affine_coords(&group, &ctx).unwrap();
        let (tx, ty) = twice.to_affine_coords(&group, &ctx).unwrap();
        assert_eq!(dx, tx);
        assert_eq!(dy, ty);

        // -G == (x, p - y), and G + (-G) is the point at infinity.
        let neg_g = EcPoint::new_from_affine_coords(bn(GX), bn(GY), &group, &ctx)
            .unwrap()
            .invert(&group, &ctx)
            .unwrap();
        let (nx, ny) = neg_g.to_affine_coords(&group, &ctx).unwrap();
        assert_eq!(nx, bn(GX));
        assert_eq!(ny, bn(P).sub(bn(GY)).unwrap());
        assert!(g.add(&group, &neg_g, &ctx).unwrap().is_point_at_infinity(&group));
    }
}