@@ -34,7 +34,7 @@ use crate::types::{
3434 IsEquivalentVisitor , KnownInstanceType , ManualPEP695TypeAliasType , MaterializationKind ,
3535 NormalizedVisitor , PropertyInstanceType , StringLiteralType , TypeAliasType , TypeMapping ,
3636 TypeRelation , TypeVarBoundOrConstraints , TypeVarInstance , TypeVarKind , TypedDictParams ,
37- VarianceInferable , declaration_type, infer_definition_types, todo_type ,
37+ UnionBuilder , VarianceInferable , declaration_type, infer_definition_types,
3838} ;
3939use crate :: {
4040 Db , FxIndexMap , FxOrderSet , Program ,
@@ -51,7 +51,7 @@ use crate::{
5151 semantic_index, use_def_map,
5252 } ,
5353 types:: {
54- CallArguments , CallError , CallErrorKind , MetaclassCandidate , UnionBuilder , UnionType ,
54+ CallArguments , CallError , CallErrorKind , MetaclassCandidate , UnionType ,
5555 definition_expression_type,
5656 } ,
5757} ;
@@ -2331,49 +2331,179 @@ impl<'db> ClassLiteral<'db> {
23312331 ) ) )
23322332 }
23332333 ( CodeGeneratorKind :: TypedDict , "get" ) => {
2334- // TODO: synthesize a set of overloads with precise types
2335- let signature = Signature :: new (
2336- Parameters :: new ( [
2337- Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2338- . with_annotated_type ( instance_ty) ,
2339- Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) ) ,
2340- Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) )
2341- . with_default_type ( Type :: unknown ( ) ) ,
2342- ] ) ,
2343- Some ( todo_type ! ( "Support for `TypedDict`" ) ) ,
2344- ) ;
2334+ let overloads = self
2335+ . fields ( db, specialization, field_policy)
2336+ . into_iter ( )
2337+ . flat_map ( |( name, field) | {
2338+ let key_type =
2339+ Type :: StringLiteral ( StringLiteralType :: new ( db, name. as_str ( ) ) ) ;
2340+
2341+ // For a required key, `.get()` always returns the value type. For a non-required key,
2342+ // `.get()` returns the union of the value type and the type of the default argument
2343+ // (which defaults to `None`).
2344+
2345+ // TODO: For now, we use two overloads here. They can be merged into a single function
2346+ // once the generics solver takes default arguments into account.
2347+
2348+ let get_sig = Signature :: new (
2349+ Parameters :: new ( [
2350+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2351+ . with_annotated_type ( instance_ty) ,
2352+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2353+ . with_annotated_type ( key_type) ,
2354+ ] ) ,
2355+ Some ( if field. is_required ( ) {
2356+ field. declared_ty
2357+ } else {
2358+ UnionType :: from_elements ( db, [ field. declared_ty , Type :: none ( db) ] )
2359+ } ) ,
2360+ ) ;
23452361
2346- Some ( CallableType :: function_like ( db, signature) )
2362+ let t_default =
2363+ BoundTypeVarInstance :: synthetic ( db, "T" , TypeVarVariance :: Covariant ) ;
2364+
2365+ let get_with_default_sig = Signature :: new_generic (
2366+ Some ( GenericContext :: from_typevar_instances ( db, [ t_default] ) ) ,
2367+ Parameters :: new ( [
2368+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2369+ . with_annotated_type ( instance_ty) ,
2370+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2371+ . with_annotated_type ( key_type) ,
2372+ Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) )
2373+ . with_annotated_type ( Type :: TypeVar ( t_default) ) ,
2374+ ] ) ,
2375+ Some ( if field. is_required ( ) {
2376+ field. declared_ty
2377+ } else {
2378+ UnionType :: from_elements (
2379+ db,
2380+ [ field. declared_ty , Type :: TypeVar ( t_default) ] ,
2381+ )
2382+ } ) ,
2383+ ) ;
2384+
2385+ [ get_sig, get_with_default_sig]
2386+ } )
2387+ // Fallback overloads for unknown keys
2388+ . chain ( std:: iter:: once ( {
2389+ Signature :: new (
2390+ Parameters :: new ( [
2391+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2392+ . with_annotated_type ( instance_ty) ,
2393+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2394+ . with_annotated_type ( KnownClass :: Str . to_instance ( db) ) ,
2395+ ] ) ,
2396+ Some ( UnionType :: from_elements (
2397+ db,
2398+ [ Type :: unknown ( ) , Type :: none ( db) ] ,
2399+ ) ) ,
2400+ )
2401+ } ) )
2402+ . chain ( std:: iter:: once ( {
2403+ let t_default =
2404+ BoundTypeVarInstance :: synthetic ( db, "T" , TypeVarVariance :: Covariant ) ;
2405+
2406+ Signature :: new_generic (
2407+ Some ( GenericContext :: from_typevar_instances ( db, [ t_default] ) ) ,
2408+ Parameters :: new ( [
2409+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2410+ . with_annotated_type ( instance_ty) ,
2411+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2412+ . with_annotated_type ( KnownClass :: Str . to_instance ( db) ) ,
2413+ Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) )
2414+ . with_annotated_type ( Type :: TypeVar ( t_default) ) ,
2415+ ] ) ,
2416+ Some ( UnionType :: from_elements (
2417+ db,
2418+ [ Type :: unknown ( ) , Type :: TypeVar ( t_default) ] ,
2419+ ) ) ,
2420+ )
2421+ } ) ) ;
2422+
2423+ Some ( Type :: Callable ( CallableType :: new (
2424+ db,
2425+ CallableSignature :: from_overloads ( overloads) ,
2426+ true ,
2427+ ) ) )
23472428 }
23482429 ( CodeGeneratorKind :: TypedDict , "pop" ) => {
2349- // TODO: synthesize a set of overloads with precise types.
2350- // Required keys should be forbidden to be popped.
2351- let signature = Signature :: new (
2352- Parameters :: new ( [
2353- Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2354- . with_annotated_type ( instance_ty) ,
2355- Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) ) ,
2356- Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) )
2357- . with_default_type ( Type :: unknown ( ) ) ,
2358- ] ) ,
2359- Some ( todo_type ! ( "Support for `TypedDict`" ) ) ,
2360- ) ;
2430+ let fields = self . fields ( db, specialization, field_policy) ;
2431+ let overloads = fields
2432+ . iter ( )
2433+ . filter ( |( _, field) | {
2434+ // Only synthesize `pop` for fields that are not required.
2435+ !field. is_required ( )
2436+ } )
2437+ . flat_map ( |( name, field) | {
2438+ let key_type =
2439+ Type :: StringLiteral ( StringLiteralType :: new ( db, name. as_str ( ) ) ) ;
2440+
2441+ // TODO: Similar to above: consider merging these two overloads into one
2442+
2443+ // `.pop()` without default
2444+ let pop_sig = Signature :: new (
2445+ Parameters :: new ( [
2446+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2447+ . with_annotated_type ( instance_ty) ,
2448+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2449+ . with_annotated_type ( key_type) ,
2450+ ] ) ,
2451+ Some ( field. declared_ty ) ,
2452+ ) ;
23612453
2362- Some ( CallableType :: function_like ( db, signature) )
2454+ // `.pop()` with a default value
2455+ let t_default =
2456+ BoundTypeVarInstance :: synthetic ( db, "T" , TypeVarVariance :: Covariant ) ;
2457+
2458+ let pop_with_default_sig = Signature :: new_generic (
2459+ Some ( GenericContext :: from_typevar_instances ( db, [ t_default] ) ) ,
2460+ Parameters :: new ( [
2461+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2462+ . with_annotated_type ( instance_ty) ,
2463+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2464+ . with_annotated_type ( key_type) ,
2465+ Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) )
2466+ . with_annotated_type ( Type :: TypeVar ( t_default) ) ,
2467+ ] ) ,
2468+ Some ( UnionType :: from_elements (
2469+ db,
2470+ [ field. declared_ty , Type :: TypeVar ( t_default) ] ,
2471+ ) ) ,
2472+ ) ;
2473+
2474+ [ pop_sig, pop_with_default_sig]
2475+ } ) ;
2476+
2477+ Some ( Type :: Callable ( CallableType :: new (
2478+ db,
2479+ CallableSignature :: from_overloads ( overloads) ,
2480+ true ,
2481+ ) ) )
23632482 }
23642483 ( CodeGeneratorKind :: TypedDict , "setdefault" ) => {
2365- // TODO: synthesize a set of overloads with precise types
2366- let signature = Signature :: new (
2367- Parameters :: new ( [
2368- Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2369- . with_annotated_type ( instance_ty) ,
2370- Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) ) ,
2371- Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) ) ,
2372- ] ) ,
2373- Some ( todo_type ! ( "Support for `TypedDict`" ) ) ,
2374- ) ;
2484+ let fields = self . fields ( db, specialization, field_policy) ;
2485+ let overloads = fields. iter ( ) . map ( |( name, field) | {
2486+ let key_type = Type :: StringLiteral ( StringLiteralType :: new ( db, name. as_str ( ) ) ) ;
23752487
2376- Some ( CallableType :: function_like ( db, signature) )
2488+ // `setdefault` always returns the field type
2489+ Signature :: new (
2490+ Parameters :: new ( [
2491+ Parameter :: positional_only ( Some ( Name :: new_static ( "self" ) ) )
2492+ . with_annotated_type ( instance_ty) ,
2493+ Parameter :: positional_only ( Some ( Name :: new_static ( "key" ) ) )
2494+ . with_annotated_type ( key_type) ,
2495+ Parameter :: positional_only ( Some ( Name :: new_static ( "default" ) ) )
2496+ . with_annotated_type ( field. declared_ty ) ,
2497+ ] ) ,
2498+ Some ( field. declared_ty ) ,
2499+ )
2500+ } ) ;
2501+
2502+ Some ( Type :: Callable ( CallableType :: new (
2503+ db,
2504+ CallableSignature :: from_overloads ( overloads) ,
2505+ true ,
2506+ ) ) )
23772507 }
23782508 ( CodeGeneratorKind :: TypedDict , "update" ) => {
23792509 // TODO: synthesize a set of overloads with precise types
0 commit comments