use ark_ec::{short_weierstrass::Affine, AffineRepr, CurveGroup};
use ark_ff::{BigInteger, PrimeField, Zero};
use bs58;
use core::fmt;
use sha2::{Digest, Sha256};
use std::ops::{Mul, Neg};
use thiserror::Error;
use crate::{BaseField, CurvePoint, ScalarField, SecKey};
use o1_utils::FieldHelpers;
/// Errors produced while parsing, deriving or decoding public keys.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum PubKeyError {
    /// The address string does not have the expected length.
    #[error("invalid address length")]
    AddressLength,
    /// The address string is not valid base58.
    #[error("invalid address base58")]
    AddressBase58,
    /// The base58-decoded address does not have the expected number of bytes.
    #[error("invalid raw address bytes length")]
    AddressRawByteLength,
    /// The address checksum does not match its payload.
    #[error("invalid address checksum")]
    AddressChecksum,
    /// The address version prefix is not the expected one.
    #[error("invalid address version")]
    AddressVersion,
    /// The x-coordinate bytes could not be decoded as a field element.
    #[error("invalid x-coordinate bytes")]
    XCoordinateBytes,
    /// No curve point exists with the given x-coordinate.
    #[error("invalid x-coordinate")]
    XCoordinate,
    /// The y-coordinate bytes could not be decoded (or the input length is wrong).
    #[error("invalid y-coordinate bytes")]
    YCoordinateBytes,
    /// The y-coordinate parity section has the wrong length.
    #[error("invalid y-coordinate parity bytes")]
    YCoordinateParityBytes,
    /// The y-coordinate parity byte is not a valid value.
    #[error("invalid y-coordinate parity")]
    YCoordinateParity,
    /// The decoded point does not lie on the curve.
    #[error("point not on curve")]
    NonCurvePoint,
    /// The public key hex string could not be decoded.
    #[error("invalid public key hex")]
    Hex,
    /// The secret key is invalid (e.g. zero).
    #[error("invalid secret key")]
    SecKey,
}
/// Convenience result type whose error is always [`PubKeyError`].
pub type Result<T> = core::result::Result<T, PubKeyError>;

/// Length, in characters, of a base58check-encoded Mina address string.
pub const MINA_ADDRESS_LEN: usize = 55;

/// Length, in bytes, of a base58-decoded Mina address:
/// 3-byte version prefix + 32-byte x-coordinate + 1 parity byte + 4-byte checksum.
const MINA_ADDRESS_RAW_LEN: usize = 3 + 32 + 1 + 4;
/// An uncompressed public key: a full affine point on the curve.
///
/// Constructed through the validating `from_*` constructors, or through
/// `from_point_unsafe` when the caller has already validated the point.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PubKey(CurvePoint);
impl PubKey {
    /// Wraps `point` as a public key without performing any validation.
    ///
    /// The caller is responsible for ensuring `point` is a valid curve point.
    pub fn from_point_unsafe(point: CurvePoint) -> Self {
        Self(point)
    }

    /// Parses an uncompressed public key: the x-coordinate bytes immediately
    /// followed by the y-coordinate bytes (each `BaseField::size_in_bytes()` long).
    ///
    /// # Errors
    ///
    /// * [`PubKeyError::YCoordinateBytes`] if `bytes` is not exactly two field
    ///   elements long, or the y-coordinate bytes do not decode.
    /// * [`PubKeyError::XCoordinateBytes`] if the x-coordinate bytes do not decode.
    /// * [`PubKeyError::XCoordinate`] if no curve point has this x-coordinate.
    /// * [`PubKeyError::NonCurvePoint`] if `(x, y)` is not a point on the curve.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let field_len = BaseField::size_in_bytes();
        if bytes.len() != field_len * 2 {
            return Err(PubKeyError::YCoordinateBytes);
        }
        let (x_bytes, y_bytes) = bytes.split_at(field_len);
        let x = BaseField::from_bytes(x_bytes).map_err(|_| PubKeyError::XCoordinateBytes)?;
        let y = BaseField::from_bytes(y_bytes).map_err(|_| PubKeyError::YCoordinateBytes)?;

