@@ -379,11 +379,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()

-def load_models_gpu(models, memory_required=0, force_patch_weights=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
     global vram_state

     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)
+    if minimum_memory_required is None:
+        minimum_memory_required = extra_mem
+    else:
+        minimum_memory_required = max(inference_memory, minimum_memory_required)

     models = set(models)
@@ -446,8 +450,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - extra_mem)))
-            if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
+            if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0

         if vram_set_state == VRAMState.NO_VRAM:
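
In plain terms, the change lets a caller reserve a specific amount of free VRAM for inference via minimum_memory_required (defaulting to extra_mem), and the per-model weight budget plus the lowvram switch are then computed against that reservation instead of against extra_mem and inference_memory separately. Below is a minimal, self-contained sketch of that decision in isolation; the helper name lowvram_budget and the example sizes are assumptions for illustration, not part of the commit:

# Standalone sketch (not ComfyUI code): mirrors the patched lowvram decision
# with plain integers so the effect of minimum_memory_required is visible.
def lowvram_budget(current_free_mem, minimum_memory_required, model_size):
    # Reserve minimum_memory_required for inference, but never budget less
    # than 64MB for the model weights.
    lowvram_model_memory = int(max(64 * (1024 * 1024), current_free_mem - minimum_memory_required))
    if model_size <= lowvram_model_memory:  # model fits in the budget: full load, no lowvram
        lowvram_model_memory = 0
    return lowvram_model_memory

GB = 1024 ** 3
# Example values (assumptions): 8GB free, 6GB model, 3GB reserved for inference
# -> a 5GB budget, so the model is partially loaded in lowvram mode.
print(lowvram_budget(8 * GB, 3 * GB, 6 * GB))
# Same model with only 1.5GB reserved -> the model fits, 0 means a full load.
print(lowvram_budget(8 * GB, int(1.5 * GB), 6 * GB))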