Skip to content

Commit c206550

Browse files
ruippeixotogfacebook-github-bot
authored andcommitted
Add custom type checking for element/2
Summary: This would have caught the issue in D50360629, where the usage of `element` was clearly unsafe. Reviewed By: ilya-klyuchnikov Differential Revision: D50412145 fbshipit-source-id: a93a679027f7de3a2f469cfcf78fbcd53237d10c
1 parent bb47c23 commit c206550

File tree

8 files changed

+389
-4
lines changed

8 files changed

+389
-4
lines changed

eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
package com.whatsapp.eqwalizer.tc
88

99
import scala.annotation.tailrec
10-
import com.whatsapp.eqwalizer.ast.Exprs.{AtomLit, Cons, Expr, Lambda, NilLit}
10+
import com.whatsapp.eqwalizer.ast.Exprs.{AtomLit, Cons, Expr, IntLit, Lambda, NilLit}
1111
import com.whatsapp.eqwalizer.ast.Types._
1212
import com.whatsapp.eqwalizer.ast.{Exprs, Pos, RemoteId}
13-
import com.whatsapp.eqwalizer.tc.TcDiagnostics.{ExpectedSubtype, UnboundVar, UnboundRecord}
13+
import com.whatsapp.eqwalizer.tc.TcDiagnostics.{ExpectedSubtype, IndexOutOfBounds, UnboundVar, UnboundRecord}
1414
import com.whatsapp.eqwalizer.ast.CompilerMacro
1515

