Skip to content

Commit 14e2880

Browse files
cdisselkoenjuanvgarcia
authored andcommitted
a few more tweaks to protobuf format (cedar-policy#1535)
Signed-off-by: Craig Disselkoen <[email protected]>
1 parent 723080d commit 14e2880

File tree

3 files changed

+121
-61
lines changed

3 files changed

+121
-61
lines changed

cedar-policy/protobuf_schema/core.proto

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,19 @@ message Request {
2121
EntityUid principal = 1;
2222
EntityUid action = 2;
2323
EntityUid resource = 3;
24-
Expr context = 4;
24+
map<string, Expr> context = 4;
2525
}
2626

2727
// the protobuf PolicySet message describes a complete policy set, including
2828
// templates, static policies, and/or template-linked policies.
2929
message PolicySet {
30-
// Key is PolicyID as a string.
31-
// Value is a `TemplateBody`.
32-
// Both templates and static policies are included in this map, with static
30+
// Both templates and static policies are included here, with static
3331
// policies represented as templates with zero slots.
34-
map<string, TemplateBody> templates = 1;
35-
// Key is PolicyID as a string.
36-
// Value is a `Policy`.
37-
// All static policies and template-linked policies are included in this map.
38-
// Static policies must have exactly one entry in this map, and the PolicyID
39-
// of the static policy must be the same in this map and the above map.
40-
map<string, Policy> links = 2;
32+
repeated TemplateBody templates = 1;
33+
// All static policies and template-linked policies are included here.
34+
// Static policies must appear exactly once, and the PolicyID of the static
35+
// policy must be the same in this list and the above list.
36+
repeated Policy links = 2;
4137
}
4238

4339
message Entities {

cedar-policy/src/proto/ast.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use super::models;
2020
use cedar_policy_core::{
2121
ast, evaluator::RestrictedEvaluator, extensions::Extensions, FromNormalizedStr,
2222
};
23+
use smol_str::ToSmolStr;
2324
use std::{collections::HashSet, sync::Arc};
2425

2526
// PANIC SAFETY: experimental feature
@@ -436,6 +437,12 @@ impl From<&ast::Expr> for models::Expr {
436437
}
437438
}
438439

440+
impl From<&ast::Value> for models::Expr {
441+
fn from(v: &ast::Value) -> Self {
442+
(&ast::Expr::from(v.clone())).into()
443+
}
444+
}
445+
439446
impl From<&models::expr::Var> for ast::Var {
440447
fn from(v: &models::expr::Var) -> Self {
441448
match v {
@@ -620,18 +627,49 @@ impl From<&models::Request> for ast::Request {
620627
ast::EntityUIDEntry::from(v.principal.as_ref().expect("principal.as_ref()")),
621628
ast::EntityUIDEntry::from(v.action.as_ref().expect("action.as_ref()")),
622629
ast::EntityUIDEntry::from(v.resource.as_ref().expect("resource.as_ref()")),
623-
v.context.as_ref().map(ast::Context::from),
630+
Some(
631+
ast::Context::from_pairs(
632+
v.context.iter().map(|(k, v)| {
633+
(
634+
k.to_smolstr(),
635+
ast::RestrictedExpr::new(ast::Expr::from(v))
636+
.expect("encoded context should be a valid RestrictedExpr"),
637+
)
638+
}),
639+
Extensions::all_available(),
640+
)
641+
.expect("encoded context should be valid"),
642+
),
624643
)
625644
}
626645
}
627646

628647
impl From<&ast::Request> for models::Request {
648+
// PANIC SAFETY: experimental feature
649+
#[allow(clippy::unimplemented)]
629650
fn from(v: &ast::Request) -> Self {
630651
Self {
631652
principal: Some(models::EntityUid::from(v.principal())),
632653
action: Some(models::EntityUid::from(v.action())),
633654
resource: Some(models::EntityUid::from(v.resource())),
634-
context: v.context().map(models::Expr::from),
655+
context: {
656+
let ctx = match v.context() {
657+
Some(ctx) => ctx,
658+
None => unimplemented!(
659+
"Requests with unknown context currently cannot be modeled in protobuf"
660+
),
661+
};
662+
match ctx {
663+
ast::Context::Value(map) => map
664+
.iter()
665+
.map(|(k, v)| (k.to_string(), models::Expr::from(v)))
666+
.collect(),
667+
ast::Context::RestrictedResidual(map) => map
668+
.iter()
669+
.map(|(k, v)| (k.to_string(), models::Expr::from(v)))
670+
.collect(),
671+
}
672+
},
635673
}
636674
}
637675
}
@@ -642,10 +680,10 @@ impl From<&models::Expr> for ast::Context {
642680
#[allow(clippy::expect_used)]
643681
ast::Context::from_expr(
644682
ast::BorrowedRestrictedExpr::new(&ast::Expr::from(v))
645-
.expect("context should be valid restricted expr"),
683+
.expect("encoded context should be valid restricted expr"),
646684
Extensions::none(),
647685
)
648-
.expect("Context::from_expr")
686+
.expect("encoded context should be valid")
649687
}
650688
}
651689

cedar-policy/src/proto/policy.rs

Lines changed: 72 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -427,19 +427,27 @@ impl From<&ast::Effect> for models::Effect {
427427
}
428428

429429
impl From<&models::PolicySet> for ast::LiteralPolicySet {
430+
// PANIC SAFETY: experimental feature
431+
#[allow(clippy::expect_used)]
430432
fn from(v: &models::PolicySet) -> Self {
431-
let templates = v.templates.iter().map(|(key, value)| {
433+
let templates = v.templates.iter().map(|tb| {
432434
(
433-
ast::PolicyID::from_string(key),
434-
ast::Template::from(ast::TemplateBody::from(value)),
435+
ast::PolicyID::from_string(&tb.id),
436+
ast::Template::from(ast::TemplateBody::from(tb)),
435437
)
436438
});
437439

438-
let links = v.links.iter().map(|(key, value)| {
439-
(
440-
ast::PolicyID::from_string(key),
441-
ast::LiteralPolicy::from(value),
442-
)
440+
let links = v.links.iter().map(|p| {
441+
// per docs in core.proto, for static policies, `link_id` is omitted/ignored,
442+
// and the ID of the policy is the `template_id`.
443+
let id = if p.is_template_link {
444+
p.link_id
445+
.as_ref()
446+
.expect("template link should have a link_id")
447+
} else {
448+
&p.template_id
449+
};
450+
(ast::PolicyID::from_string(id), ast::LiteralPolicy::from(p))
443451
});
444452

445453
Self::new(templates, links)
@@ -448,45 +456,16 @@ impl From<&models::PolicySet> for ast::LiteralPolicySet {
448456

449457
impl From<&ast::LiteralPolicySet> for models::PolicySet {
450458
fn from(v: &ast::LiteralPolicySet) -> Self {
451-
let templates = v
452-
.templates()
453-
.map(|template| {
454-
(
455-
String::from(template.id().as_ref()),
456-
models::TemplateBody::from(template),
457-
)
458-
})
459-
.collect();
460-
let links = v
461-
.policies()
462-
.map(|policy| {
463-
(
464-
String::from(policy.id().as_ref()),
465-
models::Policy::from(policy),
466-
)
467-
})
468-
.collect();
469-
459+
let templates = v.templates().map(models::TemplateBody::from).collect();
460+
let links = v.policies().map(models::Policy::from).collect();
470461
Self { templates, links }
471462
}
472463
}
473464

474465
impl From<&ast::PolicySet> for models::PolicySet {
475466
fn from(v: &ast::PolicySet) -> Self {
476-
let templates: HashMap<String, models::TemplateBody> = v
477-
.all_templates()
478-
.map(|t| (String::from(t.id().as_ref()), models::TemplateBody::from(t)))
479-
.collect();
480-
let links: HashMap<String, models::Policy> = v
481-
.policies()
482-
.map(|policy| {
483-
(
484-
String::from(policy.id().as_ref()),
485-
models::Policy::from(policy),
486-
)
487-
})
488-
.collect();
489-
467+
let templates = v.all_templates().map(models::TemplateBody::from).collect();
468+
let links = v.policies().map(models::Policy::from).collect();
490469
Self { templates, links }
491470
}
492471
}
@@ -504,6 +483,39 @@ mod test {
504483

505484
use super::*;
506485

486+
// We add `PartialOrd` and `Ord` implementations for both `models::Policy` and
487+
// `models::TemplateBody`, so that these can be sorted for testing purposes
488+
impl Eq for models::Policy {}
489+
impl PartialOrd for models::Policy {
490+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
491+
Some(self.cmp(other))
492+
}
493+
}
494+
impl Ord for models::Policy {
495+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
496+
// assumes that (link-id, template-id) pair is unique, otherwise we're
497+
// technically violating `Ord` contract because there could exist two
498+
// policies that return `Ordering::Equal` but are not equal with `Eq`
499+
self.link_id()
500+
.cmp(other.link_id())
501+
.then_with(|| self.template_id.cmp(&other.template_id))
502+
}
503+
}
504+
impl Eq for models::TemplateBody {}
505+
impl PartialOrd for models::TemplateBody {
506+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
507+
Some(self.cmp(other))
508+
}
509+
}
510+
impl Ord for models::TemplateBody {
511+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
512+
// assumes that IDs are unique, otherwise we're technically violating
513+
// `Ord` contract because there could exist two template-bodies that
514+
// return `Ordering::Equal` but are not equal with `Eq`
515+
self.id.cmp(&other.id)
516+
}
517+
}
518+
507519
#[test]
508520
#[allow(clippy::too_many_lines)]
509521
fn policy_roundtrip() {
@@ -739,10 +751,17 @@ mod test {
739751
)]),
740752
)
741753
.unwrap();
742-
let mps = models::PolicySet::from(&ps);
743-
let mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));
754+
let mut mps = models::PolicySet::from(&ps);
755+
let mut mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));
756+
757+
// we accept permutations as equivalent, so before comparison, we sort
758+
// both `.templates` and `.links`
759+
mps.templates.sort();
760+
mps_roundtrip.templates.sort();
761+
mps.links.sort();
762+
mps_roundtrip.links.sort();
744763

745-
// Can't compare LiteralPolicySets directly, so we compare their fields
764+
// Can't compare `models::PolicySet` directly, so we compare their fields
746765
assert_eq!(mps.templates, mps_roundtrip.templates);
747766
assert_eq!(mps.links, mps_roundtrip.links);
748767
}
@@ -802,8 +821,15 @@ mod test {
802821
)]),
803822
)
804823
.unwrap();
805-
let mps = models::PolicySet::from(&ps);
806-
let mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));
824+
let mut mps = models::PolicySet::from(&ps);
825+
let mut mps_roundtrip = models::PolicySet::from(&ast::LiteralPolicySet::from(&mps));
826+
827+
// we accept permutations as equivalent, so before comparison, we sort
828+
// both `.templates` and `.links`
829+
mps.templates.sort();
830+
mps_roundtrip.templates.sort();
831+
mps.links.sort();
832+
mps_roundtrip.links.sort();
807833

808834
// Can't compare `models::PolicySet` directly, so we compare their fields
809835
assert_eq!(mps.templates, mps_roundtrip.templates);

0 commit comments

Comments
 (0)