Description
Following my correspondence with @wsmoses on the Julia Discourse, the gradient computed through the cfunc entry in the LLVM IR emitted by a @cfunc-decorated Numba function turns out to be wrong. This is not the case when Enzyme differentiates the underlying function entry in the same LLVM IR directly.
Below is a small Python script that dumps the LLVM IR of a @cfunc-decorated Numba function. The function simply multiplies two floating-point numbers:
import re

from numba import cfunc, types

sig = types.double(
    types.double,
    types.double,
)


@cfunc(sig)
def func(a, b):
    return a * b


if __name__ == "__main__":
    # get the LLVM IR
    ir_map = func.inspect_llvm()
    # Numba mangles the function's name;
    # what follows is a few regex
    # substitutions to get rid of the mangled
    # function name and improve readability
    lines = ir_map.split("\n")
    mangled = None
    # there is a global definition of a mangled
    # symbol with `XXNumbaEnv` in it. By removing
    # this substring, the mangled function name is
    # recovered.
    filtered_lines = []
    for line in lines:
        alias_match = re.match(r'^(@_[^ ]*NumbaEnv[^ ]*)\s*=', line)
        if alias_match:
            mangled = alias_match.group(1)
            continue
        filtered_lines.append(line)
    if mangled is None:
        raise ValueError("Could not find the NumbaEnv alias line.")
    # strip the leading `@`, then replace `XXNumbaEnv` with the empty string
    mangled = re.sub("@", "", mangled)
    mangled = re.sub(r"\d{2}NumbaEnv", "", mangled)
    # substitute the mangled name with a more readable
    # function name. Also change `@cfunc.func` to `@cfunc_func`
    # so we can define the symbol `:cfunc_func` in Julia.
    clean_irs = []
    for line in filtered_lines:
        # re.escape makes the mangled name match literally
        line = re.sub(re.escape(mangled), "func", line)
        line = re.sub(r'cfunc\.func', r'cfunc_func', line)
        clean_irs.append(line + "\n")
    ir_map = "".join(clean_irs)
    with open("func.ll", "w") as f:
        f.write(ir_map)
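The renaming steps in the script above can be exercised without Numba installed. The following is a minimal sketch that factors them into a hypothetical helper, clean_numba_ir, and runs it on a hand-written IR fragment. The mangled name _ZN8__main__4funcB2v1Edd and the fragment are invented for illustration and only imitate the shape of Numba's output; real mangled names are much longer.

```python
import re


def clean_numba_ir(ir: str) -> str:
    """Drop the NumbaEnv alias line and rename the mangled symbol to `func`.

    Hypothetical helper that mirrors the regex steps of the script above;
    it is not part of Numba.
    """
    mangled = None
    kept = []
    for line in ir.split("\n"):
        alias_match = re.match(r"^(@_[^ ]*NumbaEnv[^ ]*)\s*=", line)
        if alias_match:
            mangled = alias_match.group(1)
            continue
        kept.append(line)
    if mangled is None:
        raise ValueError("Could not find the NumbaEnv alias line.")
    # "@_ZN08NumbaEnv8__main__..." -> "_ZN8__main__..."
    mangled = re.sub(r"\d{2}NumbaEnv", "", mangled.lstrip("@"))
    out = []
    for line in kept:
        line = re.sub(re.escape(mangled), "func", line)
        line = re.sub(r"cfunc\.func", "cfunc_func", line)
        out.append(line)
    return "\n".join(out)


# Hand-written fragment with a made-up mangled name (not real Numba output).
sample = "\n".join([
    "@_ZN08NumbaEnv8__main__4funcB2v1Edd = common local_unnamed_addr global i8* null",
    "define i32 @_ZN8__main__4funcB2v1Edd(double* %retptr, double %a, double %b) {",
    "define double @cfunc._ZN8__main__4funcB2v1Edd(double %.1, double %.2) {",
])
print(clean_numba_ir(sample))
```

With the alias line removed, the two definitions come out as @func and @cfunc_func, the names the Julia code below looks up.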
Then, in the Julia REPL:
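using Libdl
using Enzyme

file = read("func.ll");
# compile the IR into a shared library; -fembed-bitcode keeps the
# bitcode in libfunc.so so that Enzyme can pick up the function's IR
run(
    pipeline(
        `clang -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o libfunc.so`; stdin=IOBuffer(file)
    )
);

const lib = Libdl.dlopen("./libfunc.so")
const f_ptr = Libdl.dlsym(lib, :cfunc_func)  # C-ABI wrapper generated by @cfunc
const g_ptr = Libdl.dlsym(lib, :func)        # Numba's native function

# the cfunc wrapper takes two doubles and returns a double
f(a, b) = ccall(f_ptr, Cdouble, (Cdouble, Cdouble), a, b)

# the native function writes its result through `result`, reports
# exceptions through `exc_info`, and returns a status code
function g(a::Cdouble, b::Cdouble)
    result = Ref{Cdouble}()
    exc_info = Ref{Ptr{Nothing}}()
    status = ccall(g_ptr, Cint,
                   (Ptr{Cdouble}, Ptr{Ptr{Nothing}}, Cdouble, Cdouble),
                   result, exc_info, a, b)
    status == 0 || error("Python exception raised!")
    result[]
end

a = 2.0
b = 3.0
f(a, b) == g(a, b)
gradient(Reverse, f, a, Const(b))
gradient(Reverse, g, a, Const(b))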
This yields the following results:
julia> gradient(Reverse, f, a, Const(b))
(2.072225527840143e-309, nothing)
julia> gradient(Reverse, g, a, Const(b))
(3.0, nothing)
Only the latter produces the correct answer.
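For reference, the expected value follows from ∂(a·b)/∂a = b, so at (a, b) = (2.0, 3.0) the gradient with respect to a is 3.0, which is what g returns. A minimal pure-Python cross-check using a central finite difference (the step h = 1e-6 is an arbitrary choice, and mul simply restates the Numba function):

```python
def mul(a: float, b: float) -> float:
    """Same computation as the Numba function: a * b."""
    return a * b


def central_difference(fn, a: float, b: float, h: float = 1e-6) -> float:
    """Estimate d fn / d a at (a, b) with a central finite difference."""
    return (fn(a + h, b) - fn(a - h, b)) / (2.0 * h)


a, b = 2.0, 3.0
analytic = b  # d(a * b)/da = b
estimate = central_difference(mul, a, b)
print(analytic, estimate)
```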
The Enzyme output for the gradients after setting Enzyme.API.printall!(true):
For the cfunc entry in the LLVM IR:
julia> gradient(Reverse, f, a, Const(b))
after simplification :
; Function Attrs: mustprogress nofree willreturn memory(read, argmem: none, inaccessiblemem: none)
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_6619(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1) local_unnamed_addr #5 !dbg !18 {
top:
%pgcstack = call {}*** @julia.get_pgcstack() #6
%ptls_field3 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
%2 = bitcast {}*** %ptls_field3 to i64***
%ptls_load45 = load i64**, i64*** %2, align 8, !tbaa !8
%3 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
%safepoint = load i64*, i64** %3, align 8, !tbaa !12
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #6, !dbg !19
fence syncscope("singlethread") seq_cst
%4 = call fastcc double @cfunc_func(double %0, double %1) #7, !dbg !19
ret double %4, !dbg !19
}
after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
define internal fastcc double @preprocess_cfunc_func(double %.1, double %.2) unnamed_addr #3 {
entry:
%.4 = alloca double, align 8
store double 0.000000e+00, double* %.4, align 8, !noalias !15
call fastcc void @func(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %.4, double %.1, double %.2) #6
%.18 = load double, double* %.4, align 8
ret double %.18
}
after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define internal fastcc void @preprocess_func(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %retptr, double %arg.a, double %arg.b) unnamed_addr #4 {
entry:
%.6 = fmul double %arg.a, %arg.b
store double %.6, double* %retptr, align 8, !noalias !15
ret void
}
; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffefunc(double* noalias nocapture nofree writeonly align 8 dereferenceable(8) %retptr, double* nocapture nofree align 8 %"retptr'", double %arg.a, double %arg.b) unnamed_addr #6 {
entry:
%".6'de" = alloca double, align 8
%0 = getelementptr double, double* %".6'de", i64 0
store double 0.000000e+00, double* %0, align 8
%"arg.a'de" = alloca double, align 8
%1 = getelementptr double, double* %"arg.a'de", i64 0
store double 0.000000e+00, double* %1, align 8
br label %invertentry
invertentry: ; preds = %entry
%2 = load double, double* %"retptr'", align 8, !alias.scope !27, !noalias !30
store double 0.000000e+00, double* %"retptr'", align 8, !alias.scope !27, !noalias !30
%3 = load double, double* %".6'de", align 8
%4 = fadd fast double %3, %2
store double %4, double* %".6'de", align 8
%5 = load double, double* %".6'de", align 8
store double 0.000000e+00, double* %".6'de", align 8
%6 = fmul fast double %5, %arg.b
%7 = load double, double* %"arg.a'de", align 8
%8 = fadd fast double %7, %6
store double %8, double* %"arg.a'de", align 8
%9 = load double, double* %"arg.a'de", align 8
%10 = insertvalue { double } undef, double %9, 0
ret { double } %10
}
; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffecfunc_func(double %.1, double %.2, double %differeturn) unnamed_addr #6 {
entry:
%".18'de" = alloca double, align 8
%0 = getelementptr double, double* %".18'de", i64 0
store double 0.000000e+00, double* %0, align 8
%".1'de" = alloca double, align 8
%1 = getelementptr double, double* %".1'de", i64 0
store double 0.000000e+00, double* %1, align 8
%".4'ipa" = alloca double, align 8
store double 0.000000e+00, double* %".4'ipa", align 8
br label %invertentry
invertentry: ; preds = %entry
store double %differeturn, double* %".18'de", align 8
%2 = load double, double* %".18'de", align 8
store double 0.000000e+00, double* %".18'de", align 8
%3 = load double, double* %".4'ipa", align 8, !alias.scope !22, !noalias !25
%4 = fadd fast double %3, %2
store double %4, double* %".4'ipa", align 8, !alias.scope !22, !noalias !25
%5 = call fastcc { double } @diffefunc(double* nocapture nofree writeonly align 8 undef, double* nocapture nofree align 8 %".4'ipa", double %.1, double %.2)
%6 = extractvalue { double } %5, 0
%7 = load double, double* %".1'de", align 8
%8 = fadd fast double %7, %6
store double %8, double* %".1'de", align 8
store double 0.000000e+00, double* %".4'ipa", align 8, !alias.scope !22, !noalias !27
%9 = load double, double* %".1'de", align 8
%10 = insertvalue { double } undef, double %9, 0
ret { double } %10
}
; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_f_6619(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1, double %differeturn) local_unnamed_addr #6 !dbg !20 {
top:
%"'de" = alloca double, align 8
%2 = getelementptr double, double* %"'de", i64 0
store double 0.000000e+00, double* %2, align 8
%"'de1" = alloca double, align 8
%3 = getelementptr double, double* %"'de1", i64 0
store double 0.000000e+00, double* %3, align 8
%pgcstack = call {}*** @julia.get_pgcstack() #9
%ptls_field3 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
%4 = bitcast {}*** %ptls_field3 to i64***
%ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !8, !alias.scope !21, !noalias !24
%5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
%safepoint = load i64*, i64** %5, align 8, !tbaa !12, !alias.scope !26, !noalias !29
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #9, !dbg !31
fence syncscope("singlethread") seq_cst
br label %inverttop, !dbg !31
inverttop: ; preds = %top
store double %differeturn, double* %"'de", align 8
%6 = load double, double* %"'de", align 8, !dbg !31
%7 = call fastcc { double } @diffecfunc_func(double %0, double %1, double %6), !dbg !31
%8 = extractvalue { double } %7, 0, !dbg !31
%9 = load double, double* %"'de1", align 8, !dbg !31
%10 = fadd fast double %9, %8, !dbg !31
store double %10, double* %"'de1", align 8, !dbg !31
store double 0.000000e+00, double* %"'de", align 8, !dbg !31
%11 = load double, double* %"'de1", align 8
%12 = insertvalue { double } undef, double %11, 0
ret { double } %12
}
(2.08035975524203e-309, nothing)
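To make the data flow of the printed adjoint easier to follow, here is a hand transcription of @diffecfunc_func and @diffefunc into Python, with the shadow of the return slot modeled as a one-element list. It models only the IR as printed above, not the code that is actually compiled and executed, and the Python names are illustrative. Read literally, the printed functions route the seed differeturn through the shadow of the return slot %.4 and return differeturn * b.

```python
def diffefunc(retptr_shadow: list, arg_a: float, arg_b: float) -> float:
    """Transcription of @diffefunc: adjoint of `*retptr = a * b` w.r.t. a."""
    d_prod = 0.0                   # %".6'de"
    d_a = 0.0                      # %"arg.a'de"
    seed = retptr_shadow[0]        # %2 = load double, double* %"retptr'"
    retptr_shadow[0] = 0.0         # the shadow is consumed and zeroed
    d_prod += seed
    d_a += d_prod * arg_b          # %6 = fmul fast double %5, %arg.b
    return d_a


def diffecfunc_func(x1: float, x2: float, differeturn: float) -> float:
    """Transcription of @diffecfunc_func: adjoint of the cfunc wrapper.

    %".18'de" only forwards differeturn, so it is folded away here.
    """
    d_x1 = 0.0                     # %".1'de"
    slot_shadow = [0.0]            # %".4'ipa", shadow of the return slot %.4
    slot_shadow[0] += differeturn  # seed the return slot's shadow
    d_x1 += diffefunc(slot_shadow, x1, x2)
    slot_shadow[0] = 0.0
    return d_x1


print(diffecfunc_func(2.0, 3.0, 1.0))
```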
For the direct call to the func entry in the LLVM IR:
julia> gradient(Reverse, g, a, Const(b))
after simplification :
; Function Attrs: mustprogress nofree willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_g_9336(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1) local_unnamed_addr #7 !dbg !46 {
top:
%2 = alloca i64, align 16
%3 = bitcast i64* %2 to i8*
%pgcstack = call {}*** @julia.get_pgcstack() #8
%ptls_field8 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
%4 = bitcast {}*** %ptls_field8 to i64***
%ptls_load910 = load i64**, i64*** %4, align 8, !tbaa !9
%5 = getelementptr inbounds i64*, i64** %ptls_load910, i64 2
%safepoint = load i64*, i64** %5, align 8, !tbaa !13
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #8, !dbg !47
fence syncscope("singlethread") seq_cst
store i64 0, i64* %2, align 16, !dbg !48
%6 = bitcast i64* %2 to double*, !dbg !51
call fastcc void @func(double* noalias nocapture nofree noundef nonnull writeonly align 16 dereferenceable(8) %6, double %0, double %1) #9 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !51
%7 = load double, double* %6, align 16, !dbg !52, !tbaa !31, !alias.scope !35, !noalias !38
ret double %7, !dbg !52
}
after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define internal fastcc void @preprocess_func(double* noalias nocapture nofree noundef nonnull writeonly align 16 dereferenceable(8) %retptr, double %arg.a, double %arg.b) unnamed_addr #6 {
entry:
%.6 = fmul double %arg.a, %arg.b
store double %.6, double* %retptr, align 16, !noalias !43
ret void
}
; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffefunc(double* noalias nocapture nofree writeonly align 16 dereferenceable(8) %retptr, double* nocapture nofree align 16 %"retptr'", double %arg.a, double %arg.b) unnamed_addr #8 {
entry:
%".6'de" = alloca double, align 8
%0 = getelementptr double, double* %".6'de", i64 0
store double 0.000000e+00, double* %0, align 8
%"arg.a'de" = alloca double, align 8
%1 = getelementptr double, double* %"arg.a'de", i64 0
store double 0.000000e+00, double* %1, align 8
br label %invertentry
invertentry: ; preds = %entry
%2 = load double, double* %"retptr'", align 16, !alias.scope !69, !noalias !72
store double 0.000000e+00, double* %"retptr'", align 16, !alias.scope !69, !noalias !72
%3 = load double, double* %".6'de", align 8
%4 = fadd fast double %3, %2
store double %4, double* %".6'de", align 8
%5 = load double, double* %".6'de", align 8
store double 0.000000e+00, double* %".6'de", align 8
%6 = fmul fast double %5, %arg.b
%7 = load double, double* %"arg.a'de", align 8
%8 = fadd fast double %7, %6
store double %8, double* %"arg.a'de", align 8
%9 = load double, double* %"arg.a'de", align 8
%10 = insertvalue { double } undef, double %9, 0
ret { double } %10
}
; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_g_9336(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1, double %differeturn) local_unnamed_addr #8 !dbg !55 {
top:
%"'de" = alloca double, align 8
%2 = getelementptr double, double* %"'de", i64 0
store double 0.000000e+00, double* %2, align 8
%"'de1" = alloca double, align 8
%3 = getelementptr double, double* %"'de1", i64 0
store double 0.000000e+00, double* %3, align 8
%"'ipa" = alloca i64, align 16
store i64 0, i64* %"'ipa", align 16
%pgcstack = call {}*** @julia.get_pgcstack() #10
%ptls_field8 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
%4 = bitcast {}*** %ptls_field8 to i64***
%ptls_load910 = load i64**, i64*** %4, align 8, !tbaa !9, !alias.scope !56, !noalias !59
%5 = getelementptr inbounds i64*, i64** %ptls_load910, i64 2
%safepoint = load i64*, i64** %5, align 8, !tbaa !13, !alias.scope !61, !noalias !64
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #10, !dbg !66
fence syncscope("singlethread") seq_cst
%"'ipc" = bitcast i64* %"'ipa" to double*, !dbg !67
br label %inverttop, !dbg !68
inverttop: ; preds = %top
store double %differeturn, double* %"'de", align 8
%6 = load double, double* %"'de", align 8, !dbg !68
store double 0.000000e+00, double* %"'de", align 8, !dbg !68
%7 = load double, double* %"'ipc", align 16, !dbg !68, !tbaa !31, !alias.scope !71, !noalias !74
%8 = fadd fast double %7, %6, !dbg !68
store double %8, double* %"'ipc", align 16, !dbg !68, !tbaa !31, !alias.scope !71, !noalias !74
%9 = call fastcc { double } @diffefunc(double* nocapture nofree writeonly align 16 undef, double* nocapture nofree align 16 %"'ipc", double %0, double %1) [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !67
%10 = extractvalue { double } %9, 0, !dbg !67
%11 = load double, double* %"'de1", align 8, !dbg !67
%12 = fadd fast double %11, %10, !dbg !67
store double %12, double* %"'de1", align 8, !dbg !67
store i64 0, i64* %"'ipa", align 16, !dbg !76, !alias.scope !79, !noalias !80
%13 = load double, double* %"'de1", align 8
%14 = insertvalue { double } undef, double %13, 0
ret { double } %14
}
(3.0, nothing)
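Comparing the two dumps, the main structural difference on the forward side is the extra cfunc_func wrapper: f calls @cfunc_func, which allocates and zero-initializes its own return slot (%.4), calls @func, and loads the result back, whereas g allocates the slot on the Julia side and calls @func directly. The sketch below models both call paths with ctypes from the standard library; native_func, cfunc_wrapper, and direct_call are hypothetical stand-ins that mimic the calling convention used in g, not Numba's actual functions.

```python
import ctypes


def native_func(retptr, excinfo, a: float, b: float) -> int:
    """Stand-in for Numba's native `func`: write a * b through `retptr`,
    leave `excinfo` untouched, and return a status code (0 = success)."""
    retptr[0] = a * b
    return 0


def cfunc_wrapper(a: float, b: float) -> float:
    """Model of @cfunc_func: own return slot, call func, load the result."""
    slot = ctypes.c_double(0.0)    # %.4 = alloca double; store 0.0
    excinfo = ctypes.c_void_p()
    # status check omitted: the simplified @cfunc_func above has none
    native_func(ctypes.pointer(slot), ctypes.pointer(excinfo), a, b)
    return slot.value              # %.18 = load double, double* %.4


def direct_call(a: float, b: float) -> float:
    """Model of Julia's `g`: caller-owned slot and an explicit status check."""
    result = ctypes.c_double()
    exc_info = ctypes.c_void_p()
    status = native_func(ctypes.pointer(result), ctypes.pointer(exc_info), a, b)
    if status != 0:
        raise RuntimeError("Python exception raised!")
    return result.value


print(cfunc_wrapper(2.0, 3.0), direct_call(2.0, 3.0))
```

Both paths compute the same primal value, which is consistent with f(a, b) == g(a, b) above; only the gradients differ.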
The llvmlite version used by Numba:
In [1]: import llvmlite
In [2]: llvmlite.__version__
Out[2]: '0.44.0'
which implies LLVM 15.x.x (see the llvmlite compatibility matrix).
The clang version
$ clang --version
clang version 15.0.7 (https://github.com/conda-forge/clangdev-feedstock 7546975a4a926b2b6b05f442d73827ff01b1ae76)
Target: x86_64-conda-linux-gnu
Thread model: posix
InstalledDir: /home/user001/miniforge3/envs/qruise-toolset/bin
Julia version
julia> versioninfo()
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 12 × AMD Ryzen 5 7640U w/ Radeon 760M Graphics
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver4)
Threads: 1 default, 0 interactive, 1 GC (on 12 virtual cores)
Enzyme version
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
version = "0.13.65"