@@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
562562 if model_params * 2 > free_model_memory :
563563 return fp8_dtype
564564
565- if should_use_fp16 (device = device , model_params = model_params , manual_cast = True ):
566- if torch .float16 in supported_dtypes :
567- return torch .float16
568- if should_use_bf16 (device , model_params = model_params , manual_cast = True ):
569- if torch .bfloat16 in supported_dtypes :
570- return torch .bfloat16
565+ for dt in supported_dtypes :
566+ if dt == torch .float16 and should_use_fp16 (device = device , model_params = model_params ):
567+ if torch .float16 in supported_dtypes :
568+ return torch .float16
569+ if dt == torch .bfloat16 and should_use_bf16 (device , model_params = model_params ):
570+ if torch .bfloat16 in supported_dtypes :
571+ return torch .bfloat16
572+
573+ for dt in supported_dtypes :
574+ if dt == torch .float16 and should_use_fp16 (device = device , model_params = model_params , manual_cast = True ):
575+ if torch .float16 in supported_dtypes :
576+ return torch .float16
577+ if dt == torch .bfloat16 and should_use_bf16 (device , model_params = model_params , manual_cast = True ):
578+ if torch .bfloat16 in supported_dtypes :
579+ return torch .bfloat16
580+
571581 return torch .float32
572582
573583# None means no manual cast
@@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
583593 if bf16_supported and weight_dtype == torch .bfloat16 :
584594 return None
585595
586- if fp16_supported and torch .float16 in supported_dtypes :
587- return torch .float16
596+ for dt in supported_dtypes :
597+ if dt == torch .float16 and fp16_supported :
598+ return torch .float16
599+ if dt == torch .bfloat16 and bf16_supported :
600+ return torch .bfloat16
588601
589- elif bf16_supported and torch .bfloat16 in supported_dtypes :
590- return torch .bfloat16
591- else :
592- return torch .float32
602+ return torch .float32
593603
594604def text_encoder_offload_device ():
595605 if args .gpu_only :
0 commit comments