Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions cedar-policy/protobuf_schema/core.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,19 @@ message Request {
EntityUid principal = 1;
EntityUid action = 2;
EntityUid resource = 3;
Expr context = 4;
map<string, Expr> context = 4;
}

// the protobuf PolicySet message describes a complete policy set, including
// templates, static policies, and/or template-linked policies.
message PolicySet {
// Key is PolicyID as a string.
// Value is a `TemplateBody`.
// Both templates and static policies are included in this map, with static
// Both templates and static policies are included here, with static
// policies represented as templates with zero slots.
map<string, TemplateBody> templates = 1;
// Key is PolicyID as a string.
// Value is a `Policy`.
// All static policies and template-linked policies are included in this map.
// Static policies must have exactly one entry in this map, and the PolicyID
// of the static policy must be the same in this map and the above map.
map<string, Policy> links = 2;
repeated TemplateBody templates = 1;
// All static policies and template-linked policies are included here.
// Static policies must appear exactly once, and the PolicyID of the static
// policy must be the same in this list and the above list.
repeated Policy links = 2;
}

message Entities {
Expand Down
46 changes: 42 additions & 4 deletions cedar-policy/src/proto/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use super::models;
use cedar_policy_core::{
ast, evaluator::RestrictedEvaluator, extensions::Extensions, FromNormalizedStr,
};
use smol_str::ToSmolStr;
use std::{collections::HashSet, sync::Arc};

// PANIC SAFETY: experimental feature
Expand Down Expand Up @@ -436,6 +437,12 @@ impl From<&ast::Expr> for models::Expr {
}
}

impl From<&ast::Value> for models::Expr {
fn from(v: &ast::Value) -> Self {
(&ast::Expr::from(v.clone())).into()
}
}

impl From<&models::expr::Var> for ast::Var {
fn from(v: &models::expr::Var) -> Self {
match v {
Expand Down Expand Up @@ -620,18 +627,49 @@ impl From<&models::Request> for ast::Request {
ast::EntityUIDEntry::from(v.principal.as_ref().expect("principal.as_ref()")),
ast::EntityUIDEntry::from(v.action.as_ref().expect("action.as_ref()")),
ast::EntityUIDEntry::from(v.resource.as_ref().expect("resource.as_ref()")),
v.context.as_ref().map(ast::Context::from),
Some(
ast::Context::from_pairs(
v.context.iter().map(|(k, v)| {
(
k.to_smolstr(),
ast::RestrictedExpr::new(ast::Expr::from(v))
.expect("encoded context should be a valid RestrictedExpr"),
)
}),
Extensions::all_available(),
)
.expect("encoded context should be valid"),
),
)
}
}

impl From<&ast::Request> for models::Request {
// PANIC SAFETY: experimental feature
#[allow(clippy::unimplemented)]
fn from(v: &ast::Request) -> Self {
Self {
principal: Some(models::EntityUid::from(v.principal())),
action: Some(models::EntityUid::from(v.action())),
resource: Some(models::EntityUid::from(v.resource())),
context: v.context().map(models::Expr::from),
context: {
let ctx = match v.context() {
Some(ctx) => ctx,
None => unimplemented!(
"Requests with unknown context currently cannot be modeled in protobuf"
),
};
match ctx {
ast::Context::Value(map) => map
.iter()
.map(|(k, v)| (k.to_string(), models::Expr::from(v)))
.collect(),
ast::Context::RestrictedResidual(map) => map
.iter()
.map(|(k, v)| (k.to_string(), models::Expr::from(v)))
.collect(),
}
},
}
}
}
Expand All @@ -642,10 +680,10 @@ impl From<&models::Expr> for ast::Context {
#[allow(clippy::expect_used)]
ast::Context::from_expr(
ast::BorrowedRestrictedExpr::new(&ast::Expr::from(v))
.expect("context should be valid restricted expr"),
.expect("encoded context should be valid restricted expr"),
Extensions::none(),
)
.expect("Context::from_expr")
.expect("encoded context should be valid")
}
}

Expand Down
118 changes: 72 additions & 46 deletions cedar-policy/src/proto/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,27 @@ impl From<&ast::Effect> for models::Effect {
}