        // Recover one square root for this x. The boolean argument of
        // `get_point_from_x_unchecked` picks a root by magnitude (arkworks' `greatest`
        // flag), not by parity, so compare `y` against BOTH roots instead of trusting
        // which one was returned; otherwise valid encodings could be rejected.
        let candidate = CurvePoint::get_point_from_x_unchecked(x, false)
            .ok_or(PubKeyError::XCoordinate)?;
        if y != candidate.y && y != candidate.y.neg() {
            return Err(PubKeyError::NonCurvePoint);
        }

        let public = Affine {
            x,
            y,
            infinity: candidate.infinity,
        };
        if !public.is_on_curve() {
            return Err(PubKeyError::NonCurvePoint);
        }
        Ok(PubKey::from_point_unsafe(public))
    }

    /// Parses a hex string and decodes it with [`PubKey::from_bytes`].
    ///
    /// # Errors
    ///
    /// [`PubKeyError::Hex`] if `public_hex` is not valid hex; otherwise any error
    /// returned by [`PubKey::from_bytes`].
    pub fn from_hex(public_hex: &str) -> Result<Self> {
        let bytes: Vec<u8> = hex::decode(public_hex).map_err(|_| PubKeyError::Hex)?;
        PubKey::from_bytes(&bytes)
    }

    /// Derives the public key `generator * secret` from a secret key.
    ///
    /// # Errors
    ///
    /// * [`PubKeyError::SecKey`] if the secret scalar is zero.
    /// * [`PubKeyError::NonCurvePoint`] if the derived point is not on the curve.
    pub fn from_secret_key(secret_key: SecKey) -> Result<Self> {
        // Extract the scalar once instead of cloning the key just to inspect it.
        let scalar = secret_key.into_scalar();
        if scalar == ScalarField::zero() {
            return Err(PubKeyError::SecKey);
        }
        let pt = CurvePoint::generator().mul(scalar).into_affine();
        if !pt.is_on_curve() {
            return Err(PubKeyError::NonCurvePoint);
        }
        Ok(PubKey::from_point_unsafe(pt))
    }

    /// Decodes a base58check Mina address into a public key.
    ///
    /// Decoded layout (40 bytes): version `cb 01 01` (3) || x-coordinate (32) ||
    /// y-parity (1, `0x00` even / `0x01` odd) || checksum (4), where the checksum is
    /// the first four bytes of SHA-256(SHA-256(everything before it)).
    ///
    /// # Errors
    ///
    /// * [`PubKeyError::AddressLength`] if the string is not `MINA_ADDRESS_LEN` long.
    /// * [`PubKeyError::AddressBase58`] if it is not valid base58.
    /// * [`PubKeyError::AddressRawByteLength`] if the decoded length is wrong.
    /// * [`PubKeyError::AddressChecksum`] if the checksum does not match.
    /// * [`PubKeyError::AddressVersion`] if the version prefix is wrong.
    /// * [`PubKeyError::YCoordinateParity`] if the parity byte is not `0x00`/`0x01`.
    /// * [`PubKeyError::XCoordinateBytes`] / [`PubKeyError::XCoordinate`] if the
    ///   x-coordinate does not decode or has no curve point.
    /// * [`PubKeyError::NonCurvePoint`] if the recovered point is not on the curve.
    pub fn from_address(address: &str) -> Result<Self> {
        if address.len() != MINA_ADDRESS_LEN {
            return Err(PubKeyError::AddressLength);
        }
        let bytes = bs58::decode(address)
            .into_vec()
            .map_err(|_| PubKeyError::AddressBase58)?;
        if bytes.len() != MINA_ADDRESS_RAW_LEN {
            return Err(PubKeyError::AddressRawByteLength);
        }

        let (raw, checksum) = bytes.split_at(bytes.len() - 4);
        let hash = Sha256::digest(&Sha256::digest(raw)[..]);
        if checksum != &hash[..4] {
            return Err(PubKeyError::AddressChecksum);
        }

        let (version, rest) = raw.split_at(3);
        let (x_bytes, parity) = rest.split_at(rest.len() - 1);
        if version != [0xcb, 0x01, 0x01] {
            return Err(PubKeyError::AddressVersion);
        }
        // Only 0x00/0x01 are canonical; accepting any other byte as "even" would let
        // several distinct addresses decode to the same key.
        let y_parity = match parity[0] {
            0x00 => false,
            0x01 => true,
            _ => return Err(PubKeyError::YCoordinateParity),
        };

        let x = BaseField::from_bytes(x_bytes).map_err(|_| PubKeyError::XCoordinateBytes)?;
        let mut pt =
            CurvePoint::get_point_from_x_unchecked(x, y_parity).ok_or(PubKeyError::XCoordinate)?;
        // The flag above does not encode parity, so flip to the other root if the
        // canonical parity of y does not match the requested one.
        if pt.y.into_bigint().is_odd() != y_parity {
            pt.y = pt.y.neg();
        }
        if !pt.is_on_curve() {
            return Err(PubKeyError::NonCurvePoint);
        }
        Ok(PubKey::from_point_unsafe(pt))
    }

    /// Borrows the underlying curve point.
    pub fn point(&self) -> &CurvePoint {
        &self.0
    }

    /// Consumes the key and returns the underlying curve point.
    pub fn into_point(self) -> CurvePoint {
        self.0
    }

    /// Returns the compressed form: the x-coordinate plus the parity of the
    /// canonical (non-Montgomery) y-coordinate.
    pub fn into_compressed(&self) -> CompressedPubKey {
        CompressedPubKey {
            x: self.0.x,
            is_odd: self.0.y.into_bigint().is_odd(),
        }
    }

    /// Encodes this key as a base58check Mina address.
    pub fn into_address(&self) -> String {
        let point = self.point();
        into_address(&point.x, point.y.into_bigint().is_odd())
    }

    /// Serializes the key as the x-coordinate bytes followed by the y-coordinate
    /// bytes — the layout read by [`PubKey::from_bytes`].
    pub fn to_bytes(&self) -> Vec<u8> {
        let point = self.point();
        [point.x.to_bytes(), point.y.to_bytes()].concat()
    }

    /// Returns the hex of the x-coordinate concatenated with the hex of the
    /// y-coordinate.
    pub fn to_hex(&self) -> String {
        let point = self.point();
        point.x.to_hex() + point.y.to_hex().as_str()
    }
}
/// Displays a public key as its uncompressed hex encoding (`PubKey::to_hex`).
impl fmt::Display for PubKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.to_hex();
        f.write_str(&encoded)
    }
}
/// A compressed public key: the x-coordinate and the parity of y.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CompressedPubKey {
    /// The x-coordinate of the public key point.
    pub x: BaseField,
    /// Whether the (canonical) y-coordinate of the point is odd.
    pub is_odd: bool,
}
/// Builds a base58check Mina address from an x-coordinate and y-parity.
///
/// Payload: version `cb 01 01` || x bytes || parity byte (1 = odd, 0 = even),
/// followed by the first four bytes of SHA-256(SHA-256(payload)).
fn into_address(x: &BaseField, is_odd: bool) -> String {
    // Must match the version prefix checked when decoding addresses.
    const VERSION: [u8; 3] = [0xcb, 0x01, 0x01];

    let mut payload: Vec<u8> = VERSION.to_vec();
    payload.extend(x.to_bytes());
    payload.push(u8::from(is_odd));

    let inner = Sha256::digest(&payload[..]);
    let outer = Sha256::digest(&inner[..]);
    payload.extend_from_slice(&outer[..4]);

    bs58::encode(payload).into_string()
}
impl CompressedPubKey {
    /// Encodes this key as a base58check Mina address.
    pub fn into_address(&self) -> String {
        into_address(&self.x, self.is_odd)
    }

