Skip to content

Commit 8451695

Browse files
khietaAaron Elinecdisselkoen
authored
Removing error extension function (#874) (#1132)
Signed-off-by: Aaron Eline <[email protected]> Co-authored-by: Aaron Eline <[email protected]> Co-authored-by: Craig Disselkoen <[email protected]>
1 parent 2b4d4ef commit 8451695

File tree

3 files changed

+71
-112
lines changed

3 files changed

+71
-112
lines changed

cedar-policy-core/src/ast/expr.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ impl Expr {
321321
ExprBuilder::new().ite(test_expr, then_expr, else_expr)
322322
}
323323

324+
/// Create a ternary (if-then-else) `Expr`.
325+
/// Takes `Arc`s instead of owned `Expr`s.
326+
/// `test_expr` must evaluate to a Bool type
327+
pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
328+
ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
329+
}
330+
324331
/// Create a 'not' expression. `e` must evaluate to Bool type
325332
pub fn not(e: Expr) -> Self {
326333
ExprBuilder::new().not(e)
@@ -827,6 +834,22 @@ impl<T> ExprBuilder<T> {
827834
})
828835
}
829836

837+
/// Create a ternary (if-then-else) `Expr`.
838+
/// Takes `Arc`s instead of owned `Expr`s.
839+
/// `test_expr` must evaluate to a Bool type
840+
pub fn ite_arc(
841+
self,
842+
test_expr: Arc<Expr<T>>,
843+
then_expr: Arc<Expr<T>>,
844+
else_expr: Arc<Expr<T>>,
845+
) -> Expr<T> {
846+
self.with_expr_kind(ExprKind::If {
847+
test_expr,
848+
then_expr,
849+
else_expr,
850+
})
851+
}
852+
830853
/// Create a 'not' expression. `e` must evaluate to Bool type
831854
pub fn not(self, e: Expr<T>) -> Expr<T> {
832855
self.with_expr_kind(ExprKind::UnaryApp {

cedar-policy-core/src/evaluator.rs

Lines changed: 41 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -223,29 +223,6 @@ impl<'e> Evaluator<'e> {
223223
}
224224
}
225225

226-
/// Run an expression as far as possible.
227-
/// however, if an error is encountered, instead of error-ing, wrap the error
228-
/// in a call the `error` extension function.
229-
pub fn run_to_error(
230-
&self,
231-
e: &Expr,
232-
slots: &SlotEnv,
233-
) -> (PartialValue, Option<EvaluationError>) {
234-
match self.partial_interpret(e, slots) {
235-
Ok(e) => (e, None),
236-
Err(err) => {
237-
let arg = Expr::val(format!("{err}"));
238-
// PANIC SAFETY: Input to `parse` is fully static and a valid extension function name
239-
#[allow(clippy::unwrap_used)]
240-
let fn_name = "error".parse().unwrap();
241-
(
242-
PartialValue::Residual(Expr::call_extension_fn(fn_name, vec![arg])),
243-
Some(err),
244-
)
245-
}
246-
}
247-
}
248-
249226
/// Interpret an `Expr` into a `Value` in this evaluation environment.
250227
///
251228
/// Ensures the result is not a residual.
@@ -314,10 +291,9 @@ impl<'e> Evaluator<'e> {
314291
ExprKind::And { left, right } => {
315292
match self.partial_interpret(left, slots)? {
316293
// PE Case
317-
PartialValue::Residual(e) => Ok(PartialValue::Residual(Expr::and(
318-
e,
319-
self.run_to_error(right.as_ref(), slots).0.into(),
320-
))),
294+
PartialValue::Residual(e) => {
295+
Ok(PartialValue::Residual(Expr::and(e, right.as_ref().clone())))
296+
}
321297
// Full eval case
322298
PartialValue::Value(v) => {
323299
if v.get_as_bool()? {
@@ -341,10 +317,9 @@ impl<'e> Evaluator<'e> {
341317
ExprKind::Or { left, right } => {
342318
match self.partial_interpret(left, slots)? {
343319
// PE cases
344-
PartialValue::Residual(r) => Ok(PartialValue::Residual(Expr::or(
345-
r,
346-
self.run_to_error(right, slots).0.into(),
347-
))),
320+
PartialValue::Residual(r) => {
321+
Ok(PartialValue::Residual(Expr::or(r, right.as_ref().clone())))
322+
}
348323
// Full eval case
349324
PartialValue::Value(lhs) => {
350325
if lhs.get_as_bool()? {
@@ -686,8 +661,8 @@ impl<'e> Evaluator<'e> {
686661
fn eval_if(
687662
&self,
688663
guard: &Expr,
689-
consequent: &Expr,
690-
alternative: &Expr,
664+
consequent: &Arc<Expr>,
665+
alternative: &Arc<Expr>,
691666
slots: &SlotEnv,
692667
) -> Result<PartialValue> {
693668
match self.partial_interpret(guard, slots)? {
@@ -699,13 +674,7 @@ impl<'e> Evaluator<'e> {
699674
}
700675
}
701676
PartialValue::Residual(guard) => {
702-
let (consequent, consequent_errored) = self.run_to_error(consequent, slots);
703-
let (alternative, alternative_errored) = self.run_to_error(alternative, slots);
704-
// If both branches errored, the expression will always error
705-
match (consequent_errored, alternative_errored) {
706-
(Some(e), Some(_)) => Err(e),
707-
_ => Ok(Expr::ite(guard, consequent.into(), alternative.into()).into()),
708-
}
677+
Ok(Expr::ite_arc(Arc::new(guard), consequent.clone(), alternative.clone()).into())
709678
}
710679
}
711680
}
@@ -4885,7 +4854,7 @@ pub mod test {
48854854
let b = Expr::and(Expr::val(1), Expr::val(2));
48864855
let c = Expr::val(true);
48874856

4888-
let e = Expr::ite(a, b, c);
4857+
let e = Expr::ite(a, b.clone(), c);
48894858

48904859
let es = Entities::new();
48914860

@@ -4898,10 +4867,7 @@ pub mod test {
48984867
r,
48994868
PartialValue::Residual(Expr::ite(
49004869
Expr::unknown(Unknown::new_untyped("guard")),
4901-
Expr::call_extension_fn(
4902-
"error".parse().unwrap(),
4903-
vec![Expr::val("type error: expected bool, got long")]
4904-
),
4870+
b,
49054871
Expr::val(true)
49064872
))
49074873
)
@@ -4966,14 +4932,21 @@ pub mod test {
49664932
let b = Expr::and(Expr::val(1), Expr::val(2));
49674933
let c = Expr::or(Expr::val(1), Expr::val(3));
49684934

4969-
let e = Expr::ite(a, b, c);
4935+
let e = Expr::ite(a, b.clone(), c.clone());
49704936

49714937
let es = Entities::new();
49724938

49734939
let exts = Extensions::none();
49744940
let eval = Evaluator::new(empty_request(), &es, &exts);
49754941

4976-
assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
4942+
assert_eq!(
4943+
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
4944+
PartialValue::Residual(Expr::ite(
4945+
Expr::unknown(Unknown::new_untyped("guard")),
4946+
b,
4947+
c
4948+
))
4949+
);
49774950
}
49784951

49794952
#[test]
@@ -5220,22 +5193,15 @@ pub mod test {
52205193
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
52215194
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
52225195
let alt = Expr::val(2);
5223-
let e = Expr::ite(guard.clone(), cons, alt);
5196+
let e = Expr::ite(guard.clone(), cons.clone(), alt);
52245197

52255198
let es = Entities::new();
52265199
let exts = Extensions::none();
52275200
let eval = Evaluator::new(empty_request(), &es, &exts);
52285201

52295202
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
52305203

5231-
let expected = Expr::ite(
5232-
guard,
5233-
Expr::call_extension_fn(
5234-
"error".parse().unwrap(),
5235-
vec![Expr::val("type error: expected long, got bool")],
5236-
),
5237-
Expr::val(2),
5238-
);
5204+
let expected = Expr::ite(guard, cons, Expr::val(2));
52395205

52405206
assert_eq!(r, PartialValue::Residual(expected));
52415207
}
@@ -5245,22 +5211,15 @@ pub mod test {
52455211
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
52465212
let cons = Expr::val(2);
52475213
let alt = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
5248-
let e = Expr::ite(guard.clone(), cons, alt);
5214+
let e = Expr::ite(guard.clone(), cons, alt.clone());
52495215

52505216
let es = Entities::new();
52515217
let exts = Extensions::none();
52525218
let eval = Evaluator::new(empty_request(), &es, &exts);
52535219

52545220
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
52555221

5256-
let expected = Expr::ite(
5257-
guard,
5258-
Expr::val(2),
5259-
Expr::call_extension_fn(
5260-
"error".parse().unwrap(),
5261-
vec![Expr::val("type error: expected long, got bool")],
5262-
),
5263-
);
5222+
let expected = Expr::ite(guard, Expr::val(2), alt);
52645223
assert_eq!(r, PartialValue::Residual(expected));
52655224
}
52665225

@@ -5269,13 +5228,16 @@ pub mod test {
52695228
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
52705229
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
52715230
let alt = Expr::less(Expr::val("hello"), Expr::val("bye"));
5272-
let e = Expr::ite(guard, cons, alt);
5231+
let e = Expr::ite(guard.clone(), cons.clone(), alt.clone());
52735232

52745233
let es = Entities::new();
52755234
let exts = Extensions::none();
52765235
let eval = Evaluator::new(empty_request(), &es, &exts);
52775236

5278-
assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
5237+
assert_eq!(
5238+
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
5239+
PartialValue::Residual(Expr::ite(guard, cons, alt))
5240+
);
52795241
}
52805242

52815243
// err && res -> err
@@ -5342,27 +5304,27 @@ pub mod test {
53425304
fn partial_and_res_true() {
53435305
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53445306
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
5345-
let e = Expr::and(lhs.clone(), rhs);
5307+
let e = Expr::and(lhs.clone(), rhs.clone());
53465308
let es = Entities::new();
53475309
let exts = Extensions::none();
53485310
let eval = Evaluator::new(empty_request(), &es, &exts);
53495311

53505312
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5351-
let expected = Expr::and(lhs, Expr::val(true));
5313+
let expected = Expr::and(lhs, rhs);
53525314
assert_eq!(r, PartialValue::Residual(expected));
53535315
}
53545316

53555317
#[test]
53565318
fn partial_and_res_false() {
53575319
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53585320
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
5359-
let e = Expr::and(lhs.clone(), rhs);
5321+
let e = Expr::and(lhs.clone(), rhs.clone());
53605322
let es = Entities::new();
53615323
let exts = Extensions::none();
53625324
let eval = Evaluator::new(empty_request(), &es, &exts);
53635325

53645326
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5365-
let expected = Expr::and(lhs, Expr::val(false));
5327+
let expected = Expr::and(lhs, rhs);
53665328
assert_eq!(r, PartialValue::Residual(expected));
53675329
}
53685330

@@ -5390,7 +5352,7 @@ pub mod test {
53905352
fn partial_and_res_err() {
53915353
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53925354
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
5393-
let e = Expr::and(lhs, rhs);
5355+
let e = Expr::and(lhs, rhs.clone());
53945356
let es = Entities::new();
53955357
let exts = Extensions::none();
53965358
let eval = Evaluator::new(empty_request(), &es, &exts);
@@ -5399,10 +5361,7 @@ pub mod test {
53995361

54005362
let expected = Expr::and(
54015363
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
5402-
Expr::call_extension_fn(
5403-
"error".parse().unwrap(),
5404-
vec![Expr::val("type error: expected long, got string")],
5405-
),
5364+
rhs,
54065365
);
54075366
assert_eq!(r, PartialValue::Residual(expected));
54085367
}
@@ -5444,27 +5403,27 @@ pub mod test {
54445403
fn partial_or_res_true() {
54455404
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
54465405
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
5447-
let e = Expr::or(lhs.clone(), rhs);
5406+
let e = Expr::or(lhs.clone(), rhs.clone());
54485407
let es = Entities::new();
54495408
let exts = Extensions::none();
54505409
let eval = Evaluator::new(empty_request(), &es, &exts);
54515410

54525411
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5453-
let expected = Expr::or(lhs, Expr::val(true));
5412+
let expected = Expr::or(lhs, rhs);
54545413
assert_eq!(r, PartialValue::Residual(expected));
54555414
}
54565415

54575416
#[test]
54585417
fn partial_or_res_false() {
54595418
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
54605419
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
5461-
let e = Expr::or(lhs.clone(), rhs);
5420+
let e = Expr::or(lhs.clone(), rhs.clone());
54625421
let es = Entities::new();
54635422
let exts = Extensions::none();
54645423
let eval = Evaluator::new(empty_request(), &es, &exts);
54655424

54665425
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5467-
let expected = Expr::or(lhs, Expr::val(false));
5426+
let expected = Expr::or(lhs, rhs);
54685427
assert_eq!(r, PartialValue::Residual(expected));
54695428
}
54705429

@@ -5492,7 +5451,7 @@ pub mod test {
54925451
fn partial_or_res_err() {
54935452
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
54945453
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
5495-
let e = Expr::or(lhs, rhs);
5454+
let e = Expr::or(lhs, rhs.clone());
54965455
let es = Entities::new();
54975456
let exts = Extensions::none();
54985457
let eval = Evaluator::new(empty_request(), &es, &exts);
@@ -5501,10 +5460,7 @@ pub mod test {
55015460

55025461
let expected = Expr::or(
55035462
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
5504-
Expr::call_extension_fn(
5505-
"error".parse().unwrap(),
5506-
vec![Expr::val("type error: expected long, got string")],
5507-
),
5463+
rhs,
55085464
);
55095465
assert_eq!(r, PartialValue::Residual(expected));
55105466
}

cedar-policy-core/src/extensions/partial_evaluation.rs

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
use crate::{
2020
ast::{CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, Unknown, Value},
2121
entities::SchemaType,
22-
evaluator::{self, EvaluationError},
22+
evaluator,
2323
};
2424

2525
/// Create a new untyped `Unknown`
@@ -29,37 +29,17 @@ fn create_new_unknown(v: Value) -> evaluator::Result<ExtensionOutputValue> {
2929
)))
3030
}
3131

32-
fn throw_error(v: Value) -> evaluator::Result<ExtensionOutputValue> {
33-
let msg = v.get_as_string()?;
34-
// PANIC SAFETY: This name is fully static, and is a valid extension name
35-
#[allow(clippy::unwrap_used)]
36-
let err = EvaluationError::failed_extension_function_application(
37-
"partial_evaluation".parse().unwrap(),
38-
msg.to_string(),
39-
None, // source loc will be added by the evaluator
40-
);
41-
Err(err)
42-
}
43-
4432
/// Construct the extension
4533
// PANIC SAFETY: all uses of `unwrap` here on parsing extension names are correct names
4634
#[allow(clippy::unwrap_used)]
4735
pub fn extension() -> Extension {
4836
Extension::new(
4937
"partial_evaluation".parse().unwrap(),
50-
vec![
51-
ExtensionFunction::unary_never(
52-
"unknown".parse().unwrap(),
53-
CallStyle::FunctionStyle,
54-
Box::new(create_new_unknown),
55-
Some(SchemaType::String),
56-
),
57-
ExtensionFunction::unary_never(
58-
"error".parse().unwrap(),
59-
CallStyle::FunctionStyle,
60-
Box::new(throw_error),
61-
Some(SchemaType::String),
62-
),
63-
],
38+
vec![ExtensionFunction::unary_never(
39+
"unknown".parse().unwrap(),
40+
CallStyle::FunctionStyle,
41+
Box::new(create_new_unknown),
42+
Some(SchemaType::String),
43+
)],
6444
)
6545
}

0 commit comments

Comments
 (0)