Skip to content

Commit 441004e

Browse files
Changes needed to test datetime extension (#470)
Signed-off-by: Shaobo He <[email protected]>
1 parent 4786741 commit 441004e

File tree

4 files changed

+302
-3
lines changed

4 files changed

+302
-3
lines changed

cedar-policy-generators/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ anyhow = "1.0.72"
1818
nanoid = "0.4.0"
1919
serde_with = "3.4.0"
2020
thiserror = "2.0"
21+
22+
[dev.dependencies]
23+
rand = "0.8.5"

cedar-policy-generators/src/abac.rs

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,84 @@ impl ConstantPool {
296296
.map(SmolStr::new)
297297
}
298298

299+
// Generate a roughly valid datetime string
300+
fn arbitrary_datetime_str_inner(&self, u: &mut Unstructured<'_>) -> Result<String> {
301+
let mut result = String::new();
302+
// Generate YYYY-MM-DD
303+
let y = u.int_in_range(0..=9999)?;
304+
let m = u.int_in_range(1..=12)?;
305+
let d = u.int_in_range(1..=31)?;
306+
result.push_str(&format!("{:04}-{:02}-{:02}", y, m, d));
307+
// There's a 25% chance where just a date is generated
308+
if u.ratio(1, 4)? {
309+
return Ok(result);
310+
}
311+
// Generate hh:mm:ss
312+
result.push('T');
313+
let h = u.int_in_range(0..=23)?;
314+
let m = u.int_in_range(0..=59)?;
315+
let s = u.int_in_range(0..=59)?;
316+
result.push_str(&format!("{:02}:{:02}:{:02}", h, m, s));
317+
match u.int_in_range(0..=3)? {
318+
0 => {
319+
// end the string with `Z`
320+
result.push('Z');
321+
}
322+
1 => {
323+
// Generate a millisecond and end the string
324+
let ms = u.int_in_range(0..=999)?;
325+
result.push_str(&format!(".{:03}Z", ms));
326+
}
327+
2 => {
328+
// Generate an offset
329+
let sign = if u.arbitrary()? { '+' } else { '-' };
330+
let hh = u.int_in_range(0..=14)?;
331+
let mm = u.int_in_range(0..=59)?;
332+
result.push_str(&format!("{sign}{:02}{:02}", hh, mm));
333+
}
334+
3 => {
335+
// Generate a millisecond and an offset
336+
let ms = u.int_in_range(0..=999)?;
337+
let sign = if u.arbitrary()? { '+' } else { '-' };
338+
let hh = u.int_in_range(0..=14)?;
339+
let mm = u.int_in_range(0..=59)?;
340+
result.push_str(&format!(".{:03}{sign}{:02}{:02}", ms, hh, mm));
341+
}
342+
_ => {
343+
unreachable!("the number is from 0 to 3")
344+
}
345+
}
346+
Ok(result)
347+
}
348+
349+
/// Generate a roughly valid datetime string and mutate it
350+
pub fn arbitrary_datetime_str(&self, u: &mut Unstructured<'_>) -> Result<SmolStr> {
351+
let result = self.arbitrary_datetime_str_inner(u)?;
352+
mutate_str(u, &result).map(Into::into)
353+
}
354+
355+
/// Generate a roughly valid duration string and mutate it
356+
pub fn arbitrary_duration_str(&self, u: &mut Unstructured<'_>) -> Result<SmolStr> {
357+
let mut result = String::new();
358+
// flip a coin and add `-`
359+
if u.arbitrary()? {
360+
result.push('-');
361+
}
362+
for suffix in ["d", "h", "m", "s", "ms"] {
363+
// Generate a number with certain suffix
364+
if u.arbitrary()? {
365+
let i = self.arbitrary_int_constant(u)?;
366+
result.push_str(&format!("{}{suffix}", (i as i128).abs()));
367+
}
368+
}
369+
// If none of the units is generated, generate a random milliseconds
370+
if result.is_empty() || result == "-" {
371+
let i = self.arbitrary_int_constant(u)?;
372+
result.push_str(&format!("{}ms", (i as i128).abs()));
373+
}
374+
mutate_str(u, &result).map(Into::into)
375+
}
376+
299377
/// size hint for arbitrary_string_constant()
300378
pub fn arbitrary_string_constant_size_hint(_depth: usize) -> (usize, Option<usize>) {
301379
size_hint_for_choose(None)
@@ -430,6 +508,83 @@ impl AvailableExtensionFunctions {
430508
parameter_types: vec![Type::decimal(), Type::decimal()],
431509
return_ty: Type::bool(),
432510
},
511+
AvailableExtensionFunction {
512+
name: Name::parse_unqualified_name("datetime")
513+
.expect("should be a valid identifier"),
514+
is_constructor: true,
515+
parameter_types: vec![Type::string()],
516+
return_ty: Type::datetime(),
517+
},
518+
AvailableExtensionFunction {
519+
name: Name::parse_unqualified_name("offset")
520+
.expect("should be a valid identifier"),
521+
is_constructor: false,
522+
parameter_types: vec![Type::datetime(), Type::duration()],
523+
return_ty: Type::datetime(),
524+
},
525+
AvailableExtensionFunction {
526+
name: Name::parse_unqualified_name("durationSince")
527+
.expect("should be a valid identifier"),
528+
is_constructor: false,
529+
parameter_types: vec![Type::datetime(), Type::datetime()],
530+
return_ty: Type::duration(),
531+
},
532+
AvailableExtensionFunction {
533+
name: Name::parse_unqualified_name("toDate")
534+
.expect("should be a valid identifier"),
535+
is_constructor: false,
536+
parameter_types: vec![Type::datetime()],
537+
return_ty: Type::datetime(),
538+
},
539+
AvailableExtensionFunction {
540+
name: Name::parse_unqualified_name("toTime")
541+
.expect("should be a valid identifier"),
542+
is_constructor: false,
543+
parameter_types: vec![Type::datetime()],
544+
return_ty: Type::duration(),
545+
},
546+
AvailableExtensionFunction {
547+
name: Name::parse_unqualified_name("duration")
548+
.expect("should be a valid identifier"),
549+
is_constructor: true,
550+
parameter_types: vec![Type::string()],
551+
return_ty: Type::duration(),
552+
},
553+
AvailableExtensionFunction {
554+
name: Name::parse_unqualified_name("toMilliseconds")
555+
.expect("should be a valid identifier"),
556+
is_constructor: false,
557+
parameter_types: vec![Type::duration()],
558+
return_ty: Type::long(),
559+
},
560+
AvailableExtensionFunction {
561+
name: Name::parse_unqualified_name("toSeconds")
562+
.expect("should be a valid identifier"),
563+
is_constructor: false,
564+
parameter_types: vec![Type::duration()],
565+
return_ty: Type::long(),
566+
},
567+
AvailableExtensionFunction {
568+
name: Name::parse_unqualified_name("toMinutes")
569+
.expect("should be a valid identifier"),
570+
is_constructor: false,
571+
parameter_types: vec![Type::duration()],
572+
return_ty: Type::long(),
573+
},
574+
AvailableExtensionFunction {
575+
name: Name::parse_unqualified_name("toHours")
576+
.expect("should be a valid identifier"),
577+
is_constructor: false,
578+
parameter_types: vec![Type::duration()],
579+
return_ty: Type::long(),
580+
},
581+
AvailableExtensionFunction {
582+
name: Name::parse_unqualified_name("toDays")
583+
.expect("should be a valid identifier"),
584+
is_constructor: false,
585+
parameter_types: vec![Type::duration()],
586+
return_ty: Type::long(),
587+
},
433588
]
434589
} else {
435590
vec![]
@@ -566,6 +721,10 @@ pub enum Type {
566721
IPAddr,
567722
/// Decimal numbers
568723
Decimal,
724+
/// datetime
725+
DateTime,
726+
/// duration
727+
Duration,
569728
}
570729

571730
impl Type {
@@ -606,6 +765,14 @@ impl Type {
606765
pub fn decimal() -> Self {
607766
Type::Decimal
608767
}
768+
/// datetime type
769+
pub fn datetime() -> Self {
770+
Type::DateTime
771+
}
772+
/// duration type
773+
pub fn duration() -> Self {
774+
Type::Duration
775+
}
609776

610777
/// `Type` has `Arbitrary` auto-derived for it, but for the case where you
611778
/// want "any nonextension Type", you have this
@@ -639,6 +806,12 @@ impl TryFrom<Type> for ast::Type {
639806
Type::Decimal => Ok(ast::Type::Extension {
640807
name: extensions::decimal::extension().name().clone(),
641808
}),
809+
Type::DateTime => Ok(ast::Type::Extension {
810+
name: extensions::datetime::extension().name().clone(),
811+
}),
812+
Type::Duration => Ok(ast::Type::Extension {
813+
name: "duration".parse().unwrap(),
814+
}),
642815
}
643816
}
644817
}
@@ -767,3 +940,86 @@ impl From<ABACRequest> for ast::Request {
767940
abac.0.into()
768941
}
769942
}
943+
944+
#[cfg(test)]
945+
mod tests {
946+
use std::{collections::HashMap, sync::Arc};
947+
948+
use arbitrary::{Arbitrary, Unstructured};
949+
use cedar_policy_core::{
950+
ast::{Expr, Name, Request, Value},
951+
entities::Entities,
952+
evaluator::Evaluator,
953+
extensions::Extensions,
954+
};
955+
use rand::{rngs::StdRng, RngCore, SeedableRng};
956+
use smol_str::SmolStr;
957+
958+
use super::ConstantPool;
959+
960+
// get validate string count
961+
#[track_caller]
962+
fn evaluate_batch(strs: &[SmolStr], constructor: Name) -> usize {
963+
let dummy_euid: Arc<cedar_policy_core::ast::EntityUID> =
964+
Arc::new(r#"A::"""#.parse().unwrap());
965+
let dummy_request = Request::new_unchecked(
966+
cedar_policy_core::ast::EntityUIDEntry::Known {
967+
euid: dummy_euid.clone(),
968+
loc: None,
969+
},
970+
cedar_policy_core::ast::EntityUIDEntry::Known {
971+
euid: dummy_euid.clone(),
972+
loc: None,
973+
},
974+
cedar_policy_core::ast::EntityUIDEntry::Known {
975+
euid: dummy_euid.clone(),
976+
loc: None,
977+
},
978+
None,
979+
);
980+
let entities = Entities::new();
981+
let evaluator = Evaluator::new(dummy_request, &entities, Extensions::all_available());
982+
let valid_strs: Vec<_> = strs
983+
.into_iter()
984+
.filter(|s| {
985+
evaluator
986+
.interpret(
987+
&Expr::call_extension_fn(
988+
constructor.clone(),
989+
vec![Value::from(s.to_owned().to_owned()).into()],
990+
),
991+
&HashMap::new(),
992+
)
993+
.is_ok()
994+
})
995+
.collect();
996+
valid_strs.len()
997+
}
998+
999+
#[test]
1000+
fn test_valid_extension_value_ratio() {
1001+
let mut rng = StdRng::seed_from_u64(666);
1002+
let mut bytes = [0; 4096];
1003+
rng.fill_bytes(&mut bytes);
1004+
let mut u = Unstructured::new(&bytes);
1005+
let pool = ConstantPool::arbitrary(&mut u).expect("should not fail");
1006+
1007+
let datetime_strs: Vec<_> = (0..100)
1008+
.map(|_| {
1009+
pool.arbitrary_datetime_str(&mut u)
1010+
.expect("should not fail")
1011+
})
1012+
.collect();
1013+
let valid_datetime_count = evaluate_batch(&datetime_strs, "datetime".parse().unwrap());
1014+
println!("{}", valid_datetime_count);
1015+
1016+
let duration_strs: Vec<_> = (0..100)
1017+
.map(|_| {
1018+
pool.arbitrary_duration_str(&mut u)
1019+
.expect("should not fail")
1020+
})
1021+
.collect();
1022+
let valid_duration_count = evaluate_batch(&duration_strs, "duration".parse().unwrap());
1023+
println!("{}", valid_duration_count);
1024+
}
1025+
}

cedar-policy-generators/src/expr.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ impl<'a> ExprGenerator<'a> {
11701170
})
11711171
}
11721172
}
1173-
Type::IPAddr | Type::Decimal => {
1173+
Type::IPAddr | Type::Decimal | Type::DateTime | Type::Duration => {
11741174
if !self.settings.enable_extensions {
11751175
return Err(Error::ExtensionsDisabled);
11761176
};
@@ -1183,13 +1183,17 @@ impl<'a> ExprGenerator<'a> {
11831183
let args = vec![ast::Expr::val(match target_type {
11841184
Type::IPAddr => self.constant_pool.arbitrary_ip_str(u)?,
11851185
Type::Decimal => self.constant_pool.arbitrary_decimal_str(u)?,
1186+
Type::DateTime => self.constant_pool.arbitrary_datetime_str(u)?,
1187+
Type::Duration => self.constant_pool.arbitrary_duration_str(u)?,
11861188
_ => unreachable!("ty is deemed to be an extension type"),
11871189
})];
11881190
Ok(ast::Expr::call_extension_fn(constructor.name.clone(), args))
11891191
} else {
11901192
let type_name: UnreservedId = match target_type {
11911193
Type::IPAddr => "ipaddr".parse::<UnreservedId>().unwrap(),
11921194
Type::Decimal => "decimal".parse().unwrap(),
1195+
Type::DateTime => "datetime".parse().unwrap(),
1196+
Type::Duration => "duration".parse().unwrap(),
11931197
_ => unreachable!("target type is deemed to be an extension type!"),
11941198
};
11951199
gen!(u,
@@ -1653,6 +1657,8 @@ impl<'a> ExprGenerator<'a> {
16531657
match name.as_ref() {
16541658
"ipaddr" => self.generate_expr_for_type(&Type::ipaddr(), max_depth, u),
16551659
"decimal" => self.generate_expr_for_type(&Type::decimal(), max_depth, u),
1660+
"datetime" => self.generate_expr_for_type(&Type::datetime(), max_depth, u),
1661+
"duration" => self.generate_expr_for_type(&Type::duration(), max_depth, u),
16561662
_ => panic!("unrecognized extension type: {name:?}"),
16571663
}
16581664
}
@@ -1705,7 +1711,7 @@ impl<'a> ExprGenerator<'a> {
17051711
arbitrary_specified_uid(u)?,
17061712
)))
17071713
}
1708-
Type::IPAddr | Type::Decimal => {
1714+
Type::IPAddr | Type::Decimal | Type::DateTime | Type::Duration => {
17091715
unimplemented!("constant expression of type ipaddr or decimal")
17101716
}
17111717
}
@@ -1784,6 +1790,12 @@ impl<'a> ExprGenerator<'a> {
17841790
"decimal" => {
17851791
self.generate_ext_func_call_for_type(&Type::decimal(), max_depth, u)
17861792
}
1793+
"datetime" => {
1794+
self.generate_ext_func_call_for_type(&Type::datetime(), max_depth, u)
1795+
}
1796+
"duration" => {
1797+
self.generate_ext_func_call_for_type(&Type::duration(), max_depth, u)
1798+
}
17871799
_ => panic!("unrecognized extension type: {name:?}"),
17881800
},
17891801
// no existing extension functions return set type
@@ -1837,7 +1849,7 @@ impl<'a> ExprGenerator<'a> {
18371849
// the only valid entity-typed attribute value is a UID literal
18381850
Ok(AttrValue::UIDLit(self.generate_uid(u)?))
18391851
}
1840-
Type::IPAddr | Type::Decimal => {
1852+
Type::IPAddr | Type::Decimal | Type::DateTime | Type::Duration => {
18411853
// the only valid extension-typed attribute value is a call of an extension constructor with return the type returned
18421854
if max_depth == 0 {
18431855
return Err(Error::TooDeep);
@@ -1858,6 +1870,14 @@ impl<'a> ExprGenerator<'a> {
18581870
.constant_pool
18591871
.arbitrary_decimal_str(u)
18601872
.map(AttrValue::StringLit),
1873+
Type::DateTime => self
1874+
.constant_pool
1875+
.arbitrary_datetime_str(u)
1876+
.map(AttrValue::StringLit),
1877+
Type::Duration => self
1878+
.constant_pool
1879+
.arbitrary_duration_str(u)
1880+
.map(AttrValue::StringLit),
18611881
_ => unreachable!("target_type should only be one of these two"),
18621882
})
18631883
.collect::<Result<_>>()?;
@@ -2082,6 +2102,12 @@ impl<'a> ExprGenerator<'a> {
20822102
match name.as_ref() {
20832103
"ipaddr" => self.generate_attr_value_for_type(&Type::ipaddr(), max_depth, u),
20842104
"decimal" => self.generate_attr_value_for_type(&Type::decimal(), max_depth, u),
2105+
"datetime" => {
2106+
self.generate_attr_value_for_type(&Type::datetime(), max_depth, u)
2107+
}
2108+
"duration" => {
2109+
self.generate_attr_value_for_type(&Type::duration(), max_depth, u)
2110+
}
20852111
_ => unimplemented!("extension type {name:?}"),
20862112
}
20872113
}

0 commit comments

Comments
 (0)