Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5db6c1f
zkvm config to allow tests to run in zkvm
austinabell Dec 11, 2024
195c183
initial template of code for acc
austinabell Dec 12, 2024
98e8ea5
impl
austinabell Dec 12, 2024
84ef330
uncomment dep config
austinabell Dec 16, 2024
6d87f20
accelerate point inv for proj -> affine
austinabell Dec 17, 2024
73a06f5
accelerate scalar inverse
austinabell Dec 17, 2024
d095d83
switch to unchecked (already checked)
austinabell Dec 18, 2024
64c9011
impl acceleration for decompress before sqrt
austinabell Dec 18, 2024
9856151
accelerate sqrt
austinabell Dec 18, 2024
fb939a8
move ec add to accelerated
austinabell Dec 18, 2024
d7454af
accelerate the to_mont multiply
austinabell Dec 19, 2024
d30ec8a
accelerate to_affine conversion
austinabell Dec 19, 2024
ef27d4b
add zero checks to inverse operations through zkvm
austinabell Dec 19, 2024
a5cce8f
impl remaining ec impls
austinabell Dec 19, 2024
480bedf
update impl of scalar to words to be more careful about alignment
austinabell Dec 19, 2024
3fc6064
update patch config to minimize git diff
austinabell Dec 19, 2024
2bbfc5d
rename risc0 module to minimize discoverability
austinabell Dec 19, 2024
d220d53
reduce duplicate logic
austinabell Dec 19, 2024
17cf000
switch ops to do prime check even for intermediate ops to be safe
austinabell Jan 7, 2025
db7cd24
handle identity in affine conversion, panic on invalid mont conversio…
austinabell Jan 10, 2025
09f54cb
simplify invert and FE conversions with checked mul
austinabell Jan 10, 2025
70c5519
Test and prep for 1.2.1
tzerrell Jan 16, 2025
0a739b1
Cut over to 1.2.1 release
tzerrell Jan 16, 2025
cdbf06d
Update lock file with version update
tzerrell Jan 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,478 changes: 2,318 additions & 160 deletions Cargo.lock

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions p256/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,24 @@ primeorder = { version = "0.13", optional = true, path = "../primeorder" }
serdect = { version = "0.2", optional = true, default-features = false }
sha2 = { version = "0.10", optional = true, default-features = false }

[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dependencies]
bytemuck = "1"
risc0-bigint2 = { version = "1.2.1", features = ["unstable"] }

[dev-dependencies]
blobby = "0.3"
criterion = "0.4"
ecdsa-core = { version = "0.16", package = "ecdsa", default-features = false, features = ["dev"] }
hex-literal = "0.4"
primeorder = { version = "0.13", features = ["dev"], path = "../primeorder" }
proptest = "1"
rand_core = { version = "0.6", features = ["getrandom"] }

[target.'cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))'.dev-dependencies]
proptest = "1"
criterion = "0.4"

[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dev-dependencies]
risc0-zkvm = { version = "1.2.1", features = ["getrandom"] }

[features]
default = ["arithmetic", "ecdsa", "pem", "std"]
alloc = ["ecdsa-core?/alloc", "elliptic-curve/alloc"]
Expand Down
1 change: 1 addition & 0 deletions p256/benches/field.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! secp256r1 field element benchmarks
#![cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]

use criterion::{
criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion,
Expand Down
2 changes: 2 additions & 0 deletions p256/benches/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! secp256r1 scalar arithmetic benchmarks

#![cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]

use criterion::{
criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion,
};
Expand Down
22 changes: 22 additions & 0 deletions p256/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ impl PrimeCurveArithmetic for NistP256 {
type CurveGroup = ProjectivePoint;
}

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
use primeorder::__risc0::FieldElement256;

/// Adapted from [NIST SP 800-186] § G.1.2: Curve P-256.
///
/// [NIST SP 800-186]: https://csrc.nist.gov/publications/detail/sp/800-186/final
Expand All @@ -56,4 +59,23 @@ impl PrimeCurveParams for NistP256 {
FieldElement::from_hex("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296"),
FieldElement::from_hex("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5"),
);

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const PRIME_LE_WORDS: [u32; 8] = crate::__risc0::SECP256R1_PRIME;

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const ORDER_LE_WORDS: [u32; 8] = crate::__risc0::SECP256R1_ORDER;

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const EQUATION_A_LE: FieldElement256<NistP256> =
FieldElement256::new_unchecked(crate::__risc0::SECP256R1_EQUATION_A_LE);

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const EQUATION_B_LE: FieldElement256<NistP256> =
FieldElement256::new_unchecked(crate::__risc0::SECP256R1_EQUATION_B_LE);

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
fn from_u32_words_le(words: [u32; 8]) -> FieldElement {
FieldElement::from_words_le(words)
}
}
61 changes: 61 additions & 0 deletions p256/src/arithmetic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ pub const MODULUS: U256 = U256::from_be_hex(MODULUS_HEX);
const R_2: U256 =
U256::from_be_hex("00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003");

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
use primeorder::__risc0::FieldElement256;

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const R_2_LE: FieldElement256<NistP256> = FieldElement256::new_unchecked([
0x00000001, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFE, 0x00000000,
]);

