
Commit d7430a1

Add a way to load the diffusion model in fp8 with UNETLoader node.
Parent: f2b80f9

2 files changed, +12 −7 lines

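In practice, the new parameter lets callers force the UNet weights into fp8 at load time instead of relying on autodetection. A minimal sketch of direct use, assuming a PyTorch build with float8 dtypes (2.1 or later); the model path is purely illustrative:

import torch
import comfy.sd

# Hypothetical direct call; "my_unet.safetensors" is an illustrative path.
# Passing dtype=None (or omitting it) keeps the old autodetection behavior.
model_patcher = comfy.sd.load_unet(
    "models/unet/my_unet.safetensors",
    dtype=torch.float8_e4m3fn,  # store weights as 8-bit floats, halving VRAM vs fp16
)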

comfy/sd.py

Lines changed: 8 additions & 5 deletions
@@ -567,7 +567,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_unet_state_dict(sd): #load unet in diffusers or regular format
+def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
 
     #Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -576,7 +576,6 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
         sd = temp_sd
 
     parameters = comfy.utils.calculate_parameters(sd)
-    unet_dtype = model_management.unet_dtype(model_params=parameters)
     load_device = model_management.get_torch_device()
     model_config = model_detection.model_config_from_unet(sd, "")
 
@@ -603,7 +602,11 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
             logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    if dtype is None:
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    else:
+        unet_dtype = dtype
+
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model = model_config.get_model(new_sd, "")
@@ -614,9 +617,9 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
         logging.info("left over keys in unet: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 
-def load_unet(unet_path):
+def load_unet(unet_path, dtype=None):
     sd = comfy.utils.load_torch_file(unet_path)
-    model = load_unet_state_dict(sd)
+    model = load_unet_state_dict(sd, dtype=dtype)
     if model is None:
         logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
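The core of the sd.py change is a simple override: an explicit dtype takes precedence, and otherwise the previous autodetection path runs. A standalone sketch of that fallback pattern; pick_unet_dtype and the autodetected default are illustrative stand-ins, and the float8 dtypes need PyTorch >= 2.1:

import torch

def pick_unet_dtype(dtype=None, autodetected=torch.float16):
    # Mirrors the pattern added to load_unet_state_dict: an explicit dtype
    # wins; 'autodetected' stands in for model_management.unet_dtype(...).
    return autodetected if dtype is None else dtype

assert pick_unet_dtype() is torch.float16                       # old behavior
assert pick_unet_dtype(torch.float8_e5m2) is torch.float8_e5m2  # forced fp8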

nodes.py

Lines changed: 4 additions & 2 deletions
@@ -818,15 +818,17 @@ class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
+                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
 
     CATEGORY = "advanced/loaders"
 
-    def load_unet(self, unet_name):
+    def load_unet(self, unet_name, weight_dtype):
+        weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e5m2}[weight_dtype]
         unet_path = folder_paths.get_full_path("unet", unet_name)
-        model = comfy.sd.load_unet(unet_path)
+        model = comfy.sd.load_unet(unet_path, dtype=weight_dtype)
         return (model,)
 
 class CLIPLoader:
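The node translates its widget string into a torch dtype before calling into comfy.sd, with "default" meaning None, i.e. let model_management decide. A standalone version of that mapping, assuming PyTorch >= 2.1 for the float8 dtypes:

import torch

# "default" -> None tells load_unet to fall back to dtype autodetection.
WEIGHT_DTYPES = {
    "default": None,
    "fp8_e4m3fn": torch.float8_e4m3fn,  # 4-bit exponent, 3-bit mantissa
    "fp8_e5m2": torch.float8_e5m2,      # 5-bit exponent, 2-bit mantissa
}

dtype = WEIGHT_DTYPES["fp8_e4m3fn"]
print(dtype)  # torch.float8_e4m3fn

Exposing both variants makes sense because they trade off differently: e4m3fn keeps more mantissa precision, while e5m2 covers a wider dynamic range.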
