Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 20 additions & 40 deletions crates/ty_python_semantic/resources/mdtest/call/overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ reveal_type(f(A())) # revealed: A
reveal_type(f(*(A(),))) # revealed: A

reveal_type(f(B())) # revealed: A
# TODO: revealed: A
reveal_type(f(*(B(),))) # revealed: Unknown
reveal_type(f(*(B(),))) # revealed: A

# But, in this case, the arity check filters out the first overload, so we only have one match:
reveal_type(f(B(), 1)) # revealed: B
Expand Down Expand Up @@ -551,16 +550,13 @@ from overloaded import MyEnumSubclass, ActualEnum, f

def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass):
reveal_type(f(actual_enum)) # revealed: Both
# TODO: revealed: Both
reveal_type(f(*(actual_enum,))) # revealed: Unknown
reveal_type(f(*(actual_enum,))) # revealed: Both

reveal_type(f(ActualEnum.A)) # revealed: OnlyA
# TODO: revealed: OnlyA
reveal_type(f(*(ActualEnum.A,))) # revealed: Unknown
reveal_type(f(*(ActualEnum.A,))) # revealed: OnlyA

reveal_type(f(ActualEnum.B)) # revealed: OnlyB
# TODO: revealed: OnlyB
reveal_type(f(*(ActualEnum.B,))) # revealed: Unknown
reveal_type(f(*(ActualEnum.B,))) # revealed: OnlyB

reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass
reveal_type(f(*(my_enum_instance,))) # revealed: MyEnumSubclass
Expand Down Expand Up @@ -1097,12 +1093,10 @@ reveal_type(f(*(1,))) # revealed: str

def _(list_int: list[int], list_any: list[Any]):
reveal_type(f(list_int)) # revealed: int
# TODO: revealed: int
reveal_type(f(*(list_int,))) # revealed: Unknown
reveal_type(f(*(list_int,))) # revealed: int

