Skip to content

Commit 7cd3be6

Browse files
committed
[red-knot] Add support for unpacking for target
1 parent d47fba1 commit 7cd3be6

File tree

7 files changed

+273
-83
lines changed

7 files changed

+273
-83
lines changed

crates/red_knot_python_semantic/resources/mdtest/unpacking.md

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,85 @@ def _(arg: tuple[int, str] | Iterable):
472472
reveal_type(a) # revealed: int | bytes
473473
reveal_type(b) # revealed: str | bytes
474474
```
475+
476+
## For statement
477+
478+
### Same types
479+
480+
```py
481+
def _(arg: tuple[tuple[int, int], tuple[int, int]]):
482+
for a, b in arg:
483+
reveal_type(a) # revealed: int
484+
reveal_type(b) # revealed: int
485+
```
486+
487+
### Mixed types (1)
488+
489+
```py
490+
def _(arg: tuple[tuple[int, int], tuple[int, str]]):
491+
for a, b in arg:
492+
reveal_type(a) # revealed: int
493+
reveal_type(b) # revealed: int | str
494+
```
495+
496+
### Mixed types (2)
497+
498+
```py
499+
def _(arg: tuple[tuple[int, str], tuple[str, int]]):
500+
for a, b in arg:
501+
reveal_type(a) # revealed: int | str
502+
reveal_type(b) # revealed: str | int
503+
```
504+
505+
### Mixed types (3)
506+
507+
```py
508+
def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]):
509+
for a, b, c in arg:
510+
reveal_type(a) # revealed: int
511+
reveal_type(b) # revealed: int | str
512+
reveal_type(c) # revealed: int | bytes | str
513+
```
514+
515+
### Same literal types
516+
517+
```py
518+
for a, b in ((1, 2), (3, 4)):
519+
reveal_type(a) # revealed: Literal[1, 3]
520+
reveal_type(b) # revealed: Literal[2, 4]
521+
```
522+
523+
### Mixed literal types
524+
525+
```py
526+
for a, b in ((1, 2), ("a", "b")):
527+
reveal_type(a) # revealed: Literal[1] | Literal["a"]
528+
reveal_type(b) # revealed: Literal[2] | Literal["b"]
529+
```
530+
531+
### Mixed literals (2)
532+
533+
```py
534+
# error: "Object of type `Literal[1]` is not iterable"
535+
# error: "Object of type `Literal[2]` is not iterable"
536+
# error: "Object of type `Literal[4]` is not iterable"
537+
for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c"):
538+
reveal_type(a) # revealed: Unknown | Literal[3, 5] | LiteralString
539+
reveal_type(b) # revealed: Unknown | Literal["a", "b"]
540+
```
541+
542+
### Custom iterator
543+
544+
```py
545+
class Iterator:
546+
def __next__(self) -> tuple[int, int]:
547+
return (1, 2)
548+
549+
class Iterable:
550+
def __iter__(self) -> Iterator:
551+
return Iterator()
552+
553+
for a, b in Iterable():
554+
reveal_type(a) # revealed: int
555+
reveal_type(b) # revealed: int
556+
```

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::semantic_index::symbol::{
2626
};
2727
use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
2828
use crate::semantic_index::SemanticIndex;
29-
use crate::unpack::Unpack;
29+
use crate::unpack::{Unpack, UnpackValue};
3030
use crate::Db;
3131

3232
use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
@@ -726,7 +726,7 @@ where
726726
unsafe {
727727
AstNodeRef::new(self.module.clone(), target)
728728
},
729-
value,
729+
UnpackValue::Assign(value),
730730
countme::Count::default(),
731731
)),
732732
})
@@ -909,16 +909,45 @@ where
909909
orelse,
910910
},
911911
) => {
912-
self.add_standalone_expression(iter);
912+
debug_assert_eq!(&self.current_assignments, &[]);
913+
914+
let iter_expr = self.add_standalone_expression(iter);
913915
self.visit_expr(iter);
914916

915917
let pre_loop = self.flow_snapshot();
916918
let saved_break_states = std::mem::take(&mut self.loop_break_states);
917919

918-
debug_assert_eq!(&self.current_assignments, &[]);
919-
self.push_assignment(for_stmt.into());
920+
let current_assignment = match &**target {
921+
ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(CurrentAssignment::For {
922+
node: for_stmt,
923+
first: true,
924+
unpack: Some(Unpack::new(
925+
self.db,
926+
self.file,
927+
self.current_scope(),
928+
#[allow(unsafe_code)]
929+
unsafe {
930+
AstNodeRef::new(self.module.clone(), target)
931+
},
932+
UnpackValue::Iterable(iter_expr),
933+
countme::Count::default(),
934+
)),
935+
}),
936+
ast::Expr::Name(_) => Some(CurrentAssignment::For {
937+
node: for_stmt,
938+
unpack: None,
939+
first: false,
940+
}),
941+
_ => None,
942+
};
943+
944+
if let Some(current_assignment) = current_assignment {
945+
self.push_assignment(current_assignment);
946+
}
920947
self.visit_expr(target);
921-
self.pop_assignment();
948+
if current_assignment.is_some() {
949+
self.pop_assignment();
950+
}
922951

923952
// TODO: Definitions created by loop variables
924953
// (and definitions created inside the body)
@@ -1136,12 +1165,18 @@ where
11361165
Some(CurrentAssignment::AugAssign(aug_assign)) => {
11371166
self.add_definition(symbol, aug_assign);
11381167
}
1139-
Some(CurrentAssignment::For(node)) => {
1168+
Some(CurrentAssignment::For {
1169+
node,
1170+
first,
1171+
unpack,
1172+
}) => {
11401173
self.add_definition(
11411174
symbol,
11421175
ForStmtDefinitionNodeRef {
1176+
unpack,
1177+
first,
11431178
iterable: &node.iter,
1144-
target: name_node,
1179+
name: name_node,
11451180
is_async: node.is_async,
11461181
},
11471182
);
@@ -1177,7 +1212,9 @@ where
11771212
}
11781213
}
11791214

1180-
if let Some(CurrentAssignment::Assign { first, .. }) = self.current_assignment_mut()
1215+
if let Some(
1216+
CurrentAssignment::Assign { first, .. } | CurrentAssignment::For { first, .. },
1217+
) = self.current_assignment_mut()
11811218
{
11821219
*first = false;
11831220
}
@@ -1391,7 +1428,11 @@ enum CurrentAssignment<'a> {
13911428
},
13921429
AnnAssign(&'a ast::StmtAnnAssign),
13931430
AugAssign(&'a ast::StmtAugAssign),
1394-
For(&'a ast::StmtFor),
1431+
For {
1432+
node: &'a ast::StmtFor,
1433+
first: bool,
1434+
unpack: Option<Unpack<'a>>,
1435+
},
13951436
Named(&'a ast::ExprNamed),
13961437
Comprehension {
13971438
node: &'a ast::Comprehension,
@@ -1415,12 +1456,6 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
14151456
}
14161457
}
14171458

1418-
impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> {
1419-
fn from(value: &'a ast::StmtFor) -> Self {
1420-
Self::For(value)
1421-
}
1422-
}
1423-
14241459
impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
14251460
fn from(value: &'a ast::ExprNamed) -> Self {
14261461
Self::Named(value)

crates/red_knot_python_semantic/src/semantic_index/definition.rs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,10 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
225225

226226
#[derive(Copy, Clone, Debug)]
227227
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
228+
pub(crate) unpack: Option<Unpack<'a>>,
228229
pub(crate) iterable: &'a ast::Expr,
229-
pub(crate) target: &'a ast::ExprName,
230+
pub(crate) name: &'a ast::ExprName,
231+
pub(crate) first: bool,
230232
pub(crate) is_async: bool,
231233
}
232234

@@ -298,12 +300,16 @@ impl<'db> DefinitionNodeRef<'db> {
298300
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
299301
}
300302
DefinitionNodeRef::For(ForStmtDefinitionNodeRef {
303+
unpack,
301304
iterable,
302-
target,
305+
name,
306+
first,
303307
is_async,
304308
}) => DefinitionKind::For(ForStmtDefinitionKind {
309+
target: TargetKind::from(unpack),
305310
iterable: AstNodeRef::new(parsed.clone(), iterable),
306-
target: AstNodeRef::new(parsed, target),
311+
name: AstNodeRef::new(parsed, name),
312+
first,
307313
is_async,
308314
}),
309315
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
@@ -382,10 +388,12 @@ impl<'db> DefinitionNodeRef<'db> {
382388
Self::AnnotatedAssignment(node) => node.into(),
383389
Self::AugmentedAssignment(node) => node.into(),
384390
Self::For(ForStmtDefinitionNodeRef {
391+
unpack: _,
385392
iterable: _,
386-
target,
393+
name,
394+
first: _,
387395
is_async: _,
388-
}) => target.into(),
396+
}) => name.into(),
389397
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
390398
Self::VariadicPositionalParameter(node) => node.into(),
391399
Self::VariadicKeywordParameter(node) => node.into(),
@@ -452,7 +460,7 @@ pub enum DefinitionKind<'db> {
452460
Assignment(AssignmentDefinitionKind<'db>),
453461
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
454462
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
455-
For(ForStmtDefinitionKind),
463+
For(ForStmtDefinitionKind<'db>),
456464
Comprehension(ComprehensionDefinitionKind),
457465
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
458466
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
@@ -477,7 +485,7 @@ impl Ranged for DefinitionKind<'_> {
477485
DefinitionKind::Assignment(assignment) => assignment.name().range(),
478486
DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(),
479487
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(),
480-
DefinitionKind::For(for_stmt) => for_stmt.target().range(),
488+
DefinitionKind::For(for_stmt) => for_stmt.name().range(),
481489
DefinitionKind::Comprehension(comp) => comp.target().range(),
482490
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(),
483491
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(),
@@ -665,22 +673,32 @@ impl WithItemDefinitionKind {
665673
}
666674

667675
#[derive(Clone, Debug)]
668-
pub struct ForStmtDefinitionKind {
676+
pub struct ForStmtDefinitionKind<'db> {
677+
target: TargetKind<'db>,
669678
iterable: AstNodeRef<ast::Expr>,
670-
target: AstNodeRef<ast::ExprName>,
679+
name: AstNodeRef<ast::ExprName>,
680+
first: bool,
671681
is_async: bool,
672682
}
673683

674-
impl ForStmtDefinitionKind {
684+
impl<'db> ForStmtDefinitionKind<'db> {
675685
pub(crate) fn iterable(&self) -> &ast::Expr {
676686
self.iterable.node()
677687
}
678688

679-
pub(crate) fn target(&self) -> &ast::ExprName {
680-
self.target.node()
689+
pub(crate) fn target(&self) -> TargetKind<'db> {
690+
self.target
681691
}
682692

683-
pub(crate) fn is_async(&self) -> bool {
693+
pub(crate) fn name(&self) -> &ast::ExprName {
694+
self.name.node()
695+
}
696+
697+
pub(crate) const fn is_first(&self) -> bool {
698+
self.first
699+
}
700+
701+
pub(crate) const fn is_async(&self) -> bool {
684702
self.is_async
685703
}
686704
}
@@ -756,12 +774,6 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
756774
}
757775
}
758776

759-
impl From<&ast::StmtFor> for DefinitionNodeKey {
760-
fn from(value: &ast::StmtFor) -> Self {
761-
Self(NodeKey::from_node(value))
762-
}
763-
}
764-
765777
impl From<&ast::Parameter> for DefinitionNodeKey {
766778
fn from(node: &ast::Parameter) -> Self {
767779
Self(NodeKey::from_node(node))

0 commit comments

Comments
 (0)