Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cedar-policy-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ prost = { version = "0.13", optional = true }

[features]
# by default, enable all Cedar extensions
default = ["ipaddr", "decimal", "datetime"]
default = ["ipaddr", "decimal"]
ipaddr = []
decimal = ["dep:regex"]
datetime = ["dep:chrono", "dep:regex"]
Expand Down
36 changes: 14 additions & 22 deletions cedar-policy-core/src/ast/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,15 @@
* limitations under the License.
*/

pub use names::TYPES_WITH_OPERATOR_OVERLOADING;

use crate::ast::*;
use crate::entities::SchemaType;
use crate::evaluator;
use std::any::Any;
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::fmt::Debug;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::Arc;

// PANIC SAFETY: `Name`s in here are valid `Name`s
#[allow(clippy::expect_used)]
mod names {
use std::collections::BTreeSet;

use super::Name;

lazy_static::lazy_static! {
/// Extension type names that support operator overloading
// INVARIANT: this set must not be empty.
pub static ref TYPES_WITH_OPERATOR_OVERLOADING : BTreeSet<Name> =
BTreeSet::from_iter(
[Name::parse_unqualified_name("datetime").expect("valid identifier"),
Name::parse_unqualified_name("duration").expect("valid identifier")]
);
}
}

/// Cedar extension.
///
/// An extension can define new types and functions on those types. (Currently,
Expand All @@ -54,14 +34,21 @@ pub struct Extension {
name: Name,
/// Extension functions. These are legal to call in Cedar expressions.
functions: HashMap<Name, ExtensionFunction>,
/// Types with operator overloading
types_with_operator_overloading: BTreeSet<Name>,
}

impl Extension {
/// Create a new `Extension` with the given name and extension functions
pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
pub fn new(
name: Name,
functions: impl IntoIterator<Item = ExtensionFunction>,
types_with_operator_overloading: impl IntoIterator<Item = Name>,
) -> Self {
Self {
name,
functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
}
}

Expand All @@ -86,6 +73,11 @@ impl Extension {
pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
self.funcs().flat_map(|func| func.ext_types())
}

/// Iterate over extension types with operator overloading
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
self.types_with_operator_overloading.iter()
}
}

impl std::fmt::Debug for Extension {
Expand Down
320 changes: 122 additions & 198 deletions cedar-policy-core/src/evaluator.rs

Large diffs are not rendered by default.

21 changes: 7 additions & 14 deletions cedar-policy-core/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ pub mod partial_evaluation;

use std::collections::HashMap;

use crate::ast::{Extension, ExtensionFunction, Name, TYPES_WITH_OPERATOR_OVERLOADING};
use crate::ast::{Extension, ExtensionFunction, Name};
use crate::entities::SchemaType;
use crate::parser::Loc;
use miette::Diagnostic;
use nonempty::NonEmpty;
use thiserror::Error;

use self::extension_function_lookup_errors::FuncDoesNotExistError;
Expand Down Expand Up @@ -98,21 +97,15 @@ impl Extensions<'static> {
pub fn none() -> &'static Extensions<'static> {
&EXTENSIONS_NONE
}

/// Obtain the non-empty vector of types supporting operator overloading
pub fn types_with_operator_overloading() -> NonEmpty<Name> {
// PANIC SAFETY: There are more than one element in `TYPES_WITH_OPERATOR_OVERLOADING`
#[allow(clippy::unwrap_used)]
NonEmpty::collect(TYPES_WITH_OPERATOR_OVERLOADING.iter().cloned()).unwrap()
}

/// Iterate over extension types that support operator overloading
pub fn iter_type_with_operator_overloading() -> impl Iterator<Item = &'static Name> {
TYPES_WITH_OPERATOR_OVERLOADING.iter()
}
}

impl<'a> Extensions<'a> {
/// Obtain the non-empty vector of types supporting operator overloading
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
self.extensions
.iter()
.flat_map(|ext| ext.types_with_operator_overloading())
}
/// Get a new `Extensions` with these specific extensions enabled.
pub fn specific_extensions(
extensions: &'a [Extension],
Expand Down
4 changes: 4 additions & 0 deletions cedar-policy-core/src/extensions/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,10 @@ pub fn extension() -> Extension {
duration_type,
),
],
[
DATETIME_CONSTRUCTOR_NAME.clone(),
DURATION_CONSTRUCTOR_NAME.clone(),
],
)
}