reveal_type(f(list_any)) # revealed: int
# TODO: revealed: int
reveal_type(f(*(list_any,))) # revealed: Unknown
reveal_type(f(*(list_any,))) # revealed: int
```

### Single list argument (ambiguous)
Expand Down Expand Up @@ -1136,8 +1130,7 @@ def _(list_int: list[int], list_any: list[Any]):
# All materializations of `list[int]` are assignable to `list[int]`, so it matches the first
# overload.
reveal_type(f(list_int)) # revealed: int
# TODO: revealed: int
reveal_type(f(*(list_int,))) # revealed: Unknown
reveal_type(f(*(list_int,))) # revealed: int

# All materializations of `list[Any]` are assignable to `list[int]` and `list[Any]`, but the
# return type of first and second overloads are not equivalent, so the overload matching
Expand Down Expand Up @@ -1170,25 +1163,21 @@ reveal_type(f("a")) # revealed: str
reveal_type(f(*("a",))) # revealed: str

reveal_type(f((1, "b"))) # revealed: int
# TODO: revealed: int
reveal_type(f(*((1, "b"),))) # revealed: Unknown
reveal_type(f(*((1, "b"),))) # revealed: int

reveal_type(f((1, 2))) # revealed: int
# TODO: revealed: int
reveal_type(f(*((1, 2),))) # revealed: Unknown
reveal_type(f(*((1, 2),))) # revealed: int

def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, Any]):
# All materializations are assignable to first overload, so second and third overloads are
# eliminated
reveal_type(f(int_str)) # revealed: int
# TODO: revealed: int
reveal_type(f(*(int_str,))) # revealed: Unknown
reveal_type(f(*(int_str,))) # revealed: int

# All materializations are assignable to second overload, so the third overload is eliminated;
# the return type of first and second overload is equivalent
reveal_type(f(int_any)) # revealed: int
# TODO: revealed: int
reveal_type(f(*(int_any,))) # revealed: Unknown
reveal_type(f(*(int_any,))) # revealed: int

# All materializations of `tuple[Any, Any]` are assignable to the parameters of all the
# overloads, but the return types aren't equivalent, so the overload matching is ambiguous
Expand Down Expand Up @@ -1266,26 +1255,22 @@ def _(list_int: list[int], list_any: list[Any], int_str: tuple[int, str], int_an
# All materializations of both argument types are assignable to the first overload, so the
# second and third overloads are filtered out
reveal_type(f(list_int, int_str)) # revealed: A
# TODO: revealed: A
reveal_type(f(*(list_int, int_str))) # revealed: Unknown
reveal_type(f(*(list_int, int_str))) # revealed: A

# All materialization of first argument is assignable to first overload and for the second
# argument, they're assignable to the second overload, so the third overload is filtered out
reveal_type(f(list_int, int_any)) # revealed: A
# TODO: revealed: A
reveal_type(f(*(list_int, int_any))) # revealed: Unknown
reveal_type(f(*(list_int, int_any))) # revealed: A

# All materialization of first argument is assignable to second overload and for the second
# argument, they're assignable to the first overload, so the third overload is filtered out
reveal_type(f(list_any, int_str)) # revealed: A
# TODO: revealed: A
reveal_type(f(*(list_any, int_str))) # revealed: Unknown
reveal_type(f(*(list_any, int_str))) # revealed: A

# All materializations of both arguments are assignable to the second overload, so the third
# overload is filtered out
reveal_type(f(list_any, int_any)) # revealed: A
# TODO: revealed: A
reveal_type(f(*(list_any, int_any))) # revealed: Unknown
reveal_type(f(*(list_any, int_any))) # revealed: A

# All materializations of first argument is assignable to the second overload and for the second
# argument, they're assignable to the third overload, so no overloads are filtered out; the
Expand Down Expand Up @@ -1316,8 +1301,7 @@ from overloaded import f

def _(literal: LiteralString, string: str, any: Any):
reveal_type(f(literal)) # revealed: LiteralString
# TODO: revealed: LiteralString
reveal_type(f(*(literal,))) # revealed: Unknown
reveal_type(f(*(literal,))) # revealed: LiteralString

reveal_type(f(string)) # revealed: str
reveal_type(f(*(string,))) # revealed: str
Expand Down Expand Up @@ -1355,12 +1339,10 @@ from overloaded import f

def _(list_int: list[int], list_str: list[str], list_any: list[Any], any: Any):
reveal_type(f(list_int)) # revealed: A
# TODO: revealed: A
reveal_type(f(*(list_int,))) # revealed: Unknown
reveal_type(f(*(list_int,))) # revealed: A

reveal_type(f(list_str)) # revealed: str
# TODO: Should be `str`
reveal_type(f(*(list_str,))) # revealed: Unknown
reveal_type(f(*(list_str,))) # revealed: str

reveal_type(f(list_any)) # revealed: Unknown
reveal_type(f(*(list_any,))) # revealed: Unknown
Expand Down Expand Up @@ -1561,12 +1543,10 @@ def _(any: Any):
reveal_type(f(*(any,), flag=False)) # revealed: str

def _(args: tuple[Any, Literal[True]]):
# TODO: revealed: int
reveal_type(f(*args)) # revealed: Unknown
reveal_type(f(*args)) # revealed: int

def _(args: tuple[Any, Literal[False]]):
# TODO: revealed: str
reveal_type(f(*args)) # revealed: Unknown
reveal_type(f(*args)) # revealed: str
```

