@@ -2050,24 +2050,44 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20502050 }}
20512051 val argVals = argVals0.reverse
20522052 val argRefs = argRefs0.reverse
2053- def rec (fn : Tree ): Tree = fn match {
2053+ def rec (fn : Tree , topAscription : Option [TypeTree ]): Tree = fn match {
2054+ case Typed (expr, tpt) =>
2055+ // we need to retain any type ascriptions we see and:
2056+ // a) if we succeed, ascribe the result type of the ascription to the inlined body
2057+ // b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
2058+ // note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
2059+ rec(expr, topAscription.orElse(Some (tpt)))
20542060 case Inlined (call, bindings, expansion) =>
20552061 // this case must go before closureDef to avoid dropping the inline node
2056- cpy.Inlined (fn)(call, bindings, rec(expansion))
2062+ cpy.Inlined (fn)(call, bindings, rec(expansion, topAscription ))
20572063 case closureDef(ddef) =>
20582064 val paramSyms = ddef.vparamss.head.map(param => param.symbol)
20592065 val paramToVals = paramSyms.zip(argRefs).toMap
2060- new TreeTypeMap (
2066+ val result = new TreeTypeMap (
20612067 oldOwners = ddef.symbol :: Nil ,
20622068 newOwners = ctx.owner :: Nil ,
20632069 treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
20642070 ).transform(ddef.rhs)
2071+ topAscription match {
2072+ case Some (tpt) =>
2073+ // we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
2074+ val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf [MethodType ]
2075+ // result might contain paramrefs, so we substitute them with arg termrefs
2076+ val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
2077+ Typed (result, TypeTree (resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
2078+ case None =>
2079+ result
2080+ }
20652081 case tpd.Block (stats, expr) =>
2066- seq(stats, rec(expr)).withSpan(fn.span)
2082+ seq(stats, rec(expr, topAscription )).withSpan(fn.span)
20672083 case _ =>
2068- fn.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
2084+ val maybeAscribed = topAscription match {
2085+ case Some (tpt) => Typed (fn, tpt).withSpan(fn.span)
2086+ case None => fn
2087+ }
2088+ maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
20692089 }
2070- seq(argVals, rec(fn))
2090+ seq(argVals, rec(fn, None ))
20712091 }
20722092
20732093 // ///////////
0 commit comments