Skip to content

Commit 3bf43d9

Browse files
PrettyWoodsharkdp
authored andcommitted
[ty] typecheck dict methods for TypedDict (astral-sh#19874)
## Summary Typecheck `get()`, `setdefault()`, `pop()` for `TypedDict` ```py from typing import TypedDict from typing_extensions import NotRequired class Employee(TypedDict): name: str department: NotRequired[str] emp = Employee(name="Alice", department="Engineering") emp.get("name") emp.get("departmen", "Unknown") emp.pop("department") emp.pop("name") ``` <img width="838" height="529" alt="Screenshot 2025-08-12 at 11 42 12" src="https://github.com/user-attachments/assets/77ce150a-223c-4931-b914-551095d8a3a6" /> part of astral-sh/ty#154 ## Test Plan Updated Markdown tests --------- Co-authored-by: David Peter <[email protected]>
1 parent e592973 commit 3bf43d9

File tree

6 files changed

+312
-60
lines changed

6 files changed

+312
-60
lines changed

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,19 +450,51 @@ class Person(TypedDict, total=False):
450450

451451
```py
452452
from typing import TypedDict
453+
from typing_extensions import NotRequired
453454

454455
class Person(TypedDict):
455456
name: str
456457
age: int | None
458+
extra: NotRequired[str]
457459

458460
def _(p: Person) -> None:
459461
reveal_type(p.keys()) # revealed: dict_keys[str, object]
460462
reveal_type(p.values()) # revealed: dict_values[str, object]
461463

462-
reveal_type(p.setdefault("name", "Alice")) # revealed: @Todo(Support for `TypedDict`)
464+
# `get()` returns the field type for required keys (no None union)
465+
reveal_type(p.get("name")) # revealed: str
466+
reveal_type(p.get("age")) # revealed: int | None
463467

464-
reveal_type(p.get("name")) # revealed: @Todo(Support for `TypedDict`)
465-
reveal_type(p.get("name", "Unknown")) # revealed: @Todo(Support for `TypedDict`)
468+
# It doesn't matter if a default is specified:
469+
reveal_type(p.get("name", "default")) # revealed: str
470+
reveal_type(p.get("age", 999)) # revealed: int | None
471+
472+
# `get()` can return `None` for non-required keys
473+
reveal_type(p.get("extra")) # revealed: str | None
474+
reveal_type(p.get("extra", "default")) # revealed: str
475+
476+
# The type of the default parameter can be anything:
477+
reveal_type(p.get("extra", 0)) # revealed: str | Literal[0]
478+
479+
# We allow access to unknown keys (they could be set for a subtype of Person)
480+
reveal_type(p.get("unknown")) # revealed: Unknown | None
481+
reveal_type(p.get("unknown", "default")) # revealed: Unknown | Literal["default"]
482+
483+
# `pop()` only works on non-required fields
484+
reveal_type(p.pop("extra")) # revealed: str
485+
reveal_type(p.pop("extra", "fallback")) # revealed: str
486+
# error: [invalid-argument-type] "Cannot pop required field 'name' from TypedDict `Person`"
487+
reveal_type(p.pop("name")) # revealed: Unknown
488+
489+
# Similar to above, the default parameter can be of any type:
490+
reveal_type(p.pop("extra", 0)) # revealed: str | Literal[0]
491+
492+
# `setdefault()` always returns the field type
493+
reveal_type(p.setdefault("name", "Alice")) # revealed: str
494+
reveal_type(p.setdefault("extra", "default")) # revealed: str
495+
496+
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extraz" - did you mean "extra"?"
497+
reveal_type(p.setdefault("extraz", "value")) # revealed: Unknown
466498
```
467499

468500
## Unlike normal classes

crates/ty_python_semantic/src/types.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7528,6 +7528,28 @@ pub struct BoundTypeVarInstance<'db> {
75287528
impl get_size2::GetSize for BoundTypeVarInstance<'_> {}
75297529

75307530
impl<'db> BoundTypeVarInstance<'db> {
7531+
/// Create a new PEP 695 type variable that can be used in signatures
7532+
/// of synthetic generic functions.
7533+
pub(crate) fn synthetic(
7534+
db: &'db dyn Db,
7535+
name: &'static str,
7536+
variance: TypeVarVariance,
7537+
) -> Self {
7538+
Self::new(
7539+
db,
7540+
TypeVarInstance::new(
7541+
db,
7542+
Name::new_static(name),
7543+
None, // definition
7544+
None, // _bound_or_constraints
7545+
Some(variance),
7546+
None, // _default
7547+
TypeVarKind::Pep695,
7548+
),
7549+
BindingContext::Synthetic,
7550+
)
7551+
}
7552+
75317553
pub(crate) fn variance_with_polarity(
75327554
self,
75337555
db: &'db dyn Db,

crates/ty_python_semantic/src/types/class.rs

Lines changed: 168 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::types::{
3434
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind,
3535
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
3636
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams,
37-
VarianceInferable, declaration_type, infer_definition_types, todo_type,
37+
UnionBuilder, VarianceInferable, declaration_type, infer_definition_types,
3838
};
3939
use crate::{
4040
Db, FxIndexMap, FxOrderSet, Program,
@@ -51,7 +51,7 @@ use crate::{
5151
semantic_index, use_def_map,
5252
},
5353
types::{
54-
CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionBuilder, UnionType,
54+
CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionType,
5555
definition_expression_type,
5656
},
5757
};
@@ -2331,49 +2331,179 @@ impl<'db> ClassLiteral<'db> {
23312331
)))
23322332
}
23332333
(CodeGeneratorKind::TypedDict, "get") => {
2334-
// TODO: synthesize a set of overloads with precise types
2335-
let signature = Signature::new(
2336-
Parameters::new([
2337-
Parameter::positional_only(Some(Name::new_static("self")))
2338-
.with_annotated_type(instance_ty),
2339-
Parameter::positional_only(Some(Name::new_static("key"))),
2340-
Parameter::positional_only(Some(Name::new_static("default")))
2341-
.with_default_type(Type::unknown()),
2342-
]),
2343-
Some(todo_type!("Support for `TypedDict`")),
2344-
);
2334+
let overloads = self
2335+
.fields(db, specialization, field_policy)
2336+
.into_iter()
2337+
.flat_map(|(name, field)| {
2338+
let key_type =
2339+
Type::StringLiteral(StringLiteralType::new(db, name.as_str()));
2340+
2341+
// For a required key, `.get()` always returns the value type. For a non-required key,
2342+
// `.get()` returns the union of the value type and the type of the default argument
2343+
// (which defaults to `None`).
2344+
2345+
// TODO: For now, we use two overloads here. They can be merged into a single function
2346+
// once the generics solver takes default arguments into account.
2347+
2348+
let get_sig = Signature::new(
2349+
Parameters::new([
2350+
Parameter::positional_only(Some(Name::new_static("self")))
2351+
.with_annotated_type(instance_ty),
2352+
Parameter::positional_only(Some(Name::new_static("key")))
2353+
.with_annotated_type(key_type),
2354+
]),
2355+
Some(if field.is_required() {
2356+
field.declared_ty
2357+
} else {
2358+
UnionType::from_elements(db, [field.declared_ty, Type::none(db)])
2359+
}),
2360+
);
23452361

2346-
Some(CallableType::function_like(db, signature))
2362+
let t_default =
2363+
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
2364+
2365+
let get_with_default_sig = Signature::new_generic(
2366+
Some(GenericContext::from_typevar_instances(db, [t_default])),
2367+
Parameters::new([
2368+
Parameter::positional_only(Some(Name::new_static("self")))
2369+
.with_annotated_type(instance_ty),
2370+
Parameter::positional_only(Some(Name::new_static("key")))
2371+
.with_annotated_type(key_type),
2372+
Parameter::positional_only(Some(Name::new_static("default")))
2373+
.with_annotated_type(Type::TypeVar(t_default)),
2374+
]),
2375+
Some(if field.is_required() {
2376+
field.declared_ty
2377+
} else {
2378+
UnionType::from_elements(
2379+
db,
2380+
[field.declared_ty, Type::TypeVar(t_default)],
2381+
)
2382+
}),
2383+
);
2384+
2385+
[get_sig, get_with_default_sig]
2386+
})
2387+
// Fallback overloads for unknown keys
2388+
.chain(std::iter::once({
2389+
Signature::new(
2390+
Parameters::new([
2391+
Parameter::positional_only(Some(Name::new_static("self")))
2392+
.with_annotated_type(instance_ty),
2393+
Parameter::positional_only(Some(Name::new_static("key")))
2394+
.with_annotated_type(KnownClass::Str.to_instance(db)),
2395+
]),
2396+
Some(UnionType::from_elements(
2397+
db,
2398+
[Type::unknown(), Type::none(db)],
2399+
)),
2400+
)
2401+
}))
2402+
.chain(std::iter::once({
2403+
let t_default =
2404+
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
2405+
2406+
Signature::new_generic(
2407+
Some(GenericContext::from_typevar_instances(db, [t_default])),
2408+
Parameters::new([
2409+
Parameter::positional_only(Some(Name::new_static("self")))
2410+
.with_annotated_type(instance_ty),
2411+
Parameter::positional_only(Some(Name::new_static("key")))
2412+
.with_annotated_type(KnownClass::Str.to_instance(db)),
2413+
Parameter::positional_only(Some(Name::new_static("default")))
2414+
.with_annotated_type(Type::TypeVar(t_default)),
2415+
]),
2416+
Some(UnionType::from_elements(
2417+
db,
2418+
[Type::unknown(), Type::TypeVar(t_default)],
2419+
)),
2420+
)
2421+
}));
2422+
2423+
Some(Type::Callable(CallableType::new(
2424+
db,
2425+
CallableSignature::from_overloads(overloads),
2426+
true,
2427+
)))
23472428
}
23482429
(CodeGeneratorKind::TypedDict, "pop") => {
2349-
// TODO: synthesize a set of overloads with precise types.
2350-
// Required keys should be forbidden to be popped.
2351-
let signature = Signature::new(
2352-
Parameters::new([
2353-
Parameter::positional_only(Some(Name::new_static("self")))
2354-
.with_annotated_type(instance_ty),
2355-
Parameter::positional_only(Some(Name::new_static("key"))),
2356-
Parameter::positional_only(Some(Name::new_static("default")))
2357-
.with_default_type(Type::unknown()),
2358-
]),
2359-
Some(todo_type!("Support for `TypedDict`")),
2360-
);
2430+
let fields = self.fields(db, specialization, field_policy);
2431+
let overloads = fields
2432+
.iter()
2433+
.filter(|(_, field)| {
2434+
// Only synthesize `pop` for fields that are not required.
2435+
!field.is_required()
2436+
})
2437+
.flat_map(|(name, field)| {
2438+
let key_type =
2439+
Type::StringLiteral(StringLiteralType::new(db, name.as_str()));
2440+
2441+
// TODO: Similar to above: consider merging these two overloads into one
2442+
2443+
// `.pop()` without default
2444+
let pop_sig = Signature::new(
2445+
Parameters::new([
2446+
Parameter::positional_only(Some(Name::new_static("self")))
2447+
.with_annotated_type(instance_ty),
2448+
Parameter::positional_only(Some(Name::new_static("key")))
2449+
.with_annotated_type(key_type),
2450+
]),
2451+
Some(field.declared_ty),
2452+
);
23612453

2362-
Some(CallableType::function_like(db, signature))
2454+
// `.pop()` with a default value
2455+
let t_default =
2456+
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
2457+
2458+
let pop_with_default_sig = Signature::new_generic(
2459+
Some(GenericContext::from_typevar_instances(db, [t_default])),
2460+
Parameters::new([
2461+
Parameter::positional_only(Some(Name::new_static("self")))
2462+
.with_annotated_type(instance_ty),
2463+
Parameter::positional_only(Some(Name::new_static("key")))
2464+
.with_annotated_type(key_type),
2465+
Parameter::positional_only(Some(Name::new_static("default")))
2466+
.with_annotated_type(Type::TypeVar(t_default)),
2467+
]),
2468+
Some(UnionType::from_elements(
2469+
db,
2470+
[field.declared_ty, Type::TypeVar(t_default)],
2471+
)),
2472+
);
2473+
2474+
[pop_sig, pop_with_default_sig]
2475+
});
2476+
2477+
Some(Type::Callable(CallableType::new(
2478+
db,
2479+
CallableSignature::from_overloads(overloads),
2480+
true,
2481+
)))
23632482
}
23642483
(CodeGeneratorKind::TypedDict, "setdefault") => {
2365-
// TODO: synthesize a set of overloads with precise types
2366-
let signature = Signature::new(
2367-
Parameters::new([
2368-
Parameter::positional_only(Some(Name::new_static("self")))
2369-
.with_annotated_type(instance_ty),
2370-
Parameter::positional_only(Some(Name::new_static("key"))),
2371-
Parameter::positional_only(Some(Name::new_static("default"))),
2372-
]),
2373-
Some(todo_type!("Support for `TypedDict`")),
2374-
);
2484+
let fields = self.fields(db, specialization, field_policy);
2485+
let overloads = fields.iter().map(|(name, field)| {
2486+
let key_type = Type::StringLiteral(StringLiteralType::new(db, name.as_str()));
23752487

2376-
Some(CallableType::function_like(db, signature))
2488+
// `setdefault` always returns the field type
2489+
Signature::new(
2490+
Parameters::new([
2491+
Parameter::positional_only(Some(Name::new_static("self")))
2492+
.with_annotated_type(instance_ty),
2493+
Parameter::positional_only(Some(Name::new_static("key")))
2494+
.with_annotated_type(key_type),
2495+
Parameter::positional_only(Some(Name::new_static("default")))
2496+
.with_annotated_type(field.declared_ty),
2497+
]),
2498+
Some(field.declared_ty),
2499+
)
2500+
});
2501+
2502+
Some(Type::Callable(CallableType::new(
2503+
db,
2504+
CallableSignature::from_overloads(overloads),
2505+
true,
2506+
)))
23772507
}
23782508
(CodeGeneratorKind::TypedDict, "update") => {
23792509
// TODO: synthesize a set of overloads with precise types

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,6 +2952,21 @@ pub(crate) fn report_missing_typed_dict_key<'db>(
29522952
}
29532953
}
29542954