impl From<&models::PolicySet> for ast::LiteralPolicySet {
// PANIC SAFETY: experimental feature
#[allow(clippy::expect_used)]
fn from(v: &models::PolicySet) -> Self {
let templates = v.templates.iter().map(|(key, value)| {
let templates = v.templates.iter().map(|tb| {
(
ast::PolicyID::from_string(key),
ast::Template::from(ast::TemplateBody::from(value)),
ast::PolicyID::from_string(&tb.id),
ast::Template::from(ast::TemplateBody::from(tb)),
)
});

let links = v.links.iter().map(|(key, value)| {
(
ast::PolicyID::from_string(key),
ast::LiteralPolicy::from(value),
)
let links = v.links.iter().map(|p| {
// per docs in core.proto, for static policies, `link_id` is omitted/ignored,
// and the ID of the policy is the `template_id`.
let id = if p.is_template_link {
p.link_id
.as_ref()
.expect("template link should have a link_id")
} else {
&p.template_id
};
(ast::PolicyID::from_string(id), ast::LiteralPolicy::from(p))
});

Self::new(templates, links)
Expand All @@ -448,45 +456,16 @@ impl From<&models::PolicySet> for ast::LiteralPolicySet {

impl From<&ast::LiteralPolicySet> for models::PolicySet {
fn from(v: &ast::LiteralPolicySet) -> Self {
let templates = v
.templates()
.map(|template| {
(
String::from(template.id().as_ref()),
models::TemplateBody::from(template),
)
})
.collect();
let links = v
.policies()
.map(|policy| {
(
String::from(policy.id().as_ref()),
models::Policy::from(policy),
)
})
.collect();

let templates = v.templates().map(models::TemplateBody::from).collect();
let links = v.policies().map(models::Policy::from).collect();
Self { templates, links }
}
}

impl From<&ast::PolicySet> for models::PolicySet {
fn from(v: &ast::PolicySet) -> Self {
let templates: HashMap<String, models::TemplateBody> = v
.all_templates()
.map(|t| (String::from(t.id().as_ref()), models::TemplateBody::from(t)))
.collect();
let links: HashMap<String, models::Policy> = v
.policies()
.map(|policy| {
(
String::from(policy.id().as_ref()),
models::Policy::from(policy),
)
})
.collect();

let templates = v.all_templates().map(models::TemplateBody::from).collect();
let links = v.policies().map(models::Policy::from).collect();
Self { templates, links }
}
}
Expand All @@ -504,6 +483,39 @@ mod test {

use super::*;

// We add `PartialOrd` and `Ord` implementations for both `models::Policy` and
// `models::TemplateBody`, so that these can be sorted for testing purposes
impl Eq for models::Policy {}
impl PartialOrd for models::Policy {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for models::Policy {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// assumes that (link-id, template-id) pair is unique, otherwise we're
// technically violating `Ord` contract because there could exist two
// policies that return `Ordering::Equal` but are not equal with `Eq`
self.link_id()
.cmp(other.link_id())
.then_with(|| self.template_id.cmp(&other.template_id))
}
}
impl Eq for models::TemplateBody {}
impl PartialOrd for models::TemplateBody {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for models::TemplateBody {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// assumes that IDs are unique, otherwise we're technically violating
// `Ord` contract because there could exist two template-bodies that
// return `Ordering::Equal` but are not equal with `Eq`
self.id.cmp(&other.id)
}
}

#[test]
#[allow(clippy::too_many_lines)]
fn policy_roundtrip() {
Expand Down Expand Up @@ -739,10 +751,17 @@ mod test {
)]),
)
.unwrap();
let mps = models::PolicySet::from(&ps);
let mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));
let mut mps = models::PolicySet::from(&ps);
let mut mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));

// we accept permutations as equivalent, so before comparison, we sort
// both `.templates` and `.links`
mps.templates.sort();
mps_roundtrip.templates.sort();
mps.links.sort();
mps_roundtrip.links.sort();

// Can't compare LiteralPolicySets directly, so we compare their fields
// Can't compare `models::PolicySet` directly, so we compare their fields
assert_eq!(mps.templates, mps_roundtrip.templates);
assert_eq!(mps.links, mps_roundtrip.links);
}
Expand Down Expand Up @@ -802,8 +821,15 @@ mod test {
)]),
)
.unwrap();
let mps = models::PolicySet::from(&ps);
let mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));
let mut mps = models::PolicySet::from(&ps);
let mut mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));

// we accept permutations as equivalent, so before comparison, we sort
// both `.templates` and `.links`
mps.templates.sort();
mps_roundtrip.templates.sort();
mps.links.sort();
mps_roundtrip.links.sort();

// Can't compare `models::PolicySet` directly, so we compare their fields
assert_eq!(mps.templates, mps_roundtrip.templates);
Expand Down
Loading