Skip to content

Commit 5f77e4a

Browse files
committed
[ty] Support dataclass_transform for base class models
1 parent 1ee1dea commit 5f77e4a

File tree

4 files changed

+39
-27
lines changed

4 files changed

+39
-27
lines changed

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,6 @@ class CustomerModel(ModelBase):
122122
id: int
123123
name: str
124124

125-
# TODO: this is not supported yet
126-
# error: [unknown-argument]
127-
# error: [unknown-argument]
128125
CustomerModel(id=1, name="Test")
129126
```
130127

@@ -216,10 +213,6 @@ class OrderedModelBase: ...
216213
class TestWithBase(OrderedModelBase):
217214
inner: int
218215

219-
# TODO: No errors here
220-
# error: [too-many-positional-arguments]
221-
# error: [too-many-positional-arguments]
222-
# error: [unsupported-operator]
223216
TestWithBase(1) < TestWithBase(2)
224217
```
225218

@@ -277,8 +270,7 @@ class ModelBase: ...
277270
class TestBase(ModelBase):
278271
name: str
279272

280-
# TODO: This should be `(self: TestBase, *, name: str) -> None`
281-
reveal_type(TestBase.__init__) # revealed: def __init__(self) -> None
273+
reveal_type(TestBase.__init__) # revealed: (self: TestBase, *, name: str) -> None
282274
```
283275

284276
### `frozen_default`
@@ -333,12 +325,9 @@ class ModelBase: ...
333325
class TestMeta(ModelBase):
334326
name: str
335327

336-
# TODO: no error here
337-
# error: [unknown-argument]
338328
t = TestMeta(name="test")
339329

340-
# TODO: this should be an `invalid-assignment` error
341-
t.name = "new"
330+
t.name = "new" # error: [invalid-assignment]
342331
```
343332

344333
### Combining parameters

crates/ty_python_semantic/src/types/class.rs

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,11 @@ pub(crate) enum CodeGeneratorKind {
184184
}
185185

