@@ -526,15 +526,12 @@ def set_accel_env_vars():
526
526
get_accel ()
527
527
528
528
529
- def get_accel_env_vars ():
529
+ def get_gpu_type_env_vars ():
530
530
gpu_vars = (
531
531
"ASAHI_VISIBLE_DEVICES" ,
532
532
"ASCEND_VISIBLE_DEVICES" ,
533
533
"CUDA_VISIBLE_DEVICES" ,
534
- "CUDA_LAUNCH_BLOCKING" ,
535
534
"HIP_VISIBLE_DEVICES" ,
536
- "HSA_VISIBLE_DEVICES" ,
537
- "HSA_OVERRIDE_GFX_VERSION" ,
538
535
"INTEL_VISIBLE_DEVICES" ,
539
536
"MUSA_VISIBLE_DEVICES" ,
540
537
)
@@ -543,6 +540,22 @@ def get_accel_env_vars():
543
540
return env_vars
544
541
545
542
543
+ def get_accel_env_vars ():
544
+ # Start with GPU type env vars
545
+ env_vars = get_gpu_type_env_vars ()
546
+
547
+ # Add other accelerator-specific vars
548
+ accel_vars = (
549
+ "CUDA_LAUNCH_BLOCKING" ,
550
+ "HSA_VISIBLE_DEVICES" ,
551
+ "HSA_OVERRIDE_GFX_VERSION" ,
552
+ )
553
+ for k in accel_vars :
554
+ if k in os .environ :
555
+ env_vars [k ] = os .environ [k ]
556
+ return env_vars
557
+
558
+
546
559
def rm_until_substring (input , substring ):
547
560
pos = input .find (substring )
548
561
if pos == - 1 :
@@ -627,29 +640,40 @@ def select_cuda_image(config):
627
640
raise RuntimeError (f"CUDA version { cuda_version } is not supported. Minimum required version is 12.4." )
628
641
629
642
630
- def accel_image (config , args ):
631
- if args and args .image and len (args .image .split (":" )) > 1 :
643
+ def resolve_image_from_args_and_env (config , args ):
644
+ """
645
+ Resolves the base image based on arguments, environment variables, and config.
646
+ Returns the resolved image string, or None if not found.
647
+ """
648
+ if args and getattr (args , "image" , None ) and len (args .image .split (":" )) > 1 :
632
649
return args .image
633
650
634
- if hasattr (args , 'image_override' ):
635
- return tagged_image (args .image )
636
-
637
651
image = os .getenv ("RAMALAMA_IMAGE" )
638
652
if image :
639
653
return tagged_image (image )
640
654
641
- if config .is_set ('image' ):
642
- return tagged_image (config ['image' ])
655
+ if config .is_set ("image" ):
656
+ return tagged_image (config ["image" ])
657
+
658
+ return None # Defer to the next function for more advanced logic
659
+
660
+
661
+ def accel_image (config , args ):
662
+ """
663
+ Selects and the appropriate image based on config, arguments, environment.
664
+ """
665
+ # Try to resolve using args/environment/config
666
+ image = resolve_image_from_args_and_env (config , args )
667
+ if image :
668
+ return image
643
669
644
670
conman = config ['engine' ]
645
671
images = config ['images' ]
646
672
set_accel_env_vars ()
647
- env_vars = get_accel_env_vars ()
648
-
649
- if not env_vars :
650
- gpu_type = None
673
+ if gpu_type_env_vars := get_gpu_type_env_vars ():
674
+ gpu_type , _ = next (iter (gpu_type_env_vars .items ()))
651
675
else :
652
- gpu_type , _ = next ( iter ( env_vars . items ()))
676
+ gpu_type = None
653
677
654
678
# Get image based on detected GPU type
655
679
image = images .get (gpu_type , config ["image" ])
0 commit comments