Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_float(&self, strict: bool) -> ValMatch<EitherFloat<'_>>;

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
if strict {
self.strict_decimal(py)
} else {
self.lax_decimal(py)
}
}
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
self.strict_decimal(py)
}
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;

type Dict<'a>: ValidatedDict<'py>
where
Expand Down
13 changes: 7 additions & 6 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,13 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
JsonValue::Float(f) => create_decimal(&PyString::new_bound(py, &f.to_string()), self),

JsonValue::Float(f) => {
create_decimal(&PyString::new_bound(py, &f.to_string()), self).map(ValidationMatch::strict)
}
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
create_decimal(self.to_object(py).bind(py), self)
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::lax)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if Str should be strict... we serialize decimals as strings now, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, ok! What about floats and ints, should those be strict as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is anything lax then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess they're probably all strict? Ungh 😂

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don't need lax hahah

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Python, I think they should be lax, based on existing tests...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pytest.mark.parametrize(
    'input_value,expected',
    [
        (Decimal(0), Decimal(0)),
        (Decimal(1), Decimal(1)),
        (Decimal(42), Decimal(42)),
        (Decimal('42.0'), Decimal('42.0')),
        (Decimal('42.5'), Decimal('42.5')),
        (42.0, Err('Input should be an instance of Decimal [type=is_instance_of, input_value=42.0, input_type=float]')),
        ('42', Err("Input should be an instance of Decimal [type=is_instance_of, input_value='42', input_type=str]")),
        (42, Err('Input should be an instance of Decimal [type=is_instance_of, input_value=42, input_type=int]')),
        (True, Err('Input should be an instance of Decimal [type=is_instance_of, input_value=True, input_type=bool]')),
    ],
    ids=repr,
)
def test_decimal_strict_py(input_value, expected):
    v = SchemaValidator({'type': 'decimal', 'strict': True})
    if isinstance(expected, Err):
        with pytest.raises(ValidationError, match=re.escape(expected.message)):
            v.validate_python(input_value)
    else:
        output = v.validate_python(input_value)
        assert output == expected
        assert isinstance(output, Decimal)

_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
Expand Down Expand Up @@ -373,8 +374,8 @@ impl<'py> Input<'py> for str {
str_as_float(self, self).map(ValidationMatch::lax)
}

fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
create_decimal(self.to_object(py).bind(py), self)
fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::lax)
}

type Dict<'a> = Never;
Expand Down
64 changes: 27 additions & 37 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
str_as_int(self, s)
} else if self.is_exact_instance_of::<PyFloat>() {
float_as_int(self, self.extract::<f64>()?)
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
decimal_as_int(self, &decimal)
} else if let Ok(decimal) = self.validate_decimal(strict, self.py()) {
decimal_as_int(self, &decimal.into_inner())
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else if let Some(enum_val) = maybe_as_enum(self) {
Expand Down Expand Up @@ -307,47 +307,37 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
}

fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
let decimal_type = get_decimal_type(py);
// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self.to_owned());
}

// Try subclasses of decimals, they will be upcast to Decimal
if self.is_instance(decimal_type)? {
return create_decimal(self, self);
}

Err(ValError::new(
ErrorType::IsInstanceOf {
class: decimal_type
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "Decimal".to_owned()),
context: None,
},
self,
))
}

fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
let decimal_type = get_decimal_type(py);
// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self.to_owned().clone());
}

if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>()) {
// checking isinstance for str / int / bool is fast compared to decimal / float
create_decimal(self, self)
Ok(ValidationMatch::exact(self.to_owned().clone()))
} else if self.is_instance(decimal_type)? {
// upcast subclasses to decimal
return create_decimal(self, self);
} else if self.is_instance_of::<PyFloat>() {
create_decimal(self.str()?.as_any(), self)
// Upcast subclasses to decimal
create_decimal(self, self).map(ValidationMatch::strict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was deliberately below the PyString / PyInt / PyBool check because those are faster to check for (see comment now on line 334), maybe move those up (but have them with a !strict guard too).

} else {
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
if strict {
return Err(ValError::new(
ErrorType::IsInstanceOf {
class: decimal_type
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "Decimal".to_owned()),
context: None,
},
self,
));
}
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>())
{
// Checking isinstance for str / int / bool is fast compared to decimal / float
create_decimal(self, self).map(ValidationMatch::lax)
} else if self.is_instance_of::<PyFloat>() {
create_decimal(self.str()?.as_any(), self).map(ValidationMatch::lax)
} else {
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn strict_decimal(&self, _py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
fn validate_decimal(&self, _strict: bool, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
Self::String(s) => create_decimal(s, self),
Self::String(s) => create_decimal(s, self).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/validators/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl Validator for DecimalValidator {
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?.unpack(state);

if !self.allow_inf_nan || self.check_digits {
if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? {
Expand Down
22 changes: 21 additions & 1 deletion tests/validators/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from dirty_equals import FunctionCheck, IsStr

from pydantic_core import SchemaValidator, ValidationError
from pydantic_core import SchemaValidator, ValidationError, core_schema

from ..conftest import Err, PyAndJson, plain_repr

Expand Down Expand Up @@ -467,3 +467,23 @@ def test_validate_max_digits_and_decimal_places_edge_case() -> None:
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
'9999999999999999.999999999999999999'
)


def test_str_validation_w_strict() -> None:
s = SchemaValidator(core_schema.decimal_schema(strict=True))

with pytest.raises(ValidationError):
assert s.validate_python('1.23')


def test_str_validation_w_lax() -> None:
s = SchemaValidator(core_schema.decimal_schema(strict=False))

assert s.validate_python('1.23') == Decimal('1.23')


def test_union_with_str_prefers_str() -> None:
s = SchemaValidator(core_schema.union_schema([core_schema.decimal_schema(), core_schema.str_schema()]))

assert s.validate_python('1.23') == '1.23'
assert s.validate_python(1.23) == Decimal('1.23')