Skip to content

Commit c14ac98

Browse files
Unload models and load them back in lowvram mode no free vram.
1 parent 2894511 commit c14ac98

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

comfy/model_management.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
352352
def free_memory(memory_required, device, keep_loaded=[]):
353353
unloaded_model = []
354354
can_unload = []
355+
unloaded_models = []
355356

356357
for i in range(len(current_loaded_models) -1, -1, -1):
357358
shift_model = current_loaded_models[i]
@@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
369370
unloaded_model.append(i)
370371

371372
for i in sorted(unloaded_model, reverse=True):
372-
current_loaded_models.pop(i)
373+
unloaded_models.append(current_loaded_models.pop(i))
373374

374375
if len(unloaded_model) > 0:
375376
soft_empty_cache()
@@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
378379
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
379380
if mem_free_torch > mem_free_total * 0.25:
380381
soft_empty_cache()
382+
return unloaded_models
381383

382384
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
383385
global vram_state
@@ -421,7 +423,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
421423
for d in devs:
422424
if d != torch.device("cpu"):
423425
free_memory(extra_mem, d, models_already_loaded)
424-
return
426+
free_mem = get_free_memory(d)
427+
if free_mem < minimum_memory_required:
428+
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
429+
models_to_load = free_memory(minimum_memory_required, d)
430+
logging.info("{} models unloaded.".format(len(models_to_load)))
431+
if len(models_to_load) == 0:
432+
return
425433

426434
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
427435

0 commit comments

Comments
 (0)