Skip to content

Commit 3dc262d

Browse files
custom handling of maps:foreach
Summary: - implementing the logic similar to maps:fold Reviewed By: VLanvin Differential Revision: D77147855 fbshipit-source-id: c0a0c7df292b6e0fbe01760de4b1bca518d1d05e
1 parent 3f2edc1 commit 3dc262d

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import com.whatsapp.eqwalizer.ast.Types.*
1212
import com.whatsapp.eqwalizer.ast.{Exprs, Pos, RemoteId}
1313
import com.whatsapp.eqwalizer.tc.TcDiagnostics.{ExpectedSubtype, IndexOutOfBounds, UnboundRecord}
1414
import com.whatsapp.eqwalizer.ast.CompilerMacro
15-
import com.whatsapp.eqwalizer.ast.Pats.{PatAtom, PatVar, PatWild}
15+
import com.whatsapp.eqwalizer.ast.Pats.{Pat, PatAtom, PatTuple, PatVar, PatWild}
1616

1717
import scala.collection.mutable
1818

@@ -57,6 +57,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
5757
RemoteId("maps", "filtermap", 2),
5858
RemoteId("maps", "find", 2),
5959
RemoteId("maps", "fold", 3),
60+
RemoteId("maps", "foreach", 2),
6061
RemoteId("maps", "get", 2),
6162
RemoteId("maps", "get", 3),
6263
RemoteId("maps", "intersect", 2),
@@ -514,6 +515,69 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
514515
val resItemTy = funResultsToValTy(funResTys, valTy, callPos)
515516
(mapTys.map(narrow.setAllFieldsOptional(_, Some(resItemTy))).join(), env1)
516517

518+
case RemoteId("maps", "foreach", 2) =>
519+
val List(funArg, map) = args
520+
val List(funArgTy, mapTy) = argTys
521+
val mapTys = coerceToMapsOrIter(map, mapTy)
522+
val keyTy = mapTys.map(narrow.getKeyType).join()
523+
val valTy = mapTys.map(narrow.getValType).join()
524+
def isShapeKey(pat: Pat): Boolean = {
525+
pat match {
526+
case PatAtom(_) => true
527+
case PatTuple(pats) => pats.forall(isShapeKey)
528+
case _ => false
529+
}
530+
}
531+
def asKey(pat: Pat): Key = {
532+
pat match {
533+
case PatAtom(a) => AtomKey(a)
534+
case PatTuple(pats) => TupleKey(pats.map(asKey))
535+
case _ => throw new IllegalStateException(s"unexpected pattern: $pat")
536+
}
537+
}
538+
def isShapeIterator(lambda: Lambda): Boolean = {
539+
val usedKeys = mutable.Set[Pat]()
540+
(lambda.clauses forall { clause =>
541+
clause.pats match {
542+
case List(pat, _) if isShapeKey(pat) && !usedKeys(pat) =>
543+
usedKeys.add(pat)
544+
true
545+
case List(PatVar(_) | PatWild(), _) =>
546+
true
547+
case _ =>
548+
false
549+
}
550+
}) && (lambda.clauses.count { clause =>
551+
clause.pats match {
552+
case List(PatVar(_) | PatWild(), _, _) =>
553+
true
554+
case _ =>
555+
false
556+
}
557+
} <= 1)
558+
}
559+
560+
val expFunTy = FunType(Nil, List(keyTy, valTy), AnyType)
561+
var keyTyLast = keyTy
562+
funArg match {
563+
case lambda: Lambda if isShapeIterator(lambda) =>
564+
val lamEnv = lambda.name.map(name => env.updated(name, expFunTy)).getOrElse(env)
565+
lambda.clauses.foreach { clause =>
566+
if (isShapeKey(clause.pats.head)) {
567+
val key = asKey(clause.pats.head)
568+
val refinedValTy = UnionType(mapTys.map(m => narrow.getValType(key, m)))
569+
val kTy = Key.asType(key)
570+
keyTyLast = occurrence.remove(keyTyLast, kTy)
571+
check.checkClause(clause, List(kTy, refinedValTy), AnyType, lamEnv, Set.empty)
572+
} else {
573+
check.checkClause(clause, List(keyTyLast, valTy), AnyType, lamEnv, Set.empty)
574+
}
575+
}
576+
case _ =>
577+
check.checkExpr(funArg, expFunTy, env)
578+
}
579+
(AtomLitType("ok"), env)
580+
517581
case RemoteId("maps", "remove", 2) =>
518582
val List(keyArg, map) = args
519583
val List(keyTy, mapTy) = argTys

0 commit comments

Comments
 (0)