1+ """
2+ This file is part of ComfyUI.
3+ Copyright (C) 2024 Comfy
4+
5+ This program is free software: you can redistribute it and/or modify
6+ it under the terms of the GNU General Public License as published by
7+ the Free Software Foundation, either version 3 of the License, or
8+ (at your option) any later version.
9+
10+ This program is distributed in the hope that it will be useful,
11+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+ GNU General Public License for more details.
14+
15+ You should have received a copy of the GNU General Public License
16+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17+ """
18+
119import torch
220import copy
321import inspect
422import logging
523import uuid
24+ import collections
625
726import comfy .utils
827import comfy .model_management
@@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
6382 model_options ["disable_cfg1_optimization" ] = True
6483 return model_options
6584
def wipe_lowvram_weight(m):
    """Strip low-VRAM casting state from module *m*.

    Restores the module's saved ``comfy_cast_weights`` flag (if one was
    stashed in ``prev_comfy_cast_weights``) and clears any lazy weight/bias
    patch hooks.
    """
    try:
        m.comfy_cast_weights = m.prev_comfy_cast_weights
    except AttributeError:
        pass  # module was never put into lowvram mode
    else:
        del m.prev_comfy_cast_weights

    # Hooks are cleared unconditionally; harmless when already unset.
    m.weight_function = None
    m.bias_function = None
91+
class LowVramPatch:
    """Callable hook that applies a model patcher's pending patches to a
    weight on demand (used when weights stay offloaded in lowvram mode).

    Instances are assigned to a module's ``weight_function``/``bias_function``
    and invoked with the raw weight tensor at cast time.
    """

    def __init__(self, key, model_patcher):
        self.key = key
        self.model_patcher = model_patcher

    def __call__(self, weight):
        patcher = self.model_patcher
        k = self.key
        return patcher.calculate_weight(patcher.patches[k], weight, k)
