Skip to content

Do not override a small subset of env vars #1475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,15 +526,12 @@ def set_accel_env_vars():
get_accel()


def get_accel_env_vars():
def get_gpu_type_env_vars():
gpu_vars = (
"ASAHI_VISIBLE_DEVICES",
"ASCEND_VISIBLE_DEVICES",
"CUDA_VISIBLE_DEVICES",
"CUDA_LAUNCH_BLOCKING",
"HIP_VISIBLE_DEVICES",
"HSA_VISIBLE_DEVICES",
"HSA_OVERRIDE_GFX_VERSION",
"INTEL_VISIBLE_DEVICES",
"MUSA_VISIBLE_DEVICES",
)
Expand All @@ -543,6 +540,22 @@ def get_accel_env_vars():
return env_vars


def get_accel_env_vars():
# Start with GPU type env vars
env_vars = get_gpu_type_env_vars()

# Add other accelerator-specific vars
accel_vars = (
"CUDA_LAUNCH_BLOCKING",
"HSA_VISIBLE_DEVICES",
"HSA_OVERRIDE_GFX_VERSION",
)
for k in accel_vars:
if k in os.environ:
env_vars[k] = os.environ[k]
return env_vars


def rm_until_substring(input, substring):
pos = input.find(substring)
if pos == -1:
Expand Down Expand Up @@ -627,29 +640,40 @@ def select_cuda_image(config):
raise RuntimeError(f"CUDA version {cuda_version} is not supported. Minimum required version is 12.4.")


def accel_image(config, args):
if args and args.image and len(args.image.split(":")) > 1:
def resolve_image_from_args_and_env(config, args):
"""
Resolves the base image based on arguments, environment variables, and config.
Returns the resolved image string, or None if not found.
"""
if args and getattr(args, "image", None) and len(args.image.split(":")) > 1:
return args.image

if hasattr(args, 'image_override'):
return tagged_image(args.image)

image = os.getenv("RAMALAMA_IMAGE")
if image:
return tagged_image(image)

if config.is_set('image'):
return tagged_image(config['image'])
if config.is_set("image"):
return tagged_image(config["image"])

return None # Defer to the next function for more advanced logic


def accel_image(config, args):
"""
Selects and the appropriate image based on config, arguments, environment.
"""
# Try to resolve using args/environment/config
image = resolve_image_from_args_and_env(config, args)
if image:
return image

conman = config['engine']
images = config['images']
set_accel_env_vars()
env_vars = get_accel_env_vars()

if not env_vars:
gpu_type = None
if gpu_type_env_vars := get_gpu_type_env_vars():
gpu_type, _ = next(iter(gpu_type_env_vars.items()))
else:
gpu_type, _ = next(iter(env_vars.items()))
gpu_type = None

# Get image based on detected GPU type
image = images.get(gpu_type, config["image"])
Expand Down
Loading