add Paella (Fast Text-Conditional Discrete Denoising on Vector-Quantized Latent Spaces) #2058
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Hey @aengusng8, super cool! This already looks great :-) Please ping me, @pcuenca, if you'd like a review.
```python
# Copied from Paella/modules.py
import math

import numpy as np
import torch
import torch.nn as nn


class ModulatedLayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-6, channels_first=True):
        super().__init__()
        self.ln = nn.LayerNorm(num_features, eps=eps)
        self.gamma = nn.Parameter(torch.randn(1, 1, 1))
        self.beta = nn.Parameter(torch.randn(1, 1, 1))
        self.channels_first = channels_first

    def forward(self, x, w=None):
        x = x.permute(0, 2, 3, 1) if self.channels_first else x
        if w is None:
            x = self.ln(x)
        else:
            x = self.gamma * w * self.ln(x) + self.beta * w
        x = x.permute(0, 3, 1, 2) if self.channels_first else x
        return x


class ResBlock(nn.Module):
    def __init__(self, c, c_hidden, c_cond=0, c_skip=0, scaler=None, layer_scale_init_value=1e-6):
        super().__init__()
        self.depthwise = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c))
        self.ln = ModulatedLayerNorm(c, channels_first=False)
        self.channelwise = nn.Sequential(
            nn.Linear(c + c_skip, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(c), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.scaler = scaler
        if c_cond > 0:
            self.cond_mapper = nn.Linear(c_cond, c)

    def forward(self, x, s=None, skip=None):
        res = x
        x = self.depthwise(x)
        if s is not None:
            if s.size(2) == s.size(3) == 1:
                s = s.expand(-1, -1, x.size(2), x.size(3))
            elif s.size(2) != x.size(2) or s.size(3) != x.size(3):
                s = nn.functional.interpolate(s, size=x.shape[-2:], mode="bilinear")
            s = self.cond_mapper(s.permute(0, 2, 3, 1))
        x = self.ln(x.permute(0, 2, 3, 1), s)
        if skip is not None:
            x = torch.cat([x, skip.permute(0, 2, 3, 1)], dim=-1)
        x = self.channelwise(x)
        x = self.gamma * x if self.gamma is not None else x
        x = res + x.permute(0, 3, 1, 2)
        if self.scaler is not None:
            x = self.scaler(x)
        return x


class DenoiseUNet(nn.Module):
    def __init__(
        self,
        num_vec_classes,
        c_hidden=1280,
        c_clip=1024,
        c_r=64,
        down_levels=[4, 8, 16],
        up_levels=[16, 8, 4],
    ):
        super().__init__()
        self.num_vec_classes = num_vec_classes
        self.c_r = c_r
        self.down_levels = down_levels
        self.up_levels = up_levels
        c_levels = [c_hidden // (2**i) for i in reversed(range(len(down_levels)))]
        self.embedding = nn.Embedding(num_vec_classes, c_levels[0])

        # DOWN BLOCKS
        self.down_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(down_levels):
            blocks = []
            if i > 0:
                blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            for _ in range(num_blocks):
                block = ResBlock(c_levels[i], c_levels[i] * 4, c_clip + c_r)
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(down_levels))
                blocks.append(block)
            self.down_blocks.append(nn.ModuleList(blocks))

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(up_levels):
            blocks = []
            for j in range(num_blocks):
                block = ResBlock(
                    c_levels[len(c_levels) - 1 - i],
                    c_levels[len(c_levels) - 1 - i] * 4,
                    c_clip + c_r,
                    c_levels[len(c_levels) - 1 - i] if (j == 0 and i > 0) else 0,
                )
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(up_levels))
                blocks.append(block)
            if i < len(up_levels) - 1:
                blocks.append(
                    nn.ConvTranspose2d(
                        c_levels[len(c_levels) - 1 - i],
                        c_levels[len(c_levels) - 2 - i],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                    )
                )
            self.up_blocks.append(nn.ModuleList(blocks))

        self.clf = nn.Conv2d(c_levels[0], num_vec_classes, kernel_size=1)

    def gamma(self, r):
        return (r * torch.pi / 2).cos()

    def gen_r_embedding(self, r, max_positions=10000):
        dtype = r.dtype
        r = self.gamma(r) * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb.to(dtype)

    def _down_encode_(self, x, s):
        level_outputs = []
        for i, blocks in enumerate(self.down_blocks):
            for block in blocks:
                if isinstance(block, ResBlock):
                    x = block(x, s)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, s):
        x = level_outputs[0]
        for i, blocks in enumerate(self.up_blocks):
            for j, block in enumerate(blocks):
                if isinstance(block, ResBlock):
                    if i > 0 and j == 0:
                        x = block(x, s, level_outputs[i])
                    else:
                        x = block(x, s)
                else:
                    x = block(x)
        return x

    def forward(self, x, c, r):  # r is a uniform value between 0 and 1
        r_embed = self.gen_r_embedding(r)
        x = self.embedding(x).permute(0, 3, 1, 2)
        if len(c.shape) == 2:
            s = torch.cat([c, r_embed], dim=-1)[:, :, None, None]
        else:
            r_embed = r_embed[:, :, None, None].expand(-1, -1, c.size(2), c.size(3))
            s = torch.cat([c, r_embed], dim=1)
        level_outputs = self._down_encode_(x, s)
        x = self._up_decode(level_outputs, s)
        x = self.clf(x)
        return x
```
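For reference, `gen_r_embedding` above is a standard sinusoidal position embedding applied to the cosine-warped noise level `gamma(r)`. A minimal, dependency-free sketch of the same computation for a single scalar `r` (the function names here are my own, not part of the PR):

```python
import math


def cosine_gamma(r):
    # Paella's cosine schedule: maps r in [0, 1] to a noise level in [0, 1].
    return math.cos(r * math.pi / 2)


def sinusoidal_r_embedding(r, c_r=64, max_positions=10000):
    # Sinusoidal embedding of the rescaled noise level, mirroring
    # DenoiseUNet.gen_r_embedding for a single scalar r.
    pos = cosine_gamma(r) * max_positions
    half_dim = c_r // 2
    scale = math.log(max_positions) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    emb = [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]
    if c_r % 2 == 1:
        emb.append(0.0)  # zero-pad odd embedding widths
    return emb


print(len(sinusoidal_r_embedding(0.3, c_r=64)))  # 64
```

The model concatenates this embedding with the CLIP conditioning, so every `ResBlock` sees both the text condition and the current noise level.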
Hi @pcuenca (cc @patrickvonplaten), should I use layers, blocks, or models from the `diffusers/src/diffusers/models` folder to replace some parts of the original Paella model class, or should I keep the original Paella model class unchanged?
Hey @aengusng8,
No worries! Thanks a lot for working on this :-)
It would be amazing if you could try to "mold" your code into the existing UNet2DConditionModel class:
`class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):`
Also, we've just added a design philosophy doc that might help: https://huggingface.co/docs/diffusers/main/en/conceptual/philosophy
So it would be super cool if you could gauge whether it's possible to "force" the whole modeling code into UNet2DConditionModel - feel free to design your own new UNet up- and down-block classes.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Any updates? Is this based on https://github.com/dome272/Paella ?
Hi, @patrickvonplaten. This is the draft PR that we recently mentioned on Discord. I am about 90% done moving Paella into our library, and I think I need your help to finalize it.
What is done?
The "Pipeline" and "Scheduler" are ready to run; check my Kaggle notebook: https://www.kaggle.com/code/aengusng/notebookd7ca68b633/notebook
Note: run this on CPU only in Colab, or CPU/GPU in Kaggle.
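For readers unfamiliar with how such a pipeline samples: Paella-style generation iteratively denoises discrete VQ token indices, letting the model predict clean tokens and then re-noising a shrinking random fraction of them. A rough, dependency-free sketch of that loop (my own simplification with a stand-in `predict_tokens`; this is not the PR's actual scheduler):

```python
import random


def paella_style_sample(predict_tokens, num_tokens, vocab_size, steps=8, seed=0):
    # Rough sketch of discrete denoising sampling in the spirit of Paella.
    rng = random.Random(seed)
    # Start from uniformly random VQ token indices (fully "noised").
    x = [rng.randrange(vocab_size) for _ in range(num_tokens)]
    for step in range(steps):
        r = 1.0 - step / steps             # current noise level in (0, 1]
        x = predict_tokens(x, r)           # model's estimate of the clean tokens
        r_next = 1.0 - (step + 1) / steps  # target noise level for the next step
        # Re-noise roughly an r_next fraction of positions with random tokens.
        x = [rng.randrange(vocab_size) if rng.random() < r_next else t for t in x]
    return x


# With a dummy "model" that always predicts token 0, the final step
# (r_next == 0) leaves the prediction untouched:
out = paella_style_sample(lambda x, r: [0] * len(x), num_tokens=16, vocab_size=1024)
print(out == [0] * 16)  # True
```

In the real pipeline, `predict_tokens` would sample from the softmax over `DenoiseUNet`'s logits under text conditioning, and the decoded tokens would be fed through the VQ-VAE decoder.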
Current bottleneck problems?
I have a few questions that I would appreciate your help with:
- Should I use layers, blocks, or models from the `diffusers/src/diffusers/models` folder to replace some parts of the original Paella model class, or should I keep the original Paella model class unchanged?
- Should I keep the dependencies on `einops`, `rudalle`, and `open_clip_torch`, since they are part of the author's code?
- Their `vqvae` is initialized from `rudalle.get_vae`, and their `text_encoder` and `tokenizer` are initialized from `open_clip`. How do I save and upload the model classes/configurations of `vqvae`, `text_encoder`, and `tokenizer` that are outside of Diffusers (like https://huggingface.co/CompVis/stable-diffusion-v1-4)?

What is next?
Update: I closed this PR because comparing the internal and external models takes time and deliberation.