Skip to content

Commit ff446f9

Browse files
committed
Do not override a small subset of env vars
RamaLama does not try to detect GPU if the user has already set certain env vars. Make this list smaller. Signed-off-by: Eric Curtin <[email protected]>
1 parent ef7bd2a commit ff446f9

File tree

1 file changed

+40
-16
lines changed

1 file changed

+40
-16
lines changed

ramalama/common.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -526,15 +526,12 @@ def set_accel_env_vars():
526526
get_accel()
527527

528528

529-
def get_accel_env_vars():
529+
def get_gpu_type_env_vars():
530530
gpu_vars = (
531531
"ASAHI_VISIBLE_DEVICES",
532532
"ASCEND_VISIBLE_DEVICES",
533533
"CUDA_VISIBLE_DEVICES",
534-
"CUDA_LAUNCH_BLOCKING",
535534
"HIP_VISIBLE_DEVICES",
536-
"HSA_VISIBLE_DEVICES",
537-
"HSA_OVERRIDE_GFX_VERSION",
538535
"INTEL_VISIBLE_DEVICES",
539536
"MUSA_VISIBLE_DEVICES",
540537
)
@@ -543,6 +540,22 @@ def get_accel_env_vars():
543540
return env_vars
544541

545542

543+
def get_accel_env_vars():
544+
# Start with GPU type env vars
545+
env_vars = get_gpu_type_env_vars()
546+
547+
# Add other accelerator-specific vars
548+
accel_vars = (
549+
"CUDA_LAUNCH_BLOCKING",
550+
"HSA_VISIBLE_DEVICES",
551+
"HSA_OVERRIDE_GFX_VERSION",
552+
)
553+
for k in accel_vars:
554+
if k in os.environ:
555+
env_vars[k] = os.environ[k]
556+
return env_vars
557+
558+
546559
def rm_until_substring(input, substring):
547560
pos = input.find(substring)
548561
if pos == -1:
@@ -627,29 +640,40 @@ def select_cuda_image(config):
627640
raise RuntimeError(f"CUDA version {cuda_version} is not supported. Minimum required version is 12.4.")
628641

629642

630-
def accel_image(config, args):
631-
if args and args.image and len(args.image.split(":")) > 1:
643+
def resolve_image_from_args_and_env(config, args):
644+
"""
645+
Resolves the base image based on arguments, environment variables, and config.
646+
Returns the resolved image string, or None if not found.
647+
"""
648+
if args and getattr(args, "image", None) and len(args.image.split(":")) > 1:
632649
return args.image
633650

634-
if hasattr(args, 'image_override'):
635-
return tagged_image(args.image)
636-
637651
image = os.getenv("RAMALAMA_IMAGE")
638652
if image:
639653
return tagged_image(image)
640654

641-
if config.is_set('image'):
642-
return tagged_image(config['image'])
655+
if config.is_set("image"):
656+
return tagged_image(config["image"])
657+
658+
return None # Defer to the next function for more advanced logic
659+
660+
661+
def accel_image(config, args):
662+
"""
663+
Selects and the appropriate image based on config, arguments, environment.
664+
"""
665+
# Try to resolve using args/environment/config
666+
image = resolve_image_from_args_and_env(config, args)
667+
if image:
668+
return image
643669

644670
conman = config['engine']
645671
images = config['images']
646672
set_accel_env_vars()
647-
env_vars = get_accel_env_vars()
648-
649-
if not env_vars:
650-
gpu_type = None
673+
if gpu_type_env_vars := get_gpu_type_env_vars():
674+
gpu_type, _ = next(iter(gpu_type_env_vars.items()))
651675
else:
652-
gpu_type, _ = next(iter(env_vars.items()))
676+
gpu_type = None
653677

654678
# Get image based on detected GPU type
655679
image = images.get(gpu_type, config["image"])

0 commit comments

Comments
 (0)