@@ -194,6 +194,63 @@ is_sumtype(T::Type) = false
194194
195195function variant_idx end
196196
197+ function _is_sumtype_structurally (T)
198+ return T isa DataType && fieldcount (T) == 1 && fieldname (T, 1 ) === :variants && fieldtype (T, 1 ) isa Union
199+ end
200+
201+ function _get_variant_types (T_sum)
202+ field_T = fieldtype (T_sum, 1 )
203+
204+ ! (field_T isa Union) && return [field_T]
205+
206+ types = []
207+ curr = field_T
208+ while curr isa Union
209+ push! (types, curr. a)
210+ curr = curr. b
211+ end
212+ push! (types, curr)
213+ return types
214+ end
215+
216+ @generated function apply (f:: F , args:: Tuple ) where {F}
217+
218+
219+ args = fieldtypes (args)
220+ sumtype_args = [(i, T) for (i, T) in enumerate (args) if _is_sumtype_structurally (T)]
221+
222+ if isempty (sumtype_args)
223+ return :(f (args... ))
224+ end
225+
226+ final_args = Any[:(args[$ i]) for i in 1 : length (args)]
227+ for (idx, T) in sumtype_args
228+ final_args[idx] = Symbol (" v_" , idx)
229+ end
230+
231+ body = :(f ($ (final_args... )))
232+
233+ for (idx, T) in reverse (sumtype_args)
234+ unwrapped_var = Symbol (" v_" , idx)
235+
236+ variant_types = _get_variant_types (T)
237+
238+ branch_expr = :(error (" THIS_SHOULD_BE_UNREACHABLE" ))
239+ for V_type in reverse (variant_types)
240+ condition = :($ unwrapped_var isa $ V_type)
241+ branch_expr = Expr (:elseif , condition, body, branch_expr)
242+ end
243+ branch_expr = Expr (:if , branch_expr. args... )
244+
245+ body = quote
246+ let $ (unwrapped_var) = $ LightSumTypes. unwrap (args[$ idx])
247+ $ branch_expr
248+ end
249+ end
250+ end
251+ return body
252+ end
253+
197254include (" precompile.jl" )
198255
199256end
0 commit comments