Expand Down
5 changes: 3 additions & 2 deletions cedar-policy-core/src/extensions/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ pub fn extension() -> Extension {
(decimal_type.clone(), decimal_type),
),
],
std::iter::empty(),
)
}

Expand Down Expand Up @@ -627,12 +628,12 @@ mod tests {
&parse_expr(r#"decimal("1.23") < decimal("1.24")"#).expect("parsing error")
),
Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => {
assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}]);
assert_eq!(expected, nonempty![Type::Long]);
assert_eq!(actual, Type::Extension {
name: Name::parse_unqualified_name("decimal")
.expect("should be a valid identifier")
});
assert_eq!(advice, Some("Only extension types `datetime` and `duration` support operator overloading".into()));
assert_eq!(advice, Some("Only types long support comparison".into()));
}
);
assert_matches!(
Expand Down
5 changes: 3 additions & 2 deletions cedar-policy-core/src/extensions/ipaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ pub fn extension() -> Extension {
(ipaddr_type.clone(), ipaddr_type),
),
],
std::iter::empty(),
)
}

Expand Down Expand Up @@ -609,12 +610,12 @@ mod tests {
assert_matches!(
eval.interpret_inline_policy(&Expr::less(ip("127.0.0.1"), ip("10.0.0.10"))),
Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => {
assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}]);
assert_eq!(expected, nonempty![Type::Long]);
assert_eq!(actual, Type::Extension {
name: Name::parse_unqualified_name("ipaddr")
.expect("should be a valid identifier")
});
assert_eq!(advice, Some("Only extension types `datetime` and `duration` support operator overloading".into()));
assert_eq!(advice, Some("Only types long support comparison".into()));
}
);
// test that isIpv4 on a String is an error
Expand Down
1 change: 1 addition & 0 deletions cedar-policy-core/src/extensions/partial_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ pub fn extension() -> Extension {
Box::new(create_new_unknown),
SchemaType::String,
)],
std::iter::empty(),
)
}
2 changes: 1 addition & 1 deletion cedar-policy-validator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ prost = { version = "0.13", optional = true }

[features]
# by default, enable all Cedar extensions
default = ["ipaddr", "decimal", "datetime"]
default = ["ipaddr", "decimal"]
# when enabling a feature, make sure that the Core feature is also enabled
ipaddr = ["cedar-policy-core/ipaddr"]
decimal = ["cedar-policy-core/decimal"]
Expand Down
11 changes: 11 additions & 0 deletions cedar-policy-validator/src/extension_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

use std::collections::BTreeSet;

use crate::types::Type;
use cedar_policy_core::ast::{Expr, Name};

Expand All @@ -23,6 +25,8 @@ pub struct ExtensionSchema {
name: Name,
/// Type information for extension functions
function_types: Vec<ExtensionFunctionType>,
/// Types that support operator overloading
types_with_operator_overloading: BTreeSet<Name>,
}

