Skip to content

Commit 542d52f

Browse files
Cherry-pick SymCC fixes/refactorings to 4.6.x (#1872)
Signed-off-by: Shaobo He <[email protected]>
1 parent 11fe99c commit 542d52f

File tree

15 files changed

+154
-106
lines changed

15 files changed

+154
-106
lines changed

cedar-policy-core/src/extensions/decimal.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,12 @@ impl Decimal {
162162

163163
impl std::fmt::Display for Decimal {
164164
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165-
write!(
166-
f,
167-
"{}.{}",
168-
self.value / i64::pow(10, NUM_DIGITS),
169-
(self.value % i64::pow(10, NUM_DIGITS)).abs()
170-
)
165+
let abs = i128::from(self.value).abs();
166+
if self.value.is_negative() {
167+
write!(f, "-")?;
168+
}
169+
let pow = i128::pow(10, NUM_DIGITS);
170+
write!(f, "{}.{:04}", abs / pow, abs % pow)
171171
}
172172
}
173173

@@ -664,7 +664,7 @@ mod tests {
664664

665665
fn check_round_trip(s: &str) {
666666
let d = Decimal::from_str(s).expect("should be a valid decimal");
667-
assert_eq!(s, d.to_string());
667+
assert_eq!(d, Decimal::from_str(d.to_string()).unwrap());
668668
}
669669

670670
#[test]

cedar-policy-symcc/src/symcc/concretizer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ impl Term {
282282
impl Udf {
283283
fn get_all_entity_uids(&self, uids: &mut BTreeSet<EntityUid>) {
284284
self.default.get_all_entity_uids(uids);
285-
for (k, v) in &self.table {
285+
for (k, v) in self.table.iter() {
286286
k.get_all_entity_uids(uids);
287287
v.get_all_entity_uids(uids);
288288
}

cedar-policy-symcc/src/symcc/decoder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ impl Uuf {
676676
Udf {
677677
arg: self.arg.clone(),
678678
out: self.out.clone(),
679-
table: BTreeMap::new(),
679+
table: Arc::new(BTreeMap::new()),
680680
default: self.out.default_literal(env),
681681
}
682682
}
@@ -1139,7 +1139,7 @@ impl SExpr {
11391139
Udf {
11401140
arg: uuf.arg.clone(),
11411141
out: uuf.out.clone(),
1142-
table,
1142+
table: Arc::new(table),
11431143
default,
11441144
},
11451145
));

cedar-policy-symcc/src/symcc/encoder.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,13 @@ impl<S: tokio::io::AsyncWrite + Unpin + Send> Encoder<'_, S> {
329329
ty_enc: &str,
330330
t_encs: impl IntoIterator<Item = &'s str>,
331331
) -> Result<String> {
332-
self.define_term(
333-
ty_enc,
334-
&format!("({ty_enc} {})", t_encs.into_iter().join(" ")),
335-
)
336-
.await
332+
let t_encs = t_encs.into_iter().join(" ");
333+
let t_enc = if t_encs.is_empty() {
334+
format!("{ty_enc}")
335+
} else {
336+
format!("({ty_enc} {})", t_encs)
337+
};
338+
self.define_term(ty_enc, &t_enc).await
337339
}
338340

339341
pub async fn encode_uuf(&mut self, uuf: &Uuf) -> Result<String> {
@@ -827,7 +829,7 @@ mod unit_tests {
827829
};
828830
let mut encoder = Encoder::new(&symenv, Vec::<u8>::new()).unwrap();
829831
let my_uuf = crate::symcc::op::Uuf {
830-
id: "my_fun".to_string(),
832+
id: "my_fun".into(),
831833
arg: TermType::Bool,
832834
out: TermType::Bool,
833835
};

cedar-policy-symcc/src/symcc/env.rs

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use cedar_policy_core::validator::{
3535
ValidatorActionId,
3636
};
3737
use cedar_policy_core::validator::{ValidatorEntityType, ValidatorEntityTypeKind};
38-
use smol_str::SmolStr;
38+
use smol_str::{format_smolstr, SmolStr};
3939
use std::collections::{BTreeMap, BTreeSet};
4040
use std::ops::Deref;
4141
use std::sync::Arc;
@@ -55,15 +55,15 @@ impl SymRequest {
5555
pub fn empty_sym_req() -> Self {
5656
SymRequest {
5757
principal: Term::Var(TermVar {
58-
id: "principal".to_string(),
58+
id: "principal".into(),
5959
ty: TermType::Bool,
6060
}),
6161
action: Term::Var(TermVar {
62-
id: "action".to_string(),
62+
id: "action".into(),
6363
ty: TermType::Bool,
6464
}),
6565
resource: Term::Var(TermVar {
66-
id: "resource".to_string(),
66+
id: "resource".into(),
6767
ty: TermType::Bool,
6868
}),
6969
context: Term::Record(Arc::new(BTreeMap::new())),
@@ -183,30 +183,30 @@ impl SymEntityData {
183183
match EntitySchemaEntry::of_schema(ety, validator_ety, schema) {
184184
// Corresponds to `SymEntityData.ofStandardEntityType` in Lean
185185
EntitySchemaEntry::Standard(sch) => {
186-
let attrs_uuf = Uuf(op::Uuf {
187-
id: format!("attrs[{ety}]"),
186+
let attrs_uuf = Uuf(Arc::new(op::Uuf {
187+
id: format_smolstr!("attrs[{ety}]"),
188188
arg: entity(ety.clone()), // more efficient than the Lean: avoids `TermType::of_type()` and constructs the `TermType` directly
189189
out: TermType::of_type(&record(sch.attrs))?,
190-
});
190+
}));
191191
let ancs_uuf = |anc_ty: &EntityType| {
192-
Uuf(op::Uuf {
193-
id: format!("ancs[{ety}, {anc_ty}]"),
192+
Uuf(Arc::new(op::Uuf {
193+
id: format_smolstr!("ancs[{ety}, {anc_ty}]"),
194194
arg: entity(ety.clone()), // more efficient than the Lean: avoids `TermType::of_type()` and constructs the `TermType` directly
195195
out: TermType::set_of(entity(anc_ty.clone())), // more efficient than the Lean: avoids `TermType::of_type()` and constructs the `TermType` directly
196-
})
196+
}))
197197
};
198198
let sym_tags = |tag_ty: Type| -> Result<SymTags, CompileError> {
199199
Ok(SymTags {
200-
keys: Uuf(op::Uuf {
201-
id: format!("tagKeys[{ety}]"),
200+
keys: Uuf(Arc::new(op::Uuf {
201+
id: format_smolstr!("tagKeys[{ety}]"),
202202
arg: entity(ety.clone()), // more efficient than the Lean: avoids `TermType::of_type()` and constructs the `TermType` directly
203203
out: TermType::set_of(TermType::String),
204-
}),
205-
vals: Uuf(op::Uuf {
206-
id: format!("tagVals[{ety}]"),
204+
})),
205+
vals: Uuf(Arc::new(op::Uuf {
206+
id: format_smolstr!("tagVals[{ety}]"),
207207
arg: TermType::tag_for(ety.clone()), // record representing the pair type (ety, .string)
208208
out: TermType::of_type(&tag_ty)?,
209-
}),
209+
})),
210210
})
211211
};
212212

@@ -227,14 +227,14 @@ impl SymEntityData {
227227

228228
// Corresponds to `SymEntityData.ofEnumEntityType` in Lean
229229
EntitySchemaEntry::Enum(eids) => {
230-
let attrs_udf = Udf(function::Udf {
230+
let attrs_udf = Udf(Arc::new(function::Udf {
231231
arg: entity(ety.clone()),
232232
out: TermType::Record {
233233
rty: Arc::new(BTreeMap::new()),
234234
},
235-
table: BTreeMap::new(),
235+
table: Arc::new(BTreeMap::new()),
236236
default: Term::Record(Arc::new(BTreeMap::new())),
237-
});
237+
}));
238238
Ok(SymEntityData {
239239
attrs: attrs_udf,
240240
ancestors: BTreeMap::new(),
@@ -251,14 +251,14 @@ impl SymEntityData {
251251
schema: &ValidatorSchema,
252252
) -> Self {
253253
let sch = ActionSchemaEntries::of_schema(schema);
254-
let attrs_udf = Udf(function::Udf {
254+
let attrs_udf = Udf(Arc::new(function::Udf {
255255
arg: entity(act_ty.clone()),
256256
out: TermType::Record {
257257
rty: Arc::new(BTreeMap::new()),
258258
},
259-
table: BTreeMap::new(),
259+
table: Arc::new(BTreeMap::new()),
260260
default: Term::Record(Arc::new(BTreeMap::new())),
261-
});
261+
}));
262262
let term_of_type = |ety: EntityType, uid: EntityUID| -> Option<Term> {
263263
if uid.type_name() == &ety {
264264
Some(Term::Prim(TermPrim::Entity(uid)))
@@ -277,23 +277,24 @@ impl SymEntityData {
277277
}
278278
};
279279
let ancs_udf = |anc_ty: &EntityType| -> UnaryFunction {
280-
Udf(function::Udf {
280+
Udf(Arc::new(function::Udf {
281281
arg: entity(act_ty.clone()),
282282
out: TermType::set_of(entity(anc_ty.clone())),
283-
table: sch
284-
.iter()
285-
.filter_map(|(uid, entry)| {
286-
Some((
287-
term_of_type(act_ty.clone(), uid.clone())?,
288-
ancs_term(anc_ty, &entry.ancestors),
289-
))
290-
})
291-
.collect(),
283+
table: Arc::new(
284+
sch.iter()
285+
.filter_map(|(uid, entry)| {
286+
Some((
287+
term_of_type(act_ty.clone(), uid.clone())?,
288+
ancs_term(anc_ty, &entry.ancestors),
289+
))
290+
})
291+
.collect(),
292+
),
292293
default: Term::Set {
293294
elts: Arc::new(BTreeSet::new()),
294295
elts_ty: entity(anc_ty.clone()),
295296
},
296-
})
297+
}))
297298
};
298299
let acts = sch
299300
.iter()
@@ -372,20 +373,20 @@ impl SymRequest {
372373
fn of_request_type(req_ty: &RequestType<'_>) -> Result<Self, CompileError> {
373374
Ok(Self {
374375
principal: Term::Var(TermVar {
375-
id: "principal".to_string(),
376+
id: "principal".into(),
376377
ty: TermType::Entity {
377378
ety: req_ty.principal.clone(),
378379
},
379380
}),
380381
action: Term::Prim(TermPrim::Entity(req_ty.action.clone())),
381382
resource: Term::Var(TermVar {
382-
id: "resource".to_string(),
383+
id: "resource".into(),
383384
ty: TermType::Entity {
384385
ety: req_ty.resource.clone(),
385386
},
386387
}),
387388
context: Term::Var(TermVar {
388-
id: "context".to_string(),
389+
id: "context".into(),
389390
ty: TermType::of_type(&record(req_ty.context.clone()))?,
390391
}),
391392
})

cedar-policy-symcc/src/symcc/extension_types/decimal.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ impl FromStr for Decimal {
107107

108108
impl std::fmt::Display for Decimal {
109109
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110-
write!(
111-
f,
112-
"{}.{}",
113-
self.0 / i64::pow(10, DECIMAL_DIGITS),
114-
(self.0 % i64::pow(10, DECIMAL_DIGITS)).abs()
115-
)
110+
let abs = i128::from(self.0).abs();
111+
if self.0.is_negative() {
112+
write!(f, "-")?;
113+
}
114+
let pow = i128::pow(10, DECIMAL_DIGITS);
115+
write!(f, "{}.{:04}", abs / pow, abs % pow)
116116
}
117117
}
118118

cedar-policy-symcc/src/symcc/extractor.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
//! and transitive (assuming the suitable acyclicity and transitivity
2626
//! constraints are satisfied for the footprint).
2727
28-
use std::collections::BTreeSet;
28+
use std::{collections::BTreeSet, sync::Arc};
2929

3030
use cedar_policy_core::ast::Expr;
3131

@@ -57,17 +57,21 @@ impl Uuf {
5757
if uid.type_name() == arg_ety {
5858
let t = Term::Prim(TermPrim::Entity(uid.clone()));
5959
// In the domain of this ancestor function
60-
Some((t.clone(), factory::app(UnaryFunction::Udf(udf.clone()), t)))
60+
Some((
61+
t.clone(),
62+
factory::app(UnaryFunction::Udf(Arc::new(udf.clone())), t),
63+
))
6164
} else {
6265
None
6366
}
6467
})
6568
.collect();
6669

6770
Udf {
68-
table: new_table,
69-
default: udf.out.default_literal(interp.env), // i.e., empty set
70-
..udf
71+
table: Arc::new(new_table),
72+
default: udf.default.clone(),
73+
arg: udf.arg.clone(),
74+
out: udf.out,
7175
}
7276
}
7377
}
@@ -98,7 +102,7 @@ impl Interpretation<'_> {
98102
for fun in ent_data.ancestors.values() {
99103
if let UnaryFunction::Uuf(uuf) = fun {
100104
funs.insert(
101-
uuf.clone(),
105+
uuf.as_ref().clone(),
102106
uuf.repair_as_counterexample(ety, &footprint_uids, self),
103107
);
104108
}

cedar-policy-symcc/src/symcc/factory.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ pub fn app(f: UnaryFunction, t: Term) -> Term {
215215
UnaryFunction::Uuf(f) => {
216216
let ret_ty = f.out.clone();
217217
Term::App {
218-
op: Op::Uuf(Arc::new(f)),
218+
op: Op::Uuf(f),
219219
args: Arc::new(vec![t]),
220220
ret_ty,
221221
}
@@ -224,10 +224,10 @@ pub fn app(f: UnaryFunction, t: Term) -> Term {
224224
if t.is_literal() {
225225
match f.table.get(&t) {
226226
Some(v) => v.clone(),
227-
None => f.default,
227+
None => f.default.clone(),
228228
}
229229
} else {
230-
f.table.iter().rfold(f.default, |acc, (t1, t2)| {
230+
f.table.iter().rfold(f.default.clone(), |acc, (t1, t2)| {
231231
ite(eq(t.clone(), t1.clone()), t2.clone(), acc)
232232
})
233233
}

cedar-policy-symcc/src/symcc/function.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
use std::collections::BTreeMap;
17+
use std::{collections::BTreeMap, sync::Arc};
1818

1919
use super::{op::Uuf, term::Term, term_type::TermType};
2020

@@ -23,7 +23,7 @@ use super::{op::Uuf, term::Term, term_type::TermType};
2323
pub struct Udf {
2424
pub arg: TermType,
2525
pub out: TermType,
26-
pub table: BTreeMap<Term, Term>,
26+
pub table: Arc<BTreeMap<Term, Term>>,
2727
pub default: Term,
2828
}
2929

@@ -50,22 +50,22 @@ impl Udf {
5050
/// solver (CVC5) always returns interpretations of this form.
5151
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd)]
5252
pub enum UnaryFunction {
53-
Uuf(Uuf),
54-
Udf(Udf),
53+
Uuf(Arc<Uuf>),
54+
Udf(Arc<Udf>),
5555
}
5656

5757
impl UnaryFunction {
58-
pub fn arg_type(self) -> TermType {
58+
pub fn arg_type(&self) -> &TermType {
5959
match self {
60-
UnaryFunction::Uuf(f) => f.arg,
61-
UnaryFunction::Udf(f) => f.arg,
60+
UnaryFunction::Uuf(f) => &f.arg,
61+
UnaryFunction::Udf(f) => &f.arg,
6262
}
6363
}
6464

65-
pub fn out_type(self) -> TermType {
65+
pub fn out_type(&self) -> &TermType {
6666
match self {
67-
UnaryFunction::Uuf(f) => f.out,
68-
UnaryFunction::Udf(f) => f.out,
67+
UnaryFunction::Uuf(f) => &f.out,
68+
UnaryFunction::Udf(f) => &f.out,
6969
}
7070
}
7171

0 commit comments

Comments
 (0)