Skip to content

Commit 88014c6

Browse files
qdm12timwu20
authored andcommitted
chore(trie): lib/trie/hash.go tests (ChainSafe#2049)
1 parent 81c245b commit 88014c6

File tree

8 files changed

+1446
-90
lines changed

8 files changed

+1446
-90
lines changed

lib/trie/bytesBuffer_mock_test.go

Lines changed: 77 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/trie/hash.go

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ func hashNode(n node, digestBuffer io.Writer) (err error) {
5757
// if length of encoded leaf is less than 32 bytes, do not hash
5858
if encodingBuffer.Len() < 32 {
5959
_, err = digestBuffer.Write(encodingBuffer.Bytes())
60-
return err
60+
if err != nil {
61+
return fmt.Errorf("cannot write encoded node to buffer: %w", err)
62+
}
63+
return nil
6164
}
6265

6366
// otherwise, hash encoded node
@@ -72,16 +75,26 @@ func hashNode(n node, digestBuffer io.Writer) (err error) {
7275
}
7376

7477
_, err = digestBuffer.Write(hasher.Sum(nil))
75-
return err
78+
if err != nil {
79+
return fmt.Errorf("cannot write hash sum of node to buffer: %w", err)
80+
}
81+
return nil
7682
}
7783

7884
var ErrNodeTypeUnsupported = errors.New("node type is not supported")
7985

86+
type bytesBuffer interface {
87+
// note: cannot compose with io.Writer for mock generation
88+
Write(p []byte) (n int, err error)
89+
Len() int
90+
Bytes() []byte
91+
}
92+
8093
// encodeNode writes the encoding of the node to the buffer given.
8194
// It is the high-level function wrapping the encoding for different
8295
// node types. The encoding has the following format:
8396
// NodeHeader | Extra partial key length | Partial Key | Value
84-
func encodeNode(n node, buffer *bytes.Buffer, parallel bool) (err error) {
97+
func encodeNode(n node, buffer bytesBuffer, parallel bool) (err error) {
8598
switch n := n.(type) {
8699
case *branch:
87100
err := encodeBranch(n, buffer, parallel)
@@ -104,70 +117,58 @@ func encodeNode(n node, buffer *bytes.Buffer, parallel bool) (err error) {
104117
copy(n.encoding, buffer.Bytes())
105118
return nil
106119
case nil:
107-
buffer.Write([]byte{0})
120+
_, err := buffer.Write([]byte{0})
121+
if err != nil {
122+
return fmt.Errorf("cannot encode nil node: %w", err)
123+
}
108124
return nil
109125
default:
110126
return fmt.Errorf("%w: %T", ErrNodeTypeUnsupported, n)
111127
}
112128
}
113129

114-
func encodeAndHash(n node) ([]byte, error) {
115-
buffer := digestBufferPool.Get().(*bytes.Buffer)
116-
buffer.Reset()
117-
defer digestBufferPool.Put(buffer)
118-
119-
err := hashNode(n, buffer)
120-
if err != nil {
121-
return nil, err
122-
}
123-
124-
scEncChild, err := scale.Marshal(buffer.Bytes())
125-
if err != nil {
126-
return nil, err
127-
}
128-
return scEncChild, nil
129-
}
130-
131130
// encodeBranch encodes a branch with the encoding specified at the top of this package
132131
// to the buffer given.
133132
func encodeBranch(b *branch, buffer io.Writer, parallel bool) (err error) {
134133
if !b.dirty && b.encoding != nil {
135134
_, err = buffer.Write(b.encoding)
136135
if err != nil {
137-
return fmt.Errorf("cannot write stored encoded branch to buffer: %w", err)
136+
return fmt.Errorf("cannot write stored encoding to buffer: %w", err)
138137
}
139138
return nil
140139
}
141140

142-
encoding, err := b.header()
141+
encodedHeader, err := b.header()
143142
if err != nil {
144-
return fmt.Errorf("cannot encode branch header: %w", err)
143+
return fmt.Errorf("cannot encode header: %w", err)
145144
}
146145

147-
_, err = buffer.Write(encoding)
146+
_, err = buffer.Write(encodedHeader)
148147
if err != nil {
149-
return fmt.Errorf("cannot write encoded branch header to buffer: %w", err)
148+
return fmt.Errorf("cannot write encoded header to buffer: %w", err)
150149
}
151150

152-
_, err = buffer.Write(nibblesToKeyLE(b.key))
151+
keyLE := nibblesToKeyLE(b.key)
152+
_, err = buffer.Write(keyLE)
153153
if err != nil {
154-
return fmt.Errorf("cannot write encoded branch key to buffer: %w", err)
154+
return fmt.Errorf("cannot write encoded key to buffer: %w", err)
155155
}
156156

157-
_, err = buffer.Write(common.Uint16ToBytes(b.childrenBitmap()))
157+
childrenBitmap := common.Uint16ToBytes(b.childrenBitmap())
158+
_, err = buffer.Write(childrenBitmap)
158159
if err != nil {
159-
return fmt.Errorf("cannot write branch children bitmap to buffer: %w", err)
160+
return fmt.Errorf("cannot write children bitmap to buffer: %w", err)
160161
}
161162

162163
if b.value != nil {
163164
bytes, err := scale.Marshal(b.value)
164165
if err != nil {
165-
return fmt.Errorf("cannot scale encode branch value: %w", err)
166+
return fmt.Errorf("cannot scale encode value: %w", err)
166167
}
167168

168169
_, err = buffer.Write(bytes)
169170
if err != nil {
170-
return fmt.Errorf("cannot write encoded branch value to buffer: %w", err)
171+
return fmt.Errorf("cannot write encoded value to buffer: %w", err)
171172
}
172173
}
173174

@@ -222,10 +223,15 @@ func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) {
222223
// write as many completed buffers to the result buffer.
223224
for currentIndex < len(children) &&
224225
resultBuffers[currentIndex] != nil {
225-
// note buffer.Write copies the byte slice given as argument
226-
_, writeErr := buffer.Write(resultBuffers[currentIndex].Bytes())
227-
if writeErr != nil && err == nil {
228-
err = writeErr
226+
bufferSlice := resultBuffers[currentIndex].Bytes()
227+
if len(bufferSlice) > 0 {
228+
// note buffer.Write copies the byte slice given as argument
229+
_, writeErr := buffer.Write(bufferSlice)
230+
if writeErr != nil && err == nil {
231+
err = fmt.Errorf(
232+
"cannot write encoding of child at index %d: %w",
233+
currentIndex, writeErr)
234+
}
229235
}
230236

231237
encodingBufferPool.Put(resultBuffers[currentIndex])
@@ -246,17 +252,26 @@ func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) {
246252
}
247253

248254
func encodeChildrenSequentially(children [16]node, buffer io.Writer) (err error) {
249-
for _, child := range children {
255+
for i, child := range children {
250256
err = encodeChild(child, buffer)
251257
if err != nil {
252-
return err
258+
return fmt.Errorf("cannot encode child at index %d: %w", i, err)
253259
}
254260
}
255261
return nil
256262
}
257263

258264
func encodeChild(child node, buffer io.Writer) (err error) {
259-
if child == nil {
265+
var isNil bool
266+
switch impl := child.(type) {
267+
case *branch:
268+
isNil = impl == nil
269+
case *leaf:
270+
isNil = impl == nil
271+
default:
272+
isNil = child == nil
273+
}
274+
if isNil {
260275
return nil
261276
}
262277

@@ -273,6 +288,23 @@ func encodeChild(child node, buffer io.Writer) (err error) {
273288
return nil
274289
}
275290

291+
func encodeAndHash(n node) (b []byte, err error) {
292+
buffer := digestBufferPool.Get().(*bytes.Buffer)
293+
buffer.Reset()
294+
defer digestBufferPool.Put(buffer)
295+
296+
err = hashNode(n, buffer)
297+
if err != nil {
298+
return nil, fmt.Errorf("cannot hash node: %w", err)
299+
}
300+
301+
scEncChild, err := scale.Marshal(buffer.Bytes())
302+
if err != nil {
303+
return nil, fmt.Errorf("cannot scale encode hashed node: %w", err)
304+
}
305+
return scEncChild, nil
306+
}
307+
276308
// encodeLeaf encodes a leaf to the buffer given, with the encoding
277309
// specified at the top of this package.
278310
func encodeLeaf(l *leaf, buffer io.Writer) (err error) {
@@ -286,27 +318,28 @@ func encodeLeaf(l *leaf, buffer io.Writer) (err error) {
286318
return nil
287319
}
288320

289-
encoding, err := l.header()
321+
encodedHeader, err := l.header()
290322
if err != nil {
291323
return fmt.Errorf("cannot encode header: %w", err)
292324
}
293325

294-
_, err = buffer.Write(encoding)
326+
_, err = buffer.Write(encodedHeader)
295327
if err != nil {
296328
return fmt.Errorf("cannot write encoded header to buffer: %w", err)
297329
}
298330

299-
_, err = buffer.Write(nibblesToKeyLE(l.key))
331+
keyLE := nibblesToKeyLE(l.key)
332+
_, err = buffer.Write(keyLE)
300333
if err != nil {
301334
return fmt.Errorf("cannot write LE key to buffer: %w", err)
302335
}
303336

304-
bytes, err := scale.Marshal(l.value) // TODO scale encoder to write to buffer
337+
encodedValue, err := scale.Marshal(l.value) // TODO scale encoder to write to buffer
305338
if err != nil {
306-
return err
339+
return fmt.Errorf("cannot scale marshal value: %w", err)
307340
}
308341

309-
_, err = buffer.Write(bytes)
342+
_, err = buffer.Write(encodedValue)
310343
if err != nil {
311344
return fmt.Errorf("cannot write scale encoded value to buffer: %w", err)
312345
}

0 commit comments

Comments
 (0)