Skip to content

Commit 1c898bf

Browse files
committed
Use tuple for MultiscalarMul
1 parent 6f116ab commit 1c898bf

File tree

6 files changed

+85
-109
lines changed

6 files changed

+85
-109
lines changed

curve25519-dalek/src/backend/mod.rs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -194,49 +194,46 @@ impl VartimePrecomputedStraus {
194194

195195
#[allow(missing_docs)]
196196
pub fn straus_multiscalar_mul<const N: usize>(
197-
scalars: &[Scalar; N],
198-
points: &[EdwardsPoint; N],
197+
points_and_scalars: &[(EdwardsPoint, Scalar); N],
199198
) -> EdwardsPoint {
200199
match get_selected_backend() {
201200
#[cfg(curve25519_dalek_backend = "simd")]
202201
BackendKind::Avx2 => {
203-
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul(scalars, points)
202+
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul(points_and_scalars)
204203
}
205204
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
206205
BackendKind::Avx512 => {
207206
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul(
208207
scalars, points,
209208
)
210209
}
211-
BackendKind::Serial => serial::scalar_mul::straus::Straus::multiscalar_mul(scalars, points),
210+
BackendKind::Serial => {
211+
serial::scalar_mul::straus::Straus::multiscalar_mul(points_and_scalars)
212+
}
212213
}
213214
}
214215

215216
#[allow(missing_docs)]
216217
#[cfg(feature = "alloc")]
217-
pub fn straus_multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
218+
pub fn straus_multiscalar_alloc_mul<I, P, S>(points_and_scalars: I) -> EdwardsPoint
218219
where
219-
I: IntoIterator,
220-
I::Item: core::borrow::Borrow<Scalar>,
221-
J: IntoIterator,
222-
J::Item: core::borrow::Borrow<EdwardsPoint>,
220+
I: IntoIterator<Item = (P, S)>,
221+
P: core::borrow::Borrow<EdwardsPoint>,
222+
S: core::borrow::Borrow<Scalar>,
223223
{
224224
match get_selected_backend() {
225225
#[cfg(curve25519_dalek_backend = "simd")]
226226
BackendKind::Avx2 => {
227-
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_alloc_mul::<I, J>(
228-
scalars, points,
229-
)
227+
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_alloc_mul(points_and_scalars)
230228
}
231229
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
232230
BackendKind::Avx512 => {
233-
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_alloc_mul::<
234-
I,
235-
J,
236-
>(scalars, points)
231+
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_alloc_mul(
232+
points_and_scalars,
233+
)
237234
}
238235
BackendKind::Serial => {
239-
serial::scalar_mul::straus::Straus::multiscalar_alloc_mul::<I, J>(scalars, points)
236+
serial::scalar_mul::straus::Straus::multiscalar_alloc_mul(points_and_scalars)
240237
}
241238
}
242239
}

curve25519-dalek/src/backend/serial/scalar_mul/straus.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,38 +55,38 @@ impl MultiscalarMul for Straus {
5555
type Point = EdwardsPoint;
5656

5757
fn multiscalar_mul<const N: usize>(
58-
scalars: &[Scalar; N],
59-
points: &[EdwardsPoint; N],
58+
points_and_scalars: &[(EdwardsPoint, Scalar); N],
6059
) -> EdwardsPoint {
61-
let lookup_tables: [_; N] =
62-
core::array::from_fn(|index| LookupTable::<ProjectiveNielsPoint>::from(&points[index]));
60+
let lookup_tables: [_; N] = core::array::from_fn(|index| {
61+
LookupTable::<ProjectiveNielsPoint>::from(&points_and_scalars[index].0)
62+
});
6363

64-
let scalar_digits: [_; N] = core::array::from_fn(|index| scalars[index].as_radix_16());
64+
let scalar_digits: [_; N] =
65+
core::array::from_fn(|index| points_and_scalars[index].1.as_radix_16());
6566

6667
multiscalar_mul(&scalar_digits, &lookup_tables)
6768
}
6869

6970
#[cfg(feature = "alloc")]
70-
fn multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
71+
fn multiscalar_alloc_mul<I, P, S>(points_and_scalars: I) -> EdwardsPoint
7172
where
72-
I: IntoIterator,
73-
I::Item: Borrow<Scalar>,
74-
J: IntoIterator,
75-
J::Item: Borrow<EdwardsPoint>,
73+
I: IntoIterator<Item = (P, S)>,
74+
P: Borrow<EdwardsPoint>,
75+
S: Borrow<Scalar>,
7676
{
77-
let lookup_tables: Vec<_> = points
78-
.into_iter()
79-
.map(|point| LookupTable::<ProjectiveNielsPoint>::from(point.borrow()))
80-
.collect();
81-
8277
// This puts the scalar digits into a heap-allocated Vec.
8378
// To ensure that these are erased, pass ownership of the Vec into a
8479
// Zeroizing wrapper.
8580
#[cfg_attr(not(feature = "zeroize"), allow(unused_mut))]
86-
let mut scalar_digits: Vec<_> = scalars
81+
let (lookup_tables, mut scalar_digits): (Vec<_>, Vec<_>) = points_and_scalars
8782
.into_iter()
88-
.map(|s| s.borrow().as_radix_16())
89-
.collect();
83+
.map(|(p, s)| {
84+
(
85+
LookupTable::<ProjectiveNielsPoint>::from(p.borrow()),
86+
s.borrow().as_radix_16(),
87+
)
88+
})
89+
.unzip();
9090

9191
let Q = multiscalar_mul(&scalar_digits, &lookup_tables);
9292

curve25519-dalek/src/backend/vector/scalar_mul/straus.rs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,38 +53,39 @@ pub mod spec {
5353
type Point = EdwardsPoint;
5454

5555
fn multiscalar_mul<const N: usize>(
56-
scalars: &[Scalar; N],
57-
points: &[EdwardsPoint; N],
56+
points_and_scalars: &[(EdwardsPoint, Scalar); N],
5857
) -> EdwardsPoint {
5958
// Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P]
6059
// for each input point P
61-
let lookup_tables: [_; N] =
62-
core::array::from_fn(|index| LookupTable::<CachedPoint>::from(&points[index]));
60+
let lookup_tables: [_; N] = core::array::from_fn(|index| {
61+
LookupTable::<CachedPoint>::from(&points_and_scalars[index].0)
62+
});
6363

64-
let scalar_digits: [_; N] = core::array::from_fn(|index| scalars[index].as_radix_16());
64+
let scalar_digits: [_; N] =
65+
core::array::from_fn(|index| points_and_scalars[index].1.as_radix_16());
6566

6667
multiscalar_mul(&scalar_digits, &lookup_tables)
6768
}
6869

6970
#[cfg(feature = "alloc")]
70-
fn multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
71+
fn multiscalar_alloc_mul<I, P, S>(points_and_scalars: I) -> EdwardsPoint
7172
where
72-
I: IntoIterator,
73-
I::Item: Borrow<Scalar>,
74-
J: IntoIterator,
75-
J::Item: Borrow<EdwardsPoint>,
73+
I: IntoIterator<Item = (P, S)>,
74+
P: Borrow<EdwardsPoint>,
75+
S: Borrow<Scalar>,
7676
{
7777
// Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P]
7878
// for each input point P
79-
let lookup_tables: Vec<_> = points
79+
let (lookup_tables, scalar_digits_vec): (Vec<_>, Vec<_>) = points_and_scalars
8080
.into_iter()
81-
.map(|point| LookupTable::<CachedPoint>::from(point.borrow()))
82-
.collect();
81+
.map(|(p, s)| {
82+
(
83+
LookupTable::<CachedPoint>::from(p.borrow()),
84+
s.borrow().as_radix_16(),
85+
)
86+
})
87+
.unzip();
8388

84-
let scalar_digits_vec: Vec<_> = scalars
85-
.into_iter()
86-
.map(|s| s.borrow().as_radix_16())
87-
.collect();
8889
// Pass ownership to a `Zeroizing` wrapper
8990
#[cfg(feature = "zeroize")]
9091
let scalar_digits_vec = zeroize::Zeroizing::new(scalar_digits_vec);

curve25519-dalek/src/edwards.rs

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -898,38 +898,19 @@ impl MultiscalarMul for EdwardsPoint {
898898
type Point = EdwardsPoint;
899899

900900
fn multiscalar_mul<const N: usize>(
901-
scalars: &[Scalar; N],
902-
points: &[Self::Point; N],
901+
points_and_scalars: &[(Self::Point, Scalar); N],
903902
) -> Self::Point {
904-
crate::backend::straus_multiscalar_mul(scalars, points)
903+
crate::backend::straus_multiscalar_mul(points_and_scalars)
905904
}
906905

907906
#[cfg(feature = "alloc")]
908-
fn multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
907+
fn multiscalar_alloc_mul<I, P, S>(points_and_scalars: I) -> EdwardsPoint
909908
where
910-
I: IntoIterator,
911-
I::Item: Borrow<Scalar>,
912-
J: IntoIterator,
913-
J::Item: Borrow<EdwardsPoint>,
909+
I: IntoIterator<Item = (P, S)>,
910+
P: Borrow<Self::Point>,
911+
S: Borrow<Scalar>,
914912
{
915-
// Sanity-check lengths of input iterators
916-
let mut scalars = scalars.into_iter();
917-
let mut points = points.into_iter();
918-
919-
// Lower and upper bounds on iterators
920-
let (s_lo, s_hi) = scalars.by_ref().size_hint();
921-
let (p_lo, p_hi) = points.by_ref().size_hint();
922-
923-
// They should all be equal
924-
assert_eq!(s_lo, p_lo);
925-
assert_eq!(s_hi, Some(s_lo));
926-
assert_eq!(p_hi, Some(p_lo));
927-
928-
// Now we know there's a single size. When we do
929-
// size-dependent algorithm dispatch, use this as the hint.
930-
let _size = s_lo;
931-
932-
crate::backend::straus_multiscalar_alloc_mul(scalars, points)
913+
crate::backend::straus_multiscalar_alloc_mul(points_and_scalars)
933914
}
934915
}
935916

@@ -2202,7 +2183,7 @@ mod test {
22022183
let Gs = xs.iter().map(EdwardsPoint::mul_base).collect::<Vec<_>>();
22032184

22042185
// Compute H1 = <xs, Gs> (consttime)
2205-
let H1 = EdwardsPoint::multiscalar_alloc_mul(&xs, &Gs);
2186+
let H1 = EdwardsPoint::multiscalar_alloc_mul(Gs.iter().zip(&xs));
22062187
// Compute H2 = <xs, Gs> (vartime)
22072188
let H2 = EdwardsPoint::vartime_multiscalar_mul(&xs, &Gs);
22082189
// Compute H3 = <xs, Gs> = sum(xi^2) * B
@@ -2360,8 +2341,9 @@ mod test {
23602341
&[A, constants::ED25519_BASEPOINT_POINT],
23612342
);
23622343
let result_consttime = EdwardsPoint::multiscalar_alloc_mul(
2363-
&[A_SCALAR, B_SCALAR],
2364-
&[A, constants::ED25519_BASEPOINT_POINT],
2344+
[A, constants::ED25519_BASEPOINT_POINT]
2345+
.into_iter()
2346+
.zip([A_SCALAR, B_SCALAR]),
23652347
);
23662348

23672349
assert_eq!(result_vartime.compress(), result_consttime.compress());

curve25519-dalek/src/ristretto.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,26 +1011,26 @@ impl MultiscalarMul for RistrettoPoint {
10111011
type Point = RistrettoPoint;
10121012

10131013
fn multiscalar_mul<const N: usize>(
1014-
scalars: &[Scalar; N],
1015-
points: &[Self::Point; N],
1016-
) -> Self::Point {
1017-
let extended_points: [_; N] = core::array::from_fn(|index| points[index].0);
1018-
RistrettoPoint(EdwardsPoint::multiscalar_mul(scalars, &extended_points))
1014+
points_and_scalars: &[(RistrettoPoint, Scalar); N],
1015+
) -> RistrettoPoint {
1016+
let points_and_scalars: [_; N] = core::array::from_fn(|index| {
1017+
let (p, s) = points_and_scalars[index];
1018+
(p.0, s)
1019+
});
1020+
RistrettoPoint(EdwardsPoint::multiscalar_mul(&points_and_scalars))
10191021
}
10201022

10211023
#[cfg(feature = "alloc")]
1022-
fn multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> RistrettoPoint
1024+
fn multiscalar_alloc_mul<I, P, S>(points_and_scalars: I) -> RistrettoPoint
10231025
where
1024-
I: IntoIterator,
1025-
I::Item: Borrow<Scalar>,
1026-
J: IntoIterator,
1027-
J::Item: Borrow<RistrettoPoint>,
1026+
I: IntoIterator<Item = (P, S)>,
1027+
P: Borrow<RistrettoPoint>,
1028+
S: Borrow<Scalar>,
10281029
{
1029-
let extended_points = points.into_iter().map(|P| P.borrow().0);
1030-
RistrettoPoint(EdwardsPoint::multiscalar_alloc_mul(
1031-
scalars,
1032-
extended_points,
1033-
))
1030+
let points_and_scalars = points_and_scalars
1031+
.into_iter()
1032+
.map(|(p, s)| (p.borrow().0, s));
1033+
RistrettoPoint(EdwardsPoint::multiscalar_alloc_mul(points_and_scalars))
10341034
}
10351035
}
10361036

curve25519-dalek/src/traits.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,15 @@ pub trait MultiscalarMul {
111111
/// let R = P + Q;
112112
///
113113
/// // A1 = a*P + b*Q + c*R
114-
/// let abc = [a,b,c];
115-
/// let A1 = RistrettoPoint::multiscalar_mul(&abc, &[P,Q,R]);
114+
/// let A1 = RistrettoPoint::multiscalar_mul(&[(P, a), (Q, b), (R, c)]);
116115
///
117116
/// // A2 = (-a)*P + (-b)*Q + (-c)*R
118-
/// let minus_abc = abc.map(|x| -x);
119-
/// let A2 = RistrettoPoint::multiscalar_mul(&minus_abc, &[P,Q,R]);
117+
/// let A2 = RistrettoPoint::multiscalar_mul(&[(P, -a), (Q, -b), (R, -c)]);
120118
///
121119
/// assert_eq!(A1.compress(), (-A2).compress());
122120
/// ```
123121
fn multiscalar_mul<const N: usize>(
124-
scalars: &[Scalar; N],
125-
points: &[Self::Point; N],
122+
points_and_scalars: &[(Self::Point, Scalar); N],
126123
) -> Self::Point;
127124

128125
/// Given an iterator of (possibly secret) scalars and an iterator of
@@ -160,24 +157,23 @@ pub trait MultiscalarMul {
160157
///
161158
/// // A1 = a*P + b*Q + c*R
162159
/// let abc = [a,b,c];
163-
/// let A1 = RistrettoPoint::multiscalar_alloc_mul(&abc, &[P,Q,R]);
160+
/// let A1 = RistrettoPoint::multiscalar_alloc_mul([P,Q,R].into_iter().zip(&abc));
164161
/// // Note: (&abc).into_iter(): Iterator<Item=&Scalar>
165162
///
166163
/// // A2 = (-a)*P + (-b)*Q + (-c)*R
167164
/// let minus_abc = abc.iter().map(|x| -x);
168-
/// let A2 = RistrettoPoint::multiscalar_alloc_mul(minus_abc, &[P,Q,R]);
165+
/// let A2 = RistrettoPoint::multiscalar_alloc_mul([P,Q,R].into_iter().zip(minus_abc));
169166
/// // Note: minus_abc.into_iter(): Iterator<Item=Scalar>
170167
///
171168
/// assert_eq!(A1.compress(), (-A2).compress());
172169
/// # }
173170
/// ```
174171
#[cfg(feature = "alloc")]
175-
fn multiscalar_alloc_mul<I, J>(scalars: I, points: J) -> Self::Point
172+
fn multiscalar_alloc_mul<I, P, S>(points_and_scalars: I) -> Self::Point
176173
where
177-
I: IntoIterator,
178-
I::Item: Borrow<Scalar>,
179-
J: IntoIterator,
180-
J::Item: Borrow<Self::Point>;
174+
I: IntoIterator<Item = (P, S)>,
175+
P: Borrow<Self::Point>,
176+
S: Borrow<Scalar>;
181177
}
182178

183179
/// A trait for variable-time multiscalar multiplication without precomputation.

0 commit comments

Comments
 (0)