/// An element in the finite field modulo p = 2^{224}(2^{32} − 1) + 2^{192} + 2^{96} − 1.
///
/// The internal representation is in little-endian order. Elements are always in
Expand All @@ -54,15 +62,66 @@ primeorder::impl_mont_field_element!(
);

impl FieldElement {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
pub(crate) fn from_words_le(fe: [u32; 8]) -> Self {
let fe = FieldElement256::new_unchecked(fe);

// Convert to montgomery form with aR mod p
let mut mont = FieldElement256::default();

// This mul will check if the result is within the modulus.
fe.mul(&R_2_LE, &mut mont);

let uint = U256::from_le_slice(bytemuck::cast_slice::<u32, u8>(&mont.data));

Self(uint)
}

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
pub(crate) fn to_words_le(&self) -> [u32; 8] {
use crate::elliptic_curve::bigint::Encoding;
// NOTE: this from mont conversion could be accelerated, but it's very little cycles.
let canonical = self.to_canonical();
let input = canonical.to_le_bytes();
let array = bytemuck::cast::<_, [u32; 8]>(input);

array
}

/// Returns the multiplicative inverse of self, if self is non-zero.
pub fn invert(&self) -> CtOption<Self> {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
{
use crate::elliptic_curve::bigint::Encoding;

// NOTE: This is not a constant time operation, as inverting zero in the zkvm is not
// possible as it will panic in the host.
if self.is_zero().into() {
return CtOption::new(FieldElement::ZERO, Choice::from(0));
} else {
let input_words = self.to_words_le();
let mut output = [0u32; 8];
risc0_bigint2::field::modinv_256(
&input_words,
&crate::__risc0::SECP256R1_PRIME,
&mut output,
);
let element = FieldElement::from_words_le(output);
return CtOption::new(element, Choice::from(1));
}
}

#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
CtOption::new(self.invert_unchecked(), !self.is_zero())
}

/// Returns the multiplicative inverse of self.
///
/// Does not check that self is non-zero.
const fn invert_unchecked(&self) -> Self {
// NOTE: It is fine that the internal invert function is not overriden for the zkvm, as this
// is only called in compile time constants given `invert` is overridden.

// We need to find b such that b * a ≡ 1 mod p. As we are in a prime
// field, we can apply Fermat's Little Theorem:
//
Expand Down Expand Up @@ -169,6 +228,7 @@ mod tests {
impl_field_identity_tests, impl_field_invert_tests, impl_field_sqrt_tests,
impl_primefield_tests,
};
#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
use proptest::{num, prelude::*};

/// t = (modulus - 1) >> S
Expand Down Expand Up @@ -263,6 +323,7 @@ mod tests {
assert_eq!(two.pow_vartime(&[2, 0, 0, 0]), four);
}

#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
proptest! {
/// This checks behaviour well within the field ranges, because it doesn't set the
/// highest limb.
Expand Down
1 change: 0 additions & 1 deletion p256/src/arithmetic/field/field32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::arithmetic::util::{adc, mac, sbb};
pub type Fe = [u32; 8];

/// Translate a field element out of the Montgomery domain.
#[inline]
pub const fn fe_from_montgomery(w: &Fe) -> Fe {
let w = fe32_to_fe64(w);
montgomery_reduce(&[w[0], w[1], w[2], w[3], 0, 0, 0, 0])
Expand Down
2 changes: 2 additions & 0 deletions p256/src/arithmetic/hash2curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ mod tests {
Curve, Field,
};
use hex_literal::hex;
#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
use proptest::{num::u64::ANY, prelude::ProptestConfig, proptest};
use sha2::Sha256;

Expand Down Expand Up @@ -289,6 +290,7 @@ mod tests {
}

#[test]
#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
fn from_okm_fuzz() {
let mut wide_order = GenericArray::default();
wide_order[16..].copy_from_slice(&NistP256::ORDER.to_be_byte_array());
Expand Down
29 changes: 29 additions & 0 deletions p256/src/arithmetic/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,30 @@ impl Scalar {

/// Returns the multiplicative inverse of self, if self is non-zero
pub fn invert(&self) -> CtOption<Self> {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
{
use crate::elliptic_curve::bigint::Encoding;

// NOTE: This is not a constant time operation, as inverting zero in the zkvm is not
// possible as it will panic in the host.
if self.is_zero().into() {
return CtOption::new(Scalar::ZERO, Choice::from(0));
} else {
let input = self.0.to_le_bytes();
let input_words = bytemuck::cast::<_, [u32; 8]>(input);
let mut output = [0u32; 8];
risc0_bigint2::field::modinv_256(
&input_words,
&crate::__risc0::SECP256R1_ORDER,
&mut output,
);
let bytes = bytemuck::cast_slice::<u32, u8>(&output);
let res = Scalar(U256::from_le_slice(bytes));
CtOption::new(res, Choice::from(1))
}
}

#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
CtOption::new(self.invert_unchecked(), !self.is_zero())
}

Expand Down Expand Up @@ -363,6 +387,11 @@ impl Invert for Scalar {
/// sidechannels.
#[allow(non_snake_case)]
fn invert_vartime(&self) -> CtOption<Self> {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
{
return self.invert();
}

let mut u = *self;
let mut v = Self(MODULUS);
let mut A = Self::ONE;
Expand Down
4 changes: 4 additions & 0 deletions p256/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#[cfg(feature = "arithmetic")]
mod arithmetic;

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
#[path = "risc0.rs"]
mod __risc0;

#[cfg(feature = "ecdh")]
pub mod ecdh;

Expand Down
31 changes: 31 additions & 0 deletions p256/src/risc0.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use risc0_bigint2::ec::{Curve, WeierstrassCurve, EC_256_WIDTH_WORDS};

/// The secp256r1 curve's prime field characteristic
pub(crate) const SECP256R1_PRIME: [u32; EC_256_WIDTH_WORDS] = [
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000001, 0xFFFFFFFF,
];

/// The secp256r1 curve's order
pub(crate) const SECP256R1_ORDER: [u32; EC_256_WIDTH_WORDS] = [
0xFC632551, 0xF3B9CAC2, 0xA7179E84, 0xBCE6FAAD, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
];

pub(crate) const SECP256R1_EQUATION_A_LE: [u32; EC_256_WIDTH_WORDS] = [
0xFFFFFFFC, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000001, 0xFFFFFFFF,
];

pub(crate) const SECP256R1_EQUATION_B_LE: [u32; EC_256_WIDTH_WORDS] = [
0x27D2604B, 0x3BCE3C3E, 0xCC53B0F6, 0x651D06B0, 0x769886BC, 0xB3EBBD55, 0xAA3A93E7, 0x5AC635D8,
];

const SECP256R1_CURVE: &WeierstrassCurve<EC_256_WIDTH_WORDS> =
&WeierstrassCurve::<EC_256_WIDTH_WORDS>::new(
SECP256R1_PRIME,
// Curve parameter a = -3 (represented mod p)
SECP256R1_EQUATION_A_LE,
SECP256R1_EQUATION_B_LE,
);

impl Curve<EC_256_WIDTH_WORDS> for crate::NistP256 {
const CURVE: &'static WeierstrassCurve<EC_256_WIDTH_WORDS> = SECP256R1_CURVE;
}
3 changes: 1 addition & 2 deletions p256/tests/ecdsa.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! ECDSA tests.

#![cfg(feature = "arithmetic")]

#![cfg(all(feature = "arithmetic", not(all(target_os = "zkvm", target_arch = "riscv32"))))]
use elliptic_curve::ops::Reduce;
use p256::{
ecdsa::{SigningKey, VerifyingKey},
Expand Down
2 changes: 1 addition & 1 deletion p256/tests/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Scalar arithmetic tests.

#![cfg(feature = "arithmetic")]
#![cfg(all(feature = "arithmetic", not(all(target_os = "zkvm", target_arch = "riscv32"))))]

use elliptic_curve::ops::{Invert, Reduce};
use p256::{Scalar, U256};
Expand Down
4 changes: 4 additions & 0 deletions primeorder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ elliptic-curve = { version = "0.13", default-features = false, features = ["arit
# optional dependencies
serdect = { version = "0.2", optional = true, default-features = false }

[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dependencies]
risc0-bigint2 = { version = "1.2.1", features = ["unstable"] }
bytemuck = "1"

[features]
std = ["elliptic-curve/std"]

Expand Down
51 changes: 51 additions & 0 deletions primeorder/src/affine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,56 @@ where
FieldBytes<C>: Copy,
{
fn decompress(x_bytes: &FieldBytes<C>, y_is_odd: Choice) -> CtOption<Self> {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
{
use crate::__risc0::FieldElement256;

// Note: buffers are kept separate for each OP as the result pointer cannot equal one
// of the input pointers.
let mut scratch = FieldElement256::<C>::default();
let mut acc = FieldElement256::<C>::default();
let mut scratch_1 = FieldElement256::<C>::from(x_bytes);

// x checked to be in the field.
C::FieldElement::from_repr(*x_bytes).and_then(|x| {
// x * &x * &x
scratch_1.mul_unchecked(&scratch_1, &mut scratch);
scratch.mul_unchecked(&scratch_1, &mut acc);

// + &(C::EQUATION_A * &x)
scratch_1.mul_unchecked(&C::EQUATION_A_LE, &mut scratch);
// Can re-use x as a buffer, no longer needed.
scratch.add_unchecked(&acc, &mut scratch_1);

// + &C::EQUATION_B
scratch_1.add_unchecked(&C::EQUATION_B_LE, &mut scratch);

// Sqrt implementation. Not separated into another function to allow
// re-using buffers.
scratch.sqrt_unchecked(&mut scratch_1, &mut acc);

// Check that the square root is correct.
acc.square(&mut scratch_1);

let sqrt = CtOption::new(acc, Choice::from(scratch_1.eq(&scratch) as u8));

// Checked that the result is within the field.
sqrt.and_then(|sqrt| {
C::FieldElement::from_repr(sqrt.into()).map(|beta| {
let y = C::FieldElement::conditional_select(
// TODO this neg can be accelerated too
&-beta,
&beta,
beta.is_odd().ct_eq(&y_is_odd),
);

Self { x, y, infinity: 0 }
})
})
})
}

#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
C::FieldElement::from_repr(*x_bytes).and_then(|x| {
let alpha = x * &x * &x + &(C::EQUATION_A * &x) + &C::EQUATION_B;
let beta = alpha.sqrt();
Expand Down Expand Up @@ -416,6 +466,7 @@ where
type Output = ProjectivePoint<C>;

fn mul(self, scalar: S) -> ProjectivePoint<C> {
// TODO avoid proj conversion
ProjectivePoint::<C>::from(self) * scalar
}
}
Expand Down
3 changes: 0 additions & 3 deletions primeorder/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ macro_rules! impl_field_op {
impl ::core::ops::$op for $fe {
type Output = $fe;

#[inline]
fn $op_fn(self, rhs: $fe) -> $fe {
$fe($func(self.as_ref(), rhs.as_ref()).into())
}
Expand All @@ -465,7 +464,6 @@ macro_rules! impl_field_op {
impl ::core::ops::$op<&$fe> for $fe {
type Output = $fe;

#[inline]
fn $op_fn(self, rhs: &$fe) -> $fe {
$fe($func(self.as_ref(), rhs.as_ref()).into())
}
Expand All @@ -474,7 +472,6 @@ macro_rules! impl_field_op {
impl ::core::ops::$op<&$fe> for &$fe {
type Output = $fe;

#[inline]
fn $op_fn(self, rhs: &$fe) -> $fe {
$fe($func(self.as_ref(), rhs.as_ref()).into())
}
Expand Down
Loading