@@ -12,7 +12,7 @@ import com.whatsapp.eqwalizer.ast.Types.*
12
12
import com .whatsapp .eqwalizer .ast .{Exprs , Pos , RemoteId }
13
13
import com .whatsapp .eqwalizer .tc .TcDiagnostics .{ExpectedSubtype , IndexOutOfBounds , UnboundRecord }
14
14
import com .whatsapp .eqwalizer .ast .CompilerMacro
15
- import com .whatsapp .eqwalizer .ast .Pats .{PatAtom , PatVar , PatWild }
15
+ import com .whatsapp .eqwalizer .ast .Pats .{Pat , PatAtom , PatTuple , PatVar , PatWild }
16
16
17
17
import scala .collection .mutable
18
18
@@ -57,6 +57,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
57
57
RemoteId (" maps" , " filtermap" , 2 ),
58
58
RemoteId (" maps" , " find" , 2 ),
59
59
RemoteId (" maps" , " fold" , 3 ),
60
+ RemoteId (" maps" , " foreach" , 2 ),
60
61
RemoteId (" maps" , " get" , 2 ),
61
62
RemoteId (" maps" , " get" , 3 ),
62
63
RemoteId (" maps" , " intersect" , 2 ),
@@ -514,6 +515,69 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
514
515
val resItemTy = funResultsToValTy(funResTys, valTy, callPos)
515
516
(mapTys.map(narrow.setAllFieldsOptional(_, Some (resItemTy))).join(), env1)
516
517
518
+ case RemoteId (" maps" , " foreach" , 2 ) =>
519
+ val List (funArg, map) = args
520
+ val List (funArgTy, mapTy) = argTys
521
+ val mapTys = coerceToMapsOrIter(map, mapTy)
522
+ val keyTy = mapTys.map(narrow.getKeyType).join()
523
+ val valTy = mapTys.map(narrow.getValType).join()
524
+ def isShapeKey (pat : Pat ): Boolean = {
525
+ pat match {
526
+ case PatAtom (_) => true
527
+ case PatTuple (pats) => pats.forall(isShapeKey)
528
+ case _ => false
529
+ }
530
+ }
531
+ def asKey (pat : Pat ): Key = {
532
+ pat match {
533
+ case PatAtom (a) => AtomKey (a)
534
+ case PatTuple (pats) => TupleKey (pats.map(asKey))
535
+ case _ => throw new IllegalStateException (s " unexpected pattern: $pat" )
536
+ }
537
+ }
538
+ def isShapeIterator (lambda : Lambda ): Boolean = {
539
+ val usedKeys = mutable.Set [Pat ]()
540
+ (lambda.clauses forall { clause =>
541
+ clause.pats match {
542
+ case List (pat, _) if isShapeKey(pat) && ! usedKeys(pat) =>
543
+ usedKeys.add(pat)
544
+ true
545
+ case List (PatVar (_) | PatWild (), _) =>
546
+ true
547
+ case _ =>
548
+ false
549
+ }
550
+ }) && (lambda.clauses.count { clause =>
551
+ clause.pats match {
552
+ case List (PatVar (_) | PatWild (), _, _) =>
553
+ true
554
+ case _ =>
555
+ false
556
+ }
557
+ } <= 1 )
558
+ }
559
+
560
+ val expFunTy = FunType (Nil , List (keyTy, valTy), AnyType )
561
+ var keyTyLast = keyTy
562
+ funArg match {
563
+ case lambda : Lambda if isShapeIterator(lambda) =>
564
+ val lamEnv = lambda.name.map(name => env.updated(name, expFunTy)).getOrElse(env)
565
+ lambda.clauses.foreach { clause =>
566
+ if (isShapeKey(clause.pats.head)) {
567
+ val key = asKey(clause.pats.head)
568
+ val refinedValTy = UnionType (mapTys.map(m => narrow.getValType(key, m)))
569
+ val kTy = Key .asType(key)
570
+ keyTyLast = occurrence.remove(keyTyLast, kTy)
571
+ check.checkClause(clause, List (kTy, refinedValTy), AnyType , lamEnv, Set .empty)
572
+ } else {
573
+ check.checkClause(clause, List (keyTyLast, valTy), AnyType , lamEnv, Set .empty)
574
+ }
575
+ }
576
+ case _ =>
577
+ check.checkExpr(funArg, expFunTy, env)
578
+ }
579
+ (AtomLitType (" ok" ), env)
580
+
517
581
case RemoteId (" maps" , " remove" , 2 ) =>
518
582
val List (keyArg, map) = args
519
583
val List (keyTy, mapTy) = argTys
0 commit comments