    /// Parses a compressed key: the x-coordinate bytes followed by exactly one
    /// parity byte (`0x00` = even, `0x01` = odd).
    ///
    /// # Errors
    ///
    /// * [`PubKeyError::XCoordinateBytes`] if `bytes` is shorter than one field
    ///   element, or the x-coordinate bytes do not decode.
    /// * [`PubKeyError::YCoordinateParityBytes`] if there is not exactly one byte
    ///   after the x-coordinate.
    /// * [`PubKeyError::YCoordinateParity`] if the parity byte is neither `0x00`
    ///   nor `0x01`.
    /// * [`PubKeyError::XCoordinate`] if no curve point has this x-coordinate.
    /// * [`PubKeyError::NonCurvePoint`] if the recovered point is not on the curve.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let field_len = BaseField::size_in_bytes();
        // Check before splitting: slicing `bytes[0..field_len]` on short input would
        // panic, and this is reachable from untrusted hex through `from_hex`.
        if bytes.len() < field_len {
            return Err(PubKeyError::XCoordinateBytes);
        }
        let (x_bytes, parity_bytes) = bytes.split_at(field_len);
        let x = BaseField::from_bytes(x_bytes).map_err(|_| PubKeyError::XCoordinateBytes)?;

        let is_odd = match parity_bytes {
            [0x00] => false,
            [0x01] => true,
            [_] => return Err(PubKeyError::YCoordinateParity),
            _ => return Err(PubKeyError::YCoordinateParityBytes),
        };

