Add Learned PE selection for Auraflow #9182

Merged: 5 commits merged on Aug 15, 2024

Changes from 3 commits
18 changes: 17 additions & 1 deletion src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -68,6 +68,21 @@ def __init__(
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size

+    def pe_selection_index_based_on_dim(self, h, w):
+        # Select a subset of the positional embeddings based on H and W, the height and width of the latent.
+        # The PE is viewed as a 2D grid, and an (H/p) x (W/p) block of it is selected.
+        # Because the original input is in flattened format, this 2D grid has to be flattened as well.
+        h_p, w_p = h // self.patch_size, w // self.patch_size
+        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
+        h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
+        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+        starth = h_max // 2 - h_p // 2
+        endh = starth + h_p
+        startw = w_max // 2 - w_p // 2
+        endw = startw + w_p
+        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
+        return original_pe_indexes.flatten()

    def forward(self, latent):
        batch_size, num_channels, height, width = latent.size()
        latent = latent.view(
@@ -80,7 +95,8 @@ def forward(self, latent):
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
-        return latent + self.pos_embed
+        pe_index = self.pe_selection_index_based_on_dim(height, width)
+        return latent + self.pos_embed[:, pe_index]


# Taken from the original Aura flow inference code.
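For reference, a minimal standalone sketch of the centered-crop selection that the new `pe_selection_index_based_on_dim` method performs. The values for `pos_embed_max_size`, `patch_size`, and the latent size below are hypothetical and not taken from the AuraFlow config:

```python
import torch

# Hypothetical settings: a 64x64 positional-embedding grid (pos_embed_max_size = 4096)
# and a 32x32 latent with patch_size = 2, so a 16x16 block of PE tokens is needed.
pos_embed_max_size = 4096
patch_size = 2
h, w = 32, 32  # latent height/width

h_p, w_p = h // patch_size, w // patch_size   # 16 x 16 patches
h_max = w_max = int(pos_embed_max_size**0.5)  # 64 x 64 PE grid
indexes = torch.arange(pos_embed_max_size).view(h_max, w_max)

# Take a centered h_p x w_p window of the grid, then flatten it back to 1D
# so it lines up with the flattened patch sequence produced by the patch embed.
starth = h_max // 2 - h_p // 2
startw = w_max // 2 - w_p // 2
pe_index = indexes[starth : starth + h_p, startw : startw + w_p].flatten()

print(pe_index.shape)  # torch.Size([256]) == h_p * w_p
```

The selected indices correspond to the central rows and columns of the learned PE grid, so latents smaller than the maximum supported resolution reuse the middle portion of the positional embeddings rather than failing on a shape mismatch.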