Skip to content

Commit c026788

Browse files
committed
use type context for inference of generic function calls
1 parent c0fb235 commit c026788

File tree

5 files changed

+90
-15
lines changed

5 files changed

+90
-15
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo
234234
x: int = 1
235235
reveal_type(x) # revealed: Literal[1]
236236
```
237+
238+
## Annotations influence generic call inference
239+
240+
```toml
241+
[environment]
242+
python-version = "3.12"
243+
```
244+
245+
```py
246+
from typing import Literal
247+
248+
def f[T](x: T) -> list[T]:
249+
return [x]
250+
251+
a = f("a")
252+
reveal_type(a) # revealed: list[Literal["a"]]
253+
254+
b: list[int | Literal["a"]] = f("a")
255+
reveal_type(b) # revealed: list[int | Literal["a"]]
256+
257+
c: list[int | str] = f("a")
258+
reveal_type(c) # revealed: list[int | str]
259+
260+
d: list[int | tuple[int, int]] = f((1, 2))
261+
reveal_type(d) # revealed: list[int | tuple[int, int]]
262+
263+
e: list[int] = f(True)
264+
reveal_type(e) # revealed: list[int]
265+
266+
# TODO: the RHS should be inferred as `list[Literal["a"]]` here
267+
# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`"
268+
g: list[int] = f("a")
269+
270+
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
271+
h: tuple[int] = f("a")
272+
273+
def f2[T: int](x: T) -> T:
274+
return x
275+
276+
i: int = f2(True)
277+
reveal_type(i) # revealed: int
278+
279+
j: int | str = f2(True)
280+
reveal_type(j) # revealed: Literal[True]
281+
```

crates/ty_python_semantic/src/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4805,7 +4805,7 @@ impl<'db> Type<'db> {
48054805
) -> Result<Bindings<'db>, CallError<'db>> {
48064806
self.bindings(db)
48074807
.match_parameters(db, argument_types)
4808-
.check_types(db, argument_types)
4808+
.check_types(db, argument_types, &TypeContext::default())
48094809
}
48104810

48114811
/// Look up a dunder method on the meta-type of `self` and call it.
@@ -4854,7 +4854,7 @@ impl<'db> Type<'db> {
48544854
let bindings = dunder_callable
48554855
.bindings(db)
48564856
.match_parameters(db, argument_types)
4857-
.check_types(db, argument_types)?;
4857+
.check_types(db, argument_types, &TypeContext::default())?;
48584858
if boundness == Boundness::PossiblyUnbound {
48594859
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
48604860
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use crate::types::tuple::{TupleLength, TupleType};
3232
use crate::types::{
3333
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
3434
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
35-
TrackedConstraintSet, TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums,
36-
ide_support, todo_type,
35+
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
36+
WrapperDescriptorKind, enums, ide_support, todo_type,
3737
};
3838
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
3939
use ruff_python_ast::{self as ast, PythonVersion};
@@ -122,16 +122,22 @@ impl<'db> Bindings<'db> {
122122
/// You must provide an `argument_types` that was created from the same `arguments` that you
123123
/// provided to [`match_parameters`][Self::match_parameters].
124124
///
125+
/// The type context of the call expression is also used to infer the specialization of generic
126+
/// calls.
127+
///
125128
/// We update the bindings to include the return type of the call, the bound types for all
126129
/// parameters, and any errors resulting from binding the call, all for each union element and
127130
/// overload (if any).
128131
pub(crate) fn check_types(
129132
mut self,
130133
db: &'db dyn Db,
131134
argument_types: &CallArguments<'_, 'db>,
135+
call_expression_tcx: &TypeContext<'db>,
132136
) -> Result<Self, CallError<'db>> {
133137
for element in &mut self.elements {
134-
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
138+
if let Some(mut updated_argument_forms) =
139+
element.check_types(db, argument_types, call_expression_tcx)
140+
{
135141
// If this element returned a new set of argument forms (indicating successful
136142
// argument type expansion), update the `Bindings` with these forms.
137143
updated_argument_forms.shrink_to_fit();
@@ -1281,6 +1287,7 @@ impl<'db> CallableBinding<'db> {
12811287
&mut self,
12821288
db: &'db dyn Db,
12831289
argument_types: &CallArguments<'_, 'db>,
1290+
call_expression_tcx: &TypeContext<'db>,
12841291
) -> Option<ArgumentForms> {
12851292
// If this callable is a bound method, prepend the self instance onto the arguments list
12861293
// before checking.
@@ -1293,15 +1300,15 @@ impl<'db> CallableBinding<'db> {
12931300
// still perform type checking for non-overloaded function to provide better user
12941301
// experience.
12951302
if let [overload] = self.overloads.as_mut_slice() {
1296-
overload.check_types(db, argument_types.as_ref());
1303+
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
12971304
}
12981305
return None;
12991306
}
13001307
MatchingOverloadIndex::Single(index) => {
13011308
// If only one candidate overload remains, it is the winning match. Evaluate it as
13021309
// a regular (non-overloaded) call.
13031310
self.matching_overload_index = Some(index);
1304-
self.overloads[index].check_types(db, argument_types.as_ref());
1311+
self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx);
13051312
return None;
13061313
}
13071314
MatchingOverloadIndex::Multiple(indexes) => {
@@ -1313,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
13131320
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
13141321
// whether it is compatible with the supplied argument list.
13151322
for (_, overload) in self.matching_overloads_mut() {
1316-
overload.check_types(db, argument_types.as_ref());
1323+
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
13171324
}
13181325

13191326
match self.matching_overload_index() {
@@ -1430,7 +1437,7 @@ impl<'db> CallableBinding<'db> {
14301437
merged_argument_forms.merge(&argument_forms);
14311438

14321439
for (_, overload) in self.matching_overloads_mut() {
1433-
overload.check_types(db, expanded_arguments);
1440+
overload.check_types(db, expanded_arguments, call_expression_tcx);
14341441
}
14351442

14361443
let return_type = match self.matching_overload_index() {
@@ -2243,6 +2250,7 @@ struct ArgumentTypeChecker<'a, 'db> {
22432250
arguments: &'a CallArguments<'a, 'db>,
22442251
argument_matches: &'a [MatchedArgument<'db>],
22452252
parameter_tys: &'a mut [Option<Type<'db>>],
2253+
call_expression_tcx: &'a TypeContext<'db>,
22462254
errors: &'a mut Vec<BindingError<'db>>,
22472255

22482256
specialization: Option<Specialization<'db>>,
@@ -2256,6 +2264,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22562264
arguments: &'a CallArguments<'a, 'db>,
22572265
argument_matches: &'a [MatchedArgument<'db>],
22582266
parameter_tys: &'a mut [Option<Type<'db>>],
2267+
call_expression_tcx: &'a TypeContext<'db>,
22592268
errors: &'a mut Vec<BindingError<'db>>,
22602269
) -> Self {
22612270
Self {
@@ -2264,6 +2273,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
22642273
arguments,
22652274
argument_matches,
22662275
parameter_tys,
2276+
call_expression_tcx,
22672277
errors,
22682278
specialization: None,
22692279
inherited_specialization: None,
@@ -2304,8 +2314,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
23042314
return;
23052315
}
23062316

2307-
let parameters = self.signature.parameters();
23082317
let mut builder = SpecializationBuilder::new(self.db);
2318+
2319+
// Note that we infer the annotated type _before_ the arguments if this call is part of
2320+
// an annotated assignment, to closer match the order of any unions written in the type
2321+
// annotation.
2322+
if let Some(return_ty) = self.signature.return_ty
2323+
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
2324+
{
2325+
// Ignore any specialization errors here, because the type context is only used to
2326+
// optionally widen the return type.
2327+
let _ = builder.infer(return_ty, call_expression_tcx);
2328+
}
2329+
2330+
let parameters = self.signature.parameters();
23092331
for (argument_index, adjusted_argument_index, _, argument_type) in
23102332
self.enumerate_argument_types()
23112333
{
@@ -2316,6 +2338,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
23162338
let Some(expected_type) = parameter.annotated_type() else {
23172339
continue;
23182340
};
2341+
23192342
if let Err(error) = builder.infer(
23202343
expected_type,
23212344
variadic_argument_type.unwrap_or(argument_type),
@@ -2327,6 +2350,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
23272350
}
23282351
}
23292352
}
2353+
23302354
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
23312355
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
23322356
// The inherited generic context is used when inferring the specialization of a generic
@@ -2688,13 +2712,19 @@ impl<'db> Binding<'db> {
26882712
self.argument_matches = matcher.finish();
26892713
}
26902714

2691-
fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) {
2715+
fn check_types(
2716+
&mut self,
2717+
db: &'db dyn Db,
2718+
arguments: &CallArguments<'_, 'db>,
2719+
call_expression_tcx: &TypeContext<'db>,
2720+
) {
26922721
let mut checker = ArgumentTypeChecker::new(
26932722
db,
26942723
&self.signature,
26952724
arguments,
26962725
&self.argument_matches,
26972726
&mut self.parameter_tys,
2727+
call_expression_tcx,
26982728
&mut self.errors,
26992729
);
27002730

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> {
352352
/// more precise inference results, aka "bidirectional type inference".
353353
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
354354
pub(crate) struct TypeContext<'db> {
355-
annotation: Option<Type<'db>>,
355+
pub(crate) annotation: Option<Type<'db>>,
356356
}
357357

358358
impl<'db> TypeContext<'db> {

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5775,7 +5775,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
57755775
fn infer_call_expression(
57765776
&mut self,
57775777
call_expression: &ast::ExprCall,
5778-
_tcx: TypeContext<'db>,
5778+
tcx: TypeContext<'db>,
57795779
) -> Type<'db> {
57805780
let ast::ExprCall {
57815781
range: _,
@@ -5955,7 +5955,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59555955
}
59565956
}
59575957

5958-
let mut bindings = match bindings.check_types(self.db(), &call_arguments) {
5958+
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
59595959
Ok(bindings) => bindings,
59605960
Err(CallError(_, bindings)) => {
59615961
bindings.report_diagnostics(&self.context, call_expression.into());
@@ -8521,7 +8521,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
85218521
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
85228522
let bindings = match Bindings::from(binding)
85238523
.match_parameters(self.db(), &call_argument_types)
8524-
.check_types(self.db(), &call_argument_types)
8524+
.check_types(self.db(), &call_argument_types, &TypeContext::default())
85258525
{
85268526
Ok(bindings) => bindings,
85278527
Err(CallError(_, bindings)) => {

0 commit comments

Comments
 (0)