@@ -38,10 +38,10 @@ const pruneKeyBufferSize = 1000
38
38
39
39
// BlockState defines fields for manipulating the state of blocks, such as BlockTree, BlockDB and Header
40
40
type BlockState struct {
41
- bt * blocktree.BlockTree
42
- baseDB chaindb.Database
43
- db chaindb.Database
44
- lock sync.RWMutex
41
+ bt * blocktree.BlockTree
42
+ baseDB chaindb.Database
43
+ db chaindb.Database
44
+ sync.RWMutex
45
45
genesisHash common.Hash
46
46
47
47
// block notifiers
@@ -268,7 +268,7 @@ func (bs *BlockState) GetHeader(hash common.Hash) (*types.Header, error) {
268
268
func (bs * BlockState ) GetHashByNumber (num * big.Int ) (common.Hash , error ) {
269
269
bh , err := bs .db .Get (headerHashKey (num .Uint64 ()))
270
270
if err != nil {
271
- return common.Hash {}, fmt .Errorf ("cannot get block %d: %s " , num , err )
271
+ return common.Hash {}, fmt .Errorf ("cannot get block %d: %w " , num , err )
272
272
}
273
273
274
274
return common .NewHash (bh ), nil
@@ -278,7 +278,7 @@ func (bs *BlockState) GetHashByNumber(num *big.Int) (common.Hash, error) {
278
278
func (bs * BlockState ) GetHeaderByNumber (num * big.Int ) (* types.Header , error ) {
279
279
bh , err := bs .db .Get (headerHashKey (num .Uint64 ()))
280
280
if err != nil {
281
- return nil , fmt .Errorf ("cannot get block %d: %s " , num , err )
281
+ return nil , fmt .Errorf ("cannot get block %d: %w " , num , err )
282
282
}
283
283
284
284
hash := common .NewHash (bh )
@@ -304,7 +304,7 @@ func (bs *BlockState) GetBlockByNumber(num *big.Int) (*types.Block, error) {
304
304
// First retrieve the block hash in a byte array based on the block number from the database
305
305
byteHash , err := bs .db .Get (headerHashKey (num .Uint64 ()))
306
306
if err != nil {
307
- return nil , fmt .Errorf ("cannot get block %d: %s " , num , err )
307
+ return nil , fmt .Errorf ("cannot get block %d: %w " , num , err )
308
308
}
309
309
310
310
// Then find the block based on the hash
@@ -322,17 +322,14 @@ func (bs *BlockState) GetBlockHash(blockNumber *big.Int) (*common.Hash, error) {
322
322
// First retrieve the block hash in a byte array based on the block number from the database
323
323
byteHash , err := bs .db .Get (headerHashKey (blockNumber .Uint64 ()))
324
324
if err != nil {
325
- return nil , fmt .Errorf ("cannot get block %d: %s " , blockNumber , err )
325
+ return nil , fmt .Errorf ("cannot get block %d: %w " , blockNumber , err )
326
326
}
327
327
hash := common .NewHash (byteHash )
328
328
return & hash , nil
329
329
}
330
330
331
331
// SetHeader will set the header into DB
332
332
func (bs * BlockState ) SetHeader (header * types.Header ) error {
333
- bs .lock .Lock ()
334
- defer bs .lock .Unlock ()
335
-
336
333
hash := header .Hash ()
337
334
338
335
// Write the encoded header
@@ -366,11 +363,7 @@ func (bs *BlockState) GetBlockBody(hash common.Hash) (*types.Body, error) {
366
363
367
364
// SetBlockBody will add a block body to the db
368
365
func (bs * BlockState ) SetBlockBody (hash common.Hash , body * types.Body ) error {
369
- bs .lock .Lock ()
370
- defer bs .lock .Unlock ()
371
-
372
- err := bs .db .Put (blockBodyKey (hash ), body .AsOptional ().Value ())
373
- return err
366
+ return bs .db .Put (blockBodyKey (hash ), body .AsOptional ().Value ())
374
367
}
375
368
376
369
// HasFinalizedBlock returns true if there is a finalized block for a given round and setID, false otherwise
@@ -427,6 +420,9 @@ func (bs *BlockState) GetFinalizedHash(round, setID uint64) (common.Hash, error)
427
420
428
421
// SetFinalizedHash sets the latest finalized block header
429
422
func (bs * BlockState ) SetFinalizedHash (hash common.Hash , round , setID uint64 ) error {
423
+ bs .Lock ()
424
+ defer bs .Unlock ()
425
+
430
426
go bs .notifyFinalized (hash )
431
427
if round > 0 {
432
428
err := bs .SetRound (round )
@@ -496,6 +492,8 @@ func (bs *BlockState) CompareAndSetBlockData(bd *types.BlockData) error {
496
492
497
493
// AddBlock adds a block to the blocktree and the DB with arrival time as current unix time
498
494
func (bs * BlockState ) AddBlock (block * types.Block ) error {
495
+ bs .Lock ()
496
+ defer bs .Unlock ()
499
497
return bs .AddBlockWithArrivalTime (block , time .Now ())
500
498
}
501
499
@@ -506,6 +504,8 @@ func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime ti
506
504
return err
507
505
}
508
506
507
+ prevHead := bs .bt .DeepestBlockHash ()
508
+
509
509
// add block to blocktree
510
510
err = bs .bt .AddBlock (block .Header , uint64 (arrivalTime .UnixNano ()))
511
511
if err != nil {
@@ -541,12 +541,58 @@ func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime ti
541
541
return err
542
542
}
543
543
544
+ // check if there was a re-org, if so, re-set the canonical number->hash mapping
545
+ err = bs .handleAddedBlock (prevHead , bs .bt .DeepestBlockHash ())
546
+ if err != nil {
547
+ return err
548
+ }
549
+
544
550
go bs .notifyImported (block )
545
551
return bs .baseDB .Flush ()
546
552
}
547
553
554
+ // handleAddedBlock re-sets the canonical number->hash mapping if there was a chain re-org.
555
+ // prev is the previous best block hash before the new block was added to the blocktree.
556
+ // curr is the current best blogetck hash.
557
+ func (bs * BlockState ) handleAddedBlock (prev , curr common.Hash ) error {
558
+ ancestor , err := bs .HighestCommonAncestor (prev , curr )
559
+ if err != nil {
560
+ return err
561
+ }
562
+
563
+ // if the highest common ancestor of the previous chain head and current chain head is the previous chain head,
564
+ // then the current chain head is the descendant of the previous and thus are on the same chain
565
+ if ancestor == prev {
566
+ return nil
567
+ }
568
+
569
+ subchain , err := bs .SubChain (ancestor , curr )
570
+ if err != nil {
571
+ return err
572
+ }
573
+
574
+ batch := bs .db .NewBatch ()
575
+ for _ , hash := range subchain {
576
+ // TODO: set number from ancestor.Number + i ?
577
+ header , err := bs .GetHeader (hash )
578
+ if err != nil {
579
+ return fmt .Errorf ("failed to get header in subchain: %w" , err )
580
+ }
581
+
582
+ err = batch .Put (headerHashKey (header .Number .Uint64 ()), hash .ToBytes ())
583
+ if err != nil {
584
+ return err
585
+ }
586
+ }
587
+
588
+ return batch .Flush ()
589
+ }
590
+
548
591
// AddBlockToBlockTree adds the given block to the blocktree. It does not write it to the database.
549
592
func (bs * BlockState ) AddBlockToBlockTree (header * types.Header ) error {
593
+ bs .Lock ()
594
+ defer bs .Unlock ()
595
+
550
596
arrivalTime , err := bs .GetArrivalTime (header .Hash ())
551
597
if err != nil {
552
598
arrivalTime = time .Now ()
@@ -567,7 +613,7 @@ func (bs *BlockState) isBlockOnCurrentChain(header *types.Header) (bool, error)
567
613
}
568
614
569
615
// if the new block is ahead of our best block, then it is on our current chain.
570
- if header .Number .Cmp (bestBlock .Number ) == 1 {
616
+ if header .Number .Cmp (bestBlock .Number ) > 0 {
571
617
return true , nil
572
618
}
573
619
0 commit comments