Skip to content
38 changes: 35 additions & 3 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -450,19 +450,51 @@ class Person(TypedDict, total=False):

```py
from typing import TypedDict
from typing_extensions import NotRequired

class Person(TypedDict):
name: str
age: int | None
extra: NotRequired[str]

def _(p: Person) -> None:
reveal_type(p.keys()) # revealed: dict_keys[str, object]
reveal_type(p.values()) # revealed: dict_values[str, object]

reveal_type(p.setdefault("name", "Alice")) # revealed: @Todo(Support for `TypedDict`)
# `get()` returns the field type for required keys (no None union)
reveal_type(p.get("name")) # revealed: str
reveal_type(p.get("age")) # revealed: int | None

reveal_type(p.get("name")) # revealed: @Todo(Support for `TypedDict`)
reveal_type(p.get("name", "Unknown")) # revealed: @Todo(Support for `TypedDict`)
# It doesn't matter if a default is specified:
reveal_type(p.get("name", "default")) # revealed: str
reveal_type(p.get("age", 999)) # revealed: int | None

# `get()` can return `None` for non-required keys
reveal_type(p.get("extra")) # revealed: str | None
reveal_type(p.get("extra", "default")) # revealed: str

# The type of the default parameter can be anything:
reveal_type(p.get("extra", 0)) # revealed: str | Literal[0]

# We allow access to unknown keys (they could be set for a subtype of Person)
reveal_type(p.get("unknown")) # revealed: Unknown
reveal_type(p.get("unknown", "default")) # revealed: Unknown | Literal["default"]

# `pop()` only works on non-required fields
reveal_type(p.pop("extra")) # revealed: str
reveal_type(p.pop("extra", "fallback")) # revealed: str
# error: [invalid-argument-type] "Cannot pop required field 'name' from TypedDict `Person`"
reveal_type(p.pop("name")) # revealed: Unknown
Comment on lines +486 to +487
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like ideally we'd still infer str here (despite the fact that the operation is illegal, it seems pretty clear that the result of this operation will always be a str), but I can see that that might be difficult to do consistently with the synthesized-overload approach. And consistently inferring Unknown here seems better than sometimes inferring Unknown, sometimes str (depending on exactly how the method is invoked on the typeddict object).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like ideally we'd still infer str here (despite the fact that the operation is illegal, it seems pretty clear that the result of this operation will always be a str)

Yes, interesting point. I agree that this would probably be the better return type, but as you said, it's not possible to achieve that with synthesized overloads, unless we manage to synthesize overloads that trigger specific CallErrors when selected, or similar. Until then, Unknown is not ideal, but it's also not wrong. And if you're going to # type: ignore your illegal .pop operation, you might as well throw in a cast or a redeclaration :-)


# Similar to above, the default parameter can be of any type:
reveal_type(p.pop("extra", 0)) # revealed: str | Literal[0]

# `setdefault()` always returns the field type
reveal_type(p.setdefault("name", "Alice")) # revealed: str
reveal_type(p.setdefault("extra", "default")) # revealed: str

# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extraz" - did you mean "extra"?"
reveal_type(p.setdefault("extraz", "value")) # revealed: Unknown
```

## Unlike normal classes
Expand Down
22 changes: 22 additions & 0 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7528,6 +7528,28 @@ pub struct BoundTypeVarInstance<'db> {
impl get_size2::GetSize for BoundTypeVarInstance<'_> {}

impl<'db> BoundTypeVarInstance<'db> {
/// Create a new PEP 695 type variable that can be used in signatures
/// of synthetic generic functions.
pub(crate) fn synthetic(
db: &'db dyn Db,
name: &'static str,
variance: TypeVarVariance,
) -> Self {
Self::new(
db,
TypeVarInstance::new(
db,
Name::new_static(name),
None, // definition
None, // _bound_or_constraints
Some(variance),
None, // _default
TypeVarKind::Pep695,
),
BindingContext::Synthetic,
)
}

pub(crate) fn variance_with_polarity(
self,
db: &'db dyn Db,
Expand Down
203 changes: 165 additions & 38 deletions crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::types::{
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams,
VarianceInferable, declaration_type, infer_definition_types, todo_type,
UnionBuilder, VarianceInferable, declaration_type, infer_definition_types,
};
use crate::{
Db, FxIndexMap, FxOrderSet, Program,
Expand All @@ -51,7 +51,7 @@ use crate::{
semantic_index, use_def_map,
},
types::{
CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionBuilder, UnionType,
CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionType,
definition_expression_type,
},
};
Expand Down Expand Up @@ -2331,49 +2331,176 @@ impl<'db> ClassLiteral<'db> {
)))
}
(CodeGeneratorKind::TypedDict, "get") => {
// TODO: synthesize a set of overloads with precise types
let signature = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key"))),
Parameter::positional_only(Some(Name::new_static("default")))
.with_default_type(Type::unknown()),
]),
Some(todo_type!("Support for `TypedDict`")),
);
let overloads = self
.fields(db, specialization, field_policy)
.into_iter()
.flat_map(|(name, field)| {
let key_type =
Type::StringLiteral(StringLiteralType::new(db, name.as_str()));

// For a required key, `.get()` always returns the value type. For a non-required key,
// `.get()` returns the union of the value type and the type of the default argument
// (which defaults to `None`).

// TODO: For now, we use two overloads here. They can be merged into a single function
// once the generics solver takes default arguments into account.

let get_sig = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
]),
Some(if field.is_required() {
field.declared_ty
} else {
UnionType::from_elements(db, [field.declared_ty, Type::none(db)])
}),
);

