Description
The goal of this issue is to provide a guideline on how to fix status-im/nimbus-eth1#1584.
Also pinging @treeform, @guzba on how to implement fast RSA (modexp is the bottleneck)
and @dlesnoff for nim/bigint
Here is a write-up on how to implement fast modular exponentiation.
Recommended textbooks for implementers:
- Modern Computer Arithmetic https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
- Handbook of Applied Cryptography, chapter 14: https://cacr.uwaterloo.ca/hac/about/chap14.pdf
We assume 64-bit words.
Vocabulary
In textbooks, you'll encounter the term "reduction", which should be read as "modular reduction" and just means taking the remainder after Euclidean division.
The assembly myth
First of all, let's dispel a myth: that assembly is the key to speed.
It does help, but that misses the big picture.
Some benchmarks on Constantine modular exponentiation on BN254 prime field by a 254-bit integer:
- GCC 13.1.1 + assembly
  - constant-time 131147.541 ops/s, 7625 ns/op, 25180 CPU cycles
  - variable-time 182882.224 ops/s, 5468 ns/op, 18056 CPU cycles
- GCC 13.1.1 - no assembly
  - constant-time 87351.502 ops/s, 11448 ns/op, 37805 CPU cycles
  - variable-time 103842.160 ops/s, 9630 ns/op, 31801 CPU cycles
- Clang 15.0.7 + assembly
  - constant-time 131268.049 ops/s, 7618 ns/op, 25157 CPU cycles
  - variable-time 184638.109 ops/s, 5416 ns/op, 17887 CPU cycles
- Clang 15.0.7 - no assembly
  - constant-time 116726.976 ops/s, 8567 ns/op, 28291 CPU cycles
  - variable-time 143513.203 ops/s, 6968 ns/op, 23012 CPU cycles
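So with Clang, assembly is "only" about 12% (constant-time) to 29% (variable-time) faster than no assembly.
And don't use GCC with bigints unless you use assembly: without it, GCC is about 1.5x to 1.8x slower than with it in these benchmarks.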
The big picture
Modular exponentiation is implemented through an algorithm called square-and-multiply (the exponentiation counterpart of double-and-add), which does as many modular squarings as the number of bits in the exponent and as many modular multiplications as the number of set bits in the exponent.
For random numbers, about 50% of the bits are set. Assuming a 256-bit exponent, that's 256 squarings + 128 multiplications = 384 modular operations.
Each modular multiplication is naively a multiplication 256-bit x 256-bit -> 512-bit and then modulo a 256-bit number.
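To make the operation count concrete, here is a minimal sketch of the naive algorithm on single-word numbers. It is only an illustration: `OpCount` and `powmodCounted` are hypothetical names, and the modulus is assumed to be below 2^32 so that products fit in a `uint64` (a real bigint version works on limbs).

```nim
import std/bitops

type OpCount = object
  squarings, multiplications: int

func powmodCounted(base, exponent, modulus: uint64, ops: var OpCount): uint64 =
  ## Left-to-right square-and-multiply with the naive `mod` reduction.
  ## Assumes 1 <= modulus < 2^32 so that every product fits in a uint64.
  let b = base mod modulus
  result = 1'u64 mod modulus
  if exponent == 0:
    return
  for i in countdown(fastLog2(exponent), 0):
    result = (result * result) mod modulus      # one squaring per exponent bit
    inc ops.squarings
    if ((exponent shr i) and 1'u64) == 1'u64:
      result = (result * b) mod modulus         # one multiplication per set bit
      inc ops.multiplications

when isMainModule:
  var ops: OpCount
  echo powmodCounted(2, 10, 1000, ops)          # 2^10 mod 1000 = 24
  echo ops.squarings, " squarings, ", ops.multiplications, " multiplications"
```

Every `mod` in this loop compiles to a hardware division, which is exactly the cost discussed next.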
The bottleneck
Let's take https://www.agner.org/optimize/instruction_tables.pdf and have a look at the speed of the DIV instruction, which is necessary to compute a modulo.
x86 started to be extremely optimized for BigInt after Broadwell which introduced ADCX and ADOX (and MULX was introduced in Haswell, Broadwell predecessor).
DIV on 64-bit input takes 36 cycles, and has a latency of up to 95 cycles (i.e. anything that depends on that result may wait up to 95 cycles before proceeding).
In comparison add and shifts are just 1 cycle. So anything that uses division starts with a heavy disadvantage.
Note: that disadvantage is still faster than doing bit-by-bit division like here (the algorithm is chosen if there is at most 8-bit of length difference between operands)
nim-stint/stint/private/uint_div.nim
Lines 206 to 230 in 94fc521

    func divmodBS(x, y: UintImpl, q, r: var UintImpl) =
      ## Division for multi-precision unsigned uint
      ## Implementation through binary shift division

      doAssert y.isZero.not() # This should be checked on release mode in the divmod caller proc

      type SubTy = type x.lo

      var
        shift = y.leadingZeros - x.leadingZeros
        d = y shl shift

      r = x

      while shift >= 0:
        q += q
        if r >= d:
          r -= d
          q.lo = q.lo or one(SubTy)

        d = d shr 1
        dec(shift)

    const BinaryShiftThreshold = 8  # If the difference in bit-length is below 8
                                    # binary shift is probably faster

Back-of-the-napkin perf:
A 256-bit modular reduction will need 4 64-bit DIV (and other things). So we already have a cost of about 400 cycles.
We need that for 384 operations, so 384 x 400 = 153600 cycles. That's 4 to 5x more costly than my slowest benchmarks, GCC without assembly (well, those are on 254-bit instead of 256-bit).
And there is a lot more work beside the divisions, see
nim-stint/stint/private/uint_div.nim
Lines 131 to 169 in 94fc521

    func div2n1n[T: SomeUnsignedInt](q, r: var T, n_hi, n_lo, d: T) =

      # doAssert leadingZeros(d) == 0, "Divisor was not normalized"

      const
        size = bitsof(q)
        halfSize = size div 2
        halfMask = (1.T shl halfSize) - 1.T

      template halfQR(n_hi, n_lo, d, d_hi, d_lo: T): tuple[q,r: T] =

        var (q, r) = divmod(n_hi, d_hi)
        let m = q * d_lo
        r = (r shl halfSize) or n_lo

        # Fix the reminder, we're at most 2 iterations off
        if r < m:
          dec q
          r += d
          if r >= d and r < m:
            dec q
            r += d
        r -= m
        (q, r)

      let
        d_hi = d shr halfSize
        d_lo = d and halfMask
        n_lohi = n_lo shr halfSize
        n_lolo = n_lo and halfMask

      # First half of the quotient
      let (q1, r1) = halfQR(n_hi, n_lohi, d, d_hi, d_lo)

      # Second half
      let (q2, r2) = halfQR(r1, n_lolo, d, d_hi, d_lo)

      q = (q1 shl halfSize) or q2
      r = r2

i.e. all the example implementations on Wikipedia are really slow: https://en.wikipedia.org/wiki/Modular_arithmetic#Example_implementations
How to avoid division
There are 2 main techniques to avoid costly divisions:
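Barrett reduction
https://en.wikipedia.org/wiki/Barrett_reduction
Instead of computing a*b mod m with a division, you precompute µ = ⌊(2⁶⁴)ᵏ/m⌋ once, multiply a*b by µ, and then shift right by k words (i.e. divide by (2⁶⁴)ᵏ). This gives an estimate q of the quotient ⌊a*b/m⌋; the remainder is then a*b - q*m, fixed up with a small number of conditional subtractions of m.
This is called Barrett reduction and is interesting when (2⁶⁴)ᵏ/m can be reused many times.
k is chosen so that the rounding error in the division by m is inconsequential.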
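As an illustration, here is a toy single-word sketch of the idea, assuming a modulus below 2^16 and a fixed shift of 32 bits instead of k words, so that every intermediate fits in a `uint64`. The names `Barrett`, `initBarrett` and `barrettMulMod` are hypothetical; a real implementation applies the same steps to limbs.

```nim
type Barrett = object
  m: uint64    # modulus, assumed 2 <= m < 2^16 in this toy version
  mu: uint64   # floor(2^32 / m), precomputed once with a real division

func initBarrett(m: uint64): Barrett =
  doAssert m >= 2 and m < (1'u64 shl 16)
  Barrett(m: m, mu: (1'u64 shl 32) div m)

func barrettMulMod(ctx: Barrett, a, b: uint64): uint64 =
  ## a*b mod m for already reduced inputs (a, b < m), without any division.
  let x = a * b                      # < 2^32
  let q = (x * ctx.mu) shr 32        # estimate of x div m, never too large
  result = x - q * ctx.m             # x mod m, possibly plus a small multiple of m
  while result >= ctx.m:             # correction step
    result -= ctx.m

when isMainModule:
  let ctx = initBarrett(65521)
  echo ctx.barrettMulMod(12345, 54321)   # 50831
```

The only division is the one in `initBarrett`, which is why the technique pays off when the same modulus is reused many times.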
Montgomery reduction
- https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
- https://eprint.iacr.org/2017/1057.pdf
Montgomery reduction uses a similar approach to Barrett reduction but with lower complexity, at the price of having to "transport"/convert the number being reduced to the "Montgomery domain".
In practice we do all computations on a' = aR (mod m), with R = (2⁶⁴)^numWords and numWords = 4 for 256-bit numbers (R mod m is the Montgomery representation of 1).
Once we have numbers in the Montgomery domain, there is an operation called Montgomery modular multiplication (montMul) that does: montMul(a', b') = montMul(aR, bR) = abR (mod m)
montMul is the fastest modular multiplication algorithm that works on almost all moduli.
Why almost?
Well, it only works on odd moduli (R must be invertible modulo m), and all primes besides 2 are odd.
Anyway, if we have an odd modulus, we can compute the Montgomery modular exponentiation instead of the plain modular exponentiation.
Converting to and from the Montgomery domain only requires a montMul by R² (mod m) or by 1.
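The following is a minimal single-word sketch of montMul and the two conversions. It assumes a toy word size of 32 bits (R = 2^32) and an odd modulus below 2^31 so that everything fits in a `uint64`; `MontCtx`, `initMont`, `toMont` and `fromMont` are hypothetical names, not the Constantine or Stint API.

```nim
const
  WordBits = 32                            # toy word size: R = 2^32
  Mask = (1'u64 shl WordBits) - 1

type MontCtx = object
  m: uint64        # odd modulus, assumed < 2^31 in this sketch
  negInv: uint64   # -1/m mod R
  r2: uint64       # R^2 mod m

func initMont(m: uint64): MontCtx =
  doAssert (m and 1'u64) == 1'u64 and m < (1'u64 shl 31)
  var inv = m                              # m*m = 1 (mod 8): 3 correct bits
  for _ in 0 ..< 4:                        # Newton steps: 6, 12, 24, 48 correct bits
    inv = inv * (2'u64 - m * inv)
  let rModM = (1'u64 shl WordBits) mod m   # one-time setup division
  MontCtx(m: m, negInv: (0'u64 - inv) and Mask, r2: (rModM * rModM) mod m)

func montMul(ctx: MontCtx, a, b: uint64): uint64 =
  ## Returns a*b/R mod m for a, b < m, using only multiplications and shifts.
  let t = a * b                                  # < 2^62
  let u = ((t and Mask) * ctx.negInv) and Mask   # makes t + u*m divisible by R
  result = (t + u * ctx.m) shr WordBits          # exact division by R, result < 2m
  if result >= ctx.m:
    result -= ctx.m

func toMont(ctx: MontCtx, a: uint64): uint64 =
  ctx.montMul(a, ctx.r2)                   # a * R^2 / R = aR (mod m)

func fromMont(ctx: MontCtx, a: uint64): uint64 =
  ctx.montMul(a, 1)                        # aR * 1 / R = a (mod m)

when isMainModule:
  let ctx = initMont(1000003)
  let p = ctx.montMul(ctx.toMont(123456), ctx.toMont(654321))
  echo ctx.fromMont(p) == (123456'u64 * 654321'u64) mod 1000003'u64
```

The divisions in `initMont` happen once per modulus; each `montMul` afterwards is division-free.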
Reconciling Montgomery and even modulus.
One issue is that in the Ethereum Virtual Machine, modexp can receive any modulus, not just odd.
Thankfully, we can invoke the Chinese Remainder Theorem (CRT) (https://en.wikipedia.org/wiki/Chinese_remainder_theorem), which states that if your modulus is m = a * b with a and b coprime, you can compute mod a and mod b separately, and it gives you a way to recombine the results mod m.
So if you have an even modulus, you split it into a = 2ᵏ, a power of 2, and b, an odd number.
Odd numbers are coprime with powers of 2, so you can for sure apply the CRT.
- Doing modulo a power of 2 is very easy: x mod 2ᵏ == x and (2ᵏ - 1).
- And for the odd number you have Montgomery.
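The split and recombination can be sketched as follows, again on single words with a modulus assumed to be below 2^32. All function names are hypothetical, and the plain `powmod` stands in for the Montgomery exponentiation that a real implementation would use on the odd part.

```nim
import std/bitops

func powmod(base, e, m: uint64): uint64 =
  ## Plain square-and-multiply for 1 <= m < 2^32 (stand-in for Montgomery).
  result = 1'u64 mod m
  var b = base mod m
  var x = e
  while x != 0:
    if (x and 1'u64) == 1'u64:
      result = (result * b) mod m
    b = (b * b) mod m
    x = x shr 1

func powmodPow2(base, e: uint64, k: int): uint64 =
  ## base^e mod 2^k: wrapping multiplication plus masking, no division.
  let mask = (1'u64 shl k) - 1
  result = 1'u64 and mask
  var b = base and mask
  var x = e
  while x != 0:
    if (x and 1'u64) == 1'u64:
      result = (result * b) and mask
    b = (b * b) and mask
    x = x shr 1

func invModPow2(oddValue: uint64, k: int): uint64 =
  ## 1/oddValue mod 2^k by Newton iteration (each step doubles the correct bits).
  var inv = oddValue
  for _ in 0 ..< 5:
    inv = inv * (2'u64 - oddValue * inv)
  inv and ((1'u64 shl k) - 1)

func powmodAnyModulus(base, e, m: uint64): uint64 =
  ## base^e mod m for any 1 <= m < 2^32, odd or even.
  let k = countTrailingZeroBits(m)        # m = 2^k * oddPart
  let oddPart = m shr k
  let mask = (1'u64 shl k) - 1
  let x1 = powmodPow2(base, e, k)         # result mod 2^k
  let x2 = powmod(base, e, oddPart)       # result mod oddPart
  # CRT recombination: x = x2 + oddPart * ((x1 - x2) / oddPart mod 2^k)
  let h = ((x1 - x2) * invModPow2(oddPart, k)) and mask
  x2 + oddPart * h

when isMainModule:
  echo powmodAnyModulus(7, 5, 12)         # 7^5 mod 12 = 7
```

The subtraction `x1 - x2` may wrap around, which is harmless here because the result is only used modulo 2ᵏ.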
Engineering & Implementation
Now that we have the theory, let's look at engineering problems and reference code.
Fast Montgomery multiplication
There are many ways to multiply bigints; they were categorized in Acar's thesis: https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf
- Separated Operand Scanning (SOS): multiprecision multiplication and Montgomery reduction are done separately
- Coarsely Integrated Operand Scanning (CIOS): This interleaves schoolbook multiplication and Montgomery reduction
- Finely Integrated Operand Scanning (FIOS): similar to previous
- Finely Integrated Product Scanning (FIPS): When doing schoolbook multiplication you can compute rows-by-rows or columns-by-columns, this does columns-by-columns.
- Coarsely integrated Hybrid Scanning (CIHS)
On modern architectures, with MULX/ADCX/ADOX, CIOS is the fastest. Without MULX/ADCX/ADOX (and so potentially on everything besides x86), FIPS is the fastest. Why? AFAIK data movement is better and there are fewer carries to save/restore.
The no-assembly algorithm for both is available here:
- CIOS: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/arithmetic/limbs_montgomery.nim#L215-L261
- FIPS: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/arithmetic/limbs_montgomery.nim#L263-L305
Fast large multiplication
One issue with multiplication is that it's O(n²) in the number of words, since we need to multiply each word of one multiplicand with each word of the other.
The Karatsuba algorithm has a complexity of about O(n^1.585) (the exponent is log₂ 3), with a constant-factor overhead that becomes negligible at around 8~12 words (so 512-768 bits, to be measured): https://en.wikipedia.org/wiki/Karatsuba_algorithm
Fast exponentiation
Now we can look into exponentiation. The basic algorithm is square-and-multiply https://en.wikipedia.org/wiki/Exponentiation_by_squaring
You scan the bits of the exponent (from left to right, i.e. MSB-to-LSB, or from right to left, i.e. LSB-to-MSB; both variants are possible). You always square, and then you multiply by something (depending on your scanning direction) if the bit is set.
However, you can precompute the powers for bit patterns, for example 1101, 1001, 0111, and instead of doing 2 or 3 multiplications, ensure that you only do 1 every 4 squarings.
That's called the window method.
An example fixed-window implementation is available in Constantine: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/arithmetic/limbs_montgomery.nim#L614-L836. As there is no constant-time requirement in Stint, it can be simplified; in Constantine I needed to ensure I could use the value of secret bits for windows (for example RSA uses modular exponentiation) without revealing what the secret bits are.
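A simplified, variable-time version of the fixed-window idea can be sketched like this. It is not the Constantine code: `mulmod` and `powmodWindow4` are hypothetical names, the window is fixed at 4 bits, and the modulus is assumed to be below 2^32 with a plain `mod` standing in for montMul.

```nim
func mulmod(a, b, m: uint64): uint64 =
  (a * b) mod m                            # stand-in for montMul, needs m < 2^32

func powmodWindow4(base, e, m: uint64): uint64 =
  ## Left-to-right exponentiation with a fixed 4-bit window.
  ## Variable-time: zero windows skip the multiplication.
  let b = base mod m
  var table: array[16, uint64]             # table[i] = base^i mod m
  table[0] = 1'u64 mod m
  for i in 1 .. 15:
    table[i] = mulmod(table[i-1], b, m)
  result = table[0]
  for w in countdown(15, 0):               # 16 windows of 4 bits in a 64-bit exponent
    for _ in 0 ..< 4:
      result = mulmod(result, result, m)   # 4 squarings per window
    let digit = int((e shr (4 * w)) and 0xF'u64)
    if digit != 0:
      result = mulmod(result, table[digit], m)  # at most 1 multiplication per window

when isMainModule:
  echo powmodWindow4(2, 10, 1000)          # 24
```

The table costs 14 extra multiplications up front, so the window only pays off once the exponent is long enough; for a 64-bit exponent this sketch also wastes squarings on leading zero windows, which a real implementation would skip.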
Sliding window
An extra optimization is to use windows of variable size, called sliding windows. I don't have an implementation of that (it cannot be made constant-time) but Wikipedia has pseudocode: https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Sliding-window_method
NAF/signed recoding
In your research you might come across NAF or signed recoding. This only applies to elliptic curves (because the inversion 1/a (mod m) is not cheap in modular arithmetic, whereas the negation -P is cheap on an elliptic curve).
Montgomery domain, conversion and constants
Computing R and R² can be done like this: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/config/precompute.nim#L307-L350
You will also need 1/M0 (mod 2⁶⁴), or its negation -1/M0 (mod 2⁶⁴) depending on the variant of the reduction, with M0 being the first (least significant) limb of your modulus M.
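As a rough sketch of how these constants can be obtained without any division, the following computes -1/M0 (mod 2⁶⁴) by Newton iteration and R² (mod m) by repeated doubling. It is not the linked Constantine code: the function names are hypothetical, and `r2ModM` assumes a single-word modulus below 2^63 (a multi-limb version would do the same doubling and conditional subtraction on limbs).

```nim
func negInvModWord(m0: uint64): uint64 =
  ## -1/m0 mod 2^64 for odd m0, relying on wrapping uint64 arithmetic.
  doAssert (m0 and 1'u64) == 1'u64, "the modulus must be odd"
  var inv = m0                       # m0*m0 = 1 (mod 8): 3 correct bits
  for _ in 0 ..< 5:                  # Newton steps: 6, 12, 24, 48, 96 correct bits
    inv = inv * (2'u64 - m0 * inv)
  0'u64 - inv

func r2ModM(m: uint64): uint64 =
  ## R^2 mod m with R = 2^64, for a single-word modulus 1 < m < 2^63.
  ## Doubles 1 modulo m 128 times, so that the result is 2^128 mod m.
  doAssert m > 1 and m < (1'u64 shl 63)
  result = 1
  for _ in 0 ..< 128:
    result = result + result         # < 2m < 2^64, no overflow
    if result >= m:
      result -= m

when isMainModule:
  echo 7'u64 * negInvModWord(7) == high(uint64)   # m0 * (-1/m0) = -1 (mod 2^64)
  echo r2ModM(7)                                  # 2^128 mod 7 = 4
```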
Wrapup
The implementation steps should be:
- Implement Montgomery modular multiplication montMul
- Compute the Montgomery constants R and R²
- Implement conversion to the Montgomery domain with montMul(a, R²) = a' = aR (mod m)
- Implement conversion from the Montgomery domain with montMul(a', 1) = a (mod m)
- Implement exponentiation by squaring
- Optional
  - implement the window method
  - implement assembly

and use Clang