Skip to content

Commit 6f116ab

Browse files
committed
Add allocation-free MultiscalarMul method
1 parent c3f91f7 commit 6f116ab

File tree

9 files changed

+246
-111
lines changed

9 files changed

+246
-111
lines changed

curve25519-dalek/benches/dalek_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ mod multiscalar_benches {
139139
// rerandomize the scalars for every call just in case.
140140
b.iter_batched(
141141
|| construct_scalars(size),
142-
|scalars| EdwardsPoint::multiscalar_mul(&scalars, &points),
142+
|scalars| EdwardsPoint::multiscalar_alloc_mul(&scalars, &points),
143143
BatchSize::SmallInput,
144144
);
145145
},

curve25519-dalek/src/backend/mod.rs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
3737
use crate::EdwardsPoint;
3838
use crate::Scalar;
39+
use crate::traits::MultiscalarMul;
3940

4041
pub mod serial;
4142

@@ -191,30 +192,51 @@ impl VartimePrecomputedStraus {
191192
}
192193
}
193194

195+
#[allow(missing_docs)]
196+
pub fn straus_multiscalar_mul<const N: usize>(
197+
scalars: &[Scalar; N],
198+
points: &[EdwardsPoint; N],
199+
) -> EdwardsPoint {
200+
match get_selected_backend() {
201+
#[cfg(curve25519_dalek_backend = "simd")]
202+
BackendKind::Avx2 => {
203+
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul(scalars, points)
204+
}
205+
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
206+
BackendKind::Avx512 => {
207+
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul(
208+
scalars, points,
209+
)
210+
}
211+
BackendKind::Serial => serial::scalar_mul::straus::Straus::multiscalar_mul(scalars, points),
212+
}
213+
}
214+
194215
#[allow(missing_docs)]
195216
#[cfg(feature = "alloc")]
196-
pub fn straus_multiscalar_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
217+
pub fn straus_multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
197218
where
198219
I: IntoIterator,
199220
I::Item: core::borrow::Borrow<Scalar>,
200221
J: IntoIterator,
201222
J::Item: core::borrow::Borrow<EdwardsPoint>,
202223
{
203-
use crate::traits::MultiscalarMul;
204-
205224
match get_selected_backend() {
206225
#[cfg(curve25519_dalek_backend = "simd")]
207226
BackendKind::Avx2 => {
208-
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::<I, J>(scalars, points)
227+
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_alloc_mul::<I, J>(
228+
scalars, points,
229+
)
209230
}
210231
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
211232
BackendKind::Avx512 => {
212-
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::<I, J>(
213-
scalars, points,
214-
)
233+
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_alloc_mul::<
234+
I,
235+
J,
236+
>(scalars, points)
215237
}
216238
BackendKind::Serial => {
217-
serial::scalar_mul::straus::Straus::multiscalar_mul::<I, J>(scalars, points)
239+
serial::scalar_mul::straus::Straus::multiscalar_alloc_mul::<I, J>(scalars, points)
218240
}
219241
}
220242
}

curve25519-dalek/src/backend/serial/scalar_mul/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ pub mod variable_base;
2323
#[allow(missing_docs)]
2424
pub mod vartime_double_base;
2525

26-
#[cfg(feature = "alloc")]
2726
pub mod straus;
2827

2928
#[cfg(feature = "alloc")]

curve25519-dalek/src/backend/serial/scalar_mul/straus.rs

Lines changed: 94 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,20 @@
1313
1414
#![allow(non_snake_case)]
1515

16+
#[cfg(feature = "alloc")]
1617
use alloc::vec::Vec;
1718

19+
#[cfg(feature = "alloc")]
1820
use core::borrow::Borrow;
19-
use core::cmp::Ordering;
2021

22+
use crate::backend::serial::curve_models::ProjectiveNielsPoint;
2123
use crate::edwards::EdwardsPoint;
2224
use crate::scalar::Scalar;
25+
use crate::traits::Identity;
2326
use crate::traits::MultiscalarMul;
27+
#[cfg(feature = "alloc")]
2428
use crate::traits::VartimeMultiscalarMul;
29+
use crate::window::LookupTable;
2530