98+
99+
66100class ModelPatcher :
67101 def __init__ (self , model , load_device , offload_device , size = 0 , weight_inplace_update = False ):
68102 self .size = size
@@ -82,16 +116,29 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
82116 self .load_device = load_device
83117 self .offload_device = offload_device
84118 self .weight_inplace_update = weight_inplace_update
85- self .model_lowvram = False
86- self .lowvram_patch_counter = 0
87119 self .patches_uuid = uuid .uuid4 ()
88120
121+ if not hasattr (self .model , 'model_loaded_weight_memory' ):
122+ self .model .model_loaded_weight_memory = 0
123+
124+ if not hasattr (self .model , 'lowvram_patch_counter' ):
125+ self .model .lowvram_patch_counter = 0
126+
127+ if not hasattr (self .model , 'model_lowvram' ):
128+ self .model .model_lowvram = False
129+
89130 def model_size (self ):
90131 if self .size > 0 :
91132 return self .size
92133 self .size = comfy .model_management .module_size (self .model )
93134 return self .size
94135
136+ def loaded_size (self ):
137+ return self .model .model_loaded_weight_memory
138+
139+ def lowvram_patch_counter (self ):
140+ return self .model .lowvram_patch_counter
141+
95142 def clone (self ):
96143 n = ModelPatcher (self .model , self .load_device , self .offload_device , self .size , weight_inplace_update = self .weight_inplace_update )
97144 n .patches = {}
@@ -265,16 +312,16 @@ def model_state_dict(self, filter_prefix=None):
265312 sd .pop (k )
266313 return sd
267314
268- def patch_weight_to_device (self , key , device_to = None ):
315+ def patch_weight_to_device (self , key , device_to = None , inplace_update = False ):
269316 if key not in self .patches :
270317 return
271318
272319 weight = comfy .utils .get_attr (self .model , key )
273320
274- inplace_update = self .weight_inplace_update
321+ inplace_update = self .weight_inplace_update or inplace_update
275322
276323 if key not in self .backup :
277- self .backup [key ] = weight .to (device = self .offload_device , copy = inplace_update )
324+ self .backup [key ] = collections . namedtuple ( 'Dimension' , [ ' weight' , 'inplace_update' ])( weight .to (device = self .offload_device , copy = inplace_update ), inplace_update )
278325
279326 if device_to is not None :
280327 temp_weight = comfy .model_management .cast_to_device (weight , device_to , torch .float32 , copy = True )
@@ -304,28 +351,24 @@ def patch_model(self, device_to=None, patch_weights=True):
304351 if device_to is not None :
305352 self .model .to (device_to )
306353 self .model .device = device_to
354+ self .model .model_loaded_weight_memory = self .model_size ()
307355
308356 return self .model
309357
310- def patch_model_lowvram (self , device_to = None , lowvram_model_memory = 0 , force_patch_weights = False ):
311- self .patch_model (device_to , patch_weights = False )
312-
358+ def lowvram_load (self , device_to = None , lowvram_model_memory = 0 , force_patch_weights = False ):
313359 logging .info ("loading in lowvram mode {}" .format (lowvram_model_memory / (1024 * 1024 )))
314- class LowVramPatch :
315- def __init__ (self , key , model_patcher ):
316- self .key = key
317- self .model_patcher = model_patcher
318- def __call__ (self , weight ):
319- return self .model_patcher .calculate_weight (self .model_patcher .patches [self .key ], weight , self .key )
320-
321360 mem_counter = 0
322361 patch_counter = 0
362+ lowvram_counter = 0
323363 for n , m in self .model .named_modules ():
324364 lowvram_weight = False
325365 if hasattr (m , "comfy_cast_weights" ):
326366 module_mem = comfy .model_management .module_size (m )
327367 if mem_counter + module_mem >= lowvram_model_memory :
328368 lowvram_weight = True
369+ lowvram_counter += 1
370+ if m .comfy_cast_weights :
371+ continue
329372
330373 weight_key = "{}.weight" .format (n )
331374 bias_key = "{}.bias" .format (n )
@@ -347,16 +390,31 @@ def __call__(self, weight):
347390 m .prev_comfy_cast_weights = m .comfy_cast_weights
348391 m .comfy_cast_weights = True
349392 else :
393+ if hasattr (m , "comfy_cast_weights" ):
394+ if m .comfy_cast_weights :
395+ wipe_lowvram_weight (m )
396+
350397 if hasattr (m , "weight" ):
351- self .patch_weight_to_device (weight_key ) #TODO: speed this up without causing OOM
398+ mem_counter += comfy .model_management .module_size (m )
399+ if m .weight is not None and m .weight .device == device_to :
400+ continue
401+ self .patch_weight_to_device (weight_key ) #TODO: speed this up without OOM
352402 self .patch_weight_to_device (bias_key )
353403 m .to (device_to )
354- mem_counter += comfy .model_management .module_size (m )
355404 logging .debug ("lowvram: loaded module regularly {} {}" .format (n , m ))
356405
357- self .model_lowvram = True
358- self .lowvram_patch_counter = patch_counter
406+ if lowvram_counter > 0 :
407+ self .model .model_lowvram = True
408+ else :
409+ self .model .model_lowvram = False
410+ self .model .lowvram_patch_counter += patch_counter
359411 self .model .device = device_to
412+ self .model .model_loaded_weight_memory = mem_counter
413+
414+
415+ def patch_model_lowvram (self , device_to = None , lowvram_model_memory = 0 , force_patch_weights = False ):
416+ self .patch_model (device_to , patch_weights = False )
417+ self .lowvram_load (device_to , lowvram_model_memory = lowvram_model_memory , force_patch_weights = force_patch_weights )
360418 return self .model
361419
362420 def calculate_weight (self , patches , weight , key ):
@@ -529,37 +587,86 @@ def calculate_weight(self, patches, weight, key):
529587
530588 def unpatch_model (self , device_to = None , unpatch_weights = True ):
531589 if unpatch_weights :
532- if self .model_lowvram :
590+ if self .model . model_lowvram :
533591 for m in self .model .modules ():
534- if hasattr (m , "prev_comfy_cast_weights" ):
535- m .comfy_cast_weights = m .prev_comfy_cast_weights
536- del m .prev_comfy_cast_weights
537- m .weight_function = None
538- m .bias_function = None
592+ wipe_lowvram_weight (m )
539593
540- self .model_lowvram = False
541- self .lowvram_patch_counter = 0
594+ self .model . model_lowvram = False
595+ self .model . lowvram_patch_counter = 0
542596
543597 keys = list (self .backup .keys ())
544598
545- if self . weight_inplace_update :
546- for k in keys :
547- comfy . utils . copy_to_param ( self . model , k , self . backup [ k ])
548- else :
549- for k in keys :
550- comfy .utils .set_attr_param (self .model , k , self . backup [ k ] )
599+ for k in keys :
600+ bk = self . backup [ k ]
601+ if bk . inplace_update :
602+ comfy . utils . copy_to_param ( self . model , k , bk . weight )
603+ else :
604+ comfy .utils .set_attr_param (self .model , k , bk . weight )
551605
552606 self .backup .clear ()
553607
554608 if device_to is not None :
555609 self .model .to (device_to )
556610 self .model .device = device_to
611+ self .model .model_loaded_weight_memory = 0
557612
558613 keys = list (self .object_patches_backup .keys ())
559614 for k in keys :
560615 comfy .utils .set_attr (self .model , k , self .object_patches_backup [k ])
561616
562617 self .object_patches_backup .clear ()
563618
619+ def partially_unload (self , device_to , memory_to_free = 0 ):
620+ memory_freed = 0
621+ patch_counter = 0
622+
623+ for n , m in list (self .model .named_modules ())[::- 1 ]:
624+ if memory_to_free < memory_freed :
625+ break
626+
627+ shift_lowvram = False
628+ if hasattr (m , "comfy_cast_weights" ):
629+ module_mem = comfy .model_management .module_size (m )
630+ weight_key = "{}.weight" .format (n )
631+ bias_key = "{}.bias" .format (n )
632+
633+
634+ if m .weight is not None and m .weight .device != device_to :
635+ for key in [weight_key , bias_key ]:
636+ bk = self .backup .get (key , None )
637+ if bk is not None :
638+ if bk .inplace_update :
639+ comfy .utils .copy_to_param (self .model , key , bk .weight )
640+ else :
641+ comfy .utils .set_attr_param (self .model , key , bk .weight )
642+ self .backup .pop (key )
643+
644+ m .to (device_to )
645+ if weight_key in self .patches :
646+ m .weight_function = LowVramPatch (weight_key , self )
647+ patch_counter += 1
648+ if bias_key in self .patches :
649+ m .bias_function = LowVramPatch (bias_key , self )
650+ patch_counter += 1
651+
652+ m .prev_comfy_cast_weights = m .comfy_cast_weights
653+ m .comfy_cast_weights = True
654+ memory_freed += module_mem
655+ logging .debug ("freed {}" .format (n ))
656+
657+ self .model .model_lowvram = True
658+ self .model .lowvram_patch_counter += patch_counter
659+ self .model .model_loaded_weight_memory -= memory_freed
660+ return memory_freed
661+
662+ def partially_load (self , device_to , extra_memory = 0 ):
663+ if self .model .model_lowvram == False :
664+ return 0
665+ if self .model .model_loaded_weight_memory + extra_memory > self .model_size ():
666+ pass #TODO: Full load
667+ current_used = self .model .model_loaded_weight_memory
668+ self .lowvram_load (device_to , lowvram_model_memory = current_used + extra_memory )
669+ return self .model .model_loaded_weight_memory - current_used
670+
564671 def current_loaded_device (self ):
565672 return self .model .device
0 commit comments