@@ -31,7 +31,8 @@ use crate::types::tuple::{TupleLength, TupleType};
3131use crate :: types:: {
3232 BoundMethodType , ClassLiteral , DataclassParams , FieldInstance , KnownBoundMethodType ,
3333 KnownClass , KnownInstanceType , PropertyInstanceType , SpecialFormType , TrackedConstraintSet ,
34- TypeAliasType , TypeMapping , UnionType , WrapperDescriptorKind , enums, ide_support, todo_type,
34+ TypeAliasType , TypeContext , TypeMapping , UnionType , WrapperDescriptorKind , enums, ide_support,
35+ todo_type,
3536} ;
3637use ruff_db:: diagnostic:: { Annotation , Diagnostic , SubDiagnostic , SubDiagnosticSeverity } ;
3738use ruff_python_ast:: { self as ast, PythonVersion } ;
@@ -120,16 +121,21 @@ impl<'db> Bindings<'db> {
120121 /// You must provide an `argument_types` that was created from the same `arguments` that you
121122 /// provided to [`match_parameters`][Self::match_parameters].
122123 ///
124+ /// The return type annotation is also used to infer the specialization of generic calls.
125+ ///
123126 /// We update the bindings to include the return type of the call, the bound types for all
124127 /// parameters, and any errors resulting from binding the call, all for each union element and
125128 /// overload (if any).
126129 pub ( crate ) fn check_types (
127130 mut self ,
128131 db : & ' db dyn Db ,
129132 argument_types : & CallArguments < ' _ , ' db > ,
133+ return_tcx : & TypeContext < ' db > ,
130134 ) -> Result < Self , CallError < ' db > > {
131135 for element in & mut self . elements {
132- if let Some ( mut updated_argument_forms) = element. check_types ( db, argument_types) {
136+ if let Some ( mut updated_argument_forms) =
137+ element. check_types ( db, argument_types, return_tcx)
138+ {
133139 // If this element returned a new set of argument forms (indicating successful
134140 // argument type expansion), update the `Bindings` with these forms.
135141 updated_argument_forms. shrink_to_fit ( ) ;
@@ -1279,6 +1285,7 @@ impl<'db> CallableBinding<'db> {
12791285 & mut self ,
12801286 db : & ' db dyn Db ,
12811287 argument_types : & CallArguments < ' _ , ' db > ,
1288+ return_tcx : & TypeContext < ' db > ,
12821289 ) -> Option < ArgumentForms > {
12831290 // If this callable is a bound method, prepend the self instance onto the arguments list
12841291 // before checking.
@@ -1291,15 +1298,15 @@ impl<'db> CallableBinding<'db> {
12911298 // still perform type checking for non-overloaded function to provide better user
12921299 // experience.
12931300 if let [ overload] = self . overloads . as_mut_slice ( ) {
1294- overload. check_types ( db, argument_types. as_ref ( ) ) ;
1301+ overload. check_types ( db, argument_types. as_ref ( ) , return_tcx ) ;
12951302 }
12961303 return None ;
12971304 }
12981305 MatchingOverloadIndex :: Single ( index) => {
12991306 // If only one candidate overload remains, it is the winning match. Evaluate it as
13001307 // a regular (non-overloaded) call.
13011308 self . matching_overload_index = Some ( index) ;
1302- self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) ) ;
1309+ self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) , return_tcx ) ;
13031310 return None ;
13041311 }
13051312 MatchingOverloadIndex :: Multiple ( indexes) => {
@@ -1311,7 +1318,7 @@ impl<'db> CallableBinding<'db> {
13111318 // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
13121319 // whether it is compatible with the supplied argument list.
13131320 for ( _, overload) in self . matching_overloads_mut ( ) {
1314- overload. check_types ( db, argument_types. as_ref ( ) ) ;
1321+ overload. check_types ( db, argument_types. as_ref ( ) , return_tcx ) ;
13151322 }
13161323
13171324 match self . matching_overload_index ( ) {
@@ -1428,7 +1435,7 @@ impl<'db> CallableBinding<'db> {
14281435 merged_argument_forms. merge ( & argument_forms) ;
14291436
14301437 for ( _, overload) in self . matching_overloads_mut ( ) {
1431- overload. check_types ( db, expanded_arguments) ;
1438+ overload. check_types ( db, expanded_arguments, return_tcx ) ;
14321439 }
14331440
14341441 let return_type = match self . matching_overload_index ( ) {
@@ -2186,6 +2193,7 @@ struct ArgumentTypeChecker<'a, 'db> {
21862193 arguments : & ' a CallArguments < ' a , ' db > ,
21872194 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
21882195 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2196+ return_tcx : & ' a TypeContext < ' db > ,
21892197 errors : & ' a mut Vec < BindingError < ' db > > ,
21902198
21912199 specialization : Option < Specialization < ' db > > ,
@@ -2199,6 +2207,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
21992207 arguments : & ' a CallArguments < ' a , ' db > ,
22002208 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
22012209 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2210+ return_tcx : & ' a TypeContext < ' db > ,
22022211 errors : & ' a mut Vec < BindingError < ' db > > ,
22032212 ) -> Self {
22042213 Self {
@@ -2207,6 +2216,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22072216 arguments,
22082217 argument_matches,
22092218 parameter_tys,
2219+ return_tcx,
22102220 errors,
22112221 specialization : None ,
22122222 inherited_specialization : None ,
@@ -2247,8 +2257,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22472257 return ;
22482258 }
22492259
2250- let parameters = self . signature . parameters ( ) ;
22512260 let mut builder = SpecializationBuilder :: new ( self . db ) ;
2261+
2262+ // Note that we infer the annotated type _before_ the arguments if this call is part of
2263+ // an annotated assignment, to closer match the order of any unions written in the type
2264+ // annotation.
2265+ if let Some ( return_ty) = self . signature . return_ty
2266+ && let Some ( return_tcx) = self . return_tcx . annotation
2267+ {
2268+ // Ignore any specialization errors here, because the type context is only used to
2269+ // optionally widen the return type.
2270+ let _ = builder. infer ( return_ty, return_tcx) ;
2271+ }
2272+
2273+ let parameters = self . signature . parameters ( ) ;
22522274 for ( argument_index, adjusted_argument_index, _, argument_type) in
22532275 self . enumerate_argument_types ( )
22542276 {
@@ -2259,6 +2281,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22592281 let Some ( expected_type) = parameter. annotated_type ( ) else {
22602282 continue ;
22612283 } ;
2284+
22622285 if let Err ( error) = builder. infer (
22632286 expected_type,
22642287 variadic_argument_type. unwrap_or ( argument_type) ,
@@ -2270,6 +2293,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22702293 }
22712294 }
22722295 }
2296+
22732297 self . specialization = self . signature . generic_context . map ( |gc| builder. build ( gc) ) ;
22742298 self . inherited_specialization = self . signature . inherited_generic_context . map ( |gc| {
22752299 // The inherited generic context is used when inferring the specialization of a generic
@@ -2516,13 +2540,19 @@ impl<'db> Binding<'db> {
25162540 self . argument_matches = matcher. finish ( ) ;
25172541 }
25182542
2519- fn check_types ( & mut self , db : & ' db dyn Db , arguments : & CallArguments < ' _ , ' db > ) {
2543+ fn check_types (
2544+ & mut self ,
2545+ db : & ' db dyn Db ,
2546+ arguments : & CallArguments < ' _ , ' db > ,
2547+ return_tcx : & TypeContext < ' db > ,
2548+ ) {
25202549 let mut checker = ArgumentTypeChecker :: new (
25212550 db,
25222551 & self . signature ,
25232552 arguments,
25242553 & self . argument_matches ,
25252554 & mut self . parameter_tys ,
2555+ return_tcx,
25262556 & mut self . errors ,
25272557 ) ;
25282558
0 commit comments