
Pure Python padding fails to compile when using multiple nsides #41

@akshaysubr


When the pure Python padding is used with torch.compile, the compiled function appears to be implicitly limited to a single nside value: the reproducer below succeeds with one nside but fails as soon as a second, different nside is used.

Here is a minimal reproducer of the failure with multiple nsides:

```python
import torch
from earth2grid.healpix import HEALPIX_PAD_XY, pad_with_dim

pad_compiled = torch.compile(pad_with_dim)

def test_pad_compile(batch_size, timesteps, nside, nchannels):
    # HEALPix data with 12 * nside**2 pixels along dim -2
    x = torch.rand([batch_size, timesteps, 12*nside*nside, nchannels], dtype=torch.bfloat16, device="cuda")
    # Eager (reference) padding
    x_pad = pad_with_dim(x, 1, dim=-2, pixel_order=HEALPIX_PAD_XY)
    # Same call through the compiled function
    x_pad_c = pad_compiled(x, 1, dim=-2, pixel_order=HEALPIX_PAD_XY)
    # Maximum absolute difference between eager and compiled results
    return torch.abs(x_pad - x_pad_c).max()

# nsides = [64]  # This works
nsides = [64, 32]  # This fails
for nside in nsides:
    print(f"compiled pad error = {test_pad_compile(16, 1, nside, 128)}")
```

Running this raises the following error from torch.compile:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'OpaqueUnaryFn_sqrt' is not defined
```
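The name `OpaqueUnaryFn_sqrt` hints that a square root is being taken of a symbolic size. A HEALPix map with resolution nside has 12 * nside² pixels (the reproducer builds exactly `12*nside*nside`), so recovering nside from the pixel dimension requires a square root. One plausible, unconfirmed explanation is that the second call with a new shape makes the pixel dimension dynamic, turning that square root symbolic. The sketch below only illustrates the nside/pixel-count relationship; `nside_from_npix` is a hypothetical helper, not part of the earth2grid API, and we do not know how earth2grid actually computes nside internally.

```python
import math


def nside_from_npix(npix: int) -> int:
    """Hypothetical helper: recover HEALPix nside from npix = 12 * nside**2."""
    if npix % 12 != 0:
        raise ValueError(f"npix={npix} is not a multiple of 12")
    # Integer square root of npix / 12; this is the sqrt the error message hints at
    nside = math.isqrt(npix // 12)
    if 12 * nside * nside != npix:
        raise ValueError(f"npix={npix} is not 12 * nside**2 for any integer nside")
    return nside


# The two resolutions from the reproducer
for nside in [64, 32]:
    npix = 12 * nside * nside
    print(nside, npix, nside_from_npix(npix))
```

For the reproducer's values this prints `64 49152 64` and `32 12288 32`. With static shapes each call sees a concrete integer; only when the size is symbolic does the sqrt have to be represented inside the compiled graph.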

Labels: bug (Something isn't working)
