Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cedar-policy-core/src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ impl Expr {
ExprBuilder::new().ite(test_expr, then_expr, else_expr)
}

/// Create a ternary (if-then-else) `Expr`.
/// Takes `Arc`s instead of owned `Expr`s
/// `test_expr` must evaluate to a Bool type
pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: can we make it pub(crate)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rest of the functions are not pub(crate), why this one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the signature of this function is not idiomatic as public functions. I was thinking another function as well but just commented on this one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just pub in Core, right? not re-exported in cedar-policy?

ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
}

/// Create a 'not' expression. `e` must evaluate to Bool type
pub fn not(e: Expr) -> Self {
ExprBuilder::new().not(e)
Expand Down Expand Up @@ -827,6 +834,22 @@ impl<T> ExprBuilder<T> {
})
}

/// Create a ternary (if-then-else) `Expr`.
/// Takes `Arc`s instead of owned `Expr`s
/// `test_expr` must evaluate to a Bool type
pub fn ite_arc(
self,
test_expr: Arc<Expr<T>>,
then_expr: Arc<Expr<T>>,
else_expr: Arc<Expr<T>>,
) -> Expr<T> {
self.with_expr_kind(ExprKind::If {
test_expr,
then_expr,
else_expr,
})
}

/// Create a 'not' expression. `e` must evaluate to Bool type
pub fn not(self, e: Expr<T>) -> Expr<T> {
self.with_expr_kind(ExprKind::UnaryApp {
Expand Down
126 changes: 41 additions & 85 deletions cedar-policy-core/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,29 +224,6 @@ impl<'e> Evaluator<'e> {
}
}

/// Run an expression as far as possible.
/// however, if an error is encountered, instead of error-ing, wrap the error
/// in a call the `error` extension function.
pub fn run_to_error(
&self,
e: &Expr,
slots: &SlotEnv,
) -> (PartialValue, Option<EvaluationError>) {
match self.partial_interpret(e, slots) {
Ok(e) => (e, None),
Err(err) => {
let arg = Expr::val(format!("{err}"));
// PANIC SAFETY: Input to `parse` is fully static and a valid extension function name
#[allow(clippy::unwrap_used)]
let fn_name = "error".parse().unwrap();
(
PartialValue::Residual(Expr::call_extension_fn(fn_name, vec![arg])),
Some(err),
)
}
}
}

/// Interpret an `Expr` into a `Value` in this evaluation environment.
///
/// Ensures the result is not a residual.
Expand Down Expand Up @@ -315,10 +292,9 @@ impl<'e> Evaluator<'e> {
ExprKind::And { left, right } => {
match self.partial_interpret(left, slots)? {
// PE Case
PartialValue::Residual(e) => Ok(PartialValue::Residual(Expr::and(
e,
self.run_to_error(right.as_ref(), slots).0.into(),
))),
PartialValue::Residual(e) => {
Ok(PartialValue::Residual(Expr::and(e, right.as_ref().clone())))
}
// Full eval case
PartialValue::Value(v) => {
if v.get_as_bool()? {
Expand All @@ -342,10 +318,9 @@ impl<'e> Evaluator<'e> {
ExprKind::Or { left, right } => {
match self.partial_interpret(left, slots)? {
// PE cases
PartialValue::Residual(r) => Ok(PartialValue::Residual(Expr::or(
r,
self.run_to_error(right, slots).0.into(),
))),
PartialValue::Residual(r) => {
Ok(PartialValue::Residual(Expr::or(r, right.as_ref().clone())))
}
// Full eval case
PartialValue::Value(lhs) => {
if lhs.get_as_bool()? {
Expand Down Expand Up @@ -687,8 +662,8 @@ impl<'e> Evaluator<'e> {
fn eval_if(
&self,
guard: &Expr,
consequent: &Expr,
alternative: &Expr,
consequent: &Arc<Expr>,
alternative: &Arc<Expr>,
slots: &SlotEnv,
) -> Result<PartialValue> {
match self.partial_interpret(guard, slots)? {
Expand All @@ -700,13 +675,7 @@ impl<'e> Evaluator<'e> {
}
}
PartialValue::Residual(guard) => {
let (consequent, consequent_errored) = self.run_to_error(consequent, slots);
let (alternative, alternative_errored) = self.run_to_error(alternative, slots);
// If both branches errored, the expression will always error
match (consequent_errored, alternative_errored) {
(Some(e), Some(_)) => Err(e),
_ => Ok(Expr::ite(guard, consequent.into(), alternative.into()).into()),
}
Ok(Expr::ite_arc(Arc::new(guard), consequent.clone(), alternative.clone()).into())
}
}
}
Expand Down Expand Up @@ -4889,7 +4858,7 @@ pub mod test {
let b = Expr::and(Expr::val(1), Expr::val(2));
let c = Expr::val(true);

let e = Expr::ite(a, b, c);
let e = Expr::ite(a, b.clone(), c);

let es = Entities::new();

Expand All @@ -4902,10 +4871,7 @@ pub mod test {
r,
PartialValue::Residual(Expr::ite(
Expr::unknown(Unknown::new_untyped("guard")),
Expr::call_extension_fn(
"error".parse().unwrap(),
vec![Expr::val("type error: expected bool, got long")]
),
b,
Expr::val(true)
))
)
Expand Down Expand Up @@ -4970,14 +4936,21 @@ pub mod test {
let b = Expr::and(Expr::val(1), Expr::val(2));
let c = Expr::or(Expr::val(1), Expr::val(3));

let e = Expr::ite(a, b, c);
let e = Expr::ite(a, b.clone(), c.clone());

let es = Entities::new();

let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
assert_eq!(
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
PartialValue::Residual(Expr::ite(
Expr::unknown(Unknown::new_untyped("guard")),
b,
c
))
);
}

#[test]
Expand Down Expand Up @@ -5224,22 +5197,15 @@ pub mod test {
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
let alt = Expr::val(2);
let e = Expr::ite(guard.clone(), cons, alt);
let e = Expr::ite(guard.clone(), cons.clone(), alt);

let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();

let expected = Expr::ite(
guard,
Expr::call_extension_fn(
"error".parse().unwrap(),
vec![Expr::val("type error: expected long, got bool")],
),
Expr::val(2),
);
let expected = Expr::ite(guard, cons, Expr::val(2));

assert_eq!(r, PartialValue::Residual(expected));
}
Expand All @@ -5249,22 +5215,15 @@ pub mod test {
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
let cons = Expr::val(2);
let alt = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
let e = Expr::ite(guard.clone(), cons, alt);
let e = Expr::ite(guard.clone(), cons, alt.clone());

let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();

let expected = Expr::ite(
guard,
Expr::val(2),
Expr::call_extension_fn(
"error".parse().unwrap(),
vec![Expr::val("type error: expected long, got bool")],
),
);
let expected = Expr::ite(guard, Expr::val(2), alt);
assert_eq!(r, PartialValue::Residual(expected));
}

Expand All @@ -5273,13 +5232,16 @@ pub mod test {
let guard = Expr::get_attr(Expr::unknown(Unknown::new_untyped("a")), "field".into());
let cons = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val(true));
let alt = Expr::less(Expr::val("hello"), Expr::val("bye"));
let e = Expr::ite(guard, cons, alt);
let e = Expr::ite(guard.clone(), cons.clone(), alt.clone());

let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

assert_matches!(eval.partial_interpret(&e, &HashMap::new()), Err(_));
assert_eq!(
eval.partial_interpret(&e, &HashMap::new()).unwrap(),
PartialValue::Residual(Expr::ite(guard, cons, alt))
);
}

// err && res -> err
Expand Down Expand Up @@ -5346,27 +5308,27 @@ pub mod test {
fn partial_and_res_true() {
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
let e = Expr::and(lhs.clone(), rhs);
let e = Expr::and(lhs.clone(), rhs.clone());
let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
let expected = Expr::and(lhs, Expr::val(true));
let expected = Expr::and(lhs, rhs);
assert_eq!(r, PartialValue::Residual(expected));
}

#[test]
fn partial_and_res_false() {
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
let e = Expr::and(lhs.clone(), rhs);
let e = Expr::and(lhs.clone(), rhs.clone());
let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
let expected = Expr::and(lhs, Expr::val(false));
let expected = Expr::and(lhs, rhs);
assert_eq!(r, PartialValue::Residual(expected));
}

Expand Down Expand Up @@ -5394,7 +5356,7 @@ pub mod test {
fn partial_and_res_err() {
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
let e = Expr::and(lhs, rhs);
let e = Expr::and(lhs, rhs.clone());
let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);
Expand All @@ -5403,10 +5365,7 @@ pub mod test {

let expected = Expr::and(
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
Expr::call_extension_fn(
"error".parse().unwrap(),
vec![Expr::val("type error: expected long, got string")],
),
rhs,
);
assert_eq!(r, PartialValue::Residual(expected));
}
Expand Down Expand Up @@ -5448,27 +5407,27 @@ pub mod test {
fn partial_or_res_true() {
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(2));
let e = Expr::or(lhs.clone(), rhs);
let e = Expr::or(lhs.clone(), rhs.clone());
let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
let expected = Expr::or(lhs, Expr::val(true));
let expected = Expr::or(lhs, rhs);
assert_eq!(r, PartialValue::Residual(expected));
}

#[test]
fn partial_or_res_false() {
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
let rhs = Expr::binary_app(BinaryOp::Eq, Expr::val(2), Expr::val(1));
let e = Expr::or(lhs.clone(), rhs);
let e = Expr::or(lhs.clone(), rhs.clone());
let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
let expected = Expr::or(lhs, Expr::val(false));
let expected = Expr::or(lhs, rhs);
assert_eq!(r, PartialValue::Residual(expected));
}

Expand Down Expand Up @@ -5496,7 +5455,7 @@ pub mod test {
fn partial_or_res_err() {
let lhs = Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into());
let rhs = Expr::binary_app(BinaryOp::Add, Expr::val(1), Expr::val("oops"));
let e = Expr::or(lhs, rhs);
let e = Expr::or(lhs, rhs.clone());
let es = Entities::new();
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);
Expand All @@ -5505,10 +5464,7 @@ pub mod test {

let expected = Expr::or(
Expr::get_attr(Expr::unknown(Unknown::new_untyped("test")), "field".into()),
Expr::call_extension_fn(
"error".parse().unwrap(),
vec![Expr::val("type error: expected long, got string")],
),
rhs,
);
assert_eq!(r, PartialValue::Residual(expected));
}
Expand Down
34 changes: 7 additions & 27 deletions cedar-policy-core/src/extensions/partial_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::{
ast::{CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, Unknown, Value},
entities::SchemaType,
evaluator::{self, EvaluationError},
evaluator,
};

/// Create a new untyped `Unknown`
Expand All @@ -28,37 +28,17 @@ fn create_new_unknown(v: Value) -> evaluator::Result<ExtensionOutputValue> {
)))
}

fn throw_error(v: Value) -> evaluator::Result<ExtensionOutputValue> {
let msg = v.get_as_string()?;
// PANIC SAFETY: This name is fully static, and is a valid extension name
#[allow(clippy::unwrap_used)]
let err = EvaluationError::failed_extension_function_application(
"partial_evaluation".parse().unwrap(),
msg.to_string(),
None, // source loc will be added by the evaluator
);
Err(err)
}

/// Construct the extension
// PANIC SAFETY: all uses of `unwrap` here on parsing extension names are correct names
#[allow(clippy::unwrap_used)]
pub fn extension() -> Extension {
Extension::new(
"partial_evaluation".parse().unwrap(),
vec![
ExtensionFunction::unary_never(
"unknown".parse().unwrap(),
CallStyle::FunctionStyle,
Box::new(create_new_unknown),
Some(SchemaType::String),
),
ExtensionFunction::unary_never(
"error".parse().unwrap(),
CallStyle::FunctionStyle,
Box::new(throw_error),
Some(SchemaType::String),
),
],
vec![ExtensionFunction::unary_never(
"unknown".parse().unwrap(),
CallStyle::FunctionStyle,
Box::new(create_new_unknown),
Some(SchemaType::String),
)],
)
}
Loading