@@ -32,8 +32,8 @@ use crate::types::tuple::{TupleLength, TupleType};
3232use crate :: types:: {
3333 BoundMethodType , ClassLiteral , DataclassParams , FieldInstance , KnownBoundMethodType ,
3434 KnownClass , KnownInstanceType , MemberLookupPolicy , PropertyInstanceType , SpecialFormType ,
35- TrackedConstraintSet , TypeAliasType , TypeMapping , UnionType , WrapperDescriptorKind , enums ,
36- ide_support, todo_type,
35+ TrackedConstraintSet , TypeAliasType , TypeContext , TypeMapping , UnionType ,
36+ WrapperDescriptorKind , enums , ide_support, todo_type,
3737} ;
3838use ruff_db:: diagnostic:: { Annotation , Diagnostic , SubDiagnostic , SubDiagnosticSeverity } ;
3939use ruff_python_ast:: { self as ast, PythonVersion } ;
@@ -122,16 +122,22 @@ impl<'db> Bindings<'db> {
122122 /// You must provide an `argument_types` that was created from the same `arguments` that you
123123 /// provided to [`match_parameters`][Self::match_parameters].
124124 ///
125+ /// The type context of the call expression is also used to infer the specialization of generic
126+ /// calls.
127+ ///
125128 /// We update the bindings to include the return type of the call, the bound types for all
126129 /// parameters, and any errors resulting from binding the call, all for each union element and
127130 /// overload (if any).
128131 pub ( crate ) fn check_types (
129132 mut self ,
130133 db : & ' db dyn Db ,
131134 argument_types : & CallArguments < ' _ , ' db > ,
135+ call_expression_tcx : & TypeContext < ' db > ,
132136 ) -> Result < Self , CallError < ' db > > {
133137 for element in & mut self . elements {
134- if let Some ( mut updated_argument_forms) = element. check_types ( db, argument_types) {
138+ if let Some ( mut updated_argument_forms) =
139+ element. check_types ( db, argument_types, call_expression_tcx)
140+ {
135141 // If this element returned a new set of argument forms (indicating successful
136142 // argument type expansion), update the `Bindings` with these forms.
137143 updated_argument_forms. shrink_to_fit ( ) ;
@@ -1281,6 +1287,7 @@ impl<'db> CallableBinding<'db> {
12811287 & mut self ,
12821288 db : & ' db dyn Db ,
12831289 argument_types : & CallArguments < ' _ , ' db > ,
1290+ call_expression_tcx : & TypeContext < ' db > ,
12841291 ) -> Option < ArgumentForms > {
12851292 // If this callable is a bound method, prepend the self instance onto the arguments list
12861293 // before checking.
@@ -1293,15 +1300,15 @@ impl<'db> CallableBinding<'db> {
12931300 // still perform type checking for non-overloaded function to provide better user
12941301 // experience.
12951302 if let [ overload] = self . overloads . as_mut_slice ( ) {
1296- overload. check_types ( db, argument_types. as_ref ( ) ) ;
1303+ overload. check_types ( db, argument_types. as_ref ( ) , call_expression_tcx ) ;
12971304 }
12981305 return None ;
12991306 }
13001307 MatchingOverloadIndex :: Single ( index) => {
13011308 // If only one candidate overload remains, it is the winning match. Evaluate it as
13021309 // a regular (non-overloaded) call.
13031310 self . matching_overload_index = Some ( index) ;
1304- self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) ) ;
1311+ self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) , call_expression_tcx ) ;
13051312 return None ;
13061313 }
13071314 MatchingOverloadIndex :: Multiple ( indexes) => {
@@ -1313,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
13131320 // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
13141321 // whether it is compatible with the supplied argument list.
13151322 for ( _, overload) in self . matching_overloads_mut ( ) {
1316- overload. check_types ( db, argument_types. as_ref ( ) ) ;
1323+ overload. check_types ( db, argument_types. as_ref ( ) , call_expression_tcx ) ;
13171324 }
13181325
13191326 match self . matching_overload_index ( ) {
@@ -1430,7 +1437,7 @@ impl<'db> CallableBinding<'db> {
14301437 merged_argument_forms. merge ( & argument_forms) ;
14311438
14321439 for ( _, overload) in self . matching_overloads_mut ( ) {
1433- overload. check_types ( db, expanded_arguments) ;
1440+ overload. check_types ( db, expanded_arguments, call_expression_tcx ) ;
14341441 }
14351442
14361443 let return_type = match self . matching_overload_index ( ) {
@@ -2243,6 +2250,7 @@ struct ArgumentTypeChecker<'a, 'db> {
22432250 arguments : & ' a CallArguments < ' a , ' db > ,
22442251 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
22452252 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2253+ call_expression_tcx : & ' a TypeContext < ' db > ,
22462254 errors : & ' a mut Vec < BindingError < ' db > > ,
22472255
22482256 specialization : Option < Specialization < ' db > > ,
@@ -2256,6 +2264,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22562264 arguments : & ' a CallArguments < ' a , ' db > ,
22572265 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
22582266 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2267+ call_expression_tcx : & ' a TypeContext < ' db > ,
22592268 errors : & ' a mut Vec < BindingError < ' db > > ,
22602269 ) -> Self {
22612270 Self {
@@ -2264,6 +2273,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22642273 arguments,
22652274 argument_matches,
22662275 parameter_tys,
2276+ call_expression_tcx,
22672277 errors,
22682278 specialization : None ,
22692279 inherited_specialization : None ,
@@ -2304,8 +2314,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
23042314 return ;
23052315 }
23062316
2307- let parameters = self . signature . parameters ( ) ;
23082317 let mut builder = SpecializationBuilder :: new ( self . db ) ;
2318+
2319+ // Note that we infer the annotated type _before_ the arguments if this call is part of
2320+ // an annotated assignment, to closer match the order of any unions written in the type
2321+ // annotation.
2322+ if let Some ( return_ty) = self . signature . return_ty
2323+ && let Some ( call_expression_tcx) = self . call_expression_tcx . annotation
2324+ {
2325+ // Ignore any specialization errors here, because the type context is only used to
2326+ // optionally widen the return type.
2327+ let _ = builder. infer ( return_ty, call_expression_tcx) ;
2328+ }
2329+
2330+ let parameters = self . signature . parameters ( ) ;
23092331 for ( argument_index, adjusted_argument_index, _, argument_type) in
23102332 self . enumerate_argument_types ( )
23112333 {
@@ -2316,6 +2338,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
23162338 let Some ( expected_type) = parameter. annotated_type ( ) else {
23172339 continue ;
23182340 } ;
2341+
23192342 if let Err ( error) = builder. infer (
23202343 expected_type,
23212344 variadic_argument_type. unwrap_or ( argument_type) ,
@@ -2327,6 +2350,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
23272350 }
23282351 }
23292352 }
2353+
23302354 self . specialization = self . signature . generic_context . map ( |gc| builder. build ( gc) ) ;
23312355 self . inherited_specialization = self . signature . inherited_generic_context . map ( |gc| {
23322356 // The inherited generic context is used when inferring the specialization of a generic
@@ -2688,13 +2712,19 @@ impl<'db> Binding<'db> {
26882712 self . argument_matches = matcher. finish ( ) ;
26892713 }
26902714
2691- fn check_types ( & mut self , db : & ' db dyn Db , arguments : & CallArguments < ' _ , ' db > ) {
2715+ fn check_types (
2716+ & mut self ,
2717+ db : & ' db dyn Db ,
2718+ arguments : & CallArguments < ' _ , ' db > ,
2719+ call_expression_tcx : & TypeContext < ' db > ,
2720+ ) {
26922721 let mut checker = ArgumentTypeChecker :: new (
26932722 db,
26942723 & self . signature ,
26952724 arguments,
26962725 & self . argument_matches ,
26972726 & mut self . parameter_tys ,
2727+ call_expression_tcx,
26982728 & mut self . errors ,
26992729 ) ;
27002730
0 commit comments