Skip to content

Commit 0a9013d

Browse files
committed
use type context for inference of generic function calls
1 parent 706be0a commit 0a9013d

File tree

5 files changed

+90
-14
lines changed

5 files changed

+90
-14
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo
234234
x: int = 1
235235
reveal_type(x) # revealed: Literal[1]
236236
```
237+
238+
## Annotations influence generic call inference
239+
240+
```toml
241+
[environment]
242+
python-version = "3.12"
243+
```
244+
245+
```py
246+
from typing import Literal
247+
248+
def f[T](x: T) -> list[T]:
249+
return [x]
250+
251+
a = f("a")
252+
reveal_type(a) # revealed: list[Literal["a"]]
253+
254+
b: list[int | Literal["a"]] = f("a")
255+
reveal_type(b) # revealed: list[int | Literal["a"]]
256+
257+
c: list[int | str] = f("a")
258+
reveal_type(c) # revealed: list[int | str]
259+
260+
d: list[int | tuple[int, int]] = f((1, 2))
261+
reveal_type(d) # revealed: list[int | tuple[int, int]]
262+
263+
e: list[int] = f(True)
264+
reveal_type(e) # revealed: list[int]
265+
266+
# TODO: the RHS should be inferred as `list[Literal["a"]]` here
267+
# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`"
268+
g: list[int] = f("a")
269+
270+
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
271+
h: tuple[int] = f("a")
272+
273+
def f2[T: int](x: T) -> T:
274+
return x
275+
276+
i: int = f2(True)
277+
reveal_type(i) # revealed: int
278+
279+
j: int | str = f2(True)
280+
reveal_type(j) # revealed: Literal[True]
281+
```

crates/ty_python_semantic/src/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4805,7 +4805,7 @@ impl<'db> Type<'db> {
48054805
) -> Result<Bindings<'db>, CallError<'db>> {
48064806
self.bindings(db)
48074807
.match_parameters(db, argument_types)
4808-
.check_types(db, argument_types)
4808+
.check_types(db, argument_types, &TypeContext::default())
48094809
}
48104810

48114811
/// Look up a dunder method on the meta-type of `self` and call it.
@@ -4854,7 +4854,7 @@ impl<'db> Type<'db> {
48544854
let bindings = dunder_callable
48554855
.bindings(db)
48564856
.match_parameters(db, argument_types)
4857-
.check_types(db, argument_types)?;
4857+
.check_types(db, argument_types, &TypeContext::default())?;
48584858
if boundness == Boundness::PossiblyUnbound {
48594859
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
48604860
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ use crate::types::tuple::{TupleLength, TupleType};
3131
use crate::types::{
3232
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
3333
KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
34-
TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums, ide_support, todo_type,
34+
TypeAliasType, TypeContext, TypeMapping, UnionType, WrapperDescriptorKind, enums, ide_support,
35+
todo_type,
3536
};
3637
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
3738
use ruff_python_ast::{self as ast, PythonVersion};
@@ -120,16 +121,22 @@ impl<'db> Bindings<'db> {
120121
/// You must provide an `argument_types` that was created from the same `arguments` that you
121122
/// provided to [`match_parameters`][Self::match_parameters].
122123
///
124+
/// The type context of the call expression is also used to infer the specialization of generic
125+
/// calls.
126+
///
123127
/// We update the bindings to include the return type of the call, the bound types for all
124128
/// parameters, and any errors resulting from binding the call, all for each union element and
125129
/// overload (if any).
126130
pub(crate) fn check_types(
127131
mut self,
128132
db: &'db dyn Db,
129133
argument_types: &CallArguments<'_, 'db>,
134+
call_expression_tcx: &TypeContext<'db>,
130135
) -> Result<Self, CallError<'db>> {
131136
for element in &mut self.elements {
132-
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
137+
if let Some(mut updated_argument_forms) =
138+
element.check_types(db, argument_types, call_expression_tcx)
139+
{
133140
// If this element returned a new set of argument forms (indicating successful
134141
// argument type expansion), update the `Bindings` with these forms.
135142
updated_argument_forms.shrink_to_fit();
@@ -1279,6 +1286,7 @@ impl<'db> CallableBinding<'db> {
12791286
&mut self,
12801287
db: &'db dyn Db,
12811288
argument_types: &CallArguments<'_, 'db>,
1289+
call_expression_tcx: &TypeContext<'db>,
12821290
) -> Option<ArgumentForms> {
12831291
// If this callable is a bound method, prepend the self instance onto the arguments list
12841292
// before checking.
@@ -1291,15 +1299,15 @@ impl<'db> CallableBinding<'db> {
12911299
// still perform type checking for non-overloaded function to provide better user
12921300
// experience.
12931301
if let [overload] = self.overloads.as_mut_slice() {
1294-
overload.check_types(db, argument_types.as_ref());
1302+
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
12951303
}
12961304
return None;
12971305
}
12981306
MatchingOverloadIndex::Single(index) => {
12991307
// If only one candidate overload remains, it is the winning match. Evaluate it as
13001308
// a regular (non-overloaded) call.
13011309
self.matching_overload_index = Some(index);
1302-
self.overloads[index].check_types(db, argument_types.as_ref());
1310+
self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx);
13031311
return None;
13041312
}
13051313
MatchingOverloadIndex::Multiple(indexes) => {
@@ -1311,7 +1319,7 @@ impl<'db> CallableBinding<'db> {
13111319
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
13121320
// whether it is compatible with the supplied argument list.
13131321
for (_, overload) in self.matching_overloads_mut() {
1314-
overload.check_types(db, argument_types.as_ref());
1322+
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
13151323
}
13161324

13171325
match self.matching_overload_index() {
@@ -1428,7 +1436,7 @@ impl<'db> CallableBinding<'db> {
14281436
merged_argument_forms.merge(&argument_forms);
14291437

14301438
for (_, overload) in self.matching_overloads_mut() {
1431-
overload.check_types(db, expanded_arguments);
1439+
overload.check_types(db, expanded_arguments, call_expression_tcx);
14321440
}
14331441

14341442
let return_type = match self.matching_overload_index() {
@@ -2186,6 +2194,7 @@ struct ArgumentTypeChecker<'a, 'db> {
21862194
arguments: &'a CallArguments<'a, 'db>,
21872195
argument_matches: &'a [MatchedArgument<'db>],
21882196
parameter_tys: &'a mut [Option<Type<'db>>],
2197+
call_expression_tcx: &'a TypeContext<'db>,
21892198
errors: &'a mut Vec<BindingError<'db>>,
21902199

21912200
specialization: Option<Specialization<'db>>,
@@ -2199,6 +2208,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
21992208
arguments: &'a CallArguments<'a, 'db>,
22002209
argument_matches: &'a [MatchedArgument<'db>],
22012210
parameter_tys: &'a mut [Option<Type<'db>>],
2211+
call_expression_tcx: &'a TypeContext<'db>,
22022212
errors: &'a mut Vec<BindingError<'db>>,
22032213
) -> Self {
22042214
Self {
@@ -2207,6 +2217,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22072217
arguments,
22082218
argument_matches,
22092219
parameter_tys,
2220+
call_expression_tcx: call_expression_tcx,
22102221
errors,
22112222
specialization: None,
22122223
inherited_specialization: None,
@@ -2247,8 +2258,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22472258
return;
22482259
}
22492260

2250-
let parameters = self.signature.parameters();
22512261
let mut builder = SpecializationBuilder::new(self.db);
2262+
2263+
// Note that we infer the annotated type _before_ the arguments if this call is part of
2264+
// an annotated assignment, to closer match the order of any unions written in the type
2265+
// annotation.
2266+
if let Some(return_ty) = self.signature.return_ty
2267+
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
2268+
{
2269+
// Ignore any specialization errors here, because the type context is only used to
2270+
// optionally widen the return type.
2271+
let _ = builder.infer(return_ty, call_expression_tcx);
2272+
}
2273+
2274+
let parameters = self.signature.parameters();
22522275
for (argument_index, adjusted_argument_index, _, argument_type) in
22532276
self.enumerate_argument_types()
22542277
{
@@ -2259,6 +2282,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22592282
let Some(expected_type) = parameter.annotated_type() else {
22602283
continue;
22612284
};
2285+
22622286
if let Err(error) = builder.infer(
22632287
expected_type,
22642288
variadic_argument_type.unwrap_or(argument_type),
@@ -2270,6 +2294,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22702294
}
22712295
}
22722296
}
2297+
22732298
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
22742299
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
22752300
// The inherited generic context is used when inferring the specialization of a generic
@@ -2516,13 +2541,19 @@ impl<'db> Binding<'db> {
25162541
self.argument_matches = matcher.finish();
25172542
}
25182543

2519-
fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) {
2544+
fn check_types(
2545+
&mut self,
2546+
db: &'db dyn Db,
2547+
arguments: &CallArguments<'_, 'db>,
2548+
call_expression_tcx: &TypeContext<'db>,
2549+
) {
25202550
let mut checker = ArgumentTypeChecker::new(
25212551
db,
25222552
&self.signature,
25232553
arguments,
25242554
&self.argument_matches,
25252555
&mut self.parameter_tys,
2556+
call_expression_tcx,
25262557
&mut self.errors,
25272558
);
25282559

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> {
352352
/// more precise inference results, aka "bidirectional type inference".
353353
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
354354
pub(crate) struct TypeContext<'db> {
355-
annotation: Option<Type<'db>>,
355+
pub(crate) annotation: Option<Type<'db>>,
356356
}
357357

358358
impl<'db> TypeContext<'db> {

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5772,7 +5772,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
57725772
fn infer_call_expression(
57735773
&mut self,
57745774
call_expression: &ast::ExprCall,
5775-
_tcx: TypeContext<'db>,
5775+
tcx: TypeContext<'db>,
57765776
) -> Type<'db> {
57775777
let ast::ExprCall {
57785778
range: _,
@@ -5950,7 +5950,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59505950
}
59515951
}
59525952

5953-
let mut bindings = match bindings.check_types(self.db(), &call_arguments) {
5953+
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
59545954
Ok(bindings) => bindings,
59555955
Err(CallError(_, bindings)) => {
59565956
bindings.report_diagnostics(&self.context, call_expression.into());
@@ -8516,7 +8516,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
85168516
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
85178517
let bindings = match Bindings::from(binding)
85188518
.match_parameters(self.db(), &call_argument_types)
8519-
.check_types(self.db(), &call_argument_types)
8519+
.check_types(self.db(), &call_argument_types, &TypeContext::default())
85208520
{
85218521
Ok(bindings) => bindings,
85228522
Err(CallError(_, bindings)) => {

0 commit comments

Comments
 (0)