186186
impl CodeGeneratorKind {
187-
pub(crate) fn from_class(db: &dyn Db, class: ClassLiteral<'_>) -> Option<Self> {
187+
pub(crate) fn from_class(
188+
db: &dyn Db,
189+
class: ClassLiteral<'_>,
190+
specialization: Option<Specialization<'_>>,
191+
) -> Option<Self> {
188192
#[salsa::tracked(
189193
cycle_fn=code_generator_of_class_recover,
190194
cycle_initial=code_generator_of_class_initial,
@@ -193,11 +197,20 @@ impl CodeGeneratorKind {
193197
fn code_generator_of_class<'db>(
194198
db: &'db dyn Db,
195199
class: ClassLiteral<'db>,
200+
specialization: Option<Specialization<'db>>,
196201
) -> Option<CodeGeneratorKind> {
197202
if class.dataclass_params(db).is_some() {
198203
Some(CodeGeneratorKind::DataclassLike(None))
199204
} else if let Ok((_, Some(transformer_params))) = class.try_metaclass(db) {
200205
Some(CodeGeneratorKind::DataclassLike(Some(transformer_params)))
206+
} else if let Some(transformer_params) =
207+
class.iter_mro(db, specialization).skip(1).find_map(|base| {
208+
base.into_class().and_then(|class| {
209+
class.class_literal(db).0.dataclass_transformer_params(db)
210+
})
211+
})
212+
{
213+
Some(CodeGeneratorKind::DataclassLike(Some(transformer_params)))
201214
} else if class
202215
.explicit_bases(db)
203216
.contains(&Type::SpecialForm(SpecialFormType::NamedTuple))
@@ -213,6 +226,7 @@ impl CodeGeneratorKind {
213226
fn code_generator_of_class_initial(
214227
_db: &dyn Db,
215228
_class: ClassLiteral<'_>,
229+
_specialization: Option<Specialization<'_>>,
216230
) -> Option<CodeGeneratorKind> {
217231
None
218232
}
@@ -223,16 +237,25 @@ impl CodeGeneratorKind {
223237
_value: &Option<CodeGeneratorKind>,
224238
_count: u32,
225239
_class: ClassLiteral<'_>,
240+
_specialization: Option<Specialization<'_>>,
226241
) -> salsa::CycleRecoveryAction<Option<CodeGeneratorKind>> {
227242
salsa::CycleRecoveryAction::Iterate
228243
}
229244

230-
code_generator_of_class(db, class)
245+
code_generator_of_class(db, class, specialization)
231246
}
232247

233-
pub(super) fn matches(self, db: &dyn Db, class: ClassLiteral<'_>) -> bool {
248+
pub(super) fn matches(
249+
self,
250+
db: &dyn Db,
251+
class: ClassLiteral<'_>,
252+
specialization: Option<Specialization<'_>>,
253+
) -> bool {
234254
matches!(
235-
(CodeGeneratorKind::from_class(db, class), self),
255+
(
256+
CodeGeneratorKind::from_class(db, class, specialization),
257+
self
258+
),
236259
(Some(Self::DataclassLike(_)), Self::DataclassLike(_))
237260
| (Some(Self::NamedTuple), Self::NamedTuple)
238261
| (Some(Self::TypedDict), Self::TypedDict)
@@ -2094,7 +2117,7 @@ impl<'db> ClassLiteral<'db> {
20942117
.with_qualifiers(TypeQualifiers::CLASS_VAR);
20952118
}
20962119

2097-
if CodeGeneratorKind::NamedTuple.matches(db, self) {
2120+
if CodeGeneratorKind::NamedTuple.matches(db, self, specialization) {
20982121
if let Some(field) = self
20992122
.own_fields(db, specialization, CodeGeneratorKind::NamedTuple)
21002123
.get(name)
@@ -2156,7 +2179,7 @@ impl<'db> ClassLiteral<'db> {
21562179
) -> Option<Type<'db>> {
21572180
let dataclass_params = self.dataclass_params(db);
21582181

2159-
let field_policy = CodeGeneratorKind::from_class(db, self)?;
2182+
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
21602183

21612184
let transformer_params =
21622185
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
@@ -2699,7 +2722,7 @@ impl<'db> ClassLiteral<'db> {
26992722
.filter_map(|superclass| {
27002723
if let Some(class) = superclass.into_class() {
27012724
let (class_literal, specialization) = class.class_literal(db);
2702-
if field_policy.matches(db, class_literal) {
2725+
if field_policy.matches(db, class_literal, specialization) {
27032726
Some((class_literal, specialization))
27042727
} else {
27052728
None
@@ -3493,7 +3516,7 @@ impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
34933516
.map(|class| class.variance_of(db, typevar));
34943517

34953518
let default_attribute_variance = {
3496-
let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self);
3519+
let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self, None);
34973520
// Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses
34983521
// their field types in contravariant position, thus meaning a frozen dataclass must
34993522
// still be invariant in its field types. Other synthesized methods on dataclasses are

crates/ty_python_semantic/src/types/ide_support.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl<'db> AllMembers<'db> {
122122
self.extend_with_instance_members(db, ty, class_literal);
123123

124124
// If this is a NamedTuple instance, include members from NamedTupleFallback
125-
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
125+
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
126126
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
127127
}
128128
}
@@ -142,7 +142,7 @@ impl<'db> AllMembers<'db> {
142142
Type::ClassLiteral(class_literal) => {
143143
self.extend_with_class_members(db, ty, class_literal);
144144

145-
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
145+
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
146146
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
147147
}
148148

@@ -153,7 +153,7 @@ impl<'db> AllMembers<'db> {
153153

154154
Type::GenericAlias(generic_alias) => {
155155
let class_literal = generic_alias.origin(db);
156-
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
156+
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
157157
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
158158
}
159159
self.extend_with_class_members(db, ty, class_literal);
@@ -164,7 +164,7 @@ impl<'db> AllMembers<'db> {
164164
let class_literal = class_type.class_literal(db).0;
165165
self.extend_with_class_members(db, ty, class_literal);
166166

167-
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
167+
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
168168
self.extend_with_type(
169169
db,
170170
KnownClass::NamedTupleFallback.to_class_literal(db),

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
555555
continue;
556556
}
557557

558-
let is_named_tuple = CodeGeneratorKind::NamedTuple.matches(self.db(), class);
558+
let is_named_tuple = CodeGeneratorKind::NamedTuple.matches(self.db(), class, None);
559559

560560
// (2) If it's a `NamedTuple` class, check that no field without a default value
561561
// appears after a field with a default value.
@@ -852,7 +852,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
852852

853853
// (6) Check that a dataclass does not have more than one `KW_ONLY`.
854854
if let Some(field_policy @ CodeGeneratorKind::DataclassLike(_)) =
855-
CodeGeneratorKind::from_class(self.db(), class)
855+
CodeGeneratorKind::from_class(self.db(), class, None)
856856
{
857857
let specialization = None;
858858

0 commit comments

Comments
 (0)