@@ -31,7 +31,8 @@ use crate::types::tuple::{TupleLength, TupleType};
3131use crate :: types:: {
3232 BoundMethodType , ClassLiteral , DataclassParams , FieldInstance , KnownBoundMethodType ,
3333 KnownClass , KnownInstanceType , PropertyInstanceType , SpecialFormType , TrackedConstraintSet ,
34- TypeAliasType , TypeMapping , UnionType , WrapperDescriptorKind , enums, ide_support, todo_type,
34+ TypeAliasType , TypeContext , TypeMapping , UnionType , WrapperDescriptorKind , enums, ide_support,
35+ todo_type,
3536} ;
3637use ruff_db:: diagnostic:: { Annotation , Diagnostic , SubDiagnostic , SubDiagnosticSeverity } ;
3738use ruff_python_ast:: { self as ast, PythonVersion } ;
@@ -120,16 +121,22 @@ impl<'db> Bindings<'db> {
120121 /// You must provide an `argument_types` that was created from the same `arguments` that you
121122 /// provided to [`match_parameters`][Self::match_parameters].
122123 ///
124+ /// The type context of the call expression is also used to infer the specialization of generic
125+ /// calls.
126+ ///
123127 /// We update the bindings to include the return type of the call, the bound types for all
124128 /// parameters, and any errors resulting from binding the call, all for each union element and
125129 /// overload (if any).
126130 pub ( crate ) fn check_types (
127131 mut self ,
128132 db : & ' db dyn Db ,
129133 argument_types : & CallArguments < ' _ , ' db > ,
134+ call_expression_tcx : & TypeContext < ' db > ,
130135 ) -> Result < Self , CallError < ' db > > {
131136 for element in & mut self . elements {
132- if let Some ( mut updated_argument_forms) = element. check_types ( db, argument_types) {
137+ if let Some ( mut updated_argument_forms) =
138+ element. check_types ( db, argument_types, call_expression_tcx)
139+ {
133140 // If this element returned a new set of argument forms (indicating successful
134141 // argument type expansion), update the `Bindings` with these forms.
135142 updated_argument_forms. shrink_to_fit ( ) ;
@@ -1279,6 +1286,7 @@ impl<'db> CallableBinding<'db> {
12791286 & mut self ,
12801287 db : & ' db dyn Db ,
12811288 argument_types : & CallArguments < ' _ , ' db > ,
1289+ call_expression_tcx : & TypeContext < ' db > ,
12821290 ) -> Option < ArgumentForms > {
12831291 // If this callable is a bound method, prepend the self instance onto the arguments list
12841292 // before checking.
@@ -1291,15 +1299,15 @@ impl<'db> CallableBinding<'db> {
12911299 // still perform type checking for non-overloaded function to provide better user
12921300 // experience.
12931301 if let [ overload] = self . overloads . as_mut_slice ( ) {
1294- overload. check_types ( db, argument_types. as_ref ( ) ) ;
1302+ overload. check_types ( db, argument_types. as_ref ( ) , call_expression_tcx ) ;
12951303 }
12961304 return None ;
12971305 }
12981306 MatchingOverloadIndex :: Single ( index) => {
12991307 // If only one candidate overload remains, it is the winning match. Evaluate it as
13001308 // a regular (non-overloaded) call.
13011309 self . matching_overload_index = Some ( index) ;
1302- self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) ) ;
1310+ self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) , call_expression_tcx ) ;
13031311 return None ;
13041312 }
13051313 MatchingOverloadIndex :: Multiple ( indexes) => {
@@ -1311,7 +1319,7 @@ impl<'db> CallableBinding<'db> {
13111319 // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
13121320 // whether it is compatible with the supplied argument list.
13131321 for ( _, overload) in self . matching_overloads_mut ( ) {
1314- overload. check_types ( db, argument_types. as_ref ( ) ) ;
1322+ overload. check_types ( db, argument_types. as_ref ( ) , call_expression_tcx ) ;
13151323 }
13161324
13171325 match self . matching_overload_index ( ) {
@@ -1428,7 +1436,7 @@ impl<'db> CallableBinding<'db> {
14281436 merged_argument_forms. merge ( & argument_forms) ;
14291437
14301438 for ( _, overload) in self . matching_overloads_mut ( ) {
1431- overload. check_types ( db, expanded_arguments) ;
1439+ overload. check_types ( db, expanded_arguments, call_expression_tcx ) ;
14321440 }
14331441
14341442 let return_type = match self . matching_overload_index ( ) {
@@ -2186,6 +2194,7 @@ struct ArgumentTypeChecker<'a, 'db> {
21862194 arguments : & ' a CallArguments < ' a , ' db > ,
21872195 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
21882196 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2197+ call_expression_tcx : & ' a TypeContext < ' db > ,
21892198 errors : & ' a mut Vec < BindingError < ' db > > ,
21902199
21912200 specialization : Option < Specialization < ' db > > ,
@@ -2199,6 +2208,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
21992208 arguments : & ' a CallArguments < ' a , ' db > ,
22002209 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
22012210 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2211+ call_expression_tcx : & ' a TypeContext < ' db > ,
22022212 errors : & ' a mut Vec < BindingError < ' db > > ,
22032213 ) -> Self {
22042214 Self {
@@ -2207,6 +2217,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22072217 arguments,
22082218 argument_matches,
22092219 parameter_tys,
2220+ call_expression_tcx : call_expression_tcx,
22102221 errors,
22112222 specialization : None ,
22122223 inherited_specialization : None ,
@@ -2247,8 +2258,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22472258 return ;
22482259 }
22492260
2250- let parameters = self . signature . parameters ( ) ;
22512261 let mut builder = SpecializationBuilder :: new ( self . db ) ;
2262+
2263+ // Note that we infer the annotated type _before_ the arguments if this call is part of
2264+ // an annotated assignment, to closer match the order of any unions written in the type
2265+ // annotation.
2266+ if let Some ( return_ty) = self . signature . return_ty
2267+ && let Some ( call_expression_tcx) = self . call_expression_tcx . annotation
2268+ {
2269+ // Ignore any specialization errors here, because the type context is only used to
2270+ // optionally widen the return type.
2271+ let _ = builder. infer ( return_ty, call_expression_tcx) ;
2272+ }
2273+
2274+ let parameters = self . signature . parameters ( ) ;
22522275 for ( argument_index, adjusted_argument_index, _, argument_type) in
22532276 self . enumerate_argument_types ( )
22542277 {
@@ -2259,6 +2282,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22592282 let Some ( expected_type) = parameter. annotated_type ( ) else {
22602283 continue ;
22612284 } ;
2285+
22622286 if let Err ( error) = builder. infer (
22632287 expected_type,
22642288 variadic_argument_type. unwrap_or ( argument_type) ,
@@ -2270,6 +2294,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22702294 }
22712295 }
22722296 }
2297+
22732298 self . specialization = self . signature . generic_context . map ( |gc| builder. build ( gc) ) ;
22742299 self . inherited_specialization = self . signature . inherited_generic_context . map ( |gc| {
22752300 // The inherited generic context is used when inferring the specialization of a generic
@@ -2516,13 +2541,19 @@ impl<'db> Binding<'db> {
25162541 self . argument_matches = matcher. finish ( ) ;
25172542 }
25182543
2519- fn check_types ( & mut self , db : & ' db dyn Db , arguments : & CallArguments < ' _ , ' db > ) {
2544+ fn check_types (
2545+ & mut self ,
2546+ db : & ' db dyn Db ,
2547+ arguments : & CallArguments < ' _ , ' db > ,
2548+ call_expression_tcx : & TypeContext < ' db > ,
2549+ ) {
25202550 let mut checker = ArgumentTypeChecker :: new (
25212551 db,
25222552 & self . signature ,
25232553 arguments,
25242554 & self . argument_matches ,
25252555 & mut self . parameter_tys ,
2556+ call_expression_tcx,
25262557 & mut self . errors ,
25272558 ) ;
25282559
0 commit comments