Skip to content

Commit fe681d0

Browse files
committed
[red-knot] Add support for unpacking for target
1 parent d47fba1 commit fe681d0

File tree

7 files changed

+294
-83
lines changed

7 files changed

+294
-83
lines changed

crates/red_knot_python_semantic/resources/mdtest/unpacking.md

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,104 @@ def _(arg: tuple[int, str] | Iterable):
472472
reveal_type(a) # revealed: int | bytes
473473
reveal_type(b) # revealed: str | bytes
474474
```
475+
476+
## For statement
477+
478+
Unpacking in a `for` statement.
479+
480+
### Same types
481+
482+
```py
483+
def _(arg: tuple[tuple[int, int], tuple[int, int]]):
484+
for a, b in arg:
485+
reveal_type(a) # revealed: int
486+
reveal_type(b) # revealed: int
487+
```
488+
489+
### Mixed types (1)
490+
491+
```py
492+
def _(arg: tuple[tuple[int, int], tuple[int, str]]):
493+
for a, b in arg:
494+
reveal_type(a) # revealed: int
495+
reveal_type(b) # revealed: int | str
496+
```
497+
498+
### Mixed types (2)
499+
500+
```py
501+
def _(arg: tuple[tuple[int, str], tuple[str, int]]):
502+
for a, b in arg:
503+
reveal_type(a) # revealed: int | str
504+
reveal_type(b) # revealed: str | int
505+
```
506+
507+
### Mixed types (3)
508+
509+
```py
510+
def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]):
511+
for a, b, c in arg:
512+
reveal_type(a) # revealed: int
513+
reveal_type(b) # revealed: int | str
514+
reveal_type(c) # revealed: int | bytes | str
515+
```
516+
517+
### Same literal values
518+
519+
```py
520+
for a, b in ((1, 2), (3, 4)):
521+
reveal_type(a) # revealed: Literal[1, 3]
522+
reveal_type(b) # revealed: Literal[2, 4]
523+
```
524+
525+
### Mixed literal values (1)
526+
527+
```py
528+
for a, b in ((1, 2), ("a", "b")):
529+
reveal_type(a) # revealed: Literal[1] | Literal["a"]
530+
reveal_type(b) # revealed: Literal[2] | Literal["b"]
531+
```
532+
533+
### Mixed literals values (2)
534+
535+
```py
536+
# error: "Object of type `Literal[1]` is not iterable"
537+
# error: "Object of type `Literal[2]` is not iterable"
538+
# error: "Object of type `Literal[4]` is not iterable"
539+
for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c"):
540+
reveal_type(a) # revealed: Unknown | Literal[3, 5] | LiteralString
541+
reveal_type(b) # revealed: Unknown | Literal["a", "b"]
542+
```
543+
544+
### Custom iterator (1)
545+
546+
```py
547+
class Iterator:
548+
def __next__(self) -> tuple[int, int]:
549+
return (1, 2)
550+
551+
class Iterable:
552+
def __iter__(self) -> Iterator:
553+
return Iterator()
554+
555+
for a, b in Iterable():
556+
reveal_type(a) # revealed: int
557+
reveal_type(b) # revealed: int
558+
```
559+
560+
### Custom iterator (2)
561+
562+
```py
563+
class Iterator:
564+
def __next__(self) -> bytes:
565+
return b""
566+
567+
class Iterable:
568+
def __iter__(self) -> Iterator:
569+
return Iterator()
570+
571+
def _(arg: tuple[tuple[int, str], Iterable]):
572+
for a, b in arg:
573+
reveal_type(a) # revealed: int | bytes
574+
reveal_type(b) # revealed: str | bytes
575+
```

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::semantic_index::symbol::{
2626
};
2727
use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
2828
use crate::semantic_index::SemanticIndex;
29-
use crate::unpack::Unpack;
29+
use crate::unpack::{Unpack, UnpackValue};
3030
use crate::Db;
3131

3232
use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
@@ -726,7 +726,7 @@ where
726726
unsafe {
727727
AstNodeRef::new(self.module.clone(), target)
728728
},
729-
value,
729+
UnpackValue::Assign(value),
730730
countme::Count::default(),
731731
)),
732732
})
@@ -909,16 +909,45 @@ where
909909
orelse,
910910
},
911911
) => {
912-
self.add_standalone_expression(iter);
912+
debug_assert_eq!(&self.current_assignments, &[]);
913+
914+
let iter_expr = self.add_standalone_expression(iter);
913915
self.visit_expr(iter);
914916

915917
let pre_loop = self.flow_snapshot();
916918
let saved_break_states = std::mem::take(&mut self.loop_break_states);
917919

918-
debug_assert_eq!(&self.current_assignments, &[]);
919-
self.push_assignment(for_stmt.into());
920+
let current_assignment = match &**target {
921+
ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(CurrentAssignment::For {
922+
node: for_stmt,
923+
first: true,
924+
unpack: Some(Unpack::new(
925+
self.db,
926+
self.file,
927+
self.current_scope(),
928+
#[allow(unsafe_code)]
929+
unsafe {
930+
AstNodeRef::new(self.module.clone(), target)
931+
},
932+
UnpackValue::Iterable(iter_expr),
933+
countme::Count::default(),
934+
)),
935+
}),
936+
ast::Expr::Name(_) => Some(CurrentAssignment::For {
937+
node: for_stmt,
938+
unpack: None,
939+
first: false,
940+
}),
941+
_ => None,
942+
};
943+
944+
if let Some(current_assignment) = current_assignment {
945+
self.push_assignment(current_assignment);
946+
}
920947
self.visit_expr(target);
921-
self.pop_assignment();
948+
if current_assignment.is_some() {
949+
self.pop_assignment();
950+
}
922951

923952
// TODO: Definitions created by loop variables
924953
// (and definitions created inside the body)
@@ -1136,12 +1165,18 @@ where
11361165
Some(CurrentAssignment::AugAssign(aug_assign)) => {
11371166
self.add_definition(symbol, aug_assign);
11381167
}
1139-
Some(CurrentAssignment::For(node)) => {
1168+
Some(CurrentAssignment::For {
1169+
node,
1170+
first,
1171+
unpack,
1172+
}) => {
11401173
self.add_definition(
11411174
symbol,
11421175
ForStmtDefinitionNodeRef {
1176+
unpack,
1177+
first,
11431178
iterable: &node.iter,
1144-
target: name_node,
1179+
name: name_node,
11451180
is_async: node.is_async,
11461181
},
11471182
);
@@ -1177,7 +1212,9 @@ where
11771212
}
11781213
}
11791214

1180-
if let Some(CurrentAssignment::Assign { first, .. }) = self.current_assignment_mut()
1215+
if let Some(
1216+
CurrentAssignment::Assign { first, .. } | CurrentAssignment::For { first, .. },
1217+
) = self.current_assignment_mut()
11811218
{
11821219
*first = false;
11831220
}
@@ -1391,7 +1428,11 @@ enum CurrentAssignment<'a> {
13911428
},
13921429
AnnAssign(&'a ast::StmtAnnAssign),
13931430
AugAssign(&'a ast::StmtAugAssign),
1394-
For(&'a ast::StmtFor),
1431+
For {
1432+
node: &'a ast::StmtFor,
1433+
first: bool,
1434+
unpack: Option<Unpack<'a>>,
1435+
},
13951436
Named(&'a ast::ExprNamed),
13961437
Comprehension {
13971438
node: &'a ast::Comprehension,
@@ -1415,12 +1456,6 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
14151456
}
14161457
}
14171458

1418-
impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> {
1419-
fn from(value: &'a ast::StmtFor) -> Self {
1420-
Self::For(value)
1421-
}
1422-
}
1423-
14241459
impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
14251460
fn from(value: &'a ast::ExprNamed) -> Self {
14261461
Self::Named(value)

crates/red_knot_python_semantic/src/semantic_index/definition.rs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,10 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
225225

226226
#[derive(Copy, Clone, Debug)]
227227
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
228+
pub(crate) unpack: Option<Unpack<'a>>,
228229
pub(crate) iterable: &'a ast::Expr,
229-
pub(crate) target: &'a ast::ExprName,
230+
pub(crate) name: &'a ast::ExprName,
231+
pub(crate) first: bool,
230232
pub(crate) is_async: bool,
231233
}
232234

@@ -298,12 +300,16 @@ impl<'db> DefinitionNodeRef<'db> {
298300
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
299301
}
300302
DefinitionNodeRef::For(ForStmtDefinitionNodeRef {
303+
unpack,
301304
iterable,
302-
target,
305+
name,
306+
first,
303307
is_async,
304308
}) => DefinitionKind::For(ForStmtDefinitionKind {
309+
target: TargetKind::from(unpack),
305310
iterable: AstNodeRef::new(parsed.clone(), iterable),
306-
target: AstNodeRef::new(parsed, target),
311+
name: AstNodeRef::new(parsed, name),
312+
first,
307313
is_async,
308314
}),
309315
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
@@ -382,10 +388,12 @@ impl<'db> DefinitionNodeRef<'db> {
382388
Self::AnnotatedAssignment(node) => node.into(),
383389
Self::AugmentedAssignment(node) => node.into(),
384390
Self::For(ForStmtDefinitionNodeRef {
391+
unpack: _,
385392
iterable: _,
386-
target,
393+
name,
394+
first: _,
387395
is_async: _,
388-
}) => target.into(),
396+
}) => name.into(),
389397
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
390398
Self::VariadicPositionalParameter(node) => node.into(),
391399
Self::VariadicKeywordParameter(node) => node.into(),
@@ -452,7 +460,7 @@ pub enum DefinitionKind<'db> {
452460
Assignment(AssignmentDefinitionKind<'db>),
453461
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
454462
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
455-
For(ForStmtDefinitionKind),
463+
For(ForStmtDefinitionKind<'db>),
456464
Comprehension(ComprehensionDefinitionKind),
457465
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
458466
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
@@ -477,7 +485,7 @@ impl Ranged for DefinitionKind<'_> {
477485
DefinitionKind::Assignment(assignment) => assignment.name().range(),
478486
DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(),
479487
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(),
480-
DefinitionKind::For(for_stmt) => for_stmt.target().range(),
488+
DefinitionKind::For(for_stmt) => for_stmt.name().range(),
481489
DefinitionKind::Comprehension(comp) => comp.target().range(),
482490
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(),
483491
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(),
@@ -665,22 +673,32 @@ impl WithItemDefinitionKind {
665673
}
666674

667675
#[derive(Clone, Debug)]
668-
pub struct ForStmtDefinitionKind {
676+
pub struct ForStmtDefinitionKind<'db> {
677+
target: TargetKind<'db>,
669678
iterable: AstNodeRef<ast::Expr>,
670-
target: AstNodeRef<ast::ExprName>,
679+
name: AstNodeRef<ast::ExprName>,
680+
first: bool,
671681
is_async: bool,
672682
}
673683

674-
impl ForStmtDefinitionKind {
684+
impl<'db> ForStmtDefinitionKind<'db> {
675685
pub(crate) fn iterable(&self) -> &ast::Expr {
676686
self.iterable.node()
677687
}
678688

679-
pub(crate) fn target(&self) -> &ast::ExprName {
680-
self.target.node()
689+
pub(crate) fn target(&self) -> TargetKind<'db> {
690+
self.target
681691
}
682692

683-
pub(crate) fn is_async(&self) -> bool {
693+
pub(crate) fn name(&self) -> &ast::ExprName {
694+
self.name.node()
695+
}
696+
697+
pub(crate) const fn is_first(&self) -> bool {
698+
self.first
699+
}
700+
701+
pub(crate) const fn is_async(&self) -> bool {
684702
self.is_async
685703
}
686704
}
@@ -756,12 +774,6 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
756774
}
757775
}
758776

759-
impl From<&ast::StmtFor> for DefinitionNodeKey {
760-
fn from(value: &ast::StmtFor) -> Self {
761-
Self(NodeKey::from_node(value))
762-
}
763-
}
764-
765777
impl From<&ast::Parameter> for DefinitionNodeKey {
766778
fn from(node: &ast::Parameter) -> Self {
767779
Self(NodeKey::from_node(node))

0 commit comments

Comments
 (0)