Skip to content

Commit b56bbad

Browse files
author
Aaron Eline
committed
Reducing precision of PE to remove error
Signed-off-by: Aaron Eline <[email protected]>
1 parent 501fcb2 commit b56bbad

File tree

4 files changed

+73
-112
lines changed

4 files changed

+73
-112
lines changed

cedar-policy-core/src/ast/expr.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ impl Expr {
321321
ExprBuilder::new().ite(test_expr, then_expr, else_expr)
322322
}
323323

324+
/// Create a ternary (if-then-else) `Expr`.
325+
/// Takes `Arc`s instead of owned `Expr`s
326+
/// `test_expr` must evaluate to a Bool type
327+
pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
328+
ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
329+
}
330+
324331
/// Create a 'not' expression. `e` must evaluate to Bool type
325332
pub fn not(e: Expr) -> Self {
326333
ExprBuilder::new().not(e)
@@ -827,6 +834,22 @@ impl<T> ExprBuilder<T> {
827834
})
828835
}
829836

837+
/// Create a ternary (if-then-else) `Expr`.
838+
/// Takes `Arc`s instead of owned `Expr`s
839+
/// `test_expr` must evaluate to a Bool type
840+
pub fn ite_arc(
841+
self,
842+
test_expr: Arc<Expr<T>>,
843+
then_expr: Arc<Expr<T>>,
844+
else_expr: Arc<Expr<T>>,
845+
) -> Expr<T> {
846+
self.with_expr_kind(ExprKind::If {
847+
test_expr,
848+
then_expr,
849+
else_expr,
850+
})
851+
}
852+
830853
/// Create a 'not' expression. `e` must evaluate to Bool type
831854
pub fn not(self, e: Expr<T>) -> Expr<T> {
832855
self.with_expr_kind(ExprKind::UnaryApp {

cedar-policy-core/src/evaluator.rs

Lines changed: 41 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -224,29 +224,6 @@ impl<'e> Evaluator<'e> {
224224
}
225225
}
226226

227-
/// Run an expression as far as possible.
228-
/// however, if an error is encountered, instead of error-ing, wrap the error
229-
/// in a call the `error` extension function.
230-
pub fn run_to_error(
231-
&self,
232-
e: &Expr,
233-
slots: &SlotEnv,
234-
) -> (PartialValue, Option<EvaluationError>) {
235-
match self.partial_interpret(e, slots) {
236-
Ok(e) => (e, None),
237-
Err(err) => {
238-
let arg = Expr::val(format!("{err}"));
239-
// PANIC SAFETY: Input to `parse` is fully static and a valid extension function name
240-
#[allow(clippy::unwrap_used)]
241-
let fn_name = "error".parse().unwrap();
242-
(
243-
PartialValue::Residual(Expr::call_extension_fn(fn_name, vec![arg])),
244-
Some(err),
245-
)
246-
}
247-
}
248-
}
249-
250227
/// Interpret an `Expr` into a `Value` in this evaluation environment.
251228
///
252229
/// Ensures the result is not a residual.
@@ -315,10 +292,9 @@ impl<'e> Evaluator<'e> {
315292
ExprKind::And { left, right } => {
316293
match self.partial_interpret(left, slots)? {
317294
// PE Case
318-
PartialValue::Residual(e) => Ok(PartialValue::Residual(Expr::and(
319-
e,
320-
self.run_to_error(right.as_ref(), slots).0.into(),
321-
))),
295+
PartialValue::Residual(e) => {
296+
Ok(PartialValue::Residual(Expr::and(e, right.as_ref().clone())))
297+
}
322298
// Full eval case
323299
PartialValue::Value(v) => {
324300
if v.get_as_bool()? {
@@ -342,10 +318,9 @@ impl<'e> Evaluator<'e> {
342318
ExprKind::Or { left, right } => {
343319
match self.partial_interpret(left, slots)? {
344320
// PE cases
345-
PartialValue::Residual(r) => Ok(PartialValue::Residual(Expr::or(
346-
r,
347-
self.run_to_error(right, slots).0.into(),
348-
))),
321+
PartialValue::Residual(r) => {
322+
Ok(PartialValue::Residual(Expr::or(r, right.as_ref().clone())))
323+
}
349324
// Full eval case
350325
PartialValue::Value(lhs) => {
351326
if lhs.get_as_bool()? {
@@ -687,8 +662,8 @@ impl<'e> Evaluator<'e> {
687662
fn eval_if(
688663
&self,
689664
guard: &Expr,
690-
consequent: &Expr,
691-
alternative: &Expr,
665+
consequent: &Arc<Expr>,
666+
alternative: &Arc<Expr>,
692667
slots: &SlotEnv,
693668
) -> Result<PartialValue> {
694669
match self.partial_interpret(guard, slots)? {
@@ -700,13 +675,7 @@ impl<'e> Evaluator<'e> {
700675
}
701676
}
702677
PartialValue::Residual(guard) => {
703-
let (consequent, consequent_errored) = self.run_to_error(consequent, slots);
704-
let (alternative, alternative_errored) = self.run_to_error(alternative, slots);
705-
// If both branches errored, the expression will always error
706-
match (consequent_errored, alternative_errored) {
707-
(Some(e), Some(_)) => Err(e),
708-
_ => Ok(Expr::ite(guard, consequent.into(), alternative.into()).into()),
709-
}
678+
Ok(Expr::ite_arc(Arc::new(guard), consequent.clone(), alternative.clone()).into())
710679
}
711680
}
712681
}
@@ -4889,7 +4858,7 @@ pub mod test {
48894858
let b = Expr::and(Expr::val(1), Expr::val(2));
48904859
let c = Expr::val(true);
48914860

4892-
let e = Expr::ite(a, b, c);
4861+
let e = Expr::ite(a, b.clone(), c);
48934862

48944863
let es = Entities::new();
48954864

@@ -4902,10 +4871,7 @@ pub mod test {
49024871
r,
49034872
PartialValue::Residual(Expr::ite(
49044873
Expr::unknown(Unknown::new_untyped("guard")),
4905-
Expr::call_extension_fn(
4906-
"error".parse().unwrap(),
4907-
vec![Expr::val("type error: expected bool, got long")]
4908-
),
4874+
b,
49094875
Expr::val(true)
49104876
))
49114877
)
@@ -4970,14 +4936,21 @@ pub mod test {
49704936
let b = Expr::and(Expr::val(1), Expr::val(2));
49714937
let c = Expr::or(Expr::val(1), Expr::val(3));
49724938

4973-
let e = Expr::ite(a, b, c);
4939+
let e = Expr::ite(a, b.clone(), c.clone());
49744940

49754941
let es = Entities::new();
49764942

49774943
let exts = Extensions::none();
49784944
let eval = Evaluator::new(empty_request(), &es, &exts);
49794945

4980-
assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
4946+
assert_eq!(
4947+
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
4948+
PartialValue::Residual(Expr::ite(
4949+
Expr::unknown(Unknown::new_untyped("guard")),
4950+
b,
4951+
c
4952+
))
4953+
);
49814954
}
49824955

49834956
#[test]
@@ -5224,22 +5197,15 @@ pub mod test {
52245197
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
52255198
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
52265199
let alt = Expr::val(2);
5227-
let e = Expr::ite(guard.clone(), cons, alt);
5200+
let e = Expr::ite(guard.clone(), cons.clone(), alt);
52285201

52295202
let es = Entities::new();
52305203
let exts = Extensions::none();
52315204
let eval = Evaluator::new(empty_request(), &es, &exts);
52325205

52335206
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
52345207

5235-
let expected = Expr::ite(
5236-
guard,
5237-
Expr::call_extension_fn(
5238-
"error".parse().unwrap(),
5239-
vec![Expr::val("type error: expected long, got bool")],
5240-
),
5241-
Expr::val(2),
5242-
);
5208+
let expected = Expr::ite(guard, cons, Expr::val(2));
52435209

52445210
assert_eq!(r, PartialValue::Residual(expected));
52455211
}
@@ -5249,22 +5215,15 @@ pub mod test {
52495215
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
52505216
let cons = Expr::val(2);
52515217
let alt = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
5252-
let e = Expr::ite(guard.clone(), cons, alt);
5218+
let e = Expr::ite(guard.clone(), cons, alt.clone());
52535219

52545220
let es = Entities::new();
52555221
let exts = Extensions::none();
52565222
let eval = Evaluator::new(empty_request(), &es, &exts);
52575223

52585224
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
52595225

5260-
let expected = Expr::ite(
5261-
guard,
5262-
Expr::val(2),
5263-
Expr::call_extension_fn(
5264-
"error".parse().unwrap(),
5265-
vec![Expr::val("type error: expected long, got bool")],
5266-
),
5267-
);
5226+
let expected = Expr::ite(guard, Expr::val(2), alt);
52685227
assert_eq!(r, PartialValue::Residual(expected));
52695228
}
52705229

@@ -5273,13 +5232,16 @@ pub mod test {
52735232
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
52745233
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
52755234
let alt = Expr::less(Expr::val("hello"), Expr::val("bye"));
5276-
let e = Expr::ite(guard, cons, alt);
5235+
let e = Expr::ite(guard.clone(), cons.clone(), alt.clone());
52775236

52785237
let es = Entities::new();
52795238
let exts = Extensions::none();
52805239
let eval = Evaluator::new(empty_request(), &es, &exts);
52815240

5282-
assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
5241+
assert_eq!(
5242+
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
5243+
PartialValue::Residual(Expr::ite(guard, cons, alt))
5244+
);
52835245
}
52845246

52855247
// err && res -> err
@@ -5346,27 +5308,27 @@ pub mod test {
53465308
fn partial_and_res_true() {
53475309
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53485310
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
5349-
let e = Expr::and(lhs.clone(), rhs);
5311+
let e = Expr::and(lhs.clone(), rhs.clone());
53505312
let es = Entities::new();
53515313
let exts = Extensions::none();
53525314
let eval = Evaluator::new(empty_request(), &es, &exts);
53535315

53545316
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5355-
let expected = Expr::and(lhs, Expr::val(true));
5317+
let expected = Expr::and(lhs, rhs);
53565318
assert_eq!(r, PartialValue::Residual(expected));
53575319
}
53585320

53595321
#[test]
53605322
fn partial_and_res_false() {
53615323
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53625324
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
5363-
let e = Expr::and(lhs.clone(), rhs);
5325+
let e = Expr::and(lhs.clone(), rhs.clone());
53645326
let es = Entities::new();
53655327
let exts = Extensions::none();
53665328
let eval = Evaluator::new(empty_request(), &es, &exts);
53675329

53685330
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5369-
let expected = Expr::and(lhs, Expr::val(false));
5331+
let expected = Expr::and(lhs, rhs);
53705332
assert_eq!(r, PartialValue::Residual(expected));
53715333
}
53725334

@@ -5394,7 +5356,7 @@ pub mod test {
53945356
fn partial_and_res_err() {
53955357
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53965358
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
5397-
let e = Expr::and(lhs, rhs);
5359+
let e = Expr::and(lhs, rhs.clone());
53985360
let es = Entities::new();
53995361
let exts = Extensions::none();
54005362
let eval = Evaluator::new(empty_request(), &es, &exts);
@@ -5403,10 +5365,7 @@ pub mod test {
54035365

54045366
let expected = Expr::and(
54055367
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
5406-
Expr::call_extension_fn(
5407-
"error".parse().unwrap(),
5408-
vec![Expr::val("type error: expected long, got string")],
5409-
),
5368+
rhs,
54105369
);
54115370
assert_eq!(r, PartialValue::Residual(expected));
54125371
}
@@ -5448,27 +5407,27 @@ pub mod test {
54485407
fn partial_or_res_true() {
54495408
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
54505409
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
5451-
let e = Expr::or(lhs.clone(), rhs);
5410+
let e = Expr::or(lhs.clone(), rhs.clone());
54525411
let es = Entities::new();
54535412
let exts = Extensions::none();
54545413
let eval = Evaluator::new(empty_request(), &es, &exts);
54555414

54565415
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5457-
let expected = Expr::or(lhs, Expr::val(true));
5416+
let expected = Expr::or(lhs, rhs);
54585417
assert_eq!(r, PartialValue::Residual(expected));
54595418
}
54605419

54615420
#[test]
54625421
fn partial_or_res_false() {
54635422
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
54645423
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
5465-
let e = Expr::or(lhs.clone(), rhs);
5424+
let e = Expr::or(lhs.clone(), rhs.clone());
54665425
let es = Entities::new();
54675426
let exts = Extensions::none();
54685427
let eval = Evaluator::new(empty_request(), &es, &exts);
54695428

54705429
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5471-
let expected = Expr::or(lhs, Expr::val(false));
5430+
let expected = Expr::or(lhs, rhs);
54725431
assert_eq!(r, PartialValue::Residual(expected));
54735432
}
54745433

@@ -5496,7 +5455,7 @@ pub mod test {
54965455
fn partial_or_res_err() {
54975456
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
54985457
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
5499-
let e = Expr::or(lhs, rhs);
5458+
let e = Expr::or(lhs, rhs.clone());
55005459
let es = Entities::new();
55015460
let exts = Extensions::none();
55025461
let eval = Evaluator::new(empty_request(), &es, &exts);
@@ -5505,10 +5464,7 @@ pub mod test {
55055464

55065465
let expected = Expr::or(
55075466
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
5508-
Expr::call_extension_fn(
5509-
"error".parse().unwrap(),
5510-
vec![Expr::val("type error: expected long, got string")],
5511-
),
5467+
rhs,
55125468
);
55135469
assert_eq!(r, PartialValue::Residual(expected));
55145470
}

cedar-policy-core/src/extensions/partial_evaluation.rs

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::{
1919
ast::{CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, Unknown, Value},
2020
entities::SchemaType,
21-
evaluator::{self, EvaluationError},
21+
evaluator,
2222
};
2323

2424
/// Create a new untyped `Unknown`
@@ -28,37 +28,17 @@ fn create_new_unknown(v: Value) -> evaluator::Result<ExtensionOutputValue> {
2828
)))
2929
}
3030

31-
fn throw_error(v: Value) -> evaluator::Result<ExtensionOutputValue> {
32-
let msg = v.get_as_string()?;
33-
// PANIC SAFETY: This name is fully static, and is a valid extension name
34-
#[allow(clippy::unwrap_used)]
35-
let err = EvaluationError::failed_extension_function_application(
36-
"partial_evaluation".parse().unwrap(),
37-
msg.to_string(),
38-
None, // source loc will be added by the evaluator
39-
);
40-
Err(err)
41-
}
42-
4331
/// Construct the extension
4432
// PANIC SAFETY: all uses of `unwrap` here on parsing extension names are correct names
4533
#[allow(clippy::unwrap_used)]
4634
pub fn extension() -> Extension {
4735
Extension::new(
4836
"partial_evaluation".parse().unwrap(),
49-
vec![
50-
ExtensionFunction::unary_never(
51-
"unknown".parse().unwrap(),
52-
CallStyle::FunctionStyle,
53-
Box::new(create_new_unknown),
54-
Some(SchemaType::String),
55-
),
56-
ExtensionFunction::unary_never(
57-
"error".parse().unwrap(),
58-
CallStyle::FunctionStyle,
59-
Box::new(throw_error),
60-
Some(SchemaType::String),
61-
),
62-
],
37+
vec![ExtensionFunction::unary_never(
38+
"unknown".parse().unwrap(),
39+
CallStyle::FunctionStyle,
40+
Box::new(create_new_unknown),
41+
Some(SchemaType::String),
42+
)],
6343
)
6444
}

0 commit comments

Comments
 (0)