impl std::fmt::Debug for ExtensionSchema {
Expand All @@ -36,10 +40,12 @@ impl ExtensionSchema {
pub fn new(
name: Name,
function_types: impl IntoIterator<Item = ExtensionFunctionType>,
types_with_operator_overloading: impl IntoIterator<Item = Name>,
) -> Self {
Self {
name,
function_types: function_types.into_iter().collect(),
types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
}
}

Expand All @@ -51,6 +57,11 @@ impl ExtensionSchema {
pub fn function_types(&self) -> impl Iterator<Item = &ExtensionFunctionType> {
self.function_types.iter()
}

/// Get all extension types that support operator overloading
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> {
self.types_with_operator_overloading.iter()
}
}

/// The type of a function used to perform custom argument validation on an
Expand Down
27 changes: 24 additions & 3 deletions cedar-policy-validator/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

//! This module contains type information for all of the standard Cedar extensions.

use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};

use cedar_policy_core::{
ast::{Name, RestrictedExpr, Value},
Expand Down Expand Up @@ -57,7 +57,7 @@ lazy_static::lazy_static! {
static ref ALL_AVAILABLE_EXTENSION_SCHEMAS : ExtensionSchemas<'static> = ExtensionSchemas::build_all_available();
}

/// Aggregate structure containing function signatures for multiple [`ExtensionSchema`].
/// Aggregate structure containing information such as function signatures for multiple [`ExtensionSchema`].
/// Ensures that no function name is defined mode than once.
/// Intentionally does not derive `Clone` to avoid clones of the `HashMap`. For the
/// moment, it's easy to pass this around by reference. We could make this
Expand All @@ -69,6 +69,8 @@ pub struct ExtensionSchemas<'a> {
/// extension function lookup that at most one extension functions exists
/// for a name.
function_types: HashMap<&'a Name, &'a ExtensionFunctionType>,
/// Extension types that support operator overloading
types_with_operator_overloading: BTreeSet<&'a Name>,
}

impl<'a> ExtensionSchemas<'a> {
Expand Down Expand Up @@ -98,14 +100,33 @@ impl<'a> ExtensionSchemas<'a> {
)
.map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;

Ok(Self { function_types })
// We already ensure that names of extension types do not collide, at the language level
let types_with_operator_overloading = extension_schemas
.iter()
.flat_map(|f| f.types_with_operator_overloading())
.collect();

Ok(Self {
function_types,
types_with_operator_overloading,
})
}

/// Get the [`ExtensionFunctionType`] for a function with this [`Name`].
/// Return `None` if no such function exists.
pub fn func_type(&self, name: &Name) -> Option<&ExtensionFunctionType> {
self.function_types.get(name).copied()
}

/// Query if `ext_ty_name` supports operator overloading
pub fn has_type_with_operator_overloading(&self, ext_ty_name: &Name) -> bool {
self.types_with_operator_overloading.contains(ext_ty_name)
}

/// Get all extension types that support operator overloading
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
self.types_with_operator_overloading.iter().cloned()
}
}

/// Evaluates ane extension function on a single string literal argument. Used
Expand Down
9 changes: 7 additions & 2 deletions cedar-policy-validator/src/extensions/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ pub fn extension_schema() -> ExtensionSchema {
let datetime_ty = Type::extension(datetime_ext.name().clone());
//PANIC SAFETY: `duration` is a valid name
#[allow(clippy::unwrap_used)]
let duration_ty = Type::extension("duration".parse().unwrap());
let duration_ty_name: Name = "duration".parse().unwrap();
let duration_ty = Type::extension(duration_ty_name.clone());

let fun_tys = datetime_ext.funcs().map(|f| {
let return_type = get_return_type(f.name(), &datetime_ty, &duration_ty);
Expand All @@ -112,7 +113,11 @@ pub fn extension_schema() -> ExtensionSchema {
get_argument_check(f.name()),
)
});
ExtensionSchema::new(datetime_ext.name().clone(), fun_tys)
ExtensionSchema::new(
datetime_ext.name().clone(),
fun_tys,
[datetime_ext.name().clone(), duration_ty_name],
)
}

/// Extra validation step for the `datetime` function.
Expand Down
2 changes: 1 addition & 1 deletion cedar-policy-validator/src/extensions/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub fn extension_schema() -> ExtensionSchema {
get_argument_check(f.name()),
)
});
ExtensionSchema::new(decimal_ext.name().clone(), fun_tys)
ExtensionSchema::new(decimal_ext.name().clone(), fun_tys, std::iter::empty())
}

/// Extra validation step for the `decimal` function.
Expand Down
2 changes: 1 addition & 1 deletion cedar-policy-validator/src/extensions/ipaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub fn extension_schema() -> ExtensionSchema {
get_argument_check(f.name()),
)
});
ExtensionSchema::new(ipaddr_ext.name().clone(), fun_tys)
ExtensionSchema::new(ipaddr_ext.name().clone(), fun_tys, std::iter::empty())
}

/// Extra validation step for the `ip` function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub fn extension_schema() -> ExtensionSchema {
None,
)
});
ExtensionSchema::new(pe_ext.name().clone(), fun_tys)
ExtensionSchema::new(pe_ext.name().clone(), fun_tys, std::iter::empty())
}

#[cfg(test)]
Expand Down
Loading
Loading