Skip to content

Commit da6b68c

Browse files
mtshibadhruvmanila
andauthored
[red-knot] infer attribute assignments bound in comprehensions (astral-sh#17396)
## Summary This PR is a follow-up to astral-sh#16852. Instance variables bound in comprehensions are recorded, allowing type inference to work correctly. This required adding support for unpacking in comprehension which resolves astral-sh#15369. ## Test Plan One TODO in `mdtest/attributes.md` is now resolved, and some new test cases are added. --------- Co-authored-by: Dhruv Manilawala <[email protected]>
1 parent 2a478ce commit da6b68c

File tree

10 files changed

+349
-108
lines changed

10 files changed

+349
-108
lines changed

crates/red_knot_project/tests/check.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ use ruff_db::parsed::parsed_module;
66
use ruff_db::system::{SystemPath, SystemPathBuf, TestSystem};
77
use ruff_python_ast::visitor::source_order;
88
use ruff_python_ast::visitor::source_order::SourceOrderVisitor;
9-
use ruff_python_ast::{self as ast, Alias, Expr, Parameter, ParameterWithDefault, Stmt};
9+
use ruff_python_ast::{
10+
self as ast, Alias, Comprehension, Expr, Parameter, ParameterWithDefault, Stmt,
11+
};
1012

1113
fn setup_db(project_root: &SystemPath, system: TestSystem) -> anyhow::Result<ProjectDatabase> {
1214
let project = ProjectMetadata::discover(project_root, &system)?;
@@ -258,6 +260,14 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
258260
source_order::walk_expr(self, expr);
259261
}
260262

263+
fn visit_comprehension(&mut self, comprehension: &Comprehension) {
264+
self.visit_expr(&comprehension.iter);
265+
self.visit_target(&comprehension.target);
266+
for if_expr in &comprehension.ifs {
267+
self.visit_expr(if_expr);
268+
}
269+
}
270+
261271
fn visit_parameter(&mut self, parameter: &Parameter) {
262272
let _ty = parameter.inferred_type(&self.model);
263273

crates/red_knot_python_semantic/resources/mdtest/attributes.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,27 @@ class IntIterable:
397397
def __iter__(self) -> IntIterator:
398398
return IntIterator()
399399

400+
class TupleIterator:
401+
def __next__(self) -> tuple[int, str]:
402+
return (1, "a")
403+
404+
class TupleIterable:
405+
def __iter__(self) -> TupleIterator:
406+
return TupleIterator()
407+
400408
class C:
401409
def __init__(self) -> None:
402410
[... for self.a in IntIterable()]
411+
[... for (self.b, self.c) in TupleIterable()]
412+
[... for self.d in IntIterable() for self.e in IntIterable()]
403413

404414
c_instance = C()
405415

406-
# TODO: Should be `Unknown | int`
407-
# error: [unresolved-attribute]
408-
reveal_type(c_instance.a) # revealed: Unknown
416+
reveal_type(c_instance.a) # revealed: Unknown | int
417+
reveal_type(c_instance.b) # revealed: Unknown | int
418+
reveal_type(c_instance.c) # revealed: Unknown | str
419+
reveal_type(c_instance.d) # revealed: Unknown | int
420+
reveal_type(c_instance.e) # revealed: Unknown | int
409421
```
410422

411423
#### Conditionally declared / bound attributes

crates/red_knot_python_semantic/resources/mdtest/unpacking.md

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,3 +708,95 @@ with ContextManager() as (a, b, c):
708708
reveal_type(b) # revealed: Unknown
709709
reveal_type(c) # revealed: Unknown
710710
```
711+
712+
## Comprehension
713+
714+
Unpacking in a comprehension.
715+
716+
### Same types
717+
718+
```py
719+
def _(arg: tuple[tuple[int, int], tuple[int, int]]):
720+
# revealed: tuple[int, int]
721+
[reveal_type((a, b)) for a, b in arg]
722+
```
723+
724+
### Mixed types (1)
725+
726+
```py
727+
def _(arg: tuple[tuple[int, int], tuple[int, str]]):
728+
# revealed: tuple[int, int | str]
729+
[reveal_type((a, b)) for a, b in arg]
730+
```
731+
732+
### Mixed types (2)
733+
734+
```py
735+
def _(arg: tuple[tuple[int, str], tuple[str, int]]):
736+
# revealed: tuple[int | str, str | int]
737+
[reveal_type((a, b)) for a, b in arg]
738+
```
739+
740+
### Mixed types (3)
741+
742+
```py
743+
def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]):
744+
# revealed: tuple[int, int | str, int | bytes | str]
745+
[reveal_type((a, b, c)) for a, b, c in arg]
746+
```
747+
748+
### Same literal values
749+
750+
```py
751+
# revealed: tuple[Literal[1, 3], Literal[2, 4]]
752+
[reveal_type((a, b)) for a, b in ((1, 2), (3, 4))]
753+
```
754+
755+
### Mixed literal values (1)
756+
757+
```py
758+
# revealed: tuple[Literal[1, "a"], Literal[2, "b"]]
759+
[reveal_type((a, b)) for a, b in ((1, 2), ("a", "b"))]
760+
```
761+
762+
### Mixed literals values (2)
763+
764+
```py
765+
# error: "Object of type `Literal[1]` is not iterable"
766+
# error: "Object of type `Literal[2]` is not iterable"
767+
# error: "Object of type `Literal[4]` is not iterable"
768+
# error: [invalid-assignment] "Not enough values to unpack (expected 2, got 1)"
769+
# revealed: tuple[Unknown | Literal[3, 5], Unknown | Literal["a", "b"]]
770+
[reveal_type((a, b)) for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c")]
771+
```
772+
773+
### Custom iterator (1)
774+
775+
```py
776+
class Iterator:
777+
def __next__(self) -> tuple[int, int]:
778+
return (1, 2)
779+
780+
class Iterable:
781+
def __iter__(self) -> Iterator:
782+
return Iterator()
783+
784+
# revealed: tuple[int, int]
785+
[reveal_type((a, b)) for a, b in Iterable()]
786+
```
787+
788+
### Custom iterator (2)
789+
790+
```py
791+
class Iterator:
792+
def __next__(self) -> bytes:
793+
return b""
794+
795+
class Iterable:
796+
def __iter__(self) -> Iterator:
797+
return Iterator()
798+
799+
def _(arg: tuple[tuple[int, str], Iterable]):
800+
# revealed: tuple[int | bytes, str | bytes]
801+
[reveal_type((a, b)) for a, b in arg]
802+
```

crates/red_knot_python_semantic/src/semantic_index.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
940940
panic!("expected generator definition")
941941
};
942942
let target = comprehension.target();
943-
let name = target.id().as_str();
943+
let name = target.as_name_expr().unwrap().id().as_str();
944944

945945
assert_eq!(name, "x");
946946
assert_eq!(target.range(), TextRange::new(23.into(), 24.into()));

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 93 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
1818
use crate::semantic_index::ast_ids::AstIdsBuilder;
1919
use crate::semantic_index::definition::{
2020
AnnotatedAssignmentDefinitionKind, AnnotatedAssignmentDefinitionNodeRef,
21-
AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef,
22-
Definition, DefinitionCategory, DefinitionKind, DefinitionNodeKey, DefinitionNodeRef,
23-
Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionKind, ForStmtDefinitionNodeRef,
24-
ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef,
25-
StarImportDefinitionNodeRef, TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef,
21+
AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionKind,
22+
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind,
23+
DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef,
24+
ForStmtDefinitionKind, ForStmtDefinitionNodeRef, ImportDefinitionNodeRef,
25+
ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef,
26+
TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef,
2627
};
2728
use crate::semantic_index::expression::{Expression, ExpressionKind};
2829
use crate::semantic_index::predicate::{
@@ -850,31 +851,35 @@ impl<'db> SemanticIndexBuilder<'db> {
850851

851852
// The `iter` of the first generator is evaluated in the outer scope, while all subsequent
852853
// nodes are evaluated in the inner scope.
853-
self.add_standalone_expression(&generator.iter);
854+
let value = self.add_standalone_expression(&generator.iter);
854855
self.visit_expr(&generator.iter);
855856
self.push_scope(scope);
856857

857-
self.push_assignment(CurrentAssignment::Comprehension {
858-
node: generator,
859-
first: true,
860-
});
861-
self.visit_expr(&generator.target);
862-
self.pop_assignment();
858+
self.add_unpackable_assignment(
859+
&Unpackable::Comprehension {
860+
node: generator,
861+
first: true,
862+
},
863+
&generator.target,
864+
value,
865+
);
863866

864867
for expr in &generator.ifs {
865868
self.visit_expr(expr);
866869
}
867870

868871
for generator in generators_iter {
869-
self.add_standalone_expression(&generator.iter);
872+
let value = self.add_standalone_expression(&generator.iter);
870873
self.visit_expr(&generator.iter);
871874

872-
self.push_assignment(CurrentAssignment::Comprehension {
873-
node: generator,
874-
first: false,
875-
});
876-
self.visit_expr(&generator.target);
877-
self.pop_assignment();
875+
self.add_unpackable_assignment(
876+
&Unpackable::Comprehension {
877+
node: generator,
878+
first: false,
879+
},
880+
&generator.target,
881+
value,
882+
);
878883

879884
for expr in &generator.ifs {
880885
self.visit_expr(expr);
@@ -933,9 +938,30 @@ impl<'db> SemanticIndexBuilder<'db> {
933938

934939
let current_assignment = match target {
935940
ast::Expr::List(_) | ast::Expr::Tuple(_) => {
941+
if matches!(unpackable, Unpackable::Comprehension { .. }) {
942+
debug_assert_eq!(
943+
self.scopes[self.current_scope()].node().scope_kind(),
944+
ScopeKind::Comprehension
945+
);
946+
}
947+
// The first iterator of the comprehension is evaluated in the outer scope, while all subsequent
948+
// nodes are evaluated in the inner scope.
949+
// SAFETY: The current scope is the comprehension, and the comprehension scope must have a parent scope.
950+
let value_file_scope =
951+
if let Unpackable::Comprehension { first: true, .. } = unpackable {
952+
self.scope_stack
953+
.iter()
954+
.rev()
955+
.nth(1)
956+
.expect("The comprehension scope must have a parent scope")
957+
.file_scope_id
958+
} else {
959+
self.current_scope()
960+
};
936961
let unpack = Some(Unpack::new(
937962
self.db,
938963
self.file,
964+
value_file_scope,
939965
self.current_scope(),
940966
// SAFETY: `target` belongs to the `self.module` tree
941967
#[allow(unsafe_code)]
@@ -1804,7 +1830,7 @@ where
18041830
let node_key = NodeKey::from_node(expr);
18051831

18061832
match expr {
1807-
ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => {
1833+
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
18081834
let (is_use, is_definition) = match (ctx, self.current_assignment()) {
18091835
(ast::ExprContext::Store, Some(CurrentAssignment::AugAssign(_))) => {
18101836
// For augmented assignment, the target expression is also used.
@@ -1867,12 +1893,17 @@ where
18671893
// implemented.
18681894
self.add_definition(symbol, named);
18691895
}
1870-
Some(CurrentAssignment::Comprehension { node, first }) => {
1896+
Some(CurrentAssignment::Comprehension {
1897+
unpack,
1898+
node,
1899+
first,
1900+
}) => {
18711901
self.add_definition(
18721902
symbol,
18731903
ComprehensionDefinitionNodeRef {
1904+
unpack,
18741905
iterable: &node.iter,
1875-
target: name_node,
1906+
target: expr,
18761907
first,
18771908
is_async: node.is_async,
18781909
},
@@ -2143,14 +2174,37 @@ where
21432174
DefinitionKind::WithItem(assignment),
21442175
);
21452176
}
2146-
Some(CurrentAssignment::Comprehension { .. }) => {
2147-
// TODO:
2177+
Some(CurrentAssignment::Comprehension {
2178+
unpack,
2179+
node,
2180+
first,
2181+
}) => {
2182+
// SAFETY: `iter` and `expr` belong to the `self.module` tree
2183+
#[allow(unsafe_code)]
2184+
let assignment = ComprehensionDefinitionKind {
2185+
target_kind: TargetKind::from(unpack),
2186+
iterable: unsafe {
2187+
AstNodeRef::new(self.module.clone(), &node.iter)
2188+
},
2189+
target: unsafe { AstNodeRef::new(self.module.clone(), expr) },
2190+
first,
2191+
is_async: node.is_async,
2192+
};
2193+
// Temporarily move to the scope of the method to which the instance attribute is defined.
2194+
// SAFETY: `self.scope_stack` is not empty because the targets in comprehensions should always introduce a new scope.
2195+
let scope = self.scope_stack.pop().expect("The popped scope must be a comprehension, which must have a parent scope");
2196+
self.register_attribute_assignment(
2197+
object,
2198+
attr,
2199+
DefinitionKind::Comprehension(assignment),
2200+
);
2201+
self.scope_stack.push(scope);
21482202
}
21492203
Some(CurrentAssignment::AugAssign(_)) => {
21502204
// TODO:
21512205
}
21522206
Some(CurrentAssignment::Named(_)) => {
2153-
// TODO:
2207+
// A named expression whose target is an attribute is syntactically prohibited
21542208
}
21552209
None => {}
21562210
}
@@ -2244,6 +2298,7 @@ enum CurrentAssignment<'a> {
22442298
Comprehension {
22452299
node: &'a ast::Comprehension,
22462300
first: bool,
2301+
unpack: Option<(UnpackPosition, Unpack<'a>)>,
22472302
},
22482303
WithItem {
22492304
item: &'a ast::WithItem,
@@ -2257,11 +2312,9 @@ impl CurrentAssignment<'_> {
22572312
match self {
22582313
Self::Assign { unpack, .. }
22592314
| Self::For { unpack, .. }
2260-
| Self::WithItem { unpack, .. } => unpack.as_mut().map(|(position, _)| position),
2261-
Self::AnnAssign(_)
2262-
| Self::AugAssign(_)
2263-
| Self::Named(_)
2264-
| Self::Comprehension { .. } => None,
2315+
| Self::WithItem { unpack, .. }
2316+
| Self::Comprehension { unpack, .. } => unpack.as_mut().map(|(position, _)| position),
2317+
Self::AnnAssign(_) | Self::AugAssign(_) | Self::Named(_) => None,
22652318
}
22662319
}
22672320
}
@@ -2316,13 +2369,17 @@ enum Unpackable<'a> {
23162369
item: &'a ast::WithItem,
23172370
is_async: bool,
23182371
},
2372+
Comprehension {
2373+
first: bool,
2374+
node: &'a ast::Comprehension,
2375+
},
23192376
}
23202377

23212378
impl<'a> Unpackable<'a> {
23222379
const fn kind(&self) -> UnpackKind {
23232380
match self {
23242381
Unpackable::Assign(_) => UnpackKind::Assign,
2325-
Unpackable::For(_) => UnpackKind::Iterable,
2382+
Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable,
23262383
Unpackable::WithItem { .. } => UnpackKind::ContextManager,
23272384
}
23282385
}
@@ -2337,6 +2394,11 @@ impl<'a> Unpackable<'a> {
23372394
is_async: *is_async,
23382395
unpack,
23392396
},
2397+
Unpackable::Comprehension { node, first } => CurrentAssignment::Comprehension {
2398+
node,
2399+
first: *first,
2400+
unpack,
2401+
},
23402402
}
23412403
}
23422404
}

0 commit comments

Comments
 (0)