Skip to content

Commit aa6c7c4

Browse files
committed
[red-knot] Dataclasses: support order=True
1 parent 03adae8 commit aa6c7c4

File tree

5 files changed

+221
-33
lines changed

5 files changed

+221
-33
lines changed

.github/workflows/mypy_primer.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
--type-checker knot \
6969
--old base_commit \
7070
--new "$GITHUB_SHA" \
71-
--project-selector '/(mypy_primer|black|pyp|git-revise|zipp|arrow|isort|itsdangerous|rich|packaging|pybind11|pyinstrument|typeshed-stats|scrapy|werkzeug|bidict|async-utils)$' \
71+
--project-selector '/(mypy_primer|black|pyp|git-revise|zipp|arrow|isort|itsdangerous|rich|packaging|pybind11|pyinstrument|typeshed-stats|scrapy|werkzeug|bidict|async-utils|python-chess|dacite|python-htmlgen|paroxython|porcupine|psycopg)$' \
7272
--output concise \
7373
--debug > mypy_primer.diff || [ $? -eq 1 ]
7474

crates/red_knot_python_semantic/resources/mdtest/dataclasses.md

Lines changed: 174 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,112 @@ repr(C())
9191
C() == C()
9292
```
9393

94+
## Other dataclass parameters
95+
96+
### `repr`
97+
98+
A custom `__repr__` method is generated by default. It can be disabled by passing `repr=False`, but
99+
in that case `__repr__` is still available via `object.__repr__`:
100+
101+
```py
102+
from dataclasses import dataclass
103+
104+
@dataclass(repr=False)
105+
class WithoutRepr:
106+
x: int
107+
108+
reveal_type(WithoutRepr(1).__repr__) # revealed: bound method WithoutRepr.__repr__() -> str
109+
```
110+
111+
### `eq`
112+
113+
The same is true for `__eq__`. Setting `eq=False` disables the generated `__eq__` method, but
114+
`__eq__` is still available via `object.__eq__`:
115+
116+
```py
117+
from dataclasses import dataclass
118+
119+
@dataclass(eq=False)
120+
class WithoutEq:
121+
x: int
122+
123+
reveal_type(WithoutEq(1) == WithoutEq(2)) # revealed: bool
124+
```
125+
126+
### `order`
127+
128+
`order` is set to `False` by default. If `order=True`, `__lt__`, `__le__`, `__gt__`, and `__ge__`
129+
methods will be generated:
130+
131+
```py
132+
from dataclasses import dataclass
133+
134+
@dataclass
135+
class WithoutOrder:
136+
x: int
137+
138+
WithoutOrder(1) < WithoutOrder(2) # error: [unsupported-operator]
139+
WithoutOrder(1) <= WithoutOrder(2) # error: [unsupported-operator]
140+
WithoutOrder(1) > WithoutOrder(2) # error: [unsupported-operator]
141+
WithoutOrder(1) >= WithoutOrder(2) # error: [unsupported-operator]
142+
143+
@dataclass(order=True)
144+
class WithOrder:
145+
x: int
146+
147+
WithOrder(1) < WithOrder(2)
148+
WithOrder(1) <= WithOrder(2)
149+
WithOrder(1) > WithOrder(2)
150+
WithOrder(1) >= WithOrder(2)
151+
```
152+
153+
Comparisons are only allowed for `WithOrder` instances:
154+
155+
```py
156+
WithOrder(1) < 2 # error: [unsupported-operator]
157+
WithOrder(1) <= 2 # error: [unsupported-operator]
158+
WithOrder(1) > 2 # error: [unsupported-operator]
159+
WithOrder(1) >= 2 # error: [unsupported-operator]
160+
```
161+
162+
This also works for generic dataclasses:
163+
164+
```py
165+
from dataclasses import dataclass
166+
167+
@dataclass(order=True)
168+
class GenericWithOrder[T]:
169+
x: T
170+
171+
GenericWithOrder[int](1) < GenericWithOrder[int](1)
172+
173+
GenericWithOrder[int](1) < GenericWithOrder[str]("a") # error: [unsupported-operator]
174+
```
175+
176+
### `unsafe_hash`
177+
178+
To do
179+
180+
### `frozen`
181+
182+
To do
183+
184+
### `match_args`
185+
186+
To do
187+
188+
### `kw_only`
189+
190+
To do
191+
192+
### `slots`
193+
194+
To do
195+
196+
### `weakref_slot`
197+
198+
To do
199+
94200
## Inheritance
95201

96202
### Normal class inheriting from a dataclass
@@ -168,13 +274,30 @@ reveal_type(d_int.description) # revealed: str
168274
DataWithDescription[int](None, "description")
169275
```
170276

