@@ -296,6 +296,84 @@ impl ConstantPool {
296296        . map ( SmolStr :: new) 
297297    } 
298298
299+     // Generate a roughly valid datetime string 
300+     fn  arbitrary_datetime_str_inner ( & self ,  u :  & mut  Unstructured < ' _ > )  -> Result < String >  { 
301+         let  mut  result = String :: new ( ) ; 
302+         // Generate YYYY-MM-DD 
303+         let  y = u. int_in_range ( 0 ..=9999 ) ?; 
304+         let  m = u. int_in_range ( 1 ..=12 ) ?; 
305+         let  d = u. int_in_range ( 1 ..=31 ) ?; 
306+         result. push_str ( & format ! ( "{:04}-{:02}-{:02}" ,  y,  m,  d) ) ; 
307+         // There's a 25% chance where just a date is generated 
308+         if  u. ratio ( 1 ,  4 ) ? { 
309+             return  Ok ( result) ; 
310+         } 
311+         // Generate hh:mm:ss 
312+         result. push ( 'T' ) ; 
313+         let  h = u. int_in_range ( 0 ..=23 ) ?; 
314+         let  m = u. int_in_range ( 0 ..=59 ) ?; 
315+         let  s = u. int_in_range ( 0 ..=59 ) ?; 
316+         result. push_str ( & format ! ( "{:02}:{:02}:{:02}" ,  h,  m,  s) ) ; 
317+         match  u. int_in_range ( 0 ..=3 ) ? { 
318+             0  => { 
319+                 // end the string with `Z` 
320+                 result. push ( 'Z' ) ; 
321+             } 
322+             1  => { 
323+                 // Generate a millisecond and end the string 
324+                 let  ms = u. int_in_range ( 0 ..=999 ) ?; 
325+                 result. push_str ( & format ! ( ".{:03}Z" ,  ms) ) ; 
326+             } 
327+             2  => { 
328+                 // Generate an offset 
329+                 let  sign = if  u. arbitrary ( ) ? {  '+'  }  else  {  '-'  } ; 
330+                 let  hh = u. int_in_range ( 0 ..=14 ) ?; 
331+                 let  mm = u. int_in_range ( 0 ..=59 ) ?; 
332+                 result. push_str ( & format ! ( "{sign}{:02}{:02}" ,  hh,  mm) ) ; 
333+             } 
334+             3  => { 
335+                 // Generate a millisecond and an offset 
336+                 let  ms = u. int_in_range ( 0 ..=999 ) ?; 
337+                 let  sign = if  u. arbitrary ( ) ? {  '+'  }  else  {  '-'  } ; 
338+                 let  hh = u. int_in_range ( 0 ..=14 ) ?; 
339+                 let  mm = u. int_in_range ( 0 ..=59 ) ?; 
340+                 result. push_str ( & format ! ( ".{:03}{sign}{:02}{:02}" ,  ms,  hh,  mm) ) ; 
341+             } 
342+             _ => { 
343+                 unreachable ! ( "the number is from 0 to 3" ) 
344+             } 
345+         } 
346+         Ok ( result) 
347+     } 
348+ 
349+     /// Generate a roughly valid datetime string and mutate it 
350+ pub  fn  arbitrary_datetime_str ( & self ,  u :  & mut  Unstructured < ' _ > )  -> Result < SmolStr >  { 
351+         let  result = self . arbitrary_datetime_str_inner ( u) ?; 
352+         mutate_str ( u,  & result) . map ( Into :: into) 
353+     } 
354+ 
355+     /// Generate a roughly valid duration string and mutate it 
356+ pub  fn  arbitrary_duration_str ( & self ,  u :  & mut  Unstructured < ' _ > )  -> Result < SmolStr >  { 
357+         let  mut  result = String :: new ( ) ; 
358+         // flip a coin and add `-` 
359+         if  u. arbitrary ( ) ? { 
360+             result. push ( '-' ) ; 
361+         } 
362+         for  suffix in  [ "d" ,  "h" ,  "m" ,  "s" ,  "ms" ]  { 
363+             // Generate a number with certain suffix 
364+             if  u. arbitrary ( ) ? { 
365+                 let  i = self . arbitrary_int_constant ( u) ?; 
366+                 result. push_str ( & format ! ( "{}{suffix}" ,  ( i as  i128 ) . abs( ) ) ) ; 
367+             } 
368+         } 
369+         // If none of the units is generated, generate a random milliseconds 
370+         if  result. is_empty ( )  || result == "-"  { 
371+             let  i = self . arbitrary_int_constant ( u) ?; 
372+             result. push_str ( & format ! ( "{}ms" ,  ( i as  i128 ) . abs( ) ) ) ; 
373+         } 
374+         mutate_str ( u,  & result) . map ( Into :: into) 
375+     } 
376+ 
299377    /// size hint for arbitrary_string_constant() 
300378pub  fn  arbitrary_string_constant_size_hint ( _depth :  usize )  -> ( usize ,  Option < usize > )  { 
301379        size_hint_for_choose ( None ) 
@@ -430,6 +508,83 @@ impl AvailableExtensionFunctions {
430508                    parameter_types:  vec![ Type :: decimal( ) ,  Type :: decimal( ) ] , 
431509                    return_ty:  Type :: bool ( ) , 
432510                } , 
511+                 AvailableExtensionFunction  { 
512+                     name:  Name :: parse_unqualified_name( "datetime" ) 
513+                         . expect( "should be a valid identifier" ) , 
514+                     is_constructor:  true , 
515+                     parameter_types:  vec![ Type :: string( ) ] , 
516+                     return_ty:  Type :: datetime( ) , 
517+                 } , 
518+                 AvailableExtensionFunction  { 
519+                     name:  Name :: parse_unqualified_name( "offset" ) 
520+                         . expect( "should be a valid identifier" ) , 
521+                     is_constructor:  false , 
522+                     parameter_types:  vec![ Type :: datetime( ) ,  Type :: duration( ) ] , 
523+                     return_ty:  Type :: datetime( ) , 
524+                 } , 
525+                 AvailableExtensionFunction  { 
526+                     name:  Name :: parse_unqualified_name( "durationSince" ) 
527+                         . expect( "should be a valid identifier" ) , 
528+                     is_constructor:  false , 
529+                     parameter_types:  vec![ Type :: datetime( ) ,  Type :: datetime( ) ] , 
530+                     return_ty:  Type :: duration( ) , 
531+                 } , 
532+                 AvailableExtensionFunction  { 
533+                     name:  Name :: parse_unqualified_name( "toDate" ) 
534+                         . expect( "should be a valid identifier" ) , 
535+                     is_constructor:  false , 
536+                     parameter_types:  vec![ Type :: datetime( ) ] , 
537+                     return_ty:  Type :: datetime( ) , 
538+                 } , 
539+                 AvailableExtensionFunction  { 
540+                     name:  Name :: parse_unqualified_name( "toTime" ) 
541+                         . expect( "should be a valid identifier" ) , 
542+                     is_constructor:  false , 
543+                     parameter_types:  vec![ Type :: datetime( ) ] , 
544+                     return_ty:  Type :: duration( ) , 
545+                 } , 
546+                 AvailableExtensionFunction  { 
547+                     name:  Name :: parse_unqualified_name( "duration" ) 
548+                         . expect( "should be a valid identifier" ) , 
549+                     is_constructor:  true , 
550+                     parameter_types:  vec![ Type :: string( ) ] , 
551+                     return_ty:  Type :: duration( ) , 
552+                 } , 
553+                 AvailableExtensionFunction  { 
554+                     name:  Name :: parse_unqualified_name( "toMilliseconds" ) 
555+                         . expect( "should be a valid identifier" ) , 
556+                     is_constructor:  false , 
557+                     parameter_types:  vec![ Type :: duration( ) ] , 
558+                     return_ty:  Type :: long( ) , 
559+                 } , 
560+                 AvailableExtensionFunction  { 
561+                     name:  Name :: parse_unqualified_name( "toSeconds" ) 
562+                         . expect( "should be a valid identifier" ) , 
563+                     is_constructor:  false , 
564+                     parameter_types:  vec![ Type :: duration( ) ] , 
565+                     return_ty:  Type :: long( ) , 
566+                 } , 
567+                 AvailableExtensionFunction  { 
568+                     name:  Name :: parse_unqualified_name( "toMinutes" ) 
569+                         . expect( "should be a valid identifier" ) , 
570+                     is_constructor:  false , 
571+                     parameter_types:  vec![ Type :: duration( ) ] , 
572+                     return_ty:  Type :: long( ) , 
573+                 } , 
574+                 AvailableExtensionFunction  { 
575+                     name:  Name :: parse_unqualified_name( "toHours" ) 
576+                         . expect( "should be a valid identifier" ) , 
577+                     is_constructor:  false , 
578+                     parameter_types:  vec![ Type :: duration( ) ] , 
579+                     return_ty:  Type :: long( ) , 
580+                 } , 
581+                 AvailableExtensionFunction  { 
582+                     name:  Name :: parse_unqualified_name( "toDays" ) 
583+                         . expect( "should be a valid identifier" ) , 
584+                     is_constructor:  false , 
585+                     parameter_types:  vec![ Type :: duration( ) ] , 
586+                     return_ty:  Type :: long( ) , 
587+                 } , 
433588            ] 
434589        }  else  { 
435590            vec ! [ ] 
@@ -566,6 +721,10 @@ pub enum Type {
566721IPAddr , 
567722    /// Decimal numbers 
568723Decimal , 
724+     /// datetime 
725+ DateTime , 
726+     /// duration 
727+ Duration , 
569728} 
570729
571730impl  Type  { 
@@ -606,6 +765,14 @@ impl Type {
606765pub  fn  decimal ( )  -> Self  { 
607766        Type :: Decimal 
608767    } 
768+     /// datetime type 
769+ pub  fn  datetime ( )  -> Self  { 
770+         Type :: DateTime 
771+     } 
772+     /// duration type 
773+ pub  fn  duration ( )  -> Self  { 
774+         Type :: Duration 
775+     } 
609776
610777    /// `Type` has `Arbitrary` auto-derived for it, but for the case where you 
611778/// want "any nonextension Type", you have this 
@@ -639,6 +806,12 @@ impl TryFrom<Type> for ast::Type {
639806            Type :: Decimal  => Ok ( ast:: Type :: Extension  { 
640807                name :  extensions:: decimal:: extension ( ) . name ( ) . clone ( ) , 
641808            } ) , 
809+             Type :: DateTime  => Ok ( ast:: Type :: Extension  { 
810+                 name :  extensions:: datetime:: extension ( ) . name ( ) . clone ( ) , 
811+             } ) , 
812+             Type :: Duration  => Ok ( ast:: Type :: Extension  { 
813+                 name :  "duration" . parse ( ) . unwrap ( ) , 
814+             } ) , 
642815        } 
643816    } 
644817} 
@@ -767,3 +940,86 @@ impl From<ABACRequest> for ast::Request {
767940        abac. 0 . into ( ) 
768941    } 
769942} 
943+ 
944+ #[ cfg( test) ]  
945+ mod  tests { 
946+     use  std:: { collections:: HashMap ,  sync:: Arc } ; 
947+ 
948+     use  arbitrary:: { Arbitrary ,  Unstructured } ; 
949+     use  cedar_policy_core:: { 
950+         ast:: { Expr ,  Name ,  Request ,  Value } , 
951+         entities:: Entities , 
952+         evaluator:: Evaluator , 
953+         extensions:: Extensions , 
954+     } ; 
955+     use  rand:: { rngs:: StdRng ,  RngCore ,  SeedableRng } ; 
956+     use  smol_str:: SmolStr ; 
957+ 
958+     use  super :: ConstantPool ; 
959+ 
960+     // get validate string count 
961+     #[ track_caller]  
962+     fn  evaluate_batch ( strs :  & [ SmolStr ] ,  constructor :  Name )  -> usize  { 
963+         let  dummy_euid:  Arc < cedar_policy_core:: ast:: EntityUID >  =
964+             Arc :: new ( r#"A::"""# . parse ( ) . unwrap ( ) ) ; 
965+         let  dummy_request = Request :: new_unchecked ( 
966+             cedar_policy_core:: ast:: EntityUIDEntry :: Known  { 
967+                 euid :  dummy_euid. clone ( ) , 
968+                 loc :  None , 
969+             } , 
970+             cedar_policy_core:: ast:: EntityUIDEntry :: Known  { 
971+                 euid :  dummy_euid. clone ( ) , 
972+                 loc :  None , 
973+             } , 
974+             cedar_policy_core:: ast:: EntityUIDEntry :: Known  { 
975+                 euid :  dummy_euid. clone ( ) , 
976+                 loc :  None , 
977+             } , 
978+             None , 
979+         ) ; 
980+         let  entities = Entities :: new ( ) ; 
981+         let  evaluator = Evaluator :: new ( dummy_request,  & entities,  Extensions :: all_available ( ) ) ; 
982+         let  valid_strs:  Vec < _ >  = strs
983+             . into_iter ( ) 
984+             . filter ( |s| { 
985+                 evaluator
986+                     . interpret ( 
987+                         & Expr :: call_extension_fn ( 
988+                             constructor. clone ( ) , 
989+                             vec ! [ Value :: from( s. to_owned( ) . to_owned( ) ) . into( ) ] , 
990+                         ) , 
991+                         & HashMap :: new ( ) , 
992+                     ) 
993+                     . is_ok ( ) 
994+             } ) 
995+             . collect ( ) ; 
996+         valid_strs. len ( ) 
997+     } 
998+ 
999+     #[ test]  
1000+     fn  test_valid_extension_value_ratio ( )  { 
1001+         let  mut  rng = StdRng :: seed_from_u64 ( 666 ) ; 
1002+         let  mut  bytes = [ 0 ;  4096 ] ; 
1003+         rng. fill_bytes ( & mut  bytes) ; 
1004+         let  mut  u = Unstructured :: new ( & bytes) ; 
1005+         let  pool = ConstantPool :: arbitrary ( & mut  u) . expect ( "should not fail" ) ; 
1006+ 
1007+         let  datetime_strs:  Vec < _ >  = ( 0 ..100 ) 
1008+             . map ( |_| { 
1009+                 pool. arbitrary_datetime_str ( & mut  u) 
1010+                     . expect ( "should not fail" ) 
1011+             } ) 
1012+             . collect ( ) ; 
1013+         let  valid_datetime_count = evaluate_batch ( & datetime_strs,  "datetime" . parse ( ) . unwrap ( ) ) ; 
1014+         println ! ( "{}" ,  valid_datetime_count) ; 
1015+ 
1016+         let  duration_strs:  Vec < _ >  = ( 0 ..100 ) 
1017+             . map ( |_| { 
1018+                 pool. arbitrary_duration_str ( & mut  u) 
1019+                     . expect ( "should not fail" ) 
1020+             } ) 
1021+             . collect ( ) ; 
1022+         let  valid_duration_count = evaluate_batch ( & duration_strs,  "duration" . parse ( ) . unwrap ( ) ) ; 
1023+         println ! ( "{}" ,  valid_duration_count) ; 
1024+     } 
1025+ } 
0 commit comments