@@ -567,7 +567,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
567567 return (model_patcher , clip , vae , clipvision )
568568
569569
570- def load_unet_state_dict (sd ): #load unet in diffusers or regular format
570+ def load_unet_state_dict (sd , dtype = None ): #load unet in diffusers or regular format
571571
572572 #Allow loading unets from checkpoint files
573573 diffusion_model_prefix = model_detection .unet_prefix_from_state_dict (sd )
@@ -576,7 +576,6 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
576576 sd = temp_sd
577577
578578 parameters = comfy .utils .calculate_parameters (sd )
579- unet_dtype = model_management .unet_dtype (model_params = parameters )
580579 load_device = model_management .get_torch_device ()
581580 model_config = model_detection .model_config_from_unet (sd , "" )
582581
@@ -603,7 +602,11 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
603602 logging .warning ("{} {}" .format (diffusers_keys [k ], k ))
604603
605604 offload_device = model_management .unet_offload_device ()
606- unet_dtype = model_management .unet_dtype (model_params = parameters , supported_dtypes = model_config .supported_inference_dtypes )
605+ if dtype is None :
606+ unet_dtype = model_management .unet_dtype (model_params = parameters , supported_dtypes = model_config .supported_inference_dtypes )
607+ else :
608+ unet_dtype = dtype
609+
607610 manual_cast_dtype = model_management .unet_manual_cast (unet_dtype , load_device , model_config .supported_inference_dtypes )
608611 model_config .set_inference_dtype (unet_dtype , manual_cast_dtype )
609612 model = model_config .get_model (new_sd , "" )
@@ -614,9 +617,9 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
614617 logging .info ("left over keys in unet: {}" .format (left_over ))
615618 return comfy .model_patcher .ModelPatcher (model , load_device = load_device , offload_device = offload_device )
616619
617- def load_unet (unet_path ):
620+ def load_unet (unet_path , dtype = None ):
618621 sd = comfy .utils .load_torch_file (unet_path )
619- model = load_unet_state_dict (sd )
622+ model = load_unet_state_dict (sd , dtype = dtype )
620623 if model is None :
621624 logging .error ("ERROR UNSUPPORTED UNET {}" .format (unet_path ))
622625 raise RuntimeError ("ERROR: Could not detect model type of: {}" .format (unet_path ))
0 commit comments