Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions cedar-policy-core/src/ast/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use serde_with::{serde_as, TryFromInto};
use smol_str::SmolStr;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;

/// The entity type that Actions must have
Expand Down Expand Up @@ -657,6 +658,25 @@ impl TCNode<EntityUID> for Entity {
}
}

impl TCNode<EntityUID> for Arc<Entity> {
fn get_key(&self) -> EntityUID {
self.uid().clone()
}

fn add_edge_to(&mut self, k: EntityUID) {
// Use Arc::make_mut to get a mutable reference to the inner value
Arc::make_mut(self).add_ancestor(k)
}

fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
Box::new(self.ancestors())
}

fn has_edge_to(&self, e: &EntityUID) -> bool {
self.is_descendant_of(e)
}
}

impl std::fmt::Display for Entity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down Expand Up @@ -753,6 +773,13 @@ impl From<&Entity> for proto::Entity {
}
}

#[cfg(feature = "protobufs")]
impl From<&Arc<Entity>> for proto::Entity {
fn from(v: &Arc<Entity>) -> Self {
Self::from(v.as_ref())
}
}

/// `PartialValue`, but serialized as a `RestrictedExpr`.
///
/// (Extension values can't be directly serialized, but can be serialized as
Expand Down
91 changes: 57 additions & 34 deletions cedar-policy-core/src/entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub struct Entities {
/// Important internal invariant: for any `Entities` object that exists, the
/// the `ancestor` relation is transitively closed.
#[serde_as(as = "Vec<(_, _)>")]
entities: HashMap<EntityUID, Entity>,
entities: HashMap<EntityUID, Arc<Entity>>,

/// The mode flag determines whether this store functions as a partial store or
/// as a fully concrete store.
Expand Down Expand Up @@ -109,7 +109,7 @@ impl Entities {

/// Iterate over the `Entity`s in the `Entities`
pub fn iter(&self) -> impl Iterator<Item = &Entity> {
self.entities.values()
self.entities.values().map(|e| e.as_ref())
}

/// Adds the [`crate::ast::Entity`]s in the iterator to this [`Entities`].
Expand All @@ -125,7 +125,7 @@ impl Entities {
/// responsible for ensuring that TC and DAG hold before calling this method.
pub fn add_entities(
mut self,
collection: impl IntoIterator<Item = Entity>,
collection: impl IntoIterator<Item = Arc<Entity>>,
schema: Option<&impl Schema>,
tc_computation: TCComputation,
extensions: &Extensions<'_>,
Expand Down Expand Up @@ -174,7 +174,7 @@ impl Entities {
tc_computation: TCComputation,
extensions: &Extensions<'_>,
) -> Result<Self> {
let mut entity_map = create_entity_map(entities.into_iter())?;
let mut entity_map = create_entity_map(entities.into_iter().map(Arc::new))?;
if let Some(schema) = schema {
// Validate non-action entities against schema.
// We do this before adding the actions, because we trust the
Expand Down Expand Up @@ -213,7 +213,7 @@ impl Entities {
schema
.action_entities()
.into_iter()
.map(|e| (e.uid().clone(), Arc::unwrap_or_clone(e))),
Copy link
Contributor

@john-h-kastner-aws john-h-kastner-aws Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a second to puzzle out why this actually gives a speed up, so I'll leave my reasoning here. Maybe worth elaborating in a comment.

We remove this unwrap_or_clone, but add a make_mut (which I assume must clone if there's another reference to the data) call inside TC computation, so at first glance it feels like that should be a wash. But, the TC was already computed for the entity map (which we just constructed, so one ref to each Arc), and we separately computed TC for the actions inside the schema constructor (while it had unique ownership). Put together these should mean we never trigger a clone on the make_mut during TC computation.

.map(|e: Arc<Entity>| (e.uid().clone(), e)),
);
}
Ok(Self {
Expand Down Expand Up @@ -252,6 +252,7 @@ impl Entities {
fn to_ejsons(&self) -> Result<Vec<EntityJson>> {
self.entities
.values()
.map(Arc::as_ref)
.map(EntityJson::from_entity)
.collect::<std::result::Result<_, JsonSerializationError>>()
.map_err(Into::into)
Expand Down Expand Up @@ -322,7 +323,9 @@ impl Entities {
}

/// Create a map from EntityUids to Entities, erroring if there are any duplicates
fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityUID, Entity>> {
fn create_entity_map(
es: impl Iterator<Item = Arc<Entity>>,
) -> Result<HashMap<EntityUID, Arc<Entity>>> {
let mut map = HashMap::new();
for e in es {
match map.entry(e.uid().clone()) {
Expand All @@ -338,10 +341,13 @@ fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityU
impl IntoIterator for Entities {
type Item = Entity;

type IntoIter = hash_map::IntoValues<EntityUID, Entity>;
type IntoIter = std::iter::Map<
std::collections::hash_map::IntoValues<EntityUID, Arc<Entity>>,
fn(Arc<Entity>) -> Entity,
>;

fn into_iter(self) -> Self::IntoIter {
self.entities.into_values()
self.entities.into_values().map(Arc::unwrap_or_clone)
}
}

Expand All @@ -363,7 +369,11 @@ impl From<&proto::Entities> for Entities {
// PANIC SAFETY: experimental feature
#[allow(clippy::expect_used)]
fn from(v: &proto::Entities) -> Self {
let entities: Vec<Entity> = v.entities.iter().map(Entity::from).collect();
let entities: Vec<Arc<Entity>> = v
.entities
.iter()
.map(|e| Arc::new(Entity::from(e)))
.collect();

#[cfg(not(feature = "partial-eval"))]
let result = Entities::new();
Expand Down Expand Up @@ -548,7 +558,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let err = simple_entities(&parser).add_entities(
addl_entities,
None::<&NoEntitiesSchema>,
Expand Down Expand Up @@ -588,7 +599,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let err = simple_entities(&parser).add_entities(
addl_entities,
None::<&NoEntitiesSchema>,
Expand Down Expand Up @@ -627,7 +639,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let err = simple_entities(&parser).add_entities(
addl_entities,
None::<&NoEntitiesSchema>,
Expand Down Expand Up @@ -670,7 +683,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let es = simple_entities(&parser)
.add_entities(
addl_entities,
Expand Down Expand Up @@ -709,7 +723,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let es = simple_entities(&parser)
.add_entities(
addl_entities,
Expand Down Expand Up @@ -750,7 +765,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let es = simple_entities(&parser)
.add_entities(
addl_entities,
Expand Down Expand Up @@ -790,7 +806,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let es = simple_entities(&parser)
.add_entities(
addl_entities,
Expand All @@ -817,7 +834,8 @@ mod json_parsing_tests {

let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let err = simple_entities(&parser)
.add_entities(
addl_entities,
Expand All @@ -838,7 +856,8 @@ mod json_parsing_tests {
let new = serde_json::json!([{"uid":{ "type": "Test", "id": "alice" }, "attrs" : {}, "parents" : []}]);
let addl_entities = parser
.iter_from_json_value(new)
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
.map(Arc::new);
let err = simple_entities(&parser).add_entities(
addl_entities,
None::<&NoEntitiesSchema>,
Expand Down Expand Up @@ -3531,14 +3550,16 @@ pub mod protobuf_tests {
let attrs = (1..=7)
.map(|id| (format!("{id}").into(), RestrictedExpr::val(true)))
.collect::<HashMap<SmolStr, _>>();
let entity: Entity = Entity::new(
r#"Foo::"bar""#.parse().unwrap(),
attrs.clone(),
HashSet::new(),
BTreeMap::new(),
&Extensions::none(),
)
.unwrap();
let entity: Arc<Entity> = Arc::new(
Entity::new(
r#"Foo::"bar""#.parse().unwrap(),
attrs.clone(),
HashSet::new(),
BTreeMap::new(),
&Extensions::none(),
)
.unwrap(),
);
let mut entities2: Entities = Entities::new();
entities2 = entities2
.add_entities(
Expand All @@ -3554,14 +3575,16 @@ pub mod protobuf_tests {
);

// Two Element Test
let entity2: Entity = Entity::new(
r#"Bar::"foo""#.parse().unwrap(),
attrs.clone(),
HashSet::new(),
BTreeMap::new(),
&Extensions::none(),
)
.unwrap();
let entity2: Arc<Entity> = Arc::new(
Entity::new(
r#"Bar::"foo""#.parse().unwrap(),
attrs.clone(),
HashSet::new(),
BTreeMap::new(),
&Extensions::none(),
)
.unwrap(),
);
let mut entities3: Entities = Entities::new();
entities3 = entities3
.add_entities(
Expand Down
9 changes: 5 additions & 4 deletions cedar-policy/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use smol_str::SmolStr;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::io::Read;
use std::str::FromStr;
use std::sync::Arc;

// PANIC SAFETY: `CARGO_PKG_VERSION` should return a valid SemVer version string
#[allow(clippy::unwrap_used)]
Expand Down Expand Up @@ -407,7 +408,7 @@ impl Entities {
) -> Result<Self, EntitiesError> {
Ok(Self(
self.0.add_entities(
entities.into_iter().map(|e| e.0),
entities.into_iter().map(|e| Arc::new(e.0)),
schema
.map(|s| cedar_policy_validator::CoreSchema::new(&s.0))
.as_ref(),
Expand Down Expand Up @@ -446,7 +447,7 @@ impl Entities {
Extensions::all_available(),
cedar_policy_core::entities::TCComputation::ComputeNow,
);
let new_entities = eparser.iter_from_json_str(json)?;
let new_entities = eparser.iter_from_json_str(json)?.map(Arc::new);
Ok(Self(self.0.add_entities(
new_entities,
schema.as_ref(),
Expand Down Expand Up @@ -484,7 +485,7 @@ impl Entities {
Extensions::all_available(),
cedar_policy_core::entities::TCComputation::ComputeNow,
);
let new_entities = eparser.iter_from_json_value(json)?;
let new_entities = eparser.iter_from_json_value(json)?.map(Arc::new);
Ok(Self(self.0.add_entities(
new_entities,
schema.as_ref(),
Expand Down Expand Up @@ -523,7 +524,7 @@ impl Entities {
Extensions::all_available(),
cedar_policy_core::entities::TCComputation::ComputeNow,
);
let new_entities = eparser.iter_from_json_file(json)?;
let new_entities = eparser.iter_from_json_file(json)?.map(Arc::new);
Ok(Self(self.0.add_entities(
new_entities,
schema.as_ref(),
Expand Down