Some(CallableType::function_like(db, signature))
let t_default =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);

let get_with_default_sig = Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [t_default])),
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(Type::TypeVar(t_default)),
]),
Some(if field.is_required() {
field.declared_ty
} else {
UnionType::from_elements(
db,
[field.declared_ty, Type::TypeVar(t_default)],
)
}),
);

[get_sig, get_with_default_sig]
})
// Fallback overloads for unknown keys
.chain(std::iter::once({
Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(KnownClass::Str.to_instance(db)),
]),
Some(Type::unknown()),
)
}))
.chain(std::iter::once({
let t_default =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);

Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [t_default])),
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(KnownClass::Str.to_instance(db)),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(Type::TypeVar(t_default)),
]),
Some(UnionType::from_elements(
db,
[Type::unknown(), Type::TypeVar(t_default)],
)),
)
}));

Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(overloads),
true,
)))
}
(CodeGeneratorKind::TypedDict, "pop") => {
// TODO: synthesize a set of overloads with precise types.
// Required keys should be forbidden to be popped.
let signature = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key"))),
Parameter::positional_only(Some(Name::new_static("default")))
.with_default_type(Type::unknown()),
]),
Some(todo_type!("Support for `TypedDict`")),
);
let fields = self.fields(db, specialization, field_policy);
let overloads = fields
.iter()
.filter(|(_, field)| {
// Only synthesize `pop` for fields that are not required.
!field.is_required()
})
.flat_map(|(name, field)| {
let key_type =
Type::StringLiteral(StringLiteralType::new(db, name.as_str()));

// TODO: Similar to above: consider merging these two overloads into one

// `.pop()` without default
let pop_sig = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
]),
Some(field.declared_ty),
);

Some(CallableType::function_like(db, signature))
// `.pop()` with a default value
let t_default =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);

let pop_with_default_sig = Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [t_default])),
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(Type::TypeVar(t_default)),
]),
Some(UnionType::from_elements(
db,
[field.declared_ty, Type::TypeVar(t_default)],
)),
);

[pop_sig, pop_with_default_sig]
});

Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(overloads),
true,
)))
}
(CodeGeneratorKind::TypedDict, "setdefault") => {
// TODO: synthesize a set of overloads with precise types
let signature = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key"))),
Parameter::positional_only(Some(Name::new_static("default"))),
]),
Some(todo_type!("Support for `TypedDict`")),
);
let fields = self.fields(db, specialization, field_policy);
let overloads = fields.iter().map(|(name, field)| {
let key_type = Type::StringLiteral(StringLiteralType::new(db, name.as_str()));

Some(CallableType::function_like(db, signature))
// `setdefault` always returns the field type
Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(field.declared_ty),
]),
Some(field.declared_ty),
)
});

Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(overloads),
true,
)))
}
(CodeGeneratorKind::TypedDict, "update") => {
// TODO: synthesize a set of overloads with precise types
Expand Down
15 changes: 15 additions & 0 deletions crates/ty_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2952,6 +2952,21 @@ pub(crate) fn report_missing_typed_dict_key<'db>(
}
}

pub(crate) fn report_cannot_pop_required_field_on_typed_dict<'db>(
context: &InferContext<'db, '_>,
key_node: AnyNodeRef,
typed_dict_ty: Type<'db>,
field_name: &str,
) {
let db = context.db();
if let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, key_node) {
let typed_dict_name = typed_dict_ty.display(db);
builder.into_diagnostic(format_args!(
"Cannot pop required field '{field_name}' from TypedDict `{typed_dict_name}`",
));
}
}

/// This function receives an unresolved `from foo import bar` import,
/// where `foo` can be resolved to a module but that module does not
/// have a `bar` member or submodule.
Expand Down
28 changes: 17 additions & 11 deletions crates/ty_python_semantic/src/types/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ impl<'db> GenericContext<'db> {
binding_context: Definition<'db>,
type_params_node: &ast::TypeParams,
) -> Self {
let variables: FxOrderSet<_> = type_params_node
.iter()
.filter_map(|type_param| {
Self::variable_from_type_param(db, index, binding_context, type_param)
})
.collect();
Self::new(db, variables)
let variables = type_params_node.iter().filter_map(|type_param| {
Self::variable_from_type_param(db, index, binding_context, type_param)
});

Self::from_typevar_instances(db, variables)
}

/// Creates a generic context from a list of `BoundTypeVarInstance`s.
pub(crate) fn from_typevar_instances(
db: &'db dyn Db,
type_params: impl IntoIterator<Item = BoundTypeVarInstance<'db>>,
) -> Self {
Self::new(db, type_params.into_iter().collect::<FxOrderSet<_>>())
}

fn variable_from_type_param(
Expand Down Expand Up @@ -365,12 +371,12 @@ impl<'db> GenericContext<'db> {
}

pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
let variables: FxOrderSet<_> = self
let variables = self
.variables(db)
.iter()
.map(|bound_typevar| bound_typevar.normalized_impl(db, visitor))
.collect();
Self::new(db, variables)
.map(|bound_typevar| bound_typevar.normalized_impl(db, visitor));

Self::from_typevar_instances(db, variables)
}

fn heap_size((variables,): &(FxOrderSet<BoundTypeVarInstance<'db>>,)) -> usize {
Expand Down
Loading
Loading