Improve lead time support for diffusion models #980


Merged

merged 35 commits (Aug 9, 2025)

Changes from 16 commits

Commits (35)
bbaea98
Improve lead time support for diffusion models
jleinonen Jun 17, 2025
232c29c
Update changelog
jleinonen Jun 17, 2025
8b94e59
Add back mistakenly removed docstrings and type hints
jleinonen Jun 17, 2025
f3338b6
Revert a couple more unintended changes
jleinonen Jun 17, 2025
7ba0563
Fix type hint of lead time label
jleinonen Jun 17, 2025
11e4ea5
Fix deterministic samples to allow CorrDiff tests to pass
jleinonen Jun 18, 2025
6f08b02
Rename utils.generative to utils.diffusion
jleinonen Jun 25, 2025
63b70c9
Merge main and resolve conflicts
jleinonen Jun 25, 2025
fd35097
Add back __init__.py in generative
jleinonen Jun 25, 2025
ef8a9a6
Revert unnecessary changes
jleinonen Jun 25, 2025
d0b1bfb
Revert unnecessary changes
jleinonen Jun 25, 2025
4d503c4
Revert unnecessary changes
jleinonen Jun 25, 2025
77a5cde
Merge branch 'NVIDIA:main' into leadtime-fixes
jleinonen Jul 2, 2025
e600803
Merge branch 'NVIDIA:main' into leadtime-fixes
jleinonen Jul 8, 2025
4929596
Merge branch 'main' into leadtime-fixes
CharlelieLrt Jul 10, 2025
457d50d
Minor docstring improvement in SongUNetPosEmdb
CharlelieLrt Jul 10, 2025
522440b
Add value checks and docstrings
jleinonen Jul 11, 2025
fd96890
Update docstrings, add error condition
jleinonen Jul 14, 2025
29ca853
Update docstrings
jleinonen Jul 14, 2025
e87c4c5
Fix lead time tests
jleinonen Aug 7, 2025
e2fdd39
Fix tests after merge
jleinonen Aug 7, 2025
6cec691
Merge branch 'NVIDIA:main' into leadtime-fixes
jleinonen Aug 7, 2025
bf461a8
Update docstring
jleinonen Aug 7, 2025
89a5bd1
Change super().__init__ to use keyword args
jleinonen Aug 8, 2025
2c0cd29
Minor formatting in deterministic_sampler docstring
CharlelieLrt Aug 8, 2025
77b162b
Minor renaming and formatting in loss.py
CharlelieLrt Aug 8, 2025
761468f
Removed dtype casting of pos_emb in SongUNetPosEmbd
CharlelieLrt Aug 8, 2025
948063e
Removed duplicate code in SongUNetPosEmbd.positional_embedding_indexing
CharlelieLrt Aug 8, 2025
58c55af
Refactor positional_embedding_indexing to eliminate dead and duplicat…
CharlelieLrt Aug 8, 2025
fe3052b
Refactor positional_embedding_selector to enable batched lead-time la…
CharlelieLrt Aug 8, 2025
b5992d5
Moved new test from song_unet_pos_embd to song_unet_pos_lt_embd
CharlelieLrt Aug 8, 2025
ebaa43f
Updated CHANGELOG.md
CharlelieLrt Aug 8, 2025
2013cf5
Added safety check to force users to use SongUNetPosLtEmdb for lead-t…
CharlelieLrt Aug 8, 2025
374c9f8
Deleted unecessary test
CharlelieLrt Aug 8, 2025
2b6ac4e
Fixed bug in positional_embedding_selector + changed samplers and tes…
CharlelieLrt Aug 9, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

### Added

- Improved lead time support for diffusion models
- Improved documentation for diffusion models and diffusion utils.
- Safe API to override `__init__`'s arguments saved in checkpoint file with
`Module.from_checkpoint("chkpt.mdlus", models_args)`.
22 changes: 19 additions & 3 deletions physicsnemo/metrics/diffusion/loss.py
@@ -220,7 +220,15 @@ def __init__(
self.P_std = P_std
self.sigma_data = sigma_data

def __call__(
self,
net,
images,
condition=None,
labels=None,
augment_pipe=None,
lead_time_label=None,
):
"""
Calculate and return the loss corresponding to the EDM formulation.

@@ -258,16 +266,24 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
augment_pipe(images) if augment_pipe is not None else (images, None)
)
n = torch.randn_like(y) * sigma
additional_labels = {
"augment_labels": augment_labels,
"lead_time_label": lead_time_label,
}
# drop None items to support models that don't have these arguments in `forward`
additional_labels = {
k: v for (k, v) in additional_labels.items() if v is not None
}
if condition is not None:
D_yn = net(
y + n,
sigma,
condition=condition,
class_labels=labels,
**additional_labels,
)
else:
D_yn = net(y + n, sigma, labels, **additional_labels)
loss = weight * ((D_yn - y) ** 2)
return loss

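The filtering above is what keeps `EDMLoss` backward compatible: optional labels are collected into a dict and the `None` entries are dropped, so a network whose `forward` does not accept `lead_time_label` (or `augment_labels`) never sees the keyword. A minimal sketch of the pattern, with hypothetical stand-in nets:

```python
import torch

def run_net(net, y, sigma, labels, augment_labels=None, lead_time_label=None):
    # Collect optional labels, then drop the None entries so that nets
    # whose signature lacks a given keyword can still be called.
    extra = {"augment_labels": augment_labels, "lead_time_label": lead_time_label}
    extra = {k: v for k, v in extra.items() if v is not None}
    return net(y, sigma, labels, **extra)

# A net without lead-time support keeps working unchanged...
plain_net = lambda y, sigma, labels, augment_labels=None: y
# ...while a lead-time-aware net receives the label when one is given.
lt_net = lambda y, sigma, labels, lead_time_label=None: y

y, sigma = torch.zeros(1, 1, 1, 1), torch.tensor([1.0])
run_net(plain_net, y, sigma, None)
run_net(lt_net, y, sigma, None, lead_time_label=torch.tensor([0]))
```
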
71 changes: 43 additions & 28 deletions physicsnemo/models/diffusion/song_unet.py
@@ -794,7 +794,7 @@ def __init__(
profile_mode: bool = False,
amp_mode: bool = False,
lead_time_mode: bool = False,
lead_time_channels: int | None = None,
lead_time_steps: int = 9,
prob_channels: List[int] = [],
):
@@ -826,7 +826,7 @@ def __init__(

self.gridtype = gridtype
self.N_grid_channels = N_grid_channels
if (self.gridtype == "learnable") or (self.N_grid_channels == 0):
self.pos_embd = self._get_positional_embedding()
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
@@ -840,6 +840,8 @@ def __init__(
self.scalar = torch.nn.Parameter(
torch.ones((1, len(self.prob_channels), 1, 1))
)
else:
self.lt_embd = None

def forward(
self,
@@ -861,11 +863,11 @@
"Cannot provide both embedding_selector and global_index."
)

if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
self.pos_embd = self.pos_embd.to(x.dtype)

# Append positional embedding to input conditioning
if (self.pos_embd is not None) or (self.lt_embd is not None):
# Select positional embeddings with a selector function
if embedding_selector is not None:
selected_pos_embd = self.positional_embedding_selector(
@@ -918,15 +920,24 @@ def positional_embedding_indexing(
the patches to extract from the positional embedding grid.
:math:`P` is the number of distinct patches in the input tensor ``x``.
The channel dimension should contain :math:`j`, :math:`i` indices that
should represent the indices of the pixels to extract from the
embedding grid.
lead_time_label : Optional[torch.Tensor], default=None
Tensor of shape :math:`(P,)` that corresponds to the lead-time label for each patch.
Only used if ``lead_time_mode`` is True.

Returns
-------
torch.Tensor
Selected embeddings with shape :math:`(P \times B, C_{PE} [+
C_{LT}], H_{in}, W_{in})`. :math:`C_{PE}` is the number of
embedding channels in the positional embedding grid, and
:math:`C_{LT}` is the number of embedding channels in the lead-time
embedding grid. If ``lead_time_label`` is provided, the lead-time
embedding channels are included. If ``global_index`` is `None`,
:math:`P = 1` is assumed, and the positional embedding grid is
duplicated :math:`B` times and returned with shape
:math:`(B, C_{PE} [+ C_{LT}], H, W)`.

Example
-------
@@ -951,7 +962,7 @@
"""
# If no global indices are provided, select all embeddings and expand
# to match the batch size of the input
if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
self.pos_embd = self.pos_embd.to(x.dtype)

if global_index is None:
@@ -989,23 +1000,26 @@
global_index = torch.reshape(
torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
) # (P, 2, X, Y) to (2, P*X*Y)

if self.pos_embd is not None:
selected_pos_embd = self.pos_embd[
:, global_index[0], global_index[1]
] # (C_pe, P*X*Y)
selected_pos_embd = torch.permute(
torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
(1, 0, 2, 3),
) # (P, C_pe, X, Y)

selected_pos_embd = selected_pos_embd.repeat(
B, 1, 1, 1
) # (B*P, C_pe, X, Y)

embeds = [selected_pos_embd]
else:
embeds = []

# Append positional and lead time embeddings to input conditioning
if self.lead_time_mode:
if self.lt_embd is not None:
lt_embds = self.lt_embd[
lead_time_label.int()
Expand All @@ -1026,8 +1040,8 @@ def positional_embedding_indexing(
) # (B*P, C_pe, X, Y)
embeds.append(selected_lt_pos_embd)

if len(embeds) > 0:
selected_pos_embd = torch.cat(embeds, dim=1)

return selected_pos_embd

@@ -1090,8 +1104,9 @@ def positional_embedding_selector(
... return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
>>>
"""
if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
self.pos_embd = self.pos_embd.to(x.dtype)

if lead_time_label is not None:
# all patches share same lead_time_label
embeddings = torch.cat(
@@ -1331,7 +1346,7 @@ def __init__(
resample_filter: List[int] = [1, 1],
gridtype: str = "sinusoidal",
N_grid_channels: int = 4,
lead_time_channels: int | None = None,
lead_time_steps: int = 9,
prob_channels: List[int] = [],
checkpoint_level: int = 0,
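
The patch-wise selection in `positional_embedding_indexing` boils down to one advanced-indexing gather over the embedding grid, followed by a reshape back to patch layout. A simplified sketch of just that step, with illustrative shapes (not the actual class method):

```python
import torch

C_pe, H_grid, W_grid = 8, 32, 32  # positional embedding grid (illustrative)
P, H, W, B = 3, 4, 4, 2           # patches, patch height/width, batch size

pos_embd = torch.randn(C_pe, H_grid, W_grid)
# Per-patch (j, i) pixel coordinates into the grid, shape (P, 2, H, W).
global_index = torch.randint(0, H_grid, (P, 2, H, W))

idx = torch.reshape(torch.permute(global_index, (1, 0, 2, 3)), (2, -1))  # (2, P*H*W)
selected = pos_embd[:, idx[0], idx[1]]                                   # (C_pe, P*H*W)
selected = torch.permute(selected.reshape(C_pe, P, H, W), (1, 0, 2, 3))  # (P, C_pe, H, W)
selected = selected.repeat(B, 1, 1, 1)                                   # (B*P, C_pe, H, W)
assert selected.shape == (B * P, C_pe, H, W)
```

With `lead_time_mode`, the same `(B*P, C_LT, H, W)` layout is built from `lt_embd` and concatenated along the channel dimension, which is where the `(P x B, C_PE [+ C_LT], H_in, W_in)` shape in the docstring comes from.
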
23 changes: 17 additions & 6 deletions physicsnemo/utils/diffusion/deterministic_sampler.py
@@ -49,6 +49,7 @@ def deterministic_sampler(
S_min: float = 0.0,
S_max: float = float("inf"),
S_noise: float = 1.0,
lead_time_label: torch.Tensor | None = None,
) -> torch.Tensor:
r"""
Generalized sampler, representing the superset of all sampling methods
@@ -167,6 +168,11 @@
# conditioning
x_lr = img_lr

# do not pass lead time labels to nets that may not support them
lead_time_label = (
{} if lead_time_label is None else {"lead_time_label": lead_time_label}
)

if solver not in ["euler", "heun"]:
raise ValueError(f"Unknown solver {solver}")
if discretization not in ["vp", "ve", "iddpm", "edm"]:
@@ -198,8 +204,7 @@
ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
ve_sigma_inv = lambda sigma: sigma**2

# Select default noise level range based on the specified time step discretization.
if sigma_min is None:
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[
@@ -304,11 +309,12 @@
sigma(t_hat),
condition=x_lr,
class_labels=class_labels,
**lead_time_label,
).to(torch.float64)
else:
denoised = net(
x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels, **lead_time_label
).to(torch.float64)
d_cur = (
sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)
) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
@@ -326,10 +332,15 @@
sigma(t_prime),
condition=x_lr,
class_labels=class_labels,
**lead_time_label,
).to(torch.float64)
else:
denoised = net(
x_prime / s(t_prime),
x_lr,
sigma(t_prime),
class_labels,
**lead_time_label,
).to(torch.float64)
d_prime = (
sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
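
Since the label is forwarded as `**lead_time_label`, passing the default `None` expands to `**{}` and reproduces the previous call exactly. A hedged usage sketch mirroring the mock-based test below (the stand-in net is illustrative, not a real model):

```python
import torch

from physicsnemo.utils.diffusion import deterministic_sampler

class TinyLtNet:
    """Illustrative stand-in for a lead-time-aware denoiser."""

    sigma_min, sigma_max = 0.002, 800.0  # assumed attributes, following EDM conventions

    def round_sigma(self, sigma):
        return torch.tensor(sigma)

    def __call__(self, x, img_lr, sigma, class_labels, lead_time_label=None):
        return x  # identity "denoiser", enough to exercise the sampler

latents = torch.randn(1, 3, 64, 64)
img_lr = torch.randn(1, 3, 64, 64)
samples = deterministic_sampler(
    net=TinyLtNet(),
    latents=latents,
    img_lr=img_lr,
    lead_time_label=torch.randint(0, 10, (1, 1)),
)
```
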
21 changes: 21 additions & 0 deletions test/metrics/diffusion/test_losses.py
@@ -139,6 +139,17 @@ def fake_condition_net(y, sigma, condition, class_labels=None, augment_labels=None):
def fake_net(y, sigma, labels, augment_labels=None):
return torch.tensor([1.0])

def fake_condition_net_lt(
y,
sigma,
condition,
class_labels=None,
augment_labels=None,
lead_time_label=None,
):
assert lead_time_label is not None # test that this is properly passed through
return torch.tensor([1.0])

loss_func = EDMLoss()

img = torch.tensor([[[[1.0]]]])
@@ -160,6 +171,16 @@ def mock_augment_pipe(imgs):
loss_value_with_augmentation = loss_func(fake_net, img, labels, mock_augment_pipe)
assert isinstance(loss_value_with_augmentation, torch.Tensor)

lead_time_label = torch.tensor([1])
loss_value = loss_func(
fake_condition_net_lt,
img,
condition=condition,
labels=labels,
lead_time_label=lead_time_label,
)
assert isinstance(loss_value, torch.Tensor)


# RegressionLoss tests

38 changes: 38 additions & 0 deletions test/models/diffusion/test_song_unet_pos_embd.py
@@ -328,3 +328,41 @@ def test_son_unet_deploy(device):
assert common.validate_onnx_runtime(
model, (*[input_image, noise_labels, class_labels],)
)


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("lead_time_mode", [False, True])
@pytest.mark.parametrize("N_grid_channels", [0, 4])
def test_song_unet_positional_leadtime(device, lead_time_mode, N_grid_channels):
"""Test that both positional and lead-time embeddings can be used independently"""

img_resolution = 16
out_channels = 2
lead_time_channels = 4
lead_time_steps = 2
in_channels = 2 + N_grid_channels + (lead_time_channels if lead_time_mode else 0)
model = UNet(
img_resolution=img_resolution,
in_channels=in_channels,
out_channels=out_channels,
N_grid_channels=N_grid_channels,
lead_time_mode=lead_time_mode,
lead_time_channels=lead_time_channels,
lead_time_steps=lead_time_steps,
).to(device)
noise_labels = torch.randn([2]).to(device)
class_labels = torch.randint(0, 1, (2, 1)).to(device)
input_image = torch.ones([2, 2, 16, 16]).to(device)
lead_time_label = torch.as_tensor([0, 1]).to(device)

assert bool(N_grid_channels) == (model.pos_embd is not None)
assert lead_time_mode == (hasattr(model, "lt_embd") and (model.lt_embd is not None))

if lead_time_mode:
output_image = model(
input_image, noise_labels, class_labels, lead_time_label=lead_time_label
)
else:
output_image = model(input_image, noise_labels, class_labels)

assert output_image.shape == (2, out_channels, img_resolution, img_resolution)
27 changes: 27 additions & 0 deletions test/utils/generative/test_deterministic_sampler.py
@@ -33,12 +33,23 @@ def round_sigma(self, sigma):
return torch.tensor(sigma)


# Version that supports lead time labels
class MockNetLt(MockNet):
def __call__(self, x, img_lr, sigma, class_labels, lead_time_label=None):
return x


# Define a fixture for the network
@pytest.fixture
def mock_net():
return MockNet()


@pytest.fixture
def mock_net_lt():
return MockNetLt()


# Basic functionality test
@import_or_fail("cftime")
def test_deterministic_sampler_output_type_and_shape(mock_net, pytestconfig):
@@ -177,3 +188,19 @@ def test_deterministic_sampler_scaling_validation(mock_net, scaling, pytestconfig):
net=mock_net, latents=latents, img_lr=img_lr, scaling=scaling
)
assert isinstance(output, torch.Tensor)


# Test support for lead time labels
@import_or_fail("cftime")
def test_deterministic_sampler_lead_time(mock_net_lt, pytestconfig):

from physicsnemo.utils.diffusion import deterministic_sampler

latents = torch.randn(1, 3, 64, 64)
img_lr = torch.randn(1, 3, 64, 64)
lt_label = torch.randint(0, 10, (1, 1))

output = deterministic_sampler(
net=mock_net_lt, latents=latents, img_lr=img_lr, lead_time_label=lt_label
)
assert isinstance(output, torch.Tensor)
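
Taken together: the model embeds the lead-time label, the loss forwards it during training, and the sampler forwards it during inference. A hedged end-to-end sketch of the intended usage, based on the signatures in this diff (the hyperparameters and the `SongUNetPosLtEmbd` import path are assumptions):

```python
import torch

from physicsnemo.models.diffusion import SongUNetPosLtEmbd  # assumed export path

# 2 data channels + 4 positional grid channels + 4 lead-time channels.
model = SongUNetPosLtEmbd(
    img_resolution=16,
    in_channels=2 + 4 + 4,
    out_channels=2,
    N_grid_channels=4,
    lead_time_channels=4,
    lead_time_steps=9,
)

x = torch.ones(2, 2, 16, 16)
noise_labels = torch.randn(2)
class_labels = torch.randint(0, 1, (2, 1))
lead_time_label = torch.tensor([0, 8])  # one step index in [0, lead_time_steps)

out = model(x, noise_labels, class_labels, lead_time_label=lead_time_label)
assert out.shape == (2, 2, 16, 16)
```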