1616
class ElabApplyCustom(pipelineContext: PipelineContext) {
@@ -28,6 +28,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
2828

2929
private lazy val custom: Set[RemoteId] =
3030
Set(
31+
RemoteId("erlang", "element", 2),
3132
RemoteId("erlang", "map_get", 2),
3233
RemoteId("file", "open", 2),
3334
RemoteId("lists", "filtermap", 2),
@@ -296,6 +297,36 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
296297
(valTy, env1)
297298
}
298299

300+
/*
301+
`-spec element(N :: NumberType, Tup :: TupleType) -> Out`, where `Out` is:
302+
- `Tup[N]` when `N` is an integer literal corresponding to a valid index
303+
- Union of element types of `Tup` when `N` is not a literal
304+
- An error otherwise (index out of bounds or unexpected type)
305+
*/
306+
case RemoteId("erlang", "element", 2) =>
307+
val List(index, tuple) = args
308+
val List(indexTy, tupleTy) = argTys
309+
310+
def validate(): Unit = {
311+
if (!subtype.subType(indexTy, NumberType))
312+
throw ExpectedSubtype(index.pos, index, expected = NumberType, got = indexTy)
313+
if (!subtype.subType(tupleTy, AnyTupleType))
314+
throw ExpectedSubtype(tuple.pos, tuple, expected = AnyTupleType, got = tupleTy)
315+
}
316+
validate()
317+
318+
val elemTy = index match {
319+
case IntLit(Some(n)) =>
320+
narrow.getTupleElement(tupleTy, n) match {
321+
case Right(elemTy) => elemTy
322+
case Left(tupLen) => throw IndexOutOfBounds(callPos, index, n, tupLen)
323+
}
324+
case _ =>
325+
narrow.getAllTupleElements(tupleTy)
326+
}
327+
328+
(elemTy, env1)
329+
299330
case RemoteId("maps", "get", 3) =>
300331
val List(key, map, defaultVal) = args
301332
val List(keyTy, mapTy, defaultValTy) = argTys

eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,104 @@ class Narrow(pipelineContext: PipelineContext) {
351351
case _ => List()
352352
}
353353

354+
/**
355+
* Given a type (required to be a subtype of `AnyTupleType`) and an index, returns the type of the tuple element at
356+
* the index wrapped in a `Right`. If the index can be possibly out of bounds (in at least one of the options in a
357+
* union) the function returns `Left(tupLen)`, where `tupLen` is the minimum index value for which this operation would
358+
* type check.
359+
*/
360+
def getTupleElement(t: Type, idx: Int): Either[Int, Type] = t match {
361+
case NoneType =>
362+
Right(NoneType)
363+
case DynamicType =>
364+
Right(DynamicType)
365+
case AnyTupleType if pipelineContext.gradualTyping =>
366+
Right(DynamicType)
367+
case AnyTupleType =>
368+
Right(AnyType)
369+
case BoundedDynamicType(t) if subtype.subType(t, AnyTupleType) =>
370+
Right(BoundedDynamicType(getTupleElement(t, idx).getOrElse(NoneType)))
371+
case BoundedDynamicType(t) =>
372+
Right(BoundedDynamicType(NoneType))
373+
case TupleType(elemTys) if idx >= 1 && idx <= elemTys.length =>
374+
Right(elemTys(idx - 1))
375+
case TupleType(elemTys) =>
376+
Left(elemTys.length)
377+
case r: RecordType =>
378+
recordToTuple(r) match {
379+
case Some(tupTy) => getTupleElement(tupTy, idx)
380+
case None if pipelineContext.gradualTyping => Right(DynamicType)
381+
case None => Right(AnyType)
382+
}
383+
case r: RefinedRecordType =>
384+
refinedRecordToTuple(r) match {
385+
case Some(tupTy) => getTupleElement(tupTy, idx)
386+
case None if pipelineContext.gradualTyping => Right(DynamicType)
387+
case None => Right(AnyType)
388+
}
389+
case UnionType(tys) =>
390+
val res = tys.map(getTupleElement(_, idx)).foldLeft[Either[Int, Set[Type]]](Right(Set.empty)) {
391+
case (Right(accTy), Right(elemTy)) => Right(accTy + elemTy)
392+
case (Left(n1), Left(n2)) => Left(n1.min(n2))
393+
case (Left(n1), _) => Left(n1)
394+
case (_, Left(n2)) => Left(n2)
395+
}
396+
res.map { optionTys => UnionType(util.flattenUnions(UnionType(optionTys)).toSet) }
397+
case RemoteType(rid, args) =>
398+
val body = util.getTypeDeclBody(rid, args)
399+
getTupleElement(body, idx)
400+
case _ =>
401+
throw new IllegalStateException()
402+
}
403+
404+
/**
405+
* Given a type (required to be a subtype of `AnyTupleType`), returns the union of all its element types.
406+
*/
407+
def getAllTupleElements(t: Type): Type = t match {
408+
case NoneType =>
409+
NoneType
410+
case DynamicType =>
411+
DynamicType
412+
case AnyTupleType if pipelineContext.gradualTyping =>
413+
DynamicType
414+
case AnyTupleType =>
415+
AnyType
416+
case BoundedDynamicType(t) if subtype.subType(t, AnyTupleType) =>
417+
BoundedDynamicType(getAllTupleElements(t))
418+
case BoundedDynamicType(t) =>
419+
BoundedDynamicType(NoneType)
420+
case TupleType(elemTys) =>
421+
UnionType(elemTys.toSet)
422+
case r: RecordType =>
423+
recordToTuple(r) match {
424+
case Some(tupTy) => getAllTupleElements(tupTy)
425+
case None if pipelineContext.gradualTyping => DynamicType
426+
case None => AnyType
427+
}
428+
case r: RefinedRecordType =>
429+
refinedRecordToTuple(r) match {
430+
case Some(tupTy) => getAllTupleElements(tupTy)
431+
case None if pipelineContext.gradualTyping => DynamicType
432+
case None => AnyType
433+
}
434+
case UnionType(tys) =>
435+
UnionType(util.flattenUnions(UnionType(tys.map(getAllTupleElements))).toSet)
436+
case RemoteType(rid, args) =>
437+
val body = util.getTypeDeclBody(rid, args)
438+
getAllTupleElements(body)
439+
case _ =>
440+
throw new IllegalStateException()
441+
}
442+
443+
private def recordToTuple(r: RecordType): Option[TupleType] =
444+
refinedRecordToTuple(RefinedRecordType(r, Map()))
445+
446+
private def refinedRecordToTuple(r: RefinedRecordType): Option[TupleType] =
447+
util.getRecord(r.recType.module, r.recType.name).map { recDecl =>
448+
val elemTys = AtomLitType(r.recType.name) :: recDecl.fields.map(f => r.fields.getOrElse(f._1, f._2.tp)).toList
449+
TupleType(elemTys)
450+
}
451+
354452
private def adjustShapeMap(t: ShapeMap, keyT: Type, valT: Type): Type =
355453
keyT match {
356454
case AtomLitType(key) =>

eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/TcDiagnostics.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ object TcDiagnostics {
5656
def errorName = "fun_arity_mismatch"
5757
override def erroneousExpr: Option[Expr] = Some(expr)
5858
}
59+
case class IndexOutOfBounds(pos: Pos, expr: Expr, index: Int, tupleArity: Int) extends TypeError {
60+
override val msg: String = s"Tried to access element $index of a tuple with $tupleArity elements"
61+
def errorName = "index_out_of_bounds"
62+
override def erroneousExpr: Option[Expr] = Some(expr)
63+
}
5964
case class NotSupportedLambdaInOverloadedCall(pos: Pos, expr: Expr) extends TypeError {
6065
override val msg: String = s"Lambdas are not allowed as args to overloaded functions"
6166
def errorName = "fun_in_overload_arg"

eqwalizer/test_projects/_cli/otp_funs.cli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ gb_sets 26
1616
proplists 51
1717
maps 149
1818
lists 169
19-
erlang 374
19+
erlang 396
2020
Per app stats:
2121
kernel 21
22-
erts 374
22+
erts 396
2323
stdlib 471

eqwalizer/test_projects/check/src/custom.erl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,92 @@
88
-import(maps, [get/2, get/3]).
99
-compile([export_all, nowarn_export_all]).
1010

11+
-record(foo, {
12+
a :: ok | error,
13+
b :: number(),
14+
c :: string()
15+
}).
16+
17+
% element/2 - basic examples
18+
19+
-spec element_2_basic_1({atom(), number(), string()}) -> atom().
20+
element_2_basic_1(Tup) ->
21+
element(1, Tup).
22+
23+
-spec element_2_basic_2_neg({atom(), number(), string(), map()}) -> atom().
24+
element_2_basic_2_neg(Tup) ->
25+
element(4, Tup).
26+
27+
-spec element_2_basic_3_neg({atom(), number(), string()}) -> atom().
28+
element_2_basic_3_neg(Tup) ->
29+
element(42, Tup).
30+
31+
% element/2 - union examples
32+
33+
-spec element_2_union_1({atom(), number() | string()} | {number(), atom()}) -> number() | string() | atom().
34+
element_2_union_1(Tup) ->
35+
element(2, Tup).
36+
37+
-spec element_2_union_2_neg({atom(), number() | string()} | {number(), atom()}) -> map().
38+
element_2_union_2_neg(Tup) ->
39+
element(2, Tup).
40+
41+
-spec element_2_union_3_neg({atom(), string()} | list()) -> string().
42+
element_2_union_3_neg(Tup) ->
43+
element(2, Tup).
44+
45+
-spec element_2_union_4_neg({c, d, e, f} | {a, b} | {b, c, d}) -> atom().
46+
element_2_union_4_neg(Tup) ->
47+
element(42, Tup).
48+
49+
% element/2 - dynamic index examples
50+
51+
-spec element_2_dynindex_1_neg(pos_integer(), {atom(), number(), string()}) -> map().
52+
element_2_dynindex_1_neg(N, Tup) ->
53+
element(N, Tup).
54+
55+
-spec element_2_dynindex_2_neg(pos_integer(), {atom(), atom()} | {atom(), atom(), number()}) -> atom().
56+
element_2_dynindex_2_neg(N, Tup) ->
57+
element(N, Tup).
58+
59+
% element/2 - tuple() examples
60+
61+
-spec element_2_anytuple_1_neg(tuple()) -> atom().
62+
element_2_anytuple_1_neg(Tup) ->
63+
element(1, Tup).
64+
65+
-spec element_2_anytuple_2_neg(tuple() | {number(), atom()}) -> atom().
66+
element_2_anytuple_2_neg(Tup) ->
67+
element(1, Tup).
68+
69+
% element/2 - record examples
70+
71+
-spec element_2_record_1(#foo{}) -> foo.
72+
element_2_record_1(Rec) ->
73+
element(1, Rec).
74+
75+
-spec element_2_record_2(#foo{}) -> ok | error.
76+
element_2_record_2(Rec) ->
77+
element(2, Rec).
78+
79+
-spec element_2_record_3(#foo{}) -> ok.
80+
element_2_record_3(Rec) when Rec#foo.a =/= error ->
81+
element(2, Rec).
82+
83+
-spec element_2_record_4_neg(pos_integer(), #foo{}) -> atom().
84+
element_2_record_4_neg(N, Rec) ->
85+
element(N, Rec).
86+
87+
% element/2 - none examples
88+
89+
-spec element_2_none_1(none()) -> number().
90+
element_2_none_1(Tup) ->
91+
element(42, Tup).
92+
93+
-spec element_2_none_2(pos_integer(), none()) -> number().
94+
element_2_none_2(N, Tup) ->
95+
element(N, Tup).
96+
1197
-spec map_get_2_1(
1298
pid(), #{pid() => atom()}
1399
) -> atom().

0 commit comments

Comments
 (0)