Skip to content

Commit 0ba2750

Browse files
Optimize BDN Signature/Key Aggregation (#546)
* Add BDN test fixtures * Remove n^2 algorithm from signature/key aggregation CountEnabled and IndexOfNthEnabled are both O(n) in the size of the mask, making this loop n^2. The BLS operations still tend to be the slow part, but the n^2 factor will start to show up with thousands of keys. * Remove an unnecessary loop from hashPointToR * Introduce a new CachedMask for BDN This new mask will pre-compute reusable values, speeding up repeated verification and aggregation of aggregate signatures (mostly the former). * Ignore golangci lint * Move Mask into BDN and remove the interface * fix docs Co-authored-by: AnomalRoil <[email protected]> * Document mutability of Mask fields --------- Co-authored-by: AnomalRoil <[email protected]>
1 parent a318fba commit 0ba2750

File tree

4 files changed

+236
-60
lines changed

4 files changed

+236
-60
lines changed

sign/bdn/bdn.go

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ package bdn
1212
import (
1313
"crypto/cipher"
1414
"errors"
15+
"fmt"
1516
"math/big"
1617

1718
"go.dedis.ch/kyber/v4"
@@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI
3132
// We also use the entire roster so that the coefficient will vary for the same
3233
// public key used in different roster
3334
func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) {
34-
peers := make([][]byte, len(pubs))
35-
for i, pub := range pubs {
36-
peer, err := pub.MarshalBinary()
37-
if err != nil {
38-
return nil, err
39-
}
40-
41-
peers[i] = peer
42-
}
43-
4435
h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil)
4536
if err != nil {
4637
return nil, err
4738
}
48-
49-
for _, peer := range peers {
50-
_, err := h.Write(peer)
39+
for _, pub := range pubs {
40+
peer, err := pub.MarshalBinary()
41+
if err != nil {
42+
return nil, err
43+
}
44+
_, err = h.Write(peer)
5145
if err != nil {
5246
return nil, err
5347
}
@@ -128,62 +122,58 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error {
128122

129123
// AggregateSignatures aggregates the signatures using a coefficient for each
130124
// one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128}
131-
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
132-
if len(sigs) != mask.CountEnabled() {
133-
return nil, errors.New("length of signatures and public keys must match")
134-
}
135-
136-
coefs, err := hashPointToR(mask.Publics())
137-
if err != nil {
138-
return nil, err
139-
}
140-
125+
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *Mask) (kyber.Point, error) {
141126
agg := scheme.sigGroup.Point()
142-
for i, buf := range sigs {
143-
peerIndex := mask.IndexOfNthEnabled(i)
144-
if peerIndex < 0 {
145-
// this should never happen as we check the lenths at the beginning
146-
// an error here is probably a bug in the mask
147-
return nil, errors.New("couldn't find the index")
127+
for i := range mask.publics {
128+
if enabled, err := mask.GetBit(i); err != nil {
129+
// this should never happen because of the loop boundary
130+
// an error here is probably a bug in the mask implementation
131+
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
132+
} else if !enabled {
133+
continue
134+
}
135+
136+
if len(sigs) == 0 {
137+
return nil, errors.New("length of signatures and public keys must match")
148138
}
149139

140+
buf := sigs[0]
141+
sigs = sigs[1:]
142+
150143
sig := scheme.sigGroup.Point()
151-
err = sig.UnmarshalBinary(buf)
144+
err := sig.UnmarshalBinary(buf)
152145
if err != nil {
153146
return nil, err
154147
}
155148

156-
sigC := sig.Clone().Mul(coefs[peerIndex], sig)
149+
sigC := sig.Clone().Mul(mask.publicCoefs[i], sig)
157150
// c+1 because R is in the range [1, 2^128] and not [0, 2^128-1]
158151
sigC = sigC.Add(sigC, sig)
159152
agg = agg.Add(agg, sigC)
160153
}
161154

155+
if len(sigs) > 0 {
156+
return nil, errors.New("length of signatures and public keys must match")
157+
}
158+
162159
return agg, nil
163160
}
164161

165162
// AggregatePublicKeys aggregates a set of public keys (similarly to
166163
// AggregateSignatures for signatures) using the hash function
167164
// H: keyGroup -> R with R = {1, ..., 2^128}.
168-
func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) {
169-
coefs, err := hashPointToR(mask.Publics())
170-
if err != nil {
171-
return nil, err
172-
}
173-
165+
func (scheme *Scheme) AggregatePublicKeys(mask *Mask) (kyber.Point, error) {
174166
agg := scheme.keyGroup.Point()
175-
for i := 0; i < mask.CountEnabled(); i++ {
176-
peerIndex := mask.IndexOfNthEnabled(i)
177-
if peerIndex < 0 {
167+
for i := range mask.publics {
168+
if enabled, err := mask.GetBit(i); err != nil {
178169
// this should never happen because of the loop boundary
179170
// an error here is probably a bug in the mask implementation
180-
return nil, errors.New("couldn't find the index")
171+
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
172+
} else if !enabled {
173+
continue
181174
}
182175

183-
pub := mask.Publics()[peerIndex]
184-
pubC := pub.Clone().Mul(coefs[peerIndex], pub)
185-
pubC = pubC.Add(pubC, pub)
186-
agg = agg.Add(agg, pubC)
176+
agg = agg.Add(agg, mask.publicTerms[i])
187177
}
188178

189179
return agg, nil
@@ -217,14 +207,14 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error {
217207
// AggregateSignatures aggregates the signatures using a coefficient for each
218208
// one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128}
219209
// Deprecated: use the new scheme methods instead.
220-
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
210+
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *Mask) (kyber.Point, error) {
221211
return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask)
222212
}
223213

224214
// AggregatePublicKeys aggregates a set of public keys (similarly to
225215
// AggregateSignatures for signatures) using the hash function
226216
// H: G2 -> R with R = {1, ..., 2^128}.
227217
// Deprecated: use the new scheme methods instead.
228-
func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) {
218+
func AggregatePublicKeys(suite pairing.Suite, mask *Mask) (kyber.Point, error) {
229219
return NewSchemeOnG1(suite).AggregatePublicKeys(mask)
230220
}

sign/bdn/bdn_test.go

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package bdn
22

33
import (
4+
"encoding"
5+
"encoding/hex"
46
"fmt"
57
"testing"
68

79
"github.com/stretchr/testify/require"
810
"go.dedis.ch/kyber/v4"
11+
"go.dedis.ch/kyber/v4/pairing/bls12381/kilic"
912
"go.dedis.ch/kyber/v4/pairing/bn256"
10-
"go.dedis.ch/kyber/v4/sign"
1113
"go.dedis.ch/kyber/v4/sign/bls"
1214
"go.dedis.ch/kyber/v4/util/random"
1315
)
@@ -30,7 +32,7 @@ func TestBDN_HashPointToR_BN256(t *testing.T) {
3032
require.Equal(t, "933f6013eb3f654f9489d6d45ad04eaf", coefs[2].String())
3133
require.Equal(t, 16, coefs[0].MarshalSize())
3234

33-
mask, _ := sign.NewMask([]kyber.Point{p1, p2, p3}, nil)
35+
mask, _ := NewMask([]kyber.Point{p1, p2, p3}, nil)
3436
mask.SetBit(0, true)
3537
mask.SetBit(1, true)
3638
mask.SetBit(2, true)
@@ -54,7 +56,7 @@ func TestBDN_AggregateSignatures(t *testing.T) {
5456
sig2, err := Sign(suite, private2, msg)
5557
require.NoError(t, err)
5658

57-
mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil)
59+
mask, _ := NewMask([]kyber.Point{public1, public2}, nil)
5860
mask.SetBit(0, true)
5961
mask.SetBit(1, true)
6062

@@ -92,7 +94,7 @@ func TestBDN_SubsetSignature(t *testing.T) {
9294
sig2, err := Sign(suite, private2, msg)
9395
require.NoError(t, err)
9496

95-
mask, _ := sign.NewMask([]kyber.Point{public1, public3, public2}, nil)
97+
mask, _ := NewMask([]kyber.Point{public1, public3, public2}, nil)
9698
mask.SetBit(0, true)
9799
mask.SetBit(2, true)
98100

@@ -131,7 +133,7 @@ func TestBDN_RogueAttack(t *testing.T) {
131133
require.NoError(t, scheme.Verify(agg, msg, sig))
132134

133135
// New scheme that should detect
134-
mask, _ := sign.NewMask(pubs, nil)
136+
mask, _ := NewMask(pubs, nil)
135137
mask.SetBit(0, true)
136138
mask.SetBit(1, true)
137139
agg, err = AggregatePublicKeys(suite, mask)
@@ -149,7 +151,7 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
149151
sig2, err := Sign(suite, private2, msg)
150152
require.Nil(b, err)
151153

152-
mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil)
154+
mask, _ := NewMask([]kyber.Point{public1, public2}, nil)
153155
mask.SetBit(0, true)
154156
mask.SetBit(1, false)
155157

@@ -158,3 +160,99 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
158160
AggregateSignatures(suite, [][]byte{sig1, sig2}, mask)
159161
}
160162
}
163+
164+
func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) {
165+
suite := kilic.NewBLS12381Suite()
166+
schemeOnG2 := NewSchemeOnG2(suite)
167+
168+
rng := random.New()
169+
pubKeys := make([]kyber.Point, 3000)
170+
privKeys := make([]kyber.Scalar, 3000)
171+
for i := range pubKeys {
172+
privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng)
173+
}
174+
175+
mask, err := NewMask(pubKeys, nil)
176+
require.NoError(b, err)
177+
for i := range pubKeys {
178+
require.NoError(b, mask.SetBit(i, true))
179+
}
180+
181+
msg := []byte("Hello many times Boneh-Lynn-Shacham")
182+
sigs := make([][]byte, len(privKeys))
183+
for i, k := range privKeys {
184+
s, err := schemeOnG2.Sign(k, msg)
185+
require.NoError(b, err)
186+
sigs[i] = s
187+
}
188+
189+
sig, err := schemeOnG2.AggregateSignatures(sigs, mask)
190+
require.NoError(b, err)
191+
sigb, err := sig.MarshalBinary()
192+
require.NoError(b, err)
193+
194+
b.ResetTimer()
195+
for i := 0; i < b.N; i++ {
196+
pk, err := schemeOnG2.AggregatePublicKeys(mask)
197+
require.NoError(b, err)
198+
require.NoError(b, schemeOnG2.Verify(pk, msg, sigb))
199+
}
200+
}
201+
202+
func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T {
203+
t.Helper()
204+
b, err := hex.DecodeString(s)
205+
require.NoError(t, err)
206+
require.NoError(t, into.UnmarshalBinary(b))
207+
return into
208+
}
209+
210+
// This tests exists to make sure we don't accidentally make breaking changes to signature
211+
// aggregation by using checking against known aggregated signatures and keys.
212+
func TestBDNFixtures(t *testing.T) {
213+
suite := bn256.NewSuite()
214+
schemeOnG1 := NewSchemeOnG1(suite)
215+
216+
public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493")
217+
private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38")
218+
public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f")
219+
private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581")
220+
public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec")
221+
private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa")
222+
223+
sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8")
224+
require.NoError(t, err)
225+
sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484")
226+
require.NoError(t, err)
227+
sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195")
228+
require.NoError(t, err)
229+
230+
aggSigExp := unmarshalHex(t, suite.G1().Point(), "43c1d2ad5a7d71a08f3cd7495db6b3c81a4547af1b76438b2f215e85ec178fea048f93f6ffed65a69ea757b47761e7178103bb347fd79689652e55b6e0054af2")
231+
aggKeyExp := unmarshalHex(t, suite.G2().Point(), "43b5161ede207b9a69fc93114b0c5022b76cc22e813ba739c7e622d826b132333cd637505399963b94e393ec7f5d4875f82391620b34be1fde1f232204fa4f723935d4dbfb725f059456bcf2557f846c03190969f7b800e904d25b0b5bcbdd421c9877d443f0313c3425dfc1e7e646b665d27b9e649faadef1129f95670d70e1")
232+
233+
msg := []byte("Hello many times Boneh-Lynn-Shacham")
234+
sig1, err := schemeOnG1.Sign(private1, msg)
235+
require.Nil(t, err)
236+
require.Equal(t, sig1Exp, sig1)
237+
238+
sig2, err := schemeOnG1.Sign(private2, msg)
239+
require.Nil(t, err)
240+
require.Equal(t, sig2Exp, sig2)
241+
242+
sig3, err := schemeOnG1.Sign(private3, msg)
243+
require.Nil(t, err)
244+
require.Equal(t, sig3Exp, sig3)
245+
246+
mask, _ := NewMask([]kyber.Point{public1, public2, public3}, nil)
247+
mask.SetBit(0, true)
248+
mask.SetBit(1, false)
249+
mask.SetBit(2, true)
250+
251+
aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig3}, mask)
252+
require.NoError(t, err)
253+
require.True(t, aggSigExp.Equal(aggSig))
254+
255+
aggKey, err := schemeOnG1.AggregatePublicKeys(mask)
256+
require.NoError(t, err)
257+
require.True(t, aggKeyExp.Equal(aggKey))
258+
}

sign/mask.go renamed to sign/bdn/mask.go

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,36 @@
1-
// Package sign contains useful tools for the different signing algorithms.
2-
package sign
1+
package bdn
32

43
import (
54
"errors"
65
"fmt"
6+
"slices"
77

88
"go.dedis.ch/kyber/v4"
99
)
1010

1111
// Mask is a bitmask of the participation to a collective signature.
1212
type Mask struct {
13-
mask []byte
13+
// The bitmask indicating which public keys are enabled/disabled for aggregation. This is
14+
// the only mutable field.
15+
mask []byte
16+
17+
// The following fields are immutable and should not be changed after the mask is created.
18+
// They may be shared between multiple masks.
19+
20+
// Public keys for aggregation & signature verification.
1421
publics []kyber.Point
22+
// Coefficients used when aggregating signatures.
23+
publicCoefs []kyber.Scalar
24+
// Terms used to aggregate public keys
25+
publicTerms []kyber.Point
1526
}
1627

1728
// NewMask creates a new mask from a list of public keys. If a key is provided, it
1829
// will set the bit of the key to 1 or return an error if it is not found.
30+
//
31+
// The returned Mask will contain pre-computed terms and coefficients for all provided public
32+
// keys, so it should be re-used for optimal performance (e.g., by creating a "base" mask and
33+
// cloning it whenever aggregating signatures and/or public keys).
1934
func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) {
2035
m := &Mask{
2136
publics: publics,
@@ -33,6 +48,18 @@ func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) {
3348
return nil, errors.New("key not found")
3449
}
3550

51+
var err error
52+
m.publicCoefs, err = hashPointToR(publics)
53+
if err != nil {
54+
return nil, fmt.Errorf("failed to hash public keys: %w", err)
55+
}
56+
57+
m.publicTerms = make([]kyber.Point, len(publics))
58+
for i, pub := range publics {
59+
pubC := pub.Clone().Mul(m.publicCoefs[i], pub)
60+
m.publicTerms[i] = pubC.Add(pubC, pub)
61+
}
62+
3663
return m, nil
3764
}
3865

@@ -58,6 +85,17 @@ func (m *Mask) SetMask(mask []byte) error {
5885
return nil
5986
}
6087

88+
// GetBit returns true if the given bit is set.
89+
func (m *Mask) GetBit(i int) (bool, error) {
90+
if i >= len(m.publics) || i < 0 {
91+
return false, errors.New("index out of range")
92+
}
93+
94+
byteIndex := i / 8
95+
mask := byte(1) << uint(i&7)
96+
return m.mask[byteIndex]&mask != 0, nil
97+
}
98+
6199
// SetBit turns on or off the bit at the given index.
62100
func (m *Mask) SetBit(i int, enable bool) error {
63101
if i >= len(m.publics) || i < 0 {
@@ -170,3 +208,14 @@ func (m *Mask) Merge(mask []byte) error {
170208

171209
return nil
172210
}
211+
212+
// Clone copies the mask while keeping the precomputed coefficients, etc. This method is thread safe
213+
// and does not modify the original mask. Modifications to the new Mask will not affect the original.
214+
func (m *Mask) Clone() *Mask {
215+
return &Mask{
216+
mask: slices.Clone(m.mask),
217+
publics: m.publics,
218+
publicCoefs: m.publicCoefs,
219+
publicTerms: m.publicTerms,
220+
}
221+
}

0 commit comments

Comments
 (0)