@@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
352352def free_memory (memory_required , device , keep_loaded = []):
353353 unloaded_model = []
354354 can_unload = []
355+ unloaded_models = []
355356
356357 for i in range (len (current_loaded_models ) - 1 , - 1 , - 1 ):
357358 shift_model = current_loaded_models [i ]
@@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
369370 unloaded_model .append (i )
370371
371372 for i in sorted (unloaded_model , reverse = True ):
372- current_loaded_models .pop (i )
373+ unloaded_models . append ( current_loaded_models .pop (i ) )
373374
374375 if len (unloaded_model ) > 0 :
375376 soft_empty_cache ()
@@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
378379 mem_free_total , mem_free_torch = get_free_memory (device , torch_free_too = True )
379380 if mem_free_torch > mem_free_total * 0.25 :
380381 soft_empty_cache ()
382+ return unloaded_models
381383
382384def load_models_gpu (models , memory_required = 0 , force_patch_weights = False , minimum_memory_required = None ):
383385 global vram_state
@@ -421,7 +423,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
421423 for d in devs :
422424 if d != torch .device ("cpu" ):
423425 free_memory (extra_mem , d , models_already_loaded )
424- return
426+ free_mem = get_free_memory (d )
427+ if free_mem < minimum_memory_required :
428+ logging .info ("Unloading models for lowram load." ) #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
429+ models_to_load = free_memory (minimum_memory_required , d )
430+ logging .info ("{} models unloaded." .format (len (models_to_load )))
431+ if len (models_to_load ) == 0 :
432+ return
425433
426434 logging .info (f"Loading { len (models_to_load )} new model{ 's' if len (models_to_load ) > 1 else '' } " )
427435
0 commit comments