171-
## Frozen instances
277+
## Descriptor-typed fields
172278

173-
To do
279+
```py
280+
from dataclasses import dataclass
174281

175-
## Descriptor-typed fields
282+
class Descriptor:
283+
_value: int = 0
176284

177-
To do
285+
def __get__(self, instance, owner) -> str:
286+
return str(self._value)
287+
288+
def __set__(self, instance, value: int) -> None:
289+
self._value = value
290+
291+
@dataclass
292+
class C:
293+
d: Descriptor = Descriptor()
294+
295+
c = C(1)
296+
reveal_type(c.d) # revealed: str
297+
298+
# TODO: should be an error
299+
C("a")
300+
```
178301

179302
## `dataclasses.field`
180303

@@ -197,18 +320,61 @@ class C:
197320
reveal_type(C.__init__) # revealed: (*args: Any, **kwargs: Any) -> None
198321
```
199322

200-
### Dataclass with `init=False`
323+
### Dataclass with custom `__init__` method
201324

202-
To do
325+
If a class already defines `__init__`, it is not replaced by the `dataclass` decorator.
203326

204-
### Dataclass with custom `__init__` method
327+
```py
328+
from dataclasses import dataclass
205329

206-
To do
330+
@dataclass(init=True)
331+
class C:
332+
x: str
333+
334+
def __init__(self, x: int) -> None:
335+
self.x = str(x)
336+
337+
C(1) # OK
338+
339+
# TODO: should be an error
340+
C("a")
341+
```
342+
343+
Similarly, if we set `init=False`, we still recognize the custom `__init__` method:
344+
345+
```py
346+
@dataclass(init=False)
347+
class D:
348+
def __init__(self, x: int) -> None:
349+
self.x = str(x)
350+
351+
D(1) # OK
352+
D() # error: [missing-argument]
353+
```
207354

208355
### Dataclass with `ClassVar`s
209356

210357
To do
211358

359+
### Return type of `dataclass(...)`
360+
361+
A call like `dataclass(order=True)` returns a callable itself, which is then used as the decorator.
362+
We can store the callable in a variable and later use it as a decorator:
363+
364+
```py
365+
from dataclasses import dataclass
366+
367+
dataclass_with_order = dataclass(order=True)
368+
369+
reveal_type(dataclass_with_order) # revealed: <decorator produced by dataclasses.dataclass>
370+
371+
@dataclass_with_order
372+
class C:
373+
x: int
374+
375+
C(1) < C(2) # ok
376+
```
377+
212378
### Using `dataclass` as a function
213379

214380
To do

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,15 @@ impl<'db> ClassType<'db> {
297297
/// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope
298298
/// directly. Use [`ClassType::class_member`] if you require a method that will
299299
/// traverse through the MRO until it finds the member.
300-
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
300+
pub(super) fn own_class_member(
301+
self,
302+
db: &'db dyn Db,
303+
specialization: Option<Specialization<'db>>,
304+
name: &str,
305+
) -> SymbolAndQualifiers<'db> {
301306
let (class_literal, _) = self.class_literal(db);
302307
class_literal
303-
.own_class_member(db, name)
308+
.own_class_member(db, specialization, name)
304309
.map_type(|ty| self.specialize_type(db, ty))
305310
}
306311

@@ -745,7 +750,8 @@ impl<'db> ClassLiteralType<'db> {
745750
}
746751

747752
lookup_result = lookup_result.or_else(|lookup_error| {
748-
lookup_error.or_fall_back_to(db, class.own_class_member(db, name))
753+
lookup_error
754+
.or_fall_back_to(db, class.own_class_member(db, specialization, name))
749755
});
750756
}
751757
}
@@ -790,23 +796,38 @@ impl<'db> ClassLiteralType<'db> {
790796
/// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope
791797
/// directly. Use [`ClassLiteralType::class_member`] if you require a method that will
792798
/// traverse through the MRO until it finds the member.
793-
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
799+
pub(super) fn own_class_member(
800+
self,
801+
db: &'db dyn Db,
802+
specialization: Option<Specialization<'db>>,
803+
name: &str,
804+
) -> SymbolAndQualifiers<'db> {
794805
if let Some(metadata) = self.dataclass_metadata(db) {
795-
if name == "__init__" {
796-
if metadata.contains(DataclassMetadata::INIT) {
797-
// TODO: Generate the signature from the attributes on the class
798-
let init_signature = Signature::new(
799-
Parameters::new([
800-
Parameter::variadic(Name::new_static("args"))
801-
.with_annotated_type(Type::any()),
802-
Parameter::keyword_variadic(Name::new_static("kwargs"))
803-
.with_annotated_type(Type::any()),
804-
]),
805-
Some(Type::none(db)),
806+
if name == "__init__" && metadata.contains(DataclassMetadata::INIT) {
807+
// TODO: Generate the signature from the attributes on the class
808+
let init_signature = Signature::new(
809+
Parameters::new([
810+
Parameter::variadic(Name::new_static("args"))
811+
.with_annotated_type(Type::any()),
812+
Parameter::keyword_variadic(Name::new_static("kwargs"))
813+
.with_annotated_type(Type::any()),
814+
]),
815+
Some(Type::none(db)),
816+
);
817+
818+
return Symbol::bound(Type::Callable(CallableType::new(db, init_signature))).into();
819+
} else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
820+
if metadata.contains(DataclassMetadata::ORDER) {
821+
let signature = Signature::new(
822+
Parameters::new([Parameter::positional_or_keyword(Name::new_static(
823+
"other",
824+
))
825+
.with_annotated_type(Type::instance(
826+
self.apply_optional_specialization(db, specialization),
827+
))]),
828+
Some(KnownClass::Bool.to_instance(db)),
806829
);
807-
808-
return Symbol::bound(Type::Callable(CallableType::new(db, init_signature)))
809-
.into();
830+
return Symbol::bound(Type::Callable(CallableType::new(db, signature))).into();
810831
}
811832
}
812833
}

crates/red_knot_python_semantic/src/types/generics.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::types::{
99
use crate::Db;
1010

1111
/// A list of formal type variables for a generic function, class, or type alias.
12-
#[salsa::tracked(debug)]
12+
#[salsa::interned(debug)]
1313
pub struct GenericContext<'db> {
1414
#[return_ref]
1515
pub(crate) variables: Box<[TypeVarInstance<'db>]>,
@@ -21,7 +21,7 @@ impl<'db> GenericContext<'db> {
2121
index: &'db SemanticIndex<'db>,
2222
type_params_node: &ast::TypeParams,
2323
) -> Self {
24-
let variables = type_params_node
24+
let variables: Box<[_]> = type_params_node
2525
.iter()
2626
.filter_map(|type_param| Self::variable_from_type_param(db, index, type_param))
2727
.collect();
@@ -100,7 +100,7 @@ impl<'db> GenericContext<'db> {
100100
}
101101

102102
/// An assignment of a specific type to each type variable in a generic scope.
103-
#[salsa::tracked(debug)]
103+
#[salsa::interned(debug)]
104104
pub struct Specialization<'db> {
105105
pub(crate) generic_context: GenericContext<'db>,
106106
#[return_ref]
@@ -122,7 +122,7 @@ impl<'db> Specialization<'db> {
122122
/// That lets us produce the generic alias `A[int]`, which is the corresponding entry in the
123123
/// MRO of `B[int]`.
124124
pub(crate) fn apply_specialization(self, db: &'db dyn Db, other: Specialization<'db>) -> Self {
125-
let types = self
125+
let types: Box<[_]> = self
126126
.types(db)
127127
.into_iter()
128128
.map(|ty| ty.apply_specialization(db, other))
@@ -131,7 +131,7 @@ impl<'db> Specialization<'db> {
131131
}
132132

133133
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
134-
let types = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
134+
let types: Box<[_]> = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
135135
Self::new(db, self.generic_context(db), types)
136136
}
137137

crates/red_knot_python_semantic/src/types/slots.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ enum SlotsKind {
2424

2525
impl SlotsKind {
2626
fn from(db: &dyn Db, base: ClassLiteralType) -> Self {
27-
let Symbol::Type(slots_ty, bound) = base.own_class_member(db, "__slots__").symbol else {
27+
let Symbol::Type(slots_ty, bound) = base.own_class_member(db, None, "__slots__").symbol
28+
else {
2829
return Self::NotSpecified;
2930
};
3031

0 commit comments

Comments
 (0)