Skip to content

Commit c1713bf

Browse files
Merge pull request #14689 from AUTOMATIC1111/fix-nested-manual-cast
Fix nested manual cast
2 parents 1dbee39 + 4a66d2f commit c1713bf

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

modules/devices.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def forward_wrapper(self, *args, **kwargs):
164164

165165
@contextlib.contextmanager
166166
def manual_cast(target_dtype):
167+
applied = False
167168
for module_type in patch_module_list:
169+
if hasattr(module_type, "org_forward"):
170+
continue
171+
applied = True
168172
org_forward = module_type.forward
169173
if module_type == torch.nn.MultiheadAttention and has_xpu():
170174
module_type.forward = manual_cast_forward(torch.float32)
@@ -174,8 +178,11 @@ def manual_cast(target_dtype):
174178
try:
175179
yield None
176180
finally:
177-
for module_type in patch_module_list:
178-
module_type.forward = module_type.org_forward
181+
if applied:
182+
for module_type in patch_module_list:
183+
if hasattr(module_type, "org_forward"):
184+
module_type.forward = module_type.org_forward
185+
delattr(module_type, "org_forward")
179186

180187

181188
def autocast(disable=False):

0 commit comments

Comments
 (0)