1
+ #![ deny( clippy:: arithmetic_side_effects) ]
2
+
1
3
use super :: config:: RateLimiterConfig ;
2
4
use crate :: rpc:: Protocol ;
3
5
use fnv:: FnvHashMap ;
4
6
use libp2p:: PeerId ;
5
7
use serde:: { Deserialize , Serialize } ;
6
8
use std:: future:: Future ;
7
9
use std:: hash:: Hash ;
10
+ use std:: num:: NonZeroU64 ;
8
11
use std:: pin:: Pin ;
9
12
use std:: sync:: Arc ;
10
13
use std:: task:: { Context , Poll } ;
@@ -55,20 +58,20 @@ pub struct Quota {
55
58
pub ( super ) replenish_all_every : Duration ,
56
59
/// Token limit. This translates on how large can an instantaneous batch of
57
60
/// tokens be.
58
- pub ( super ) max_tokens : u64 ,
61
+ pub ( super ) max_tokens : NonZeroU64 ,
59
62
}
60
63
61
64
impl Quota {
62
65
/// A hard limit of one token every `seconds`.
63
66
pub const fn one_every ( seconds : u64 ) -> Self {
64
67
Quota {
65
68
replenish_all_every : Duration :: from_secs ( seconds) ,
66
- max_tokens : 1 ,
69
+ max_tokens : NonZeroU64 :: new ( 1 ) . unwrap ( ) ,
67
70
}
68
71
}
69
72
70
73
/// Allow `n` tokens to be used every `seconds`.
71
- pub const fn n_every ( n : u64 , seconds : u64 ) -> Self {
74
+ pub const fn n_every ( n : NonZeroU64 , seconds : u64 ) -> Self {
72
75
Quota {
73
76
replenish_all_every : Duration :: from_secs ( seconds) ,
74
77
max_tokens : n,
@@ -236,7 +239,9 @@ impl RPCRateLimiterBuilder {
236
239
237
240
// check for peers to prune every 30 seconds, starting in 30 seconds
238
241
let prune_every = tokio:: time:: Duration :: from_secs ( 30 ) ;
239
- let prune_start = tokio:: time:: Instant :: now ( ) + prune_every;
242
+ let prune_start = tokio:: time:: Instant :: now ( )
243
+ . checked_add ( prune_every)
244
+ . ok_or ( "prune time overflow" ) ?;
240
245
let prune_interval = tokio:: time:: interval_at ( prune_start, prune_every) ;
241
246
Ok ( RPCRateLimiter {
242
247
prune_interval,
@@ -412,14 +417,13 @@ pub struct Limiter<Key: Hash + Eq + Clone> {
412
417
413
418
impl < Key : Hash + Eq + Clone > Limiter < Key > {
414
419
pub fn from_quota ( quota : Quota ) -> Result < Self , & ' static str > {
415
- if quota. max_tokens == 0 {
416
- return Err ( "Max number of tokens should be positive" ) ;
417
- }
418
420
let tau = quota. replenish_all_every . as_nanos ( ) ;
419
421
if tau == 0 {
420
422
return Err ( "Replenish time must be positive" ) ;
421
423
}
422
- let t = ( tau / quota. max_tokens as u128 )
424
+ let t = tau
425
+ . checked_div ( quota. max_tokens . get ( ) as u128 )
426
+ . expect ( "Division by zero never occurs, since Quota::max_token is of type NonZeroU64." )
423
427
. try_into ( )
424
428
. map_err ( |_| "total replenish time is too long" ) ?;
425
429
let tau = tau
@@ -442,7 +446,7 @@ impl<Key: Hash + Eq + Clone> Limiter<Key> {
442
446
let tau = self . tau ;
443
447
let t = self . t ;
444
448
// how long does it take to replenish these tokens
445
- let additional_time = t * tokens;
449
+ let additional_time = t. saturating_mul ( tokens) ;
446
450
if additional_time > tau {
447
451
// the time required to process this amount of tokens is longer than the time that
448
452
// makes the bucket full. So, this batch can _never_ be processed
@@ -455,16 +459,16 @@ impl<Key: Hash + Eq + Clone> Limiter<Key> {
455
459
. entry ( key. clone ( ) )
456
460
. or_insert ( time_since_start) ;
457
461
// check how soon could the request be made
458
- let earliest_time = ( * tat + additional_time) . saturating_sub ( tau) ;
462
+ let earliest_time = ( * tat) . saturating_add ( additional_time) . saturating_sub ( tau) ;
459
463
// earliest_time is in the future
460
464
if time_since_start < earliest_time {
461
465
Err ( RateLimitedErr :: TooSoon ( Duration :: from_nanos (
462
466
/* time they need to wait, i.e. how soon were they */
463
- earliest_time - time_since_start,
467
+ earliest_time. saturating_sub ( time_since_start) ,
464
468
) ) )
465
469
} else {
466
470
// calculate the new TAT
467
- * tat = time_since_start. max ( * tat) + additional_time;
471
+ * tat = time_since_start. max ( * tat) . saturating_add ( additional_time) ;
468
472
Ok ( ( ) )
469
473
}
470
474
}
@@ -479,14 +483,15 @@ impl<Key: Hash + Eq + Clone> Limiter<Key> {
479
483
480
484
#[ cfg( test) ]
481
485
mod tests {
482
- use crate :: rpc:: rate_limiter:: { Limiter , Quota } ;
486
+ use crate :: rpc:: rate_limiter:: { Limiter , Quota , RateLimitedErr } ;
487
+ use std:: num:: NonZeroU64 ;
483
488
use std:: time:: Duration ;
484
489
485
490
#[ test]
486
491
fn it_works_a ( ) {
487
492
let mut limiter = Limiter :: from_quota ( Quota {
488
493
replenish_all_every : Duration :: from_secs ( 2 ) ,
489
- max_tokens : 4 ,
494
+ max_tokens : NonZeroU64 :: new ( 4 ) . unwrap ( ) ,
490
495
} )
491
496
. unwrap ( ) ;
492
497
let key = 10 ;
@@ -523,7 +528,7 @@ mod tests {
523
528
fn it_works_b ( ) {
524
529
let mut limiter = Limiter :: from_quota ( Quota {
525
530
replenish_all_every : Duration :: from_secs ( 2 ) ,
526
- max_tokens : 4 ,
531
+ max_tokens : NonZeroU64 :: new ( 4 ) . unwrap ( ) ,
527
532
} )
528
533
. unwrap ( ) ;
529
534
let key = 10 ;
@@ -547,4 +552,22 @@ mod tests {
547
552
. allows( Duration :: from_secs_f32( 0.4 ) , & key, 1 )
548
553
. is_err( ) ) ;
549
554
}
555
+
556
+ #[ test]
557
+ fn large_tokens ( ) {
558
+ // These have been adjusted so that an overflow occurs when calculating `additional_time` in
559
+ // `Limiter::allows`. If we don't handle overflow properly, `Limiter::allows` returns `Ok`
560
+ // in this case.
561
+ let replenish_all_every = 2 ;
562
+ let tokens = u64:: MAX / 2 + 1 ;
563
+
564
+ let mut limiter = Limiter :: from_quota ( Quota {
565
+ replenish_all_every : Duration :: from_nanos ( replenish_all_every) ,
566
+ max_tokens : NonZeroU64 :: new ( 1 ) . unwrap ( ) ,
567
+ } )
568
+ . unwrap ( ) ;
569
+
570
+ let result = limiter. allows ( Duration :: from_secs_f32 ( 0.0 ) , & 10 , tokens) ;
571
+ assert ! ( matches!( result, Err ( RateLimitedErr :: TooLarge ) ) ) ;
572
+ }
550
573
}
0 commit comments