Skip to content

Commit fcb08e4

Browse files
authored
Add apply function for type stable calls of functions
1 parent d70d825 commit fcb08e4

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

src/LightSumTypes.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,63 @@ is_sumtype(T::Type) = false
194194

195195
function variant_idx end
196196

197+
function _is_sumtype_structurally(T)
198+
return T isa DataType && fieldcount(T) == 1 && fieldname(T, 1) === :variants && fieldtype(T, 1) isa Union
199+
end
200+
201+
function _get_variant_types(T_sum)
202+
field_T = fieldtype(T_sum, 1)
203+
204+
!(field_T isa Union) && return [field_T]
205+
206+
types = []
207+
curr = field_T
208+
while curr isa Union
209+
push!(types, curr.a)
210+
curr = curr.b
211+
end
212+
push!(types, curr)
213+
return types
214+
end
215+
216+
@generated function apply(f::F, args::Tuple) where {F}
217+
218+
219+
args = fieldtypes(args)
220+
sumtype_args = [(i, T) for (i, T) in enumerate(args) if _is_sumtype_structurally(T)]
221+
222+
if isempty(sumtype_args)
223+
return :(f(args...))
224+
end
225+
226+
final_args = Any[:(args[$i]) for i in 1:length(args)]
227+
for (idx, T) in sumtype_args
228+
final_args[idx] = Symbol("v_", idx)
229+
end
230+
231+
body = :(f($(final_args...)))
232+
233+
for (idx, T) in reverse(sumtype_args)
234+
unwrapped_var = Symbol("v_", idx)
235+
236+
variant_types = _get_variant_types(T)
237+
238+
branch_expr = :(error("THIS_SHOULD_BE_UNREACHABLE"))
239+
for V_type in reverse(variant_types)
240+
condition = :($unwrapped_var isa $V_type)
241+
branch_expr = Expr(:elseif, condition, body, branch_expr)
242+
end
243+
branch_expr = Expr(:if, branch_expr.args...)
244+
245+
body = quote
246+
let $(unwrapped_var) = $LightSumTypes.unwrap(args[$idx])
247+
$branch_expr
248+
end
249+
end
250+
end
251+
return body
252+
end
253+
197254
include("precompile.jl")
198255

199256
end

0 commit comments

Comments
 (0)