        // Only x needs validating here: confirm some curve point has this x-coordinate.
        let public =
            CurvePoint::get_point_from_x_unchecked(x, is_odd).ok_or(PubKeyError::XCoordinate)?;
        if !public.is_on_curve() {
            return Err(PubKeyError::NonCurvePoint);
        }
        Ok(Self { x, is_odd })
    }

    /// Parses a hex string and decodes it with [`CompressedPubKey::from_bytes`].
    ///
    /// # Errors
    ///
    /// [`PubKeyError::Hex`] if `public_hex` is not valid hex; otherwise any error
    /// returned by [`CompressedPubKey::from_bytes`].
    pub fn from_hex(public_hex: &str) -> Result<Self> {
        let bytes: Vec<u8> = hex::decode(public_hex).map_err(|_| PubKeyError::Hex)?;
        Self::from_bytes(&bytes)
    }

    /// Derives the compressed public key `generator * secret`.
    ///
    /// NOTE(review): unlike `PubKey::from_secret_key`, a zero scalar is not rejected
    /// here; the signature cannot report an error.
    pub fn from_secret_key(sec_key: SecKey) -> Self {
        let public = PubKey::from_point_unsafe(
            CurvePoint::generator()
                .mul(sec_key.into_scalar())
                .into_affine(),
        );
        public.into_compressed()
    }

    /// Decodes a base58check Mina address into a compressed key.
    ///
    /// # Errors
    ///
    /// Any error returned by `PubKey::from_address`.
    pub fn from_address(address: &str) -> Result<Self> {
        Ok(PubKey::from_address(address)?.into_compressed())
    }

    /// Returns the all-zero placeholder key (x = 0, even parity).
    pub fn empty() -> Self {
        Self {
            x: BaseField::zero(),
            is_odd: false,
        }
    }

    /// Serializes as the x-coordinate bytes followed by one parity byte
    /// (`0x01` if odd, `0x00` if even) — the layout read by `from_bytes`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = self.x.to_bytes();
        bytes.push(u8::from(self.is_odd));
        bytes
    }

    /// Returns the lowercase hex encoding of [`CompressedPubKey::to_bytes`].
    pub fn to_hex(&self) -> String {
        hex::encode(self.to_bytes())
    }
}