Skip to content

Commit 32edea8

Browse files
committed
use type context for inference of generic function calls
1 parent 706be0a commit 32edea8

File tree

5 files changed

+89
-14
lines changed

5 files changed

+89
-14
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo
234234
x: int = 1
235235
reveal_type(x) # revealed: Literal[1]
236236
```
237+
238+
## Annotations influence generic call inference
239+
240+
```toml
241+
[environment]
242+
python-version = "3.12"
243+
```
244+
245+
```py
246+
from typing import Literal
247+
248+
def f[T](x: T) -> list[T]:
249+
return [x]
250+
251+
a = f("a")
252+
reveal_type(a) # revealed: list[Literal["a"]]
253+
254+
b: list[int | Literal["a"]] = f("a")
255+
reveal_type(b) # revealed: list[int | Literal["a"]]
256+
257+
c: list[int | str] = f("a")
258+
reveal_type(c) # revealed: list[int | str]
259+
260+
d: list[int | tuple[int, int]] = f((1, 2))
261+
reveal_type(d) # revealed: list[int | tuple[int, int]]
262+
263+
e: list[int] = f(True)
264+
reveal_type(e) # revealed: list[int]
265+
266+
# TODO: the RHS should be inferred as `list[Literal["a"]]` here
267+
# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`"
268+
g: list[int] = f("a")
269+
270+
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
271+
h: tuple[int] = f("a")
272+
273+
def f2[T: int](x: T) -> T:
274+
return x
275+
276+
i: int = f2(True)
277+
reveal_type(i) # revealed: int
278+
279+
j: int | str = f2(True)
280+
reveal_type(j) # revealed: Literal[True]
281+
```

crates/ty_python_semantic/src/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4805,7 +4805,7 @@ impl<'db> Type<'db> {
48054805
) -> Result<Bindings<'db>, CallError<'db>> {
48064806
self.bindings(db)
48074807
.match_parameters(db, argument_types)
4808-
.check_types(db, argument_types)
4808+
.check_types(db, argument_types, &TypeContext::default())
48094809
}
48104810

48114811
/// Look up a dunder method on the meta-type of `self` and call it.
@@ -4854,7 +4854,7 @@ impl<'db> Type<'db> {
48544854
let bindings = dunder_callable
48554855
.bindings(db)
48564856
.match_parameters(db, argument_types)
4857-
.check_types(db, argument_types)?;
4857+
.check_types(db, argument_types, &TypeContext::default())?;
48584858
if boundness == Boundness::PossiblyUnbound {
48594859
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
48604860
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ use crate::types::tuple::{TupleLength, TupleType};
3131
use crate::types::{
3232
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
3333
KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
34-
TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums, ide_support, todo_type,
34+
TypeAliasType, TypeContext, TypeMapping, UnionType, WrapperDescriptorKind, enums, ide_support,
35+
todo_type,
3536
};
3637
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
3738
use ruff_python_ast::{self as ast, PythonVersion};
@@ -120,16 +121,21 @@ impl<'db> Bindings<'db> {
120121
/// You must provide an `argument_types` that was created from the same `arguments` that you
121122
/// provided to [`match_parameters`][Self::match_parameters].
122123
///
124+
/// The return type annotation is also used to infer the specialization of generic calls.
125+
///
123126
/// We update the bindings to include the return type of the call, the bound types for all
124127
/// parameters, and any errors resulting from binding the call, all for each union element and
125128
/// overload (if any).
126129
pub(crate) fn check_types(
127130
mut self,
128131
db: &'db dyn Db,
129132
argument_types: &CallArguments<'_, 'db>,
133+
return_tcx: &TypeContext<'db>,
130134
) -> Result<Self, CallError<'db>> {
131135
for element in &mut self.elements {
132-
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
136+
if let Some(mut updated_argument_forms) =
137+
element.check_types(db, argument_types, return_tcx)
138+
{
133139
// If this element returned a new set of argument forms (indicating successful
134140
// argument type expansion), update the `Bindings` with these forms.
135141
updated_argument_forms.shrink_to_fit();
@@ -1279,6 +1285,7 @@ impl<'db> CallableBinding<'db> {
12791285
&mut self,
12801286
db: &'db dyn Db,
12811287
argument_types: &CallArguments<'_, 'db>,
1288+
return_tcx: &TypeContext<'db>,
12821289
) -> Option<ArgumentForms> {
12831290
// If this callable is a bound method, prepend the self instance onto the arguments list
12841291
// before checking.
@@ -1291,15 +1298,15 @@ impl<'db> CallableBinding<'db> {
12911298
// still perform type checking for non-overloaded function to provide better user
12921299
// experience.
12931300
if let [overload] = self.overloads.as_mut_slice() {
1294-
overload.check_types(db, argument_types.as_ref());
1301+
overload.check_types(db, argument_types.as_ref(), return_tcx);
12951302
}
12961303
return None;
12971304
}
12981305
MatchingOverloadIndex::Single(index) => {
12991306
// If only one candidate overload remains, it is the winning match. Evaluate it as
13001307
// a regular (non-overloaded) call.
13011308
self.matching_overload_index = Some(index);
1302-
self.overloads[index].check_types(db, argument_types.as_ref());
1309+
self.overloads[index].check_types(db, argument_types.as_ref(), return_tcx);
13031310
return None;
13041311
}
13051312
MatchingOverloadIndex::Multiple(indexes) => {
@@ -1311,7 +1318,7 @@ impl<'db> CallableBinding<'db> {
13111318
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
13121319
// whether it is compatible with the supplied argument list.
13131320
for (_, overload) in self.matching_overloads_mut() {
1314-
overload.check_types(db, argument_types.as_ref());
1321+
overload.check_types(db, argument_types.as_ref(), return_tcx);
13151322
}
13161323

13171324
match self.matching_overload_index() {
@@ -1428,7 +1435,7 @@ impl<'db> CallableBinding<'db> {
14281435
merged_argument_forms.merge(&argument_forms);
14291436

14301437
for (_, overload) in self.matching_overloads_mut() {
1431-
overload.check_types(db, expanded_arguments);
1438+
overload.check_types(db, expanded_arguments, return_tcx);
14321439
}
14331440

14341441
let return_type = match self.matching_overload_index() {
@@ -2186,6 +2193,7 @@ struct ArgumentTypeChecker<'a, 'db> {
21862193
arguments: &'a CallArguments<'a, 'db>,
21872194
argument_matches: &'a [MatchedArgument<'db>],
21882195
parameter_tys: &'a mut [Option<Type<'db>>],
2196+
return_tcx: &'a TypeContext<'db>,
21892197
errors: &'a mut Vec<BindingError<'db>>,
21902198

21912199
specialization: Option<Specialization<'db>>,
@@ -2199,6 +2207,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
21992207
arguments: &'a CallArguments<'a, 'db>,
22002208
argument_matches: &'a [MatchedArgument<'db>],
22012209
parameter_tys: &'a mut [Option<Type<'db>>],
2210+
return_tcx: &'a TypeContext<'db>,
22022211
errors: &'a mut Vec<BindingError<'db>>,
22032212
) -> Self {
22042213
Self {
@@ -2207,6 +2216,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22072216
arguments,
22082217
argument_matches,
22092218
parameter_tys,
2219+
return_tcx,
22102220
errors,
22112221
specialization: None,
22122222
inherited_specialization: None,
@@ -2247,8 +2257,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22472257
return;
22482258
}
22492259

2250-
let parameters = self.signature.parameters();
22512260
let mut builder = SpecializationBuilder::new(self.db);
2261+
2262+
// Note that we infer the annotated type _before_ the arguments if this call is part of
2263+
// an annotated assignment, to closer match the order of any unions written in the type
2264+
// annotation.
2265+
if let Some(return_ty) = self.signature.return_ty
2266+
&& let Some(return_tcx) = self.return_tcx.annotation
2267+
{
2268+
// Ignore any specialization errors here, because the type context is only used to
2269+
// optionally widen the return type.
2270+
let _ = builder.infer(return_ty, return_tcx);
2271+
}
2272+
2273+
let parameters = self.signature.parameters();
22522274
for (argument_index, adjusted_argument_index, _, argument_type) in
22532275
self.enumerate_argument_types()
22542276
{
@@ -2259,6 +2281,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22592281
let Some(expected_type) = parameter.annotated_type() else {
22602282
continue;
22612283
};
2284+
22622285
if let Err(error) = builder.infer(
22632286
expected_type,
22642287
variadic_argument_type.unwrap_or(argument_type),
@@ -2270,6 +2293,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22702293
}
22712294
}
22722295
}
2296+
22732297
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
22742298
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
22752299
// The inherited generic context is used when inferring the specialization of a generic
@@ -2516,13 +2540,19 @@ impl<'db> Binding<'db> {
25162540
self.argument_matches = matcher.finish();
25172541
}
25182542

2519-
fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) {
2543+
fn check_types(
2544+
&mut self,
2545+
db: &'db dyn Db,
2546+
arguments: &CallArguments<'_, 'db>,
2547+
return_tcx: &TypeContext<'db>,
2548+
) {
25202549
let mut checker = ArgumentTypeChecker::new(
25212550
db,
25222551
&self.signature,
25232552
arguments,
25242553
&self.argument_matches,
25252554
&mut self.parameter_tys,
2555+
return_tcx,
25262556
&mut self.errors,
25272557
);
25282558

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> {
352352
/// more precise inference results, aka "bidirectional type inference".
353353
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
354354
pub(crate) struct TypeContext<'db> {
355-
annotation: Option<Type<'db>>,
355+
pub(crate) annotation: Option<Type<'db>>,
356356
}
357357

358358
impl<'db> TypeContext<'db> {

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5772,7 +5772,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
57725772
fn infer_call_expression(
57735773
&mut self,
57745774
call_expression: &ast::ExprCall,
5775-
_tcx: TypeContext<'db>,
5775+
tcx: TypeContext<'db>,
57765776
) -> Type<'db> {
57775777
let ast::ExprCall {
57785778
range: _,
@@ -5950,7 +5950,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59505950
}
59515951
}
59525952

5953-
let mut bindings = match bindings.check_types(self.db(), &call_arguments) {
5953+
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
59545954
Ok(bindings) => bindings,
59555955
Err(CallError(_, bindings)) => {
59565956
bindings.report_diagnostics(&self.context, call_expression.into());
@@ -8516,7 +8516,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
85168516
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
85178517
let bindings = match Bindings::from(binding)
85188518
.match_parameters(self.db(), &call_argument_types)
8519-
.check_types(self.db(), &call_argument_types)
8519+
.check_types(self.db(), &call_argument_types, &TypeContext::default())
85208520
{
85218521
Ok(bindings) => bindings,
85228522
Err(CallError(_, bindings)) => {

0 commit comments

Comments
 (0)