@@ -184,7 +184,11 @@ pub(crate) enum CodeGeneratorKind {
184184}
185185
186186impl CodeGeneratorKind {
187- pub ( crate ) fn from_class ( db : & dyn Db , class : ClassLiteral < ' _ > ) -> Option < Self > {
187+ pub ( crate ) fn from_class (
188+ db : & dyn Db ,
189+ class : ClassLiteral < ' _ > ,
190+ specialization : Option < Specialization < ' _ > > ,
191+ ) -> Option < Self > {
188192 #[ salsa:: tracked(
189193 cycle_fn=code_generator_of_class_recover,
190194 cycle_initial=code_generator_of_class_initial,
@@ -193,11 +197,20 @@ impl CodeGeneratorKind {
193197 fn code_generator_of_class < ' db > (
194198 db : & ' db dyn Db ,
195199 class : ClassLiteral < ' db > ,
200+ specialization : Option < Specialization < ' db > > ,
196201 ) -> Option < CodeGeneratorKind > {
197202 if class. dataclass_params ( db) . is_some ( ) {
198203 Some ( CodeGeneratorKind :: DataclassLike ( None ) )
199204 } else if let Ok ( ( _, Some ( transformer_params) ) ) = class. try_metaclass ( db) {
200205 Some ( CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) )
206+ } else if let Some ( transformer_params) =
207+ class. iter_mro ( db, specialization) . skip ( 1 ) . find_map ( |base| {
208+ base. into_class ( ) . and_then ( |class| {
209+ class. class_literal ( db) . 0 . dataclass_transformer_params ( db)
210+ } )
211+ } )
212+ {
213+ Some ( CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) )
201214 } else if class
202215 . explicit_bases ( db)
203216 . contains ( & Type :: SpecialForm ( SpecialFormType :: NamedTuple ) )
@@ -213,6 +226,7 @@ impl CodeGeneratorKind {
213226 fn code_generator_of_class_initial (
214227 _db : & dyn Db ,
215228 _class : ClassLiteral < ' _ > ,
229+ _specialization : Option < Specialization < ' _ > > ,
216230 ) -> Option < CodeGeneratorKind > {
217231 None
218232 }
@@ -223,16 +237,25 @@ impl CodeGeneratorKind {
223237 _value : & Option < CodeGeneratorKind > ,
224238 _count : u32 ,
225239 _class : ClassLiteral < ' _ > ,
240+ _specialization : Option < Specialization < ' _ > > ,
226241 ) -> salsa:: CycleRecoveryAction < Option < CodeGeneratorKind > > {
227242 salsa:: CycleRecoveryAction :: Iterate
228243 }
229244
230- code_generator_of_class ( db, class)
245+ code_generator_of_class ( db, class, specialization )
231246 }
232247
233- pub ( super ) fn matches ( self , db : & dyn Db , class : ClassLiteral < ' _ > ) -> bool {
248+ pub ( super ) fn matches (
249+ self ,
250+ db : & dyn Db ,
251+ class : ClassLiteral < ' _ > ,
252+ specialization : Option < Specialization < ' _ > > ,
253+ ) -> bool {
234254 matches ! (
235- ( CodeGeneratorKind :: from_class( db, class) , self ) ,
255+ (
256+ CodeGeneratorKind :: from_class( db, class, specialization) ,
257+ self
258+ ) ,
236259 ( Some ( Self :: DataclassLike ( _) ) , Self :: DataclassLike ( _) )
237260 | ( Some ( Self :: NamedTuple ) , Self :: NamedTuple )
238261 | ( Some ( Self :: TypedDict ) , Self :: TypedDict )
@@ -2094,7 +2117,7 @@ impl<'db> ClassLiteral<'db> {
20942117 . with_qualifiers ( TypeQualifiers :: CLASS_VAR ) ;
20952118 }
20962119
2097- if CodeGeneratorKind :: NamedTuple . matches ( db, self ) {
2120+ if CodeGeneratorKind :: NamedTuple . matches ( db, self , specialization ) {
20982121 if let Some ( field) = self
20992122 . own_fields ( db, specialization, CodeGeneratorKind :: NamedTuple )
21002123 . get ( name)
@@ -2156,7 +2179,7 @@ impl<'db> ClassLiteral<'db> {
21562179 ) -> Option < Type < ' db > > {
21572180 let dataclass_params = self . dataclass_params ( db) ;
21582181
2159- let field_policy = CodeGeneratorKind :: from_class ( db, self ) ?;
2182+ let field_policy = CodeGeneratorKind :: from_class ( db, self , specialization ) ?;
21602183
21612184 let transformer_params =
21622185 if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
@@ -2699,7 +2722,7 @@ impl<'db> ClassLiteral<'db> {
26992722 . filter_map ( |superclass| {
27002723 if let Some ( class) = superclass. into_class ( ) {
27012724 let ( class_literal, specialization) = class. class_literal ( db) ;
2702- if field_policy. matches ( db, class_literal) {
2725+ if field_policy. matches ( db, class_literal, specialization ) {
27032726 Some ( ( class_literal, specialization) )
27042727 } else {
27052728 None
@@ -3493,7 +3516,7 @@ impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
34933516 . map ( |class| class. variance_of ( db, typevar) ) ;
34943517
34953518 let default_attribute_variance = {
3496- let is_namedtuple = CodeGeneratorKind :: NamedTuple . matches ( db, self ) ;
3519+ let is_namedtuple = CodeGeneratorKind :: NamedTuple . matches ( db, self , None ) ;
34973520 // Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses
34983521 // their field types in contravariant position, thus meaning a frozen dataclass must
34993522 // still be invariant in its field types. Other synthesized methods on dataclasses are
0 commit comments