2955+
pub(crate) fn report_cannot_pop_required_field_on_typed_dict<'db>(
2956+
context: &InferContext<'db, '_>,
2957+
key_node: AnyNodeRef,
2958+
typed_dict_ty: Type<'db>,
2959+
field_name: &str,
2960+
) {
2961+
let db = context.db();
2962+
if let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, key_node) {
2963+
let typed_dict_name = typed_dict_ty.display(db);
2964+
builder.into_diagnostic(format_args!(
2965+
"Cannot pop required field '{field_name}' from TypedDict `{typed_dict_name}`",
2966+
));
2967+
}
2968+
}
2969+
29552970
/// This function receives an unresolved `from foo import bar` import,
29562971
/// where `foo` can be resolved to a module but that module does not
29572972
/// have a `bar` member or submodule.

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,19 @@ impl<'db> GenericContext<'db> {
119119
binding_context: Definition<'db>,
120120
type_params_node: &ast::TypeParams,
121121
) -> Self {
122-
let variables: FxOrderSet<_> = type_params_node
123-
.iter()
124-
.filter_map(|type_param| {
125-
Self::variable_from_type_param(db, index, binding_context, type_param)
126-
})
127-
.collect();
128-
Self::new(db, variables)
122+
let variables = type_params_node.iter().filter_map(|type_param| {
123+
Self::variable_from_type_param(db, index, binding_context, type_param)
124+
});
125+
126+
Self::from_typevar_instances(db, variables)
127+
}
128+
129+
/// Creates a generic context from a list of `BoundTypeVarInstance`s.
130+
pub(crate) fn from_typevar_instances(
131+
db: &'db dyn Db,
132+
type_params: impl IntoIterator<Item = BoundTypeVarInstance<'db>>,
133+
) -> Self {
134+
Self::new(db, type_params.into_iter().collect::<FxOrderSet<_>>())
129135
}
130136

131137
fn variable_from_type_param(
@@ -365,12 +371,12 @@ impl<'db> GenericContext<'db> {
365371
}
366372

367373
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
368-
let variables: FxOrderSet<_> = self
374+
let variables = self
369375
.variables(db)
370376
.iter()
371-
.map(|bound_typevar| bound_typevar.normalized_impl(db, visitor))
372-
.collect();
373-
Self::new(db, variables)
377+
.map(|bound_typevar| bound_typevar.normalized_impl(db, visitor));
378+
379+
Self::from_typevar_instances(db, variables)
374380
}
375381

376382
fn heap_size((variables,): &(FxOrderSet<BoundTypeVarInstance<'db>>,)) -> usize {

0 commit comments

Comments
 (0)