@@ -7,6 +7,7 @@
 import torch.nn as nn
 from .. import attention
 from einops import rearrange, repeat
+from .util import timestep_embedding
 
 def default(x, y):
     if x is not None:
@@ -230,34 +231,8 @@ def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device
         )
         self.frequency_embedding_size = frequency_embedding_size
 
-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-        :param t: a 1-D Tensor of N indices, one per batch element.
-                  These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
-            / half
-        )
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat(
-                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
-            )
-        if torch.is_floating_point(t):
-            embedding = embedding.to(dtype=t.dtype)
-        return embedding
-
     def forward(self, t, dtype, **kwargs):
-        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
 
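For context, the diff only relocates the sinusoidal-embedding helper into `.util`; its behavior is unchanged. The sketch below copies the removed implementation verbatim as a standalone function and exercises it the way `forward` does, assuming the relocated `timestep_embedding` keeps the same `(t, dim, max_period=10000)` signature; the `dim=256` in the usage mirrors the default `frequency_embedding_size`.

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal timestep embeddings: maps a 1-D tensor of N (possibly
    fractional) timestep indices to an (N, dim) tensor of features."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad with a zero column when dim is odd
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(dtype=t.dtype)
    return embedding

# Illustrative usage (dim=256 mirrors the default frequency_embedding_size):
t = torch.tensor([0.0, 250.0, 999.0])   # one timestep per batch element
emb = timestep_embedding(t, dim=256)
print(emb.shape)                         # torch.Size([3, 256])
```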