|
| 1 | +/* |
| 2 | + * Copyright 2022-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +//! Definition of a `CedarTestImplementation` trait that describes an |
| 18 | +//! implementation of Cedar to use during testing. |
| 19 | +
|
| 20 | +pub use cedar_policy::frontend::is_authorized::InterfaceResponse; |
| 21 | +use cedar_policy_core::ast::{Expr, PolicySet, Request, Value}; |
| 22 | +use cedar_policy_core::authorizer::Authorizer; |
| 23 | +use cedar_policy_core::entities::Entities; |
| 24 | +use cedar_policy_core::evaluator::Evaluator; |
| 25 | +use cedar_policy_core::extensions::Extensions; |
| 26 | +use cedar_policy_validator::{ValidationMode, Validator, ValidatorSchema}; |
| 27 | +use serde::Deserialize; |
| 28 | +use std::collections::HashMap; |
| 29 | +use std::time::{Duration, Instant}; |
| 30 | + |
| 31 | +/// Return type for `CedarTestImplementation` methods |
| 32 | +#[derive(Debug, Deserialize)] |
| 33 | +pub enum TestResult<T> { |
| 34 | + /// The request succeeded |
| 35 | + Success(T), |
| 36 | + /// The request failed (e.g., due to a parse error) |
| 37 | + Failure(String), |
| 38 | +} |
| 39 | + |
| 40 | +impl<T> TestResult<T> { |
| 41 | + /// Get the underlying value of a `TestResult`. |
| 42 | + /// # Panics |
| 43 | + /// If the `TestResult` is a `Failure`. |
| 44 | + /// PANIC SAFETY only used in testing code |
| 45 | + #[allow(clippy::panic)] |
| 46 | + pub fn expect(self, msg: &str) -> T { |
| 47 | + match self { |
| 48 | + Self::Success(t) => t, |
| 49 | + Self::Failure(err) => panic!("{msg}: {err}"), |
| 50 | + } |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +/// Simple wrapper around u128 to remind ourselves that timing info is in microseconds. |
| 55 | +#[derive(Debug, Deserialize)] |
| 56 | +pub struct Micros(pub u128); |
| 57 | + |
| 58 | +/// Version of `Response` used for testing. Includes an `InterfaceResponse` and |
| 59 | +/// a map with timing information. |
| 60 | +#[derive(Debug, Deserialize)] |
| 61 | +pub struct TestResponse { |
| 62 | + /// Actual response |
| 63 | + pub response: InterfaceResponse, |
| 64 | + /// Timing info in microseconds. This field is a `HashMap` to allow timing |
| 65 | + /// multiple components (or none at all). |
| 66 | + pub timing_info: HashMap<String, Micros>, |
| 67 | +} |
| 68 | + |
| 69 | +/// Version of `ValidationResult` used for testing. |
| 70 | +#[derive(Debug, Deserialize)] |
| 71 | +pub struct TestValidationResult { |
| 72 | + /// Validation errors |
| 73 | + pub errors: Vec<String>, |
| 74 | + /// Timing info in microseconds. This field is a `HashMap` to allow timing |
| 75 | + /// multiple components (or none at all). |
| 76 | + pub timing_info: HashMap<String, Micros>, |
| 77 | +} |
| 78 | + |
| 79 | +impl TestValidationResult { |
| 80 | + /// Check if validation succeeded |
| 81 | + pub fn validation_passed(&self) -> bool { |
| 82 | + self.errors.is_empty() |
| 83 | + } |
| 84 | +} |
| 85 | + |
| 86 | +/// Custom implementation of the Cedar authorizer, evaluator, and validator for testing. |
| 87 | +pub trait CedarTestImplementation { |
| 88 | + /// Custom authorizer entry point. |
| 89 | + fn is_authorized( |
| 90 | + &self, |
| 91 | + request: &Request, |
| 92 | + policies: &PolicySet, |
| 93 | + entities: &Entities, |
| 94 | + ) -> TestResult<TestResponse>; |
| 95 | + |
| 96 | + /// Custom evaluator entry point. The bool return value indicates the whether |
| 97 | + /// evaluating the provided expression produces the expected value. |
| 98 | + /// `expected` is optional to allow for the case where no return value is |
| 99 | + /// expected due to errors. |
| 100 | + fn interpret( |
| 101 | + &self, |
| 102 | + request: &Request, |
| 103 | + entities: &Entities, |
| 104 | + expr: &Expr, |
| 105 | + enable_extensions: bool, |
| 106 | + expected: Option<Value>, |
| 107 | + ) -> TestResult<bool>; |
| 108 | + |
| 109 | + /// Custom validator entry point. |
| 110 | + fn validate( |
| 111 | + &self, |
| 112 | + schema: &ValidatorSchema, |
| 113 | + policies: &PolicySet, |
| 114 | + mode: ValidationMode, |
| 115 | + ) -> TestResult<TestValidationResult>; |
| 116 | + |
| 117 | + /// `ErrorComparisonMode` that should be used for this `CedarTestImplementation` |
| 118 | + fn error_comparison_mode(&self) -> ErrorComparisonMode; |
| 119 | +} |
| 120 | + |
| 121 | +/// Specifies how errors coming from a `CedarTestImplementation` should be |
| 122 | +/// compared against errors coming from the Rust implementation. |
| 123 | +#[derive(Debug, Clone, PartialEq, Eq, Hash)] |
| 124 | +pub enum ErrorComparisonMode { |
| 125 | + /// Don't compare errors at all; the `CedarTestImplementation` is not |
| 126 | + /// expected to produce errors matching the Rust implementation's errors in |
| 127 | + /// any way. |
| 128 | + /// In fact, the `CedarTestImplementation` will be expected to never report |
| 129 | + /// errors. |
| 130 | + Ignore, |
| 131 | + /// The `CedarTestImplementation` is expected to produce "error messages" that |
| 132 | + /// are actually just the id of the erroring policy. This will be compared to |
| 133 | + /// ensure that the `CedarTestImplementation` agrees with the Rust |
| 134 | + /// implementation on which policies produce errors. |
| 135 | + PolicyIds, |
| 136 | + /// The `CedarTestImplementation` is expected to produce error messages that |
| 137 | + /// exactly match the Rust implementation's error messages' `Display` text. |
| 138 | + Full, |
| 139 | +} |
| 140 | + |
| 141 | +/// Basic struct to support implementing the `CedarTestImplementation` trait |
| 142 | +#[derive(Debug, Default)] |
| 143 | +pub struct RustEngine {} |
| 144 | + |
| 145 | +impl RustEngine { |
| 146 | + /// Create a new `RustEngine` |
| 147 | + pub fn new() -> Self { |
| 148 | + Self {} |
| 149 | + } |
| 150 | +} |
| 151 | + |
| 152 | +/// Timing function |
| 153 | +pub fn time_function<X, F>(f: F) -> (X, Duration) |
| 154 | +where |
| 155 | + F: FnOnce() -> X, |
| 156 | +{ |
| 157 | + let start = Instant::now(); |
| 158 | + let result = f(); |
| 159 | + (result, start.elapsed()) |
| 160 | +} |
| 161 | + |
| 162 | +/// An implementation of `CedarTestImplementation` using `cedar-policy`. |
| 163 | +/// Used for running integration tests. |
| 164 | +impl CedarTestImplementation for RustEngine { |
| 165 | + fn is_authorized( |
| 166 | + &self, |
| 167 | + request: &Request, |
| 168 | + policies: &PolicySet, |
| 169 | + entities: &Entities, |
| 170 | + ) -> TestResult<TestResponse> { |
| 171 | + let authorizer = Authorizer::new(); |
| 172 | + let (response, duration) = |
| 173 | + time_function(|| authorizer.is_authorized(request.clone(), policies, entities)); |
| 174 | + // Error messages should only include the policy id to use the |
| 175 | + // `ErrorComparisonMode::PolicyIds` mode. |
| 176 | + let response = cedar_policy::Response::from(response); |
| 177 | + let response = InterfaceResponse::new( |
| 178 | + response.decision(), |
| 179 | + response.diagnostics().reason().cloned().collect(), |
| 180 | + response |
| 181 | + .diagnostics() |
| 182 | + .errors() |
| 183 | + .map(cedar_policy::AuthorizationError::id) |
| 184 | + .map(ToString::to_string) |
| 185 | + .collect(), |
| 186 | + ); |
| 187 | + let response = TestResponse { |
| 188 | + response, |
| 189 | + timing_info: HashMap::from([("authorize".into(), Micros(duration.as_micros()))]), |
| 190 | + }; |
| 191 | + TestResult::Success(response) |
| 192 | + } |
| 193 | + |
| 194 | + fn interpret( |
| 195 | + &self, |
| 196 | + request: &Request, |
| 197 | + entities: &Entities, |
| 198 | + expr: &Expr, |
| 199 | + enable_extensions: bool, |
| 200 | + expected: Option<Value>, |
| 201 | + ) -> TestResult<bool> { |
| 202 | + let exts = if enable_extensions { |
| 203 | + Extensions::all_available() |
| 204 | + } else { |
| 205 | + Extensions::none() |
| 206 | + }; |
| 207 | + let evaluator = Evaluator::new(request.clone(), entities, &exts); |
| 208 | + let result = evaluator.interpret(expr, &HashMap::default()); |
| 209 | + let response = result.ok() == expected; |
| 210 | + TestResult::Success(response) |
| 211 | + } |
| 212 | + |
| 213 | + fn validate( |
| 214 | + &self, |
| 215 | + schema: &ValidatorSchema, |
| 216 | + policies: &PolicySet, |
| 217 | + mode: ValidationMode, |
| 218 | + ) -> TestResult<TestValidationResult> { |
| 219 | + let validator = Validator::new(schema.clone()); |
| 220 | + let (result, duration) = time_function(|| validator.validate(policies, mode)); |
| 221 | + let response = TestValidationResult { |
| 222 | + errors: result |
| 223 | + .validation_errors() |
| 224 | + .map(|err| format!("{err:?}")) |
| 225 | + .collect(), |
| 226 | + timing_info: HashMap::from([("validate".into(), Micros(duration.as_micros()))]), |
| 227 | + }; |
| 228 | + TestResult::Success(response) |
| 229 | + } |
| 230 | + |
| 231 | + fn error_comparison_mode(&self) -> ErrorComparisonMode { |
| 232 | + ErrorComparisonMode::PolicyIds |
| 233 | + } |
| 234 | +} |
0 commit comments