### Argument type expansion
Expand Down
110 changes: 79 additions & 31 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType,
WrapperDescriptorKind, enums, ide_support, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
Expand Down Expand Up @@ -1588,48 +1588,82 @@ impl<'db> CallableBinding<'db> {
arguments: &CallArguments<'_, 'db>,
matching_overload_indexes: &[usize],
) {
// The maximum number of parameters across all the overloads that are being considered
// for filtering.
let max_parameter_count = matching_overload_indexes
.iter()
.map(|&index| self.overloads[index].signature.parameters().len())
.max()
.unwrap_or(0);

// These are the parameter indexes that matches the arguments that participate in the
// filtering process.
//
// The parameter types at these indexes have at least one overload where the type isn't
// gradual equivalent to the parameter types at the same index for other overloads.
let mut participating_parameter_indexes = HashSet::new();

// These only contain the top materialized argument types for the corresponding
// participating parameter indexes.
let mut top_materialized_argument_types = vec![];

for (argument_index, argument_type) in arguments.iter_types().enumerate() {
let mut first_parameter_type: Option<Type<'db>> = None;
let mut participating_parameter_index = None;
// The parameter types at each index for the first overload containing a parameter at
// that index.
let mut first_parameter_types: Vec<Option<Type<'db>>> = vec![None; max_parameter_count];

'overload: for overload_index in matching_overload_indexes {
for argument_index in 0..arguments.len() {
for overload_index in matching_overload_indexes {
let overload = &self.overloads[*overload_index];
for parameter_index in &overload.argument_matches[argument_index].parameters {
for &parameter_index in &overload.argument_matches[argument_index].parameters {
// TODO: For an unannotated `self` / `cls` parameter, the type should be
// `typing.Self` / `type[typing.Self]`
let current_parameter_type = overload.signature.parameters()[*parameter_index]
let current_parameter_type = overload.signature.parameters()[parameter_index]
.annotated_type()
.unwrap_or(Type::unknown());
let first_parameter_type = &mut first_parameter_types[parameter_index];
if let Some(first_parameter_type) = first_parameter_type {
if !first_parameter_type.is_equivalent_to(db, current_parameter_type) {
participating_parameter_index = Some(*parameter_index);
break 'overload;
participating_parameter_indexes.insert(parameter_index);
}
} else {
first_parameter_type = Some(current_parameter_type);
*first_parameter_type = Some(current_parameter_type);
}
}
}
}

if let Some(parameter_index) = participating_parameter_index {
participating_parameter_indexes.insert(parameter_index);
top_materialized_argument_types.push(argument_type.top_materialization(db));
let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db))
.take(max_parameter_count)
.collect::<Vec<_>>();

for (argument_index, argument_type) in arguments.iter_types().enumerate() {
for overload_index in matching_overload_indexes {
let overload = &self.overloads[*overload_index];
for (parameter_index, variadic_argument_type) in
overload.argument_matches[argument_index].iter()
{
if !participating_parameter_indexes.contains(&parameter_index) {
continue;
}
union_argument_type_builders[parameter_index].add_in_place(
variadic_argument_type
.unwrap_or(argument_type)
.top_materialization(db),
);
}
}
}

let top_materialized_argument_type =
Type::heterogeneous_tuple(db, top_materialized_argument_types);
// These only contain the top materialized argument types for the corresponding
// participating parameter indexes.
let top_materialized_argument_type = Type::heterogeneous_tuple(
db,
union_argument_type_builders
.into_iter()
.filter_map(|builder| {
if builder.is_empty() {
None
} else {
Some(builder.build())
}
}),
);

// A flag to indicate whether we've found the overload that makes the remaining overloads
// unmatched for the given argument types.
Expand All @@ -1640,15 +1674,22 @@ impl<'db> CallableBinding<'db> {
self.overloads[*current_index].mark_as_unmatched_overload();
continue;
}
let mut parameter_types = Vec::with_capacity(arguments.len());

let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db))
.take(max_parameter_count)
.collect::<Vec<_>>();

// The number of parameters that have been skipped because they don't participate in
// the filtering process. This is used to make sure the types are added to the
// corresponding parameter index in `union_parameter_types`.
let mut skipped_parameters = 0;

for argument_index in 0..arguments.len() {
// The parameter types at the current argument index.
let mut current_parameter_types = vec![];
for overload_index in &matching_overload_indexes[..=upto] {
let overload = &self.overloads[*overload_index];
for parameter_index in &overload.argument_matches[argument_index].parameters {
if !participating_parameter_indexes.contains(parameter_index) {
// This parameter doesn't participate in the filtering process.
skipped_parameters += 1;
continue;
}
// TODO: For an unannotated `self` / `cls` parameter, the type should be
Expand All @@ -1664,17 +1705,24 @@ impl<'db> CallableBinding<'db> {
parameter_type =
parameter_type.apply_specialization(db, inherited_specialization);
}
current_parameter_types.push(parameter_type);
union_parameter_types[parameter_index.saturating_sub(skipped_parameters)]
.add_in_place(parameter_type);
}
}
if current_parameter_types.is_empty() {
continue;
}
parameter_types.push(UnionType::from_elements(db, current_parameter_types));
}
if top_materialized_argument_type
.is_assignable_to(db, Type::heterogeneous_tuple(db, parameter_types))
{

let parameter_types = Type::heterogeneous_tuple(
db,
union_parameter_types.into_iter().filter_map(|builder| {
if builder.is_empty() {
None
} else {
Some(builder.build())
}
}),
);

if top_materialized_argument_type.is_assignable_to(db, parameter_types) {
filter_remaining_overloads = true;
}
}
Expand Down
Loading