Commit b85216a

Lower T5 memory usage by a few hundred MB.
1 parent 82cae45 commit b85216a
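
The diffs below move the T5 (and HunYuanDiT style) embedding layers onto ComfyUI's castable ops, so the embedding tables stay in the checkpoint's storage dtype and only the looked-up rows are cast to the compute dtype. For a rough sense of scale, a back-of-the-envelope estimate, assuming T5-XXL-like dimensions (vocab_size 32128, model_dim 4096; these are illustrative and not stated anywhere in this commit):

# Back-of-the-envelope size estimate for a T5 token-embedding table.
# The dimensions are illustrative assumptions (roughly T5-XXL), not
# values taken from this commit.
vocab_size, model_dim = 32128, 4096

fp32_bytes = vocab_size * model_dim * 4
fp16_bytes = vocab_size * model_dim * 2
print(f"fp32 table: {fp32_bytes / 1e6:.0f} MB")                 # ~526 MB
print(f"fp16 table: {fp16_bytes / 1e6:.0f} MB")                 # ~263 MB
print(f"difference: {(fp32_bytes - fp16_bytes) / 1e6:.0f} MB")  # a few hundred MB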

File tree

3 files changed: 32 additions, 16 deletions


comfy/ldm/hydit/models.py

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ def forward(self,
         if self.use_style_cond:
             if style is None:
                 style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
-            style_embedding = self.style_embedder(style)
+            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
             extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

         # Concatenate all extra vectors
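
The only change here is forwarding out_dtype=x.dtype to the style embedder so the returned embedding matches extra_vec's dtype before the concatenation. A small stand-alone illustration (stand-in tensors, not HunYuanDiT code) of why the dtypes should agree: recent PyTorch type-promotes mixed-dtype inputs to torch.cat, which would silently upcast the concatenated vector.

import torch

# Stand-in tensors for illustration: mixing dtypes in the concatenation
# promotes the whole result; matching dtypes up front keeps it in fp16.
extra_vec = torch.randn(2, 4, dtype=torch.float16)
style_fp32 = torch.randn(2, 4, dtype=torch.float32)

print(torch.cat([extra_vec, style_fp32], dim=1).dtype)                      # torch.float32 (promoted)
print(torch.cat([extra_vec, style_fp32.to(extra_vec.dtype)], dim=1).dtype)  # torch.float16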

comfy/ops.py

Lines changed: 23 additions & 8 deletions
@@ -19,17 +19,27 @@
 import torch
 import comfy.model_management

+
+def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
 def cast_to_input(weight, input, non_blocking=False):
-    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+
+def cast_bias_weight(s, input=None, dtype=None, device=None):
+    if input is not None:
+        if dtype is None:
+            dtype = input.dtype
+        if device is None:
+            device = input.device

-def cast_bias_weight(s, input):
     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
+    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
+        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
@@ -176,14 +186,19 @@ def reset_parameters(self):
             self.bias = None
             return None

-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            output_dtype = out_dtype
+            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
+                out_dtype = None
+            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

         def forward(self, *args, **kwargs):
             if self.comfy_cast_weights:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
+                if "out_dtype" in kwargs:
+                    kwargs.pop("out_dtype")
                 return super().forward(*args, **kwargs)

         @classmethod
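
A minimal stand-alone sketch of the pattern the reworked Embedding op follows (the helper name below is made up for illustration, it is not the repo's code): when the stored table is already fp16/bf16, the lookup runs in the storage dtype and only the gathered rows are cast to out_dtype, so no full-precision copy of the table is ever materialized; otherwise the weight is cast first, which is what the new dtype argument on cast_bias_weight enables.

import torch

def embedding_lookup(weight, ids, out_dtype=None):
    # Assumed helper for illustration only.
    if weight.dtype in (torch.float16, torch.bfloat16):
        # Gather in the storage dtype, then cast just the fetched rows.
        return torch.nn.functional.embedding(ids, weight).to(dtype=out_dtype)
    # Otherwise cast the table to the requested dtype before the lookup.
    return torch.nn.functional.embedding(ids, weight.to(dtype=out_dtype))

table = torch.randn(1000, 64, dtype=torch.float16)  # sizes are illustrative
ids = torch.tensor([[3, 141, 592]])
out = embedding_lookup(table, ids, out_dtype=torch.float32)
print(table.dtype, out.dtype)  # torch.float16 torch.float32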

comfy/text_encoders/t5.py

Lines changed: 8 additions & 7 deletions
@@ -1,6 +1,7 @@
 import torch
 import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops

 class T5LayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -11,7 +12,7 @@ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
     def forward(self, x):
         variance = x.pow(2).mean(-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight.to(device=x.device, dtype=x.dtype) * x
+        return comfy.ops.cast_to_input(self.weight, x) * x

 activations = {
     "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -82,7 +83,7 @@ def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)

     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -132,7 +133,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length, device):
+    def compute_bias(self, query_length, key_length, device, dtype):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -143,7 +144,7 @@ def compute_bias(self, query_length, key_length, device):
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values

@@ -152,7 +153,7 @@ def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
         k = self.k(x)
         v = self.v(x)
         if self.relative_attention_bias is not None:
-            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

         if past_bias is not None:
             if mask is not None:
@@ -225,7 +226,7 @@ def __init__(self, config_dict, dtype, device, operations):

         self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
+        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

     def get_input_embeddings(self):
         return self.shared
@@ -234,5 +235,5 @@ def set_input_embeddings(self, embeddings):
         self.shared = embeddings

     def forward(self, input_ids, *args, **kwargs):
-        x = self.shared(input_ids)
+        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         return self.encoder(x, *args, **kwargs)
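
Putting it together, T5.forward now asks the shared embedding for activations in the compute dtype (taken from kwargs, defaulting to torch.float32) while the table itself keeps its storage dtype. A self-contained toy illustration of that dtype plumbing, using made-up class and size names rather than ComfyUI's API:

import torch

class TinyT5Like(torch.nn.Module):
    # Toy stand-in for illustration, not ComfyUI's T5 class.
    def __init__(self, vocab=100, dim=16, storage_dtype=torch.float16):
        super().__init__()
        self.shared = torch.nn.Embedding(vocab, dim, dtype=storage_dtype)

    def forward(self, input_ids, **kwargs):
        out_dtype = kwargs.get("dtype", torch.float32)
        # Emulates shared(input_ids, out_dtype=...): cast only the lookup result.
        return self.shared(input_ids).to(out_dtype)

model = TinyT5Like()
x = model(torch.tensor([[1, 2, 3]]), dtype=torch.bfloat16)
print(model.shared.weight.dtype, x.dtype)  # torch.float16 torch.bfloat16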
