Skip to content

Commit 9a6deeb

Browse files
Update entities hashmap to hold Arc<Entity>, change add_entities to Arc the incoming entities, update function parameters to receive Arc entities, implement TCNode for Arc<Entity>
Signed-off-by: Thomas Hill <[email protected]>
1 parent dc9163c commit 9a6deeb

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

cedar-policy-core/src/ast/entity.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use serde_with::{serde_as, TryFromInto};
3030
use smol_str::SmolStr;
3131
use std::collections::{BTreeMap, HashMap, HashSet};
3232
use std::str::FromStr;
33+
use std::sync::Arc;
3334
use thiserror::Error;
3435

3536
/// The entity type that Actions must have
@@ -657,6 +658,25 @@ impl TCNode<EntityUID> for Entity {
657658
}
658659
}
659660

661+
impl TCNode<EntityUID> for Arc<Entity> {
662+
fn get_key(&self) -> EntityUID {
663+
self.uid().clone()
664+
}
665+
666+
fn add_edge_to(&mut self, k: EntityUID) {
667+
// Use Arc::make_mut to get a mutable reference to the inner value
668+
Arc::make_mut(self).add_ancestor(k)
669+
}
670+
671+
fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
672+
Box::new(self.ancestors())
673+
}
674+
675+
fn has_edge_to(&self, e: &EntityUID) -> bool {
676+
self.is_descendant_of(e)
677+
}
678+
}
679+
660680
impl std::fmt::Display for Entity {
661681
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
662682
write!(

cedar-policy-core/src/entities.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub struct Entities {
5858
/// Important internal invariant: for any `Entities` object that exists, the
5959
/// the `ancestor` relation is transitively closed.
6060
#[serde_as(as = "Vec<(_, _)>")]
61-
entities: HashMap<EntityUID, Entity>,
61+
entities: HashMap<EntityUID, Arc<Entity>>,
6262

6363
/// The mode flag determines whether this store functions as a partial store or
6464
/// as a fully concrete store.
@@ -109,7 +109,7 @@ impl Entities {
109109

110110
/// Iterate over the `Entity`s in the `Entities`
111111
pub fn iter(&self) -> impl Iterator<Item = &Entity> {
112-
self.entities.values()
112+
self.entities.values().map(|e| e.as_ref())
113113
}
114114

115115
/// Adds the [`crate::ast::Entity`]s in the iterator to this [`Entities`].
@@ -140,7 +140,7 @@ impl Entities {
140140
return Err(EntitiesError::duplicate(entity.uid().clone()))
141141
}
142142
hash_map::Entry::Vacant(vacant_entry) => {
143-
vacant_entry.insert(entity);
143+
vacant_entry.insert(Arc::new(entity));
144144
}
145145
}
146146
}
@@ -213,7 +213,7 @@ impl Entities {
213213
schema
214214
.action_entities()
215215
.into_iter()
216-
.map(|e| (e.uid().clone(), Arc::unwrap_or_clone(e))),
216+
.map(|e: Arc<Entity>| (e.uid().clone(), e)),
217217
);
218218
}
219219
Ok(Self {
@@ -252,6 +252,7 @@ impl Entities {
252252
fn to_ejsons(&self) -> Result<Vec<EntityJson>> {
253253
self.entities
254254
.values()
255+
.map(Arc::as_ref)
255256
.map(EntityJson::from_entity)
256257
.collect::<std::result::Result<_, JsonSerializationError>>()
257258
.map_err(Into::into)
@@ -322,13 +323,13 @@ impl Entities {
322323
}
323324

324325
/// Create a map from EntityUids to Entities, erroring if there are any duplicates
325-
fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityUID, Entity>> {
326+
fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityUID, Arc<Entity>>> {
326327
let mut map = HashMap::new();
327328
for e in es {
328329
match map.entry(e.uid().clone()) {
329330
hash_map::Entry::Occupied(_) => return Err(EntitiesError::duplicate(e.uid().clone())),
330331
hash_map::Entry::Vacant(v) => {
331-
v.insert(e);
332+
v.insert(Arc::new(e));
332333
}
333334
};
334335
}
@@ -338,10 +339,14 @@ fn create_entity_map(es: impl Iterator<Item = Entity>) -> Result<HashMap<EntityU
338339
impl IntoIterator for Entities {
339340
type Item = Entity;
340341

341-
type IntoIter = hash_map::IntoValues<EntityUID, Entity>;
342+
type IntoIter = std::vec::IntoIter<Entity>;
342343

343344
fn into_iter(self) -> Self::IntoIter {
344-
self.entities.into_values()
345+
self.entities
346+
.into_values()
347+
.map(Arc::unwrap_or_clone)
348+
.collect::<Vec<Entity>>()
349+
.into_iter()
345350
}
346351
}
347352

0 commit comments

Comments
 (0)