Skip to content

Commit 6969fc9

Browse files
Make supported_dtypes a priority list.
1 parent cb7c4b4 commit 6969fc9

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

comfy/model_management.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
562562
if model_params * 2 > free_model_memory:
563563
return fp8_dtype
564564

565-
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
566-
if torch.float16 in supported_dtypes:
567-
return torch.float16
568-
if should_use_bf16(device, model_params=model_params, manual_cast=True):
569-
if torch.bfloat16 in supported_dtypes:
570-
return torch.bfloat16
565+
for dt in supported_dtypes:
566+
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
567+
if torch.float16 in supported_dtypes:
568+
return torch.float16
569+
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
570+
if torch.bfloat16 in supported_dtypes:
571+
return torch.bfloat16
572+
573+
for dt in supported_dtypes:
574+
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
575+
if torch.float16 in supported_dtypes:
576+
return torch.float16
577+
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
578+
if torch.bfloat16 in supported_dtypes:
579+
return torch.bfloat16
580+
571581
return torch.float32
572582

573583
# None means no manual cast
@@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
583593
if bf16_supported and weight_dtype == torch.bfloat16:
584594
return None
585595

586-
if fp16_supported and torch.float16 in supported_dtypes:
587-
return torch.float16
596+
for dt in supported_dtypes:
597+
if dt == torch.float16 and fp16_supported:
598+
return torch.float16
599+
if dt == torch.bfloat16 and bf16_supported:
600+
return torch.bfloat16
588601

589-
elif bf16_supported and torch.bfloat16 in supported_dtypes:
590-
return torch.bfloat16
591-
else:
592-
return torch.float32
602+
return torch.float32
593603

594604
def text_encoder_offload_device():
595605
if args.gpu_only:

0 commit comments

Comments (0)