In the current implementation, compiling a kernel that has multiple overloads with `strip_hash=True` fails with an NVRTC compilation error:
```python
from typing import Any

import warp as wp

wp.config.verbose = True


@wp.kernel
def scale(x: wp.array(dtype=Any), s: Any):
    i = wp.tid()
    x[i] = s * x[i]


scale_f16 = wp.overload(scale, [wp.array(dtype=wp.float16), wp.float16])
scale_f32 = wp.overload(scale, [wp.array(dtype=wp.float32), wp.float32])
scale_f64 = wp.overload(scale, [wp.array(dtype=wp.float64), wp.float64])

wp.compile_aot_module(__name__, wp.get_device(), strip_hash=True)
```
Output:
```
Warp NVRTC compilation error 6: NVRTC_ERROR_COMPILATION (warp/warp/native/warp.cu:3622)
wp___main__.cu(111): error: more than one instance of overloaded function "scale_cuda_kernel_forward" has "C" linkage
  extern "C" __global__ void scale_cuda_kernel_forward(
                             ^
wp___main__.cu(144): error: more than one instance of overloaded function "scale_cuda_kernel_backward" has "C" linkage
  extern "C" __global__ void scale_cuda_kernel_backward(
                             ^
wp___main__.cu(195): error: more than one instance of overloaded function "scale_cuda_kernel_forward" has "C" linkage
  extern "C" __global__ void scale_cuda_kernel_forward(
                             ^
wp___main__.cu(228): error: more than one instance of overloaded function "scale_cuda_kernel_backward" has "C" linkage
  extern "C" __global__ void scale_cuda_kernel_backward(
```
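This is not a bug but a known restriction of `strip_hash=True`: with the hash removed from the symbol names, every overload of `scale` is emitted under the same `extern "C"` names (`scale_cuda_kernel_forward` and `scale_cuda_kernel_backward`), which C linkage does not allow. We can, however, make the failure easier to understand by detecting this case in `ModuleHasher` and raising a descriptive error before NVRTC is invoked.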
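A minimal, standalone sketch of what such a check could look like. It does not use Warp internals: `forward_symbol` and `check_overload_symbols` are hypothetical helpers, the hashed naming scheme is an assumption modeled on the symbol names in the error output above, and the overload hashes are placeholders. It only checks forward symbols, since a collision there already implies one for the backward symbols.

```python
from collections import defaultdict


def forward_symbol(kernel_name: str, overload_hash: str, strip_hash: bool) -> str:
    # Hypothetical naming scheme: with strip_hash=True only the kernel name remains,
    # matching "scale_cuda_kernel_forward" in the NVRTC output.
    if strip_hash:
        return f"{kernel_name}_cuda_kernel_forward"
    return f"{kernel_name}_{overload_hash}_cuda_kernel_forward"


def check_overload_symbols(overloads, strip_hash: bool) -> None:
    """Raise a readable error if two overloads would share one extern "C" symbol.

    overloads: iterable of (kernel_name, signature, overload_hash) tuples.
    """
    signatures_by_symbol = defaultdict(list)
    for kernel_name, signature, overload_hash in overloads:
        symbol = forward_symbol(kernel_name, overload_hash, strip_hash)
        signatures_by_symbol[symbol].append(signature)

    # Any symbol claimed by more than one signature would fail NVRTC compilation.
    clashes = {sym: sigs for sym, sigs in signatures_by_symbol.items() if len(sigs) > 1}
    if clashes:
        details = "; ".join(f"{sym} <- {sigs}" for sym, sigs in sorted(clashes.items()))
        raise ValueError(
            "strip_hash=True cannot be used with a kernel that has multiple overloads, "
            f'because the overloads map to the same extern "C" symbol: {details}'
        )


# Placeholder hashes; the real values would come from the module hasher.
overloads = [
    ("scale", "(array(float16), float16)", "h16"),
    ("scale", "(array(float32), float32)", "h32"),
    ("scale", "(array(float64), float64)", "h64"),
]

check_overload_symbols(overloads, strip_hash=False)  # hashed names are distinct

try:
    check_overload_symbols(overloads, strip_hash=True)
except ValueError as err:
    print(err)
```

Raising a `ValueError` that lists the conflicting signatures points the user directly at the overloads and the `strip_hash` flag, instead of at generated `.cu` line numbers.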