Labels: bug
Description
When using the pure-Python padding (`pad_with_dim`) with `torch.compile`, the compiled function seems to be implicitly limited to a single `nside` value: it works for the first `nside` it sees, but calling it again with a different `nside` fails.
Here is a minimal reproducer for the failure with multiple `nside` values:
```python
import torch
from earth2grid.healpix import HEALPIX_PAD_XY, pad_with_dim

pad_compiled = torch.compile(pad_with_dim)

def test_pad_compile(batch_size, timesteps, nside, nchannels):
    # HEALPix tensor with 12 * nside**2 pixels along dim -2
    x = torch.rand([batch_size, timesteps, 12 * nside * nside, nchannels], dtype=torch.bfloat16, device="cuda")
    x_pad = pad_with_dim(x, 1, dim=-2, pixel_order=HEALPIX_PAD_XY)
    x_pad_c = pad_compiled(x, 1, dim=-2, pixel_order=HEALPIX_PAD_XY)
    # Max absolute difference between eager and compiled results
    return torch.abs(x_pad - x_pad_c).max()

# nsides = [64]  # This works
nsides = [64, 32]  # This fails
for nside in nsides:
    print(f"compiled pad error = {test_pad_compile(16, 1, nside, 128)}")
```
Running this prints the following error, raised from `torch.compile`:
```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'OpaqueUnaryFn_sqrt' is not defined
```
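For context, the pixel dimension in the reproducer has size `12 * nside**2`, so any code that needs `nside` back from the tensor shape has to take a square root. A plausible but unconfirmed explanation is that once a second `nside` is seen, Dynamo's automatic dynamic shapes make that dimension symbolic, and the square root ends up as an `OpaqueUnaryFn_sqrt` expression that Inductor's generated code cannot resolve. This assumes `pad_with_dim` derives `nside` from the shape, which has not been checked against the earth2grid source. The sketch below uses a hypothetical helper, `nside_from_npix` (not part of earth2grid), to show how `nside` can be recovered with integer arithmetic only:

```python
import math

def nside_from_npix(npix: int) -> int:
    """Hypothetical helper: recover HEALPix nside, assuming npix == 12 * nside**2."""
    if npix <= 0 or npix % 12 != 0:
        raise ValueError(f"npix={npix} is not a valid HEALPix pixel count")
    # Integer square root: exact for any size, no floating-point sqrt involved
    nside = math.isqrt(npix // 12)
    if 12 * nside * nside != npix:
        raise ValueError(f"npix={npix} is not 12 * nside**2 for an integer nside")
    return nside

for nside in (64, 32):
    npix = 12 * nside * nside
    print(npix, nside_from_npix(npix))  # prints "49152 64", then "12288 32"
```

This only illustrates the shape relationship. Whether an integer-only formulation traces cleanly under `torch.compile` is an open question. Untested here: passing `dynamic=False` to `torch.compile` forces shape specialization (one recompile per `nside`), which may sidestep the symbolic square root until the root cause is fixed.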