Skip to content

Commit 96cec46

Browse files
tomlikestorockcdisselkoen
authored andcommitted
Speed Improvement: Make incoming entities Arc<Entities> like schema entities (#1296)
Signed-off-by: Thomas Hill <[email protected]>
1 parent 95923dd commit 96cec46

File tree

3 files changed

+57
-21
lines changed

3 files changed

+57
-21
lines changed

cedar-policy-core/src/ast/entity.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use serde_with::{serde_as, TryFromInto};
3030
use smol_str::SmolStr;
3131
use std::collections::{BTreeMap, HashMap, HashSet};
3232
use std::str::FromStr;
33+
use std::sync::Arc;
3334
use thiserror::Error;
3435

3536
/// The entity type that Actions must have
@@ -605,6 +606,25 @@ impl TCNode<EntityUID> for Entity {
605606
}
606607
}
607608

609+
impl TCNode<EntityUID> for Arc<Entity> {
610+
fn get_key(&self) -> EntityUID {
611+
self.uid().clone()
612+
}
613+
614+
fn add_edge_to(&mut self, k: EntityUID) {
615+
// Use Arc::make_mut to get a mutable reference to the inner value
616+
Arc::make_mut(self).add_ancestor(k)
617+
}
618+
619+
fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
620+
Box::new(self.ancestors())
621+
}
622+
623+
fn has_edge_to(&self, e: &EntityUID) -> bool {
624+
self.is_descendant_of(e)
625+
}
626+
}
627+
608628
impl std::fmt::Display for Entity {
609629
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
610630
write!(

cedar-policy-core/src/entities.rs

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub struct Entities {
5858
/// Important internal invariant: for any `Entities` object that exists, the
5959
/// the `ancestor` relation is transitively closed.
6060
#[serde_as(as = "Vec<(_, _)>")]
61-
entities: HashMap<EntityUID, Entity>,
61+
entities: HashMap<EntityUID, Arc<Entity>>,
6262

6363
/// The mode flag determines whether this store functions as a partial store or
6464
/// as a fully concrete store.
@@ -109,7 +109,7 @@ impl Entities {
109109

110110
/// Iterate over the `Entity`s in the `Entities`
111111
pub fn iter(&self) -> impl Iterator<Item = &Entity> {
112-
self.entities.values()
112+
self.entities.values().map(|e| e.as_ref())
113113
}
114114

115115
/// Adds the [`crate::ast::Entity`]s in the iterator to this [`Entities`].
@@ -125,7 +125,7 @@ impl Entities {
125125
/// responsible for ensuring that TC and DAG hold before calling this method.
126126
pub fn add_entities(
127127
mut self,
128-
collection: impl IntoIterator<Item = Entity>,
128+
collection: impl IntoIterator<Item = Arc<Entity>>,
129129
schema: Option<&impl Schema>,
130130
tc_computation: TCComputation,
131131
extensions: &Extensions<'_>,
@@ -174,7 +174,7 @@ impl Entities {
174174
tc_computation: TCComputation,
175175
extensions: &Extensions<'_>,
176176
) -> Result<Self> {
177-
let mut entity_map = create_entity_map(entities.into_iter())?;
177+
let mut entity_map = create_entity_map(entities.into_iter().map(Arc::new))?;
178178
if let Some(schema) = schema {
179179
// Validate non-action entities against schema.
180180
// We do this before adding the actions, because we trust the
@@ -213,7 +213,7 @@ impl Entities {
213213
schema
214214
.action_entities()
215215
.into_iter()
216-
.map(|e| (e.uid().clone(), Arc::unwrap_or_clone(e))),
216+
.map(|e: Arc<Entity>| (e.uid().clone(), e)),
217217
);
218218
}
219219
Ok(Self {
@@ -252,6 +252,7 @@ impl Entities {
252252
fn to_ejsons(&self) -> Result<Vec<EntityJson>> {
253253
self.entities
254254
.values()
255+
.map(Arc::as_ref)
255256
.map(EntityJson::from_entity)
256257
.collect::<std::result::Result<_, JsonSerializationError>>()
257258
.map_err(Into::into)
@@ -322,7 +323,9 @@ impl Entities {
322323
}
323324

324325
/// Create a map from EntityUids to Entities, erroring if there are any duplicates
325-
fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityUID, Entity>> {
326+
fn create_entity_map(
327+
es: impl Iterator<Item = Arc<Entity>>,
328+
) -> Result<HashMap<EntityUID, Arc<Entity>>> {
326329
let mut map = HashMap::new();
327330
for e in es {
328331
match map.entry(e.uid().clone()) {
@@ -338,10 +341,13 @@ fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityU
338341
impl IntoIterator for Entities {
339342
type Item = Entity;
340343

341-
type IntoIter = hash_map::IntoValues<EntityUID, Entity>;
344+
type IntoIter = std::iter::Map<
345+
std::collections::hash_map::IntoValues<EntityUID, Arc<Entity>>,
346+
fn(Arc<Entity>) -> Entity,
347+
>;
342348

343349
fn into_iter(self) -> Self::IntoIter {
344-
self.entities.into_values()
350+
self.entities.into_values().map(Arc::unwrap_or_clone)
345351
}
346352
}
347353

@@ -497,7 +503,8 @@ mod json_parsing_tests {
497503

498504
let addl_entities = parser
499505
.iter_from_json_value(new)
500-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
506+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
507+
.map(Arc::new);
501508
let err = simple_entities(&parser).add_entities(
502509
addl_entities,
503510
None::<&NoEntitiesSchema>,
@@ -537,7 +544,8 @@ mod json_parsing_tests {
537544

538545
let addl_entities = parser
539546
.iter_from_json_value(new)
540-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
547+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
548+
.map(Arc::new);
541549
let err = simple_entities(&parser).add_entities(
542550
addl_entities,
543551
None::<&NoEntitiesSchema>,
@@ -576,7 +584,8 @@ mod json_parsing_tests {
576584

577585
let addl_entities = parser
578586
.iter_from_json_value(new)
579-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
587+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
588+
.map(Arc::new);
580589
let err = simple_entities(&parser).add_entities(
581590
addl_entities,
582591
None::<&NoEntitiesSchema>,
@@ -619,7 +628,8 @@ mod json_parsing_tests {
619628

620629
let addl_entities = parser
621630
.iter_from_json_value(new)
622-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
631+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
632+
.map(Arc::new);
623633
let es = simple_entities(&parser)
624634
.add_entities(
625635
addl_entities,
@@ -658,7 +668,8 @@ mod json_parsing_tests {
658668

659669
let addl_entities = parser
660670
.iter_from_json_value(new)
661-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
671+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
672+
.map(Arc::new);
662673
let es = simple_entities(&parser)
663674
.add_entities(
664675
addl_entities,
@@ -699,7 +710,8 @@ mod json_parsing_tests {
699710

700711
let addl_entities = parser
701712
.iter_from_json_value(new)
702-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
713+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
714+
.map(Arc::new);
703715
let es = simple_entities(&parser)
704716
.add_entities(
705717
addl_entities,
@@ -739,7 +751,8 @@ mod json_parsing_tests {
739751

740752
let addl_entities = parser
741753
.iter_from_json_value(new)
742-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
754+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
755+
.map(Arc::new);
743756
let es = simple_entities(&parser)
744757
.add_entities(
745758
addl_entities,
@@ -766,7 +779,8 @@ mod json_parsing_tests {
766779

767780
let addl_entities = parser
768781
.iter_from_json_value(new)
769-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
782+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
783+
.map(Arc::new);
770784
let err = simple_entities(&parser)
771785
.add_entities(
772786
addl_entities,
@@ -787,7 +801,8 @@ mod json_parsing_tests {
787801
let new = serde_json::json!([{"uid":{ "type": "Test", "id": "alice" }, "attrs" : {}, "parents" : []}]);
788802
let addl_entities = parser
789803
.iter_from_json_value(new)
790-
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)));
804+
.unwrap_or_else(|e| panic!("{:?}", &miette::Report::new(e)))
805+
.map(Arc::new);
791806
let err = simple_entities(&parser).add_entities(
792807
addl_entities,
793808
None::<&NoEntitiesSchema>,

cedar-policy/src/api.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ use smol_str::SmolStr;
5858
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
5959
use std::io::Read;
6060
use std::str::FromStr;
61+
use std::sync::Arc;
6162

6263
// PANIC SAFETY: `CARGO_PKG_VERSION` should return a valid SemVer version string
6364
#[allow(clippy::unwrap_used)]
@@ -407,7 +408,7 @@ impl Entities {
407408
) -> Result<Self, EntitiesError> {
408409
Ok(Self(
409410
self.0.add_entities(
410-
entities.into_iter().map(|e| e.0),
411+
entities.into_iter().map(|e| Arc::new(e.0)),
411412
schema
412413
.map(|s| cedar_policy_validator::CoreSchema::new(&s.0))
413414
.as_ref(),
@@ -446,7 +447,7 @@ impl Entities {
446447
Extensions::all_available(),
447448
cedar_policy_core::entities::TCComputation::ComputeNow,
448449
);
449-
let new_entities = eparser.iter_from_json_str(json)?;
450+
let new_entities = eparser.iter_from_json_str(json)?.map(Arc::new);
450451
Ok(Self(self.0.add_entities(
451452
new_entities,
452453
schema.as_ref(),
@@ -484,7 +485,7 @@ impl Entities {
484485
Extensions::all_available(),
485486
cedar_policy_core::entities::TCComputation::ComputeNow,
486487
);
487-
let new_entities = eparser.iter_from_json_value(json)?;
488+
let new_entities = eparser.iter_from_json_value(json)?.map(Arc::new);
488489
Ok(Self(self.0.add_entities(
489490
new_entities,
490491
schema.as_ref(),
@@ -523,7 +524,7 @@ impl Entities {
523524
Extensions::all_available(),
524525
cedar_policy_core::entities::TCComputation::ComputeNow,
525526
);
526-
let new_entities = eparser.iter_from_json_file(json)?;
527+
let new_entities = eparser.iter_from_json_file(json)?.map(Arc::new);
527528
Ok(Self(self.0.add_entities(
528529
new_entities,
529530
schema.as_ref(),

0 commit comments

Comments
 (0)