
Wrong gradient of a cfunc-decorated Numba function #2505

@ymardoukhi


Following my exchange with @wsmoses on Julia Discourse, it turns out that the gradient computed through the cfunc entry point in the LLVM IR emitted by a @cfunc-decorated Numba function is wrong. The gradient is correct, however, when Enzyme differentiates the underlying function in the same LLVM IR directly.

The following small Python script dumps the LLVM IR of a @cfunc-decorated Numba function. The function simply multiplies two floating-point numbers, i.e. $f(a, b) = a b$.

import re
from numba import cfunc, types

sig = types.double(
    types.double,
    types.double,
)

@cfunc(sig)
def func(a, b):
    return a * b

if __name__ == "__main__":

    # get the LLVM IR
    ir_map = func.inspect_llvm()

    # Numba mangles the function's name.
    # The regex substitutions below replace
    # the mangled name with a short one to
    # improve readability.

    lines = ir_map.split("\n")

    mangled = None

    # The IR contains a global definition whose
    # mangled name includes `XXNumbaEnv` (XX being
    # two digits). Removing this substring recovers
    # the mangled function name.
    filtered_lines = []
    for line in lines:
        alias_match = re.match(r'^(@_[^ ]*NumbaEnv[^ ]*)\s*=', line)
        if alias_match:
            mangled = alias_match.group(1)
            continue
        filtered_lines.append(line)
    
    if mangled is None:
        raise ValueError("Could not find the NumbaEnv alias line.")
    
    # strip the leading `@` and the `XXNumbaEnv` substring.
    mangled = re.sub("@", "", mangled)
    mangled = re.sub(r"\d{2}NumbaEnv", "", mangled)
    
    # substitute the mangled name with a more readable
    # function name. Also change `cfunc.func` to `cfunc_func`
    # so we can refer to the symbol `:cfunc_func` in Julia.
    # `re.escape` makes the mangled name match literally.
    clean_irs = []
    for line in filtered_lines:
        line = re.sub(re.escape(mangled), "func", line)
        line = re.sub(r"cfunc\.func", "cfunc_func", line)
        clean_irs.append(line + "\n")
    
    ir_map = "".join(clean_irs)
    
    with open("func.ll", "w") as f:
        f.write(ir_map)
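To illustrate what the substitutions above do, here is the same logic applied to a small, made-up IR fragment. The mangled names in SAMPLE_IR are hypothetical stand-ins shaped like Numba's (real ones are longer), and demangle_ir is an illustrative helper, not part of Numba. re.escape is used so that any regex metacharacters in a mangled name would be matched literally.

```python
import re

# Hypothetical IR fragment; real Numba mangled names carry longer ABI tags.
SAMPLE_IR = """\
@_ZN08NumbaEnv8__main__4funcB2v1Edd = common local_unnamed_addr global i8* null
define i32 @_ZN8__main__4funcB2v1Edd(double* %retptr, double %arg.a, double %arg.b) {
define double @"cfunc._ZN8__main__4funcB2v1Edd"(double %.1, double %.2) {"""


def demangle_ir(ir_text: str) -> str:
    """Drop the NumbaEnv alias line and rename the mangled function to `func`."""
    mangled = None
    kept = []
    for line in ir_text.split("\n"):
        alias = re.match(r"^(@_[^ ]*NumbaEnv[^ ]*)\s*=", line)
        if alias:
            mangled = alias.group(1)
            continue
        kept.append(line)
    if mangled is None:
        raise ValueError("Could not find the NumbaEnv alias line.")
    # "@_ZN08NumbaEnv8__main__..." -> "_ZN8__main__...", the function's own name
    mangled = re.sub(r"\d{2}NumbaEnv", "", mangled.lstrip("@"))
    out = []
    for line in kept:
        line = re.sub(re.escape(mangled), "func", line)
        line = re.sub(r"cfunc\.func", "cfunc_func", line)
        out.append(line)
    return "\n".join(out)


print(demangle_ir(SAMPLE_IR))
```

The alias line disappears, the native entry point becomes @func, and the cfunc wrapper becomes @"cfunc_func", matching the symbols looked up with dlsym in the Julia snippet.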

Then, in the Julia REPL:

using Libdl
using Enzyme

file = read("func.ll");
run(
    pipeline(
        `clang -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o libfunc.so`; stdin=IOBuffer(file)
    )
);

const lib = Libdl.dlopen("./libfunc.so")
const f_ptr = Libdl.dlsym(lib, :cfunc_func)
const g_ptr = Libdl.dlsym(lib, :func)

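# C-ABI entry point generated by @cfunc: (double, double) -> double
f(a, b) = ccall(f_ptr, Cdouble, (Cdouble, Cdouble), a, b)

# Numba's native entry point: the result is written through a pointer,
# an exception-info pointer is passed alongside, and the return value
# is a status code (0 on success).
function g(a::Cdouble, b::Cdouble)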
    result = Ref{Cdouble}()
    exc_info = Ref{Ptr{Nothing}}()

    status = ccall(g_ptr, Cint,
        (Ptr{Cdouble}, Ptr{Ptr{Nothing}}, Cdouble, Cdouble),
        result, exc_info, a, b)

    status == 0 || error("Python exception raised!")
    result[]
end

a = 2.0
b = 3.0

f(a, b) == g(a, b)

gradient(Reverse, f, a, Const(b))
gradient(Reverse, g, a, Const(b))

This yields the following results:

julia> gradient(Reverse, f, a, Const(b))
(2.072225527840143e-309, nothing)

julia> gradient(Reverse, g, a, Const(b))
(3.0, nothing)

Only the latter produces the correct answer ($\partial f / \partial a = b = 3$); the former returns a subnormal value close to zero.
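As an Enzyme-independent cross-check, a central finite difference in plain Python gives the value that gradient(Reverse, f, a, Const(b)) should return. This is a sketch: central_difference is a hypothetical helper, and the step size h is an arbitrary choice.

```python
def f(a: float, b: float) -> float:
    """Same arithmetic as the Numba function above: f(a, b) = a * b."""
    return a * b


def central_difference(func, x: float, *args: float, h: float = 1e-6) -> float:
    """Approximate d func / d x at x, holding the remaining arguments fixed."""
    return (func(x + h, *args) - func(x - h, *args)) / (2.0 * h)


a, b = 2.0, 3.0
approx = central_difference(f, a, b)
print(approx)  # very close to b = 3.0, up to floating-point rounding
```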

Enzyme's output for both gradients, after setting Enzyme.API.printall!(true), is shown below.

For the cfunc entry point in the LLVM IR:

julia> gradient(Reverse, f, a, Const(b))
after simplification :
; Function Attrs: mustprogress nofree willreturn memory(read, argmem: none, inaccessiblemem: none)
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_6619(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1) local_unnamed_addr #5 !dbg !18 {
top:
  %pgcstack = call {}*** @julia.get_pgcstack() #6
  %ptls_field3 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %2 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %2, align 8, !tbaa !8
  %3 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !12
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #6, !dbg !19
  fence syncscope("singlethread") seq_cst
  %4 = call fastcc double @cfunc_func(double %0, double %1) #7, !dbg !19
  ret double %4, !dbg !19
}

after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
define internal fastcc double @preprocess_cfunc_func(double %.1, double %.2) unnamed_addr #3 {
entry:
  %.4 = alloca double, align 8
  store double 0.000000e+00, double* %.4, align 8, !noalias !15
  call fastcc void @func(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %.4, double %.1, double %.2) #6
  %.18 = load double, double* %.4, align 8
  ret double %.18
}

after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define internal fastcc void @preprocess_func(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %retptr, double %arg.a, double %arg.b) unnamed_addr #4 {
entry:
  %.6 = fmul double %arg.a, %arg.b
  store double %.6, double* %retptr, align 8, !noalias !15
  ret void
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffefunc(double* noalias nocapture nofree writeonly align 8 dereferenceable(8) %retptr, double* nocapture nofree align 8 %"retptr'", double %arg.a, double %arg.b) unnamed_addr #6 {
entry:
  %".6'de" = alloca double, align 8
  %0 = getelementptr double, double* %".6'de", i64 0
  store double 0.000000e+00, double* %0, align 8
  %"arg.a'de" = alloca double, align 8
  %1 = getelementptr double, double* %"arg.a'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  br label %invertentry

invertentry:                                      ; preds = %entry
  %2 = load double, double* %"retptr'", align 8, !alias.scope !27, !noalias !30
  store double 0.000000e+00, double* %"retptr'", align 8, !alias.scope !27, !noalias !30
  %3 = load double, double* %".6'de", align 8
  %4 = fadd fast double %3, %2
  store double %4, double* %".6'de", align 8
  %5 = load double, double* %".6'de", align 8
  store double 0.000000e+00, double* %".6'de", align 8
  %6 = fmul fast double %5, %arg.b
  %7 = load double, double* %"arg.a'de", align 8
  %8 = fadd fast double %7, %6
  store double %8, double* %"arg.a'de", align 8
  %9 = load double, double* %"arg.a'de", align 8
  %10 = insertvalue { double } undef, double %9, 0
  ret { double } %10
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffecfunc_func(double %.1, double %.2, double %differeturn) unnamed_addr #6 {
entry:
  %".18'de" = alloca double, align 8
  %0 = getelementptr double, double* %".18'de", i64 0
  store double 0.000000e+00, double* %0, align 8
  %".1'de" = alloca double, align 8
  %1 = getelementptr double, double* %".1'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %".4'ipa" = alloca double, align 8
  store double 0.000000e+00, double* %".4'ipa", align 8
  br label %invertentry

invertentry:                                      ; preds = %entry
  store double %differeturn, double* %".18'de", align 8
  %2 = load double, double* %".18'de", align 8
  store double 0.000000e+00, double* %".18'de", align 8
  %3 = load double, double* %".4'ipa", align 8, !alias.scope !22, !noalias !25
  %4 = fadd fast double %3, %2
  store double %4, double* %".4'ipa", align 8, !alias.scope !22, !noalias !25
  %5 = call fastcc { double } @diffefunc(double* nocapture nofree writeonly align 8 undef, double* nocapture nofree align 8 %".4'ipa", double %.1, double %.2)
  %6 = extractvalue { double } %5, 0
  %7 = load double, double* %".1'de", align 8
  %8 = fadd fast double %7, %6
  store double %8, double* %".1'de", align 8
  store double 0.000000e+00, double* %".4'ipa", align 8, !alias.scope !22, !noalias !27
  %9 = load double, double* %".1'de", align 8
  %10 = insertvalue { double } undef, double %9, 0
  ret { double } %10
}

; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_f_6619(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1, double %differeturn) local_unnamed_addr #6 !dbg !20 {
top:
  %"'de" = alloca double, align 8
  %2 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %2, align 8
  %"'de1" = alloca double, align 8
  %3 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %3, align 8
  %pgcstack = call {}*** @julia.get_pgcstack() #9
  %ptls_field3 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %4 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !8, !alias.scope !21, !noalias !24
  %5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %5, align 8, !tbaa !12, !alias.scope !26, !noalias !29
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #9, !dbg !31
  fence syncscope("singlethread") seq_cst
  br label %inverttop, !dbg !31

inverttop:                                        ; preds = %top
  store double %differeturn, double* %"'de", align 8
  %6 = load double, double* %"'de", align 8, !dbg !31
  %7 = call fastcc { double } @diffecfunc_func(double %0, double %1, double %6), !dbg !31
  %8 = extractvalue { double } %7, 0, !dbg !31
  %9 = load double, double* %"'de1", align 8, !dbg !31
  %10 = fadd fast double %9, %8, !dbg !31
  store double %10, double* %"'de1", align 8, !dbg !31
  store double 0.000000e+00, double* %"'de", align 8, !dbg !31
  %11 = load double, double* %"'de1", align 8
  %12 = insertvalue { double } undef, double %11, 0
  ret { double } %12
}

(2.08035975524203e-309, nothing)
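As a reading aid for the dump above, here is a hand translation of @diffecfunc_func and @diffefunc into Python. It assumes each alloca can be modelled as a local scalar, with the shadow of the result slot (%".4'ipa") modelled as a one-element list because it is passed by pointer; the function names only mirror the IR and are not part of any API. Read literally, this reverse pass returns differeturn * b, i.e. 3.0 for a seed of 1.0, which suggests the wrong number does not come from the arithmetic in these two functions as printed.

```python
def diffefunc(retptr_shadow: list, arg_a: float, arg_b: float) -> float:
    """Mirror of @diffefunc: returns the adjoint of %arg.a."""
    d6 = 0.0                      # %".6'de", adjoint of %.6 = fmul a, b
    da = 0.0                      # %"arg.a'de"
    d6 += retptr_shadow[0]        # load %"retptr'" and accumulate
    retptr_shadow[0] = 0.0        # store 0.0 into %"retptr'"
    da += d6 * arg_b              # %6 = fmul %5, %arg.b; accumulate
    return da


def diffecfunc_func(x1: float, x2: float, differeturn: float) -> float:
    """Mirror of @diffecfunc_func: returns the adjoint of %.1."""
    d1 = 0.0                      # %".1'de"
    shadow_4 = [0.0]              # %".4'ipa", shadow of the result slot %.4
    shadow_4[0] += differeturn    # seed the shadow with the incoming adjoint
    d1 += diffefunc(shadow_4, x1, x2)
    shadow_4[0] = 0.0
    return d1


print(diffecfunc_func(2.0, 3.0, 1.0))  # 3.0
```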

For the direct call to the function entry point in the LLVM IR:

julia> gradient(Reverse, g, a, Const(b))
after simplification :
; Function Attrs: mustprogress nofree willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_g_9336(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1) local_unnamed_addr #7 !dbg !46 {
top:
  %2 = alloca i64, align 16
  %3 = bitcast i64* %2 to i8*
  %pgcstack = call {}*** @julia.get_pgcstack() #8
  %ptls_field8 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %4 = bitcast {}*** %ptls_field8 to i64***
  %ptls_load910 = load i64**, i64*** %4, align 8, !tbaa !9
  %5 = getelementptr inbounds i64*, i64** %ptls_load910, i64 2
  %safepoint = load i64*, i64** %5, align 8, !tbaa !13
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #8, !dbg !47
  fence syncscope("singlethread") seq_cst
  store i64 0, i64* %2, align 16, !dbg !48
  %6 = bitcast i64* %2 to double*, !dbg !51
  call fastcc void @func(double* noalias nocapture nofree noundef nonnull writeonly align 16 dereferenceable(8) %6, double %0, double %1) #9 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !51
  %7 = load double, double* %6, align 16, !dbg !52, !tbaa !31, !alias.scope !35, !noalias !38
  ret double %7, !dbg !52
}

after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define internal fastcc void @preprocess_func(double* noalias nocapture nofree noundef nonnull writeonly align 16 dereferenceable(8) %retptr, double %arg.a, double %arg.b) unnamed_addr #6 {
entry:
  %.6 = fmul double %arg.a, %arg.b
  store double %.6, double* %retptr, align 16, !noalias !43
  ret void
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffefunc(double* noalias nocapture nofree writeonly align 16 dereferenceable(8) %retptr, double* nocapture nofree align 16 %"retptr'", double %arg.a, double %arg.b) unnamed_addr #8 {
entry:
  %".6'de" = alloca double, align 8
  %0 = getelementptr double, double* %".6'de", i64 0
  store double 0.000000e+00, double* %0, align 8
  %"arg.a'de" = alloca double, align 8
  %1 = getelementptr double, double* %"arg.a'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  br label %invertentry

invertentry:                                      ; preds = %entry
  %2 = load double, double* %"retptr'", align 16, !alias.scope !69, !noalias !72
  store double 0.000000e+00, double* %"retptr'", align 16, !alias.scope !69, !noalias !72
  %3 = load double, double* %".6'de", align 8
  %4 = fadd fast double %3, %2
  store double %4, double* %".6'de", align 8
  %5 = load double, double* %".6'de", align 8
  store double 0.000000e+00, double* %".6'de", align 8
  %6 = fmul fast double %5, %arg.b
  %7 = load double, double* %"arg.a'de", align 8
  %8 = fadd fast double %7, %6
  store double %8, double* %"arg.a'de", align 8
  %9 = load double, double* %"arg.a'de", align 8
  %10 = insertvalue { double } undef, double %9, 0
  ret { double } %10
}

; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_g_9336(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140356334728768" "enzymejl_parmtype_ref"="0" %1, double %differeturn) local_unnamed_addr #8 !dbg !55 {
top:
  %"'de" = alloca double, align 8
  %2 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %2, align 8
  %"'de1" = alloca double, align 8
  %3 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %3, align 8
  %"'ipa" = alloca i64, align 16
  store i64 0, i64* %"'ipa", align 16
  %pgcstack = call {}*** @julia.get_pgcstack() #10
  %ptls_field8 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %4 = bitcast {}*** %ptls_field8 to i64***
  %ptls_load910 = load i64**, i64*** %4, align 8, !tbaa !9, !alias.scope !56, !noalias !59
  %5 = getelementptr inbounds i64*, i64** %ptls_load910, i64 2
  %safepoint = load i64*, i64** %5, align 8, !tbaa !13, !alias.scope !61, !noalias !64
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #10, !dbg !66
  fence syncscope("singlethread") seq_cst
  %"'ipc" = bitcast i64* %"'ipa" to double*, !dbg !67
  br label %inverttop, !dbg !68

inverttop:                                        ; preds = %top
  store double %differeturn, double* %"'de", align 8
  %6 = load double, double* %"'de", align 8, !dbg !68
  store double 0.000000e+00, double* %"'de", align 8, !dbg !68
  %7 = load double, double* %"'ipc", align 16, !dbg !68, !tbaa !31, !alias.scope !71, !noalias !74
  %8 = fadd fast double %7, %6, !dbg !68
  store double %8, double* %"'ipc", align 16, !dbg !68, !tbaa !31, !alias.scope !71, !noalias !74
  %9 = call fastcc { double } @diffefunc(double* nocapture nofree writeonly align 16 undef, double* nocapture nofree align 16 %"'ipc", double %0, double %1) [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !67
  %10 = extractvalue { double } %9, 0, !dbg !67
  %11 = load double, double* %"'de1", align 8, !dbg !67
  %12 = fadd fast double %11, %10, !dbg !67
  store double %12, double* %"'de1", align 8, !dbg !67
  store i64 0, i64* %"'ipa", align 16, !dbg !76, !alias.scope !79, !noalias !80
  %13 = load double, double* %"'de1", align 8
  %14 = insertvalue { double } undef, double %13, 0
  ret { double } %14
}

(3.0, nothing)

The llvmlite version used by Numba

In [1]: import llvmlite

In [2]: llvmlite.__version__
Out[2]: '0.44.0'

which implies LLVM 15.x (see the llvmlite compatibility matrix).

The clang version

$ clang --version
clang version 15.0.7 (https://github.com/conda-forge/clangdev-feedstock 7546975a4a926b2b6b05f442d73827ff01b1ae76)
Target: x86_64-conda-linux-gnu
Thread model: posix
InstalledDir: /home/user001/miniforge3/envs/qruise-toolset/bin

Julia version

julia> versioninfo()
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × AMD Ryzen 5 7640U w/ Radeon 760M Graphics
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver4)
Threads: 1 default, 0 interactive, 1 GC (on 12 virtual cores)

Enzyme version

name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.13.65"