2631
/// Perform multiscalar multiplication by the interleaved window
2732
/// method, also known as Straus' method (since it was apparently
@@ -49,68 +54,26 @@ pub struct Straus {}
4954
impl MultiscalarMul for Straus {
5055
type Point = EdwardsPoint;
5156

52-
/// Constant-time Straus using a fixed window of size \\(4\\).
53-
///
54-
/// Our goal is to compute
55-
/// \\[
56-
/// Q = s_1 P_1 + \cdots + s_n P_n.
57-
/// \\]
58-
///
59-
/// For each point \\( P_i \\), precompute a lookup table of
60-
/// \\[
61-
/// P_i, 2P_i, 3P_i, 4P_i, 5P_i, 6P_i, 7P_i, 8P_i.
62-
/// \\]
63-
///
64-
/// For each scalar \\( s_i \\), compute its radix-\\(2^4\\)
65-
/// signed digits \\( s_{i,j} \\), i.e.,
66-
/// \\[
67-
/// s_i = s_{i,0} + s_{i,1} 16^1 + ... + s_{i,63} 16^{63},
68-
/// \\]
69-
/// with \\( -8 \leq s_{i,j} < 8 \\). Since \\( 0 \leq |s_{i,j}|
70-
/// \leq 8 \\), we can retrieve \\( s_{i,j} P_i \\) from the
71-
/// lookup table with a conditional negation: using signed
72-
/// digits halves the required table size.
73-
///
74-
/// Then as in the single-base fixed window case, we have
75-
/// \\[
76-
/// \begin{aligned}
77-
/// s_i P_i &= P_i (s_{i,0} + s_{i,1} 16^1 + \cdots + s_{i,63} 16^{63}) \\\\
78-
/// s_i P_i &= P_i s_{i,0} + P_i s_{i,1} 16^1 + \cdots + P_i s_{i,63} 16^{63} \\\\
79-
/// s_i P_i &= P_i s_{i,0} + 16(P_i s_{i,1} + 16( \cdots +16P_i s_{i,63})\cdots )
80-
/// \end{aligned}
81-
/// \\]
82-
/// so each \\( s_i P_i \\) can be computed by alternately adding
83-
/// a precomputed multiple \\( P_i s_{i,j} \\) of \\( P_i \\) and
84-
/// repeatedly doubling.
85-
///
86-
/// Now consider the two-dimensional sum
87-
/// \\[
88-
/// \begin{aligned}
89-
/// s\_1 P\_1 &=& P\_1 s\_{1,0} &+& 16 (P\_1 s\_{1,1} &+& 16 ( \cdots &+& 16 P\_1 s\_{1,63}&) \cdots ) \\\\
90-
/// + & & + & & + & & & & + & \\\\
91-
/// s\_2 P\_2 &=& P\_2 s\_{2,0} &+& 16 (P\_2 s\_{2,1} &+& 16 ( \cdots &+& 16 P\_2 s\_{2,63}&) \cdots ) \\\\
92-
/// + & & + & & + & & & & + & \\\\
93-
/// \vdots & & \vdots & & \vdots & & & & \vdots & \\\\
94-
/// + & & + & & + & & & & + & \\\\
95-
/// s\_n P\_n &=& P\_n s\_{n,0} &+& 16 (P\_n s\_{n,1} &+& 16 ( \cdots &+& 16 P\_n s\_{n,63}&) \cdots )
96-
/// \end{aligned}
97-
/// \\]
98-
/// The sum of the left-hand column is the result \\( Q \\); by
99-
/// computing the two-dimensional sum on the right column-wise,
100-
/// top-to-bottom, then right-to-left, we need to multiply by \\(
101-
/// 16\\) only once per column, sharing the doublings across all
102-
/// of the input points.
103-
fn multiscalar_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
57+
fn multiscalar_mul<const N: usize>(
58+
scalars: &[Scalar; N],
59+
points: &[EdwardsPoint; N],
60+
) -> EdwardsPoint {
61+
let lookup_tables: [_; N] =
62+
core::array::from_fn(|index| LookupTable::<ProjectiveNielsPoint>::from(&points[index]));
63+
64+
let scalar_digits: [_; N] = core::array::from_fn(|index| scalars[index].as_radix_16());
65+
66+
multiscalar_mul(&scalar_digits, &lookup_tables)
67+
}
68+
69+
#[cfg(feature = "alloc")]
70+
fn multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
10471
where
10572
I: IntoIterator,
10673
I::Item: Borrow<Scalar>,
10774
J: IntoIterator,
10875
J::Item: Borrow<EdwardsPoint>,
10976
{
110-
use crate::backend::serial::curve_models::ProjectiveNielsPoint;
111-
use crate::traits::Identity;
112-
use crate::window::LookupTable;
113-
11477
let lookup_tables: Vec<_> = points
11578
.into_iter()
11679
.map(|point| LookupTable::<ProjectiveNielsPoint>::from(point.borrow()))
@@ -125,25 +88,86 @@ impl MultiscalarMul for Straus {
12588
.map(|s| s.borrow().as_radix_16())
12689
.collect();
12790

128-
let mut Q = EdwardsPoint::identity();
129-
for j in (0..64).rev() {
130-
Q = Q.mul_by_pow_2(4);
131-
let it = scalar_digits.iter().zip(lookup_tables.iter());
132-
for (s_i, lookup_table_i) in it {
133-
// R_i = s_{i,j} * P_i
134-
let R_i = lookup_table_i.select(s_i[j]);
135-
// Q = Q + R_i
136-
Q = (&Q + &R_i).as_extended();
137-
}
138-
}
91+
let Q = multiscalar_mul(&scalar_digits, &lookup_tables);
13992

14093
#[cfg(feature = "zeroize")]
141-
zeroize::Zeroize::zeroize(&mut scalar_digits);
94+
zeroize::Zeroize::zeroize(&mut scalar_digits.iter_mut());
14295

14396
Q
14497
}
14598
}
14699

100+
/// Constant-time Straus using a fixed window of size \\(4\\).
101+
///
102+
/// Our goal is to compute
103+
/// \\[
104+
/// Q = s_1 P_1 + \cdots + s_n P_n.
105+
/// \\]
106+
///
107+
/// For each point \\( P_i \\), precompute a lookup table of
108+
/// \\[
109+
/// P_i, 2P_i, 3P_i, 4P_i, 5P_i, 6P_i, 7P_i, 8P_i.
110+
/// \\]
111+
///
112+
/// For each scalar \\( s_i \\), compute its radix-\\(2^4\\)
113+
/// signed digits \\( s_{i,j} \\), i.e.,
114+
/// \\[
115+
/// s_i = s_{i,0} + s_{i,1} 16^1 + ... + s_{i,63} 16^{63},
116+
/// \\]
117+
/// with \\( -8 \leq s_{i,j} < 8 \\). Since \\( 0 \leq |s_{i,j}|
118+
/// \leq 8 \\), we can retrieve \\( s_{i,j} P_i \\) from the
119+
/// lookup table with a conditional negation: using signed
120+
/// digits halves the required table size.
121+
///
122+
/// Then as in the single-base fixed window case, we have
123+
/// \\[
124+
/// \begin{aligned}
125+
/// s_i P_i &= P_i (s_{i,0} + s_{i,1} 16^1 + \cdots + s_{i,63} 16^{63}) \\\\
126+
/// s_i P_i &= P_i s_{i,0} + P_i s_{i,1} 16^1 + \cdots + P_i s_{i,63} 16^{63} \\\\
127+
/// s_i P_i &= P_i s_{i,0} + 16(P_i s_{i,1} + 16( \cdots +16P_i s_{i,63})\cdots )
128+
/// \end{aligned}
129+
/// \\]
130+
/// so each \\( s_i P_i \\) can be computed by alternately adding
131+
/// a precomputed multiple \\( P_i s_{i,j} \\) of \\( P_i \\) and
132+
/// repeatedly doubling.
133+
///
134+
/// Now consider the two-dimensional sum
135+
/// \\[
136+
/// \begin{aligned}
137+
/// s\_1 P\_1 &=& P\_1 s\_{1,0} &+& 16 (P\_1 s\_{1,1} &+& 16 ( \cdots &+& 16 P\_1 s\_{1,63}&) \cdots ) \\\\
138+
/// + & & + & & + & & & & + & \\\\
139+
/// s\_2 P\_2 &=& P\_2 s\_{2,0} &+& 16 (P\_2 s\_{2,1} &+& 16 ( \cdots &+& 16 P\_2 s\_{2,63}&) \cdots ) \\\\
140+
/// + & & + & & + & & & & + & \\\\
141+
/// \vdots & & \vdots & & \vdots & & & & \vdots & \\\\
142+
/// + & & + & & + & & & & + & \\\\
143+
/// s\_n P\_n &=& P\_n s\_{n,0} &+& 16 (P\_n s\_{n,1} &+& 16 ( \cdots &+& 16 P\_n s\_{n,63}&) \cdots )
144+
/// \end{aligned}
145+
/// \\]
146+
/// The sum of the left-hand column is the result \\( Q \\); by
147+
/// computing the two-dimensional sum on the right column-wise,
148+
/// top-to-bottom, then right-to-left, we need to multiply by \\(
149+
/// 16\\) only once per column, sharing the doublings across all
150+
/// of the input points.
151+
fn multiscalar_mul(
152+
scalar_digits: &[[i8; 64]],
153+
lookup_tables: &[LookupTable<ProjectiveNielsPoint>],
154+
) -> EdwardsPoint {
155+
let mut Q = EdwardsPoint::identity();
156+
for j in (0..64).rev() {
157+
Q = Q.mul_by_pow_2(4);
158+
let it = scalar_digits.iter().zip(lookup_tables.iter());
159+
for (s_i, lookup_table_i) in it {
160+
// R_i = s_{i,j} * P_i
161+
let R_i = lookup_table_i.select(s_i[j]);
162+
// Q = Q + R_i
163+
Q = (&Q + &R_i).as_extended();
164+
}
165+
}
166+
167+
Q
168+
}
169+
170+
#[cfg(feature = "alloc")]
147171
impl VartimeMultiscalarMul for Straus {
148172
type Point = EdwardsPoint;
149173

@@ -167,6 +191,7 @@ impl VartimeMultiscalarMul for Straus {
167191
};
168192
use crate::traits::Identity;
169193
use crate::window::NafLookupTable5;
194+
use core::cmp::Ordering;
170195

171196
let nafs: Vec<_> = scalars
172197
.into_iter()

curve25519-dalek/src/backend/vector/scalar_mul/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ pub mod variable_base;
1818
pub mod vartime_double_base;
1919

2020
#[allow(missing_docs)]
21-
#[cfg(feature = "alloc")]
2221
pub mod straus;
2322

2423
#[allow(missing_docs)]

0 commit comments

Comments
 (0)