@@ -64,9 +64,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
6464 return model_options
6565
6666class ModelPatcher :
67- def __init__ (self , model , load_device , offload_device , size = 0 , current_device = None , weight_inplace_update = False ):
67+ def __init__ (self , model , load_device , offload_device , size = 0 , weight_inplace_update = False ):
6868 self .size = size
6969 self .model = model
70+ if not hasattr (self .model , 'device' ):
71+ logging .info ("Model doesn't have a device attribute." )
72+ self .model .device = offload_device
73+ elif self .model .device is None :
74+ self .model .device = offload_device
75+
7076 self .patches = {}
7177 self .backup = {}
7278 self .object_patches = {}
@@ -75,11 +81,6 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
7581 self .model_size ()
7682 self .load_device = load_device
7783 self .offload_device = offload_device
78- if current_device is None :
79- self .current_device = self .offload_device
80- else :
81- self .current_device = current_device
82-
8384 self .weight_inplace_update = weight_inplace_update
8485 self .model_lowvram = False
8586 self .lowvram_patch_counter = 0
@@ -92,7 +93,7 @@ def model_size(self):
9293 return self .size
9394
9495 def clone (self ):
95- n = ModelPatcher (self .model , self .load_device , self .offload_device , self .size , self . current_device , weight_inplace_update = self .weight_inplace_update )
96+ n = ModelPatcher (self .model , self .load_device , self .offload_device , self .size , weight_inplace_update = self .weight_inplace_update )
9697 n .patches = {}
9798 for k in self .patches :
9899 n .patches [k ] = self .patches [k ][:]
@@ -302,7 +303,7 @@ def patch_model(self, device_to=None, patch_weights=True):
302303
303304 if device_to is not None :
304305 self .model .to (device_to )
305- self .current_device = device_to
306+ self .model . device = device_to
306307
307308 return self .model
308309
@@ -355,6 +356,7 @@ def __call__(self, weight):
355356
356357 self .model_lowvram = True
357358 self .lowvram_patch_counter = patch_counter
359+ self .model .device = device_to
358360 return self .model
359361
360362 def calculate_weight (self , patches , weight , key ):
@@ -551,10 +553,13 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
551553
552554 if device_to is not None :
553555 self .model .to (device_to )
554- self .current_device = device_to
556+ self .model . device = device_to
555557
556558 keys = list (self .object_patches_backup .keys ())
557559 for k in keys :
558560 comfy .utils .set_attr (self .model , k , self .object_patches_backup [k ])
559561
560562 self .object_patches_backup .clear ()
563+
564+ def current_loaded_device (self ):
565+ return self .model .device
0 commit comments