Skip to content

Commit 7cb7e32

Browse files
Aaron Elinecdisselkoen
andauthored
Removing error extension function (#874)
Signed-off-by: Aaron Eline <[email protected]> Co-authored-by: Craig Disselkoen <[email protected]>
1 parent 2b7fe99 commit 7cb7e32

File tree

4 files changed

+73
-112
lines changed

4 files changed

+73
-112
lines changed

cedar-policy-core/src/ast/expr.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ impl Expr {
321321
ExprBuilder::new().ite(test_expr, then_expr, else_expr)
322322
}
323323

324+
/// Create a ternary (if-then-else) `Expr`.
325+
/// Takes `Arc`s instead of owned `Expr`s.
326+
/// `test_expr` must evaluate to a Bool type
327+
pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
328+
ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
329+
}
330+
324331
/// Create a 'not' expression. `e` must evaluate to Bool type
325332
pub fn not(e: Expr) -> Self {
326333
ExprBuilder::new().not(e)
@@ -827,6 +834,22 @@ impl<T> ExprBuilder<T> {
827834
})
828835
}
829836

837+
/// Create a ternary (if-then-else) `Expr`.
838+
/// Takes `Arc`s instead of owned `Expr`s.
839+
/// `test_expr` must evaluate to a Bool type
840+
pub fn ite_arc(
841+
self,
842+
test_expr: Arc<Expr<T>>,
843+
then_expr: Arc<Expr<T>>,
844+
else_expr: Arc<Expr<T>>,
845+
) -> Expr<T> {
846+
self.with_expr_kind(ExprKind::If {
847+
test_expr,
848+
then_expr,
849+
else_expr,
850+
})
851+
}
852+
830853
/// Create a 'not' expression. `e` must evaluate to Bool type
831854
pub fn not(self, e: Expr<T>) -> Expr<T> {
832855
self.with_expr_kind(ExprKind::UnaryApp {

cedar-policy-core/src/evaluator.rs

Lines changed: 41 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -226,29 +226,6 @@ impl<'e> Evaluator<'e> {
226226
}
227227
}
228228

229-
/// Run an expression as far as possible.
230-
/// however, if an error is encountered, instead of error-ing, wrap the error
231-
/// in a call the `error` extension function.
232-
pub fn run_to_error(
233-
&self,
234-
e: &Expr,
235-
slots: &SlotEnv,
236-
) -> (PartialValue, Option<EvaluationError>) {
237-
match self.partial_interpret(e, slots) {
238-
Ok(e) => (e, None),
239-
Err(err) => {
240-
let arg = Expr::val(format!("{err}"));
241-
// PANIC SAFETY: Input to `parse` is fully static and a valid extension function name
242-
#[allow(clippy::unwrap_used)]
243-
let fn_name = "error".parse().unwrap();
244-
(
245-
PartialValue::Residual(Expr::call_extension_fn(fn_name, vec![arg])),
246-
Some(err),
247-
)
248-
}
249-
}
250-
}
251-
252229
/// Interpret an `Expr` into a `Value` in this evaluation environment.
253230
///
254231
/// Ensures the result is not a residual.
@@ -317,10 +294,9 @@ impl<'e> Evaluator<'e> {
317294
ExprKind::And { left, right } => {
318295
match self.partial_interpret(left, slots)? {
319296
// PE Case
320-
PartialValue::Residual(e) => Ok(PartialValue::Residual(Expr::and(
321-
e,
322-
self.run_to_error(right.as_ref(), slots).0.into(),
323-
))),
297+
PartialValue::Residual(e) => {
298+
Ok(PartialValue::Residual(Expr::and(e, right.as_ref().clone())))
299+
}
324300
// Full eval case
325301
PartialValue::Value(v) => {
326302
if v.get_as_bool()? {
@@ -344,10 +320,9 @@ impl<'e> Evaluator<'e> {
344320
ExprKind::Or { left, right } => {
345321
match self.partial_interpret(left, slots)? {
346322
// PE cases
347-
PartialValue::Residual(r) => Ok(PartialValue::Residual(Expr::or(
348-
r,
349-
self.run_to_error(right, slots).0.into(),
350-
))),
323+
PartialValue::Residual(r) => {
324+
Ok(PartialValue::Residual(Expr::or(r, right.as_ref().clone())))
325+
}
351326
// Full eval case
352327
PartialValue::Value(lhs) => {
353328
if lhs.get_as_bool()? {
@@ -695,8 +670,8 @@ impl<'e> Evaluator<'e> {
695670
fn eval_if(
696671
&self,
697672
guard: &Expr,
698-
consequent: &Expr,
699-
alternative: &Expr,
673+
consequent: &Arc<Expr>,
674+
alternative: &Arc<Expr>,
700675
slots: &SlotEnv,
701676
) -> Result<PartialValue> {
702677
match self.partial_interpret(guard, slots)? {
@@ -708,13 +683,7 @@ impl<'e> Evaluator<'e> {
708683
}
709684
}
710685
PartialValue::Residual(guard) => {
711-
let (consequent, consequent_errored) = self.run_to_error(consequent, slots);
712-
let (alternative, alternative_errored) = self.run_to_error(alternative, slots);
713-
// If both branches errored, the expression will always error
714-
match (consequent_errored, alternative_errored) {
715-
(Some(e), Some(_)) => Err(e),
716-
_ => Ok(Expr::ite(guard, consequent.into(), alternative.into()).into()),
717-
}
686+
Ok(Expr::ite_arc(Arc::new(guard), consequent.clone(), alternative.clone()).into())
718687
}
719688
}
720689
}
@@ -4755,7 +4724,7 @@ pub mod test {
47554724
let b = Expr::and(Expr::val(1), Expr::val(2));
47564725
let c = Expr::val(true);
47574726

4758-
let e = Expr::ite(a, b, c);
4727+
let e = Expr::ite(a, b.clone(), c);
47594728

47604729
let es = Entities::new();
47614730

@@ -4768,10 +4737,7 @@ pub mod test {
47684737
r,
47694738
PartialValue::Residual(Expr::ite(
47704739
Expr::unknown(Unknown::new_untyped("guard")),
4771-
Expr::call_extension_fn(
4772-
"error".parse().unwrap(),
4773-
vec![Expr::val("type error: expected bool, got long")]
4774-
),
4740+
b,
47754741
Expr::val(true)
47764742
))
47774743
)
@@ -4836,14 +4802,21 @@ pub mod test {
48364802
let b = Expr::and(Expr::val(1), Expr::val(2));
48374803
let c = Expr::or(Expr::val(1), Expr::val(3));
48384804

4839-
let e = Expr::ite(a, b, c);
4805+
let e = Expr::ite(a, b.clone(), c.clone());
48404806

48414807
let es = Entities::new();
48424808

48434809
let exts = Extensions::none();
48444810
let eval = Evaluator::new(empty_request(), &es, &exts);
48454811

4846-
assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
4812+
assert_eq!(
4813+
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
4814+
PartialValue::Residual(Expr::ite(
4815+
Expr::unknown(Unknown::new_untyped("guard")),
4816+
b,
4817+
c
4818+
))
4819+
);
48474820
}
48484821

48494822
#[test]
@@ -5090,22 +5063,15 @@ pub mod test {
50905063
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
50915064
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
50925065
let alt = Expr::val(2);
5093-
let e = Expr::ite(guard.clone(), cons, alt);
5066+
let e = Expr::ite(guard.clone(), cons.clone(), alt);
50945067

50955068
let es = Entities::new();
50965069
let exts = Extensions::none();
50975070
let eval = Evaluator::new(empty_request(), &es, &exts);
50985071

50995072
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
51005073

5101-
let expected = Expr::ite(
5102-
guard,
5103-
Expr::call_extension_fn(
5104-
"error".parse().unwrap(),
5105-
vec![Expr::val("type error: expected long, got bool")],
5106-
),
5107-
Expr::val(2),
5108-
);
5074+
let expected = Expr::ite(guard, cons, Expr::val(2));
51095075

51105076
assert_eq!(r, PartialValue::Residual(expected));
51115077
}
@@ -5115,22 +5081,15 @@ pub mod test {
51155081
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
51165082
let cons = Expr::val(2);
51175083
let alt = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
5118-
let e = Expr::ite(guard.clone(), cons, alt);
5084+
let e = Expr::ite(guard.clone(), cons, alt.clone());
51195085

51205086
let es = Entities::new();
51215087
let exts = Extensions::none();
51225088
let eval = Evaluator::new(empty_request(), &es, &exts);
51235089

51245090
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
51255091

5126-
let expected = Expr::ite(
5127-
guard,
5128-
Expr::val(2),
5129-
Expr::call_extension_fn(
5130-
"error".parse().unwrap(),
5131-
vec![Expr::val("type error: expected long, got bool")],
5132-
),
5133-
);
5092+
let expected = Expr::ite(guard, Expr::val(2), alt);
51345093
assert_eq!(r, PartialValue::Residual(expected));
51355094
}
51365095

@@ -5139,13 +5098,16 @@ pub mod test {
51395098
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
51405099
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
51415100
let alt = Expr::less(Expr::val("hello"), Expr::val("bye"));
5142-
let e = Expr::ite(guard, cons, alt);
5101+
let e = Expr::ite(guard.clone(), cons.clone(), alt.clone());
51435102

51445103
let es = Entities::new();
51455104
let exts = Extensions::none();
51465105
let eval = Evaluator::new(empty_request(), &es, &exts);
51475106

5148-
assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
5107+
assert_eq!(
5108+
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
5109+
PartialValue::Residual(Expr::ite(guard, cons, alt))
5110+
);
51495111
}
51505112

51515113
// err && res -> err
@@ -5212,27 +5174,27 @@ pub mod test {
52125174
fn partial_and_res_true() {
52135175
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
52145176
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
5215-
let e = Expr::and(lhs.clone(), rhs);
5177+
let e = Expr::and(lhs.clone(), rhs.clone());
52165178
let es = Entities::new();
52175179
let exts = Extensions::none();
52185180
let eval = Evaluator::new(empty_request(), &es, &exts);
52195181

52205182
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5221-
let expected = Expr::and(lhs, Expr::val(true));
5183+
let expected = Expr::and(lhs, rhs);
52225184
assert_eq!(r, PartialValue::Residual(expected));
52235185
}
52245186

52255187
#[test]
52265188
fn partial_and_res_false() {
52275189
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
52285190
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
5229-
let e = Expr::and(lhs.clone(), rhs);
5191+
let e = Expr::and(lhs.clone(), rhs.clone());
52305192
let es = Entities::new();
52315193
let exts = Extensions::none();
52325194
let eval = Evaluator::new(empty_request(), &es, &exts);
52335195

52345196
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5235-
let expected = Expr::and(lhs, Expr::val(false));
5197+
let expected = Expr::and(lhs, rhs);
52365198
assert_eq!(r, PartialValue::Residual(expected));
52375199
}
52385200

@@ -5260,7 +5222,7 @@ pub mod test {
52605222
fn partial_and_res_err() {
52615223
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
52625224
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
5263-
let e = Expr::and(lhs, rhs);
5225+
let e = Expr::and(lhs, rhs.clone());
52645226
let es = Entities::new();
52655227
let exts = Extensions::none();
52665228
let eval = Evaluator::new(empty_request(), &es, &exts);
@@ -5269,10 +5231,7 @@ pub mod test {
52695231

52705232
let expected = Expr::and(
52715233
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
5272-
Expr::call_extension_fn(
5273-
"error".parse().unwrap(),
5274-
vec![Expr::val("type error: expected long, got string")],
5275-
),
5234+
rhs,
52765235
);
52775236
assert_eq!(r, PartialValue::Residual(expected));
52785237
}
@@ -5314,27 +5273,27 @@ pub mod test {
53145273
fn partial_or_res_true() {
53155274
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53165275
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
5317-
let e = Expr::or(lhs.clone(), rhs);
5276+
let e = Expr::or(lhs.clone(), rhs.clone());
53185277
let es = Entities::new();
53195278
let exts = Extensions::none();
53205279
let eval = Evaluator::new(empty_request(), &es, &exts);
53215280

53225281
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5323-
let expected = Expr::or(lhs, Expr::val(true));
5282+
let expected = Expr::or(lhs, rhs);
53245283
assert_eq!(r, PartialValue::Residual(expected));
53255284
}
53265285

53275286
#[test]
53285287
fn partial_or_res_false() {
53295288
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53305289
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
5331-
let e = Expr::or(lhs.clone(), rhs);
5290+
let e = Expr::or(lhs.clone(), rhs.clone());
53325291
let es = Entities::new();
53335292
let exts = Extensions::none();
53345293
let eval = Evaluator::new(empty_request(), &es, &exts);
53355294

53365295
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
5337-
let expected = Expr::or(lhs, Expr::val(false));
5296+
let expected = Expr::or(lhs, rhs);
53385297
assert_eq!(r, PartialValue::Residual(expected));
53395298
}
53405299

@@ -5362,7 +5321,7 @@ pub mod test {
53625321
fn partial_or_res_err() {
53635322
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
53645323
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
5365-
let e = Expr::or(lhs, rhs);
5324+
let e = Expr::or(lhs, rhs.clone());
53665325
let es = Entities::new();
53675326
let exts = Extensions::none();
53685327
let eval = Evaluator::new(empty_request(), &es, &exts);
@@ -5371,10 +5330,7 @@ pub mod test {
53715330

53725331
let expected = Expr::or(
53735332
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
5374-
Expr::call_extension_fn(
5375-
"error".parse().unwrap(),
5376-
vec![Expr::val("type error: expected long, got string")],
5377-
),
5333+
rhs,
53785334
);
53795335
assert_eq!(r, PartialValue::Residual(expected));
53805336
}

cedar-policy-core/src/extensions/partial_evaluation.rs

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::{
1919
ast::{CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, Unknown, Value},
2020
entities::SchemaType,
21-
evaluator::{self, EvaluationError},
21+
evaluator,
2222
};
2323

2424
/// Create a new untyped `Unknown`
@@ -28,37 +28,17 @@ fn create_new_unknown(v: Value) -> evaluator::Result<ExtensionOutputValue> {
2828
)))
2929
}
3030

31-
fn throw_error(v: Value) -> evaluator::Result<ExtensionOutputValue> {
32-
let msg = v.get_as_string()?;
33-
// PANIC SAFETY: This name is fully static, and is a valid extension name
34-
#[allow(clippy::unwrap_used)]
35-
let err = EvaluationError::failed_extension_function_application(
36-
"partial_evaluation".parse().unwrap(),
37-
msg.to_string(),
38-
None, // source loc will be added by the evaluator
39-
);
40-
Err(err)
41-
}
42-
4331
/// Construct the extension
4432
// PANIC SAFETY: all uses of `unwrap` here on parsing extension names are correct names
4533
#[allow(clippy::unwrap_used)]
4634
pub fn extension() -> Extension {
4735
Extension::new(
4836
"partial_evaluation".parse().unwrap(),
49-
vec![
50-
ExtensionFunction::unary_never(
51-
"unknown".parse().unwrap(),
52-
CallStyle::FunctionStyle,
53-
Box::new(create_new_unknown),
54-
Some(SchemaType::String),
55-
),
56-
ExtensionFunction::unary_never(
57-
"error".parse().unwrap(),
58-
CallStyle::FunctionStyle,
59-
Box::new(throw_error),
60-
Some(SchemaType::String),
61-
),
62-
],
37+
vec![ExtensionFunction::unary_never(
38+
"unknown".parse().unwrap(),
39+
CallStyle::FunctionStyle,
40+
Box::new(create_new_unknown),
41+
Some(SchemaType::String),
42+
)],
6343
)
6444
}

0 commit comments

Comments
 (0)