16
16
from functools import lru_cache
17
17
from typing import TYPE_CHECKING , Callable , List , Literal , Protocol , cast , get_args
18
18
19
+ import yaml
20
+
19
21
import ramalama .amdkfd as amdkfd
20
22
from ramalama .logger import logger
21
23
from ramalama .version import version
@@ -127,7 +129,7 @@ def exec_cmd(args, stdout2null: bool = False, stderr2null: bool = False):
127
129
raise
128
130
129
131
130
- def run_cmd (args , cwd = None , stdout = subprocess .PIPE , ignore_stderr = False , ignore_all = False ):
132
+ def run_cmd (args , cwd = None , stdout = subprocess .PIPE , ignore_stderr = False , ignore_all = False , encoding = None ):
131
133
"""
132
134
Run the given command arguments.
133
135
@@ -137,6 +139,7 @@ def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_
137
139
stdout: standard output configuration
138
140
ignore_stderr: if True, ignore standard error
139
141
ignore_all: if True, ignore both standard output and standard error
142
+ encoding: encoding to apply to the result text
140
143
"""
141
144
logger .debug (f"run_cmd: { quoted (args )} " )
142
145
logger .debug (f"Working directory: { cwd } " )
@@ -151,7 +154,7 @@ def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_
151
154
if ignore_all :
152
155
sout = subprocess .DEVNULL
153
156
154
- result = subprocess .run (args , check = True , cwd = cwd , stdout = sout , stderr = serr )
157
+ result = subprocess .run (args , check = True , cwd = cwd , stdout = sout , stderr = serr , encoding = encoding )
155
158
logger .debug (f"Command finished with return code: { result .returncode } " )
156
159
157
160
return result
@@ -225,34 +228,56 @@ def engine_version(engine: SUPPORTED_ENGINES) -> str:
225
228
return run_cmd (cmd_args ).stdout .decode ("utf-8" ).strip ()
226
229
227
230
228
- def resolve_cdi (spec_dirs : List [str ]):
229
- """Loads all CDI specs from the given directories."""
231
+ def load_cdi_config (spec_dirs : List [str ]) -> dict :
232
+ """Load the first YAML or JSON CDI configuration file found in the given directories."""
230
233
for spec_dir in spec_dirs :
231
234
for root , _ , files in os .walk (spec_dir ):
232
235
for file in files :
233
- if file .endswith ('.json' ) or file .endswith ('.yaml' ):
234
- if load_spec (os .path .join (root , file )):
235
- return True
236
-
237
- return False
238
-
236
+ _ , ext = os .path .splitext (file )
237
+ file_path = os .path .join (root , file )
238
+ if ext == ".yaml" or ext == ".yml" :
239
+ try :
240
+ with open (file_path , "r" ) as stream :
241
+ config = yaml .safe_load (stream )
242
+ return config
243
+ except Exception :
244
+ continue
245
+ elif ext == ".json" :
246
+ try :
247
+ with open (file_path , "r" ) as stream :
248
+ config = json .load (stream )
249
+ return config
250
+ except json .JSONDecodeError :
251
+ continue
252
+ except UnicodeDecodeError :
253
+ continue
254
+ return None
239
255
240
- def yaml_safe_load (stream ) -> dict :
241
- data = {}
242
- for line in stream :
243
- if ':' in line :
244
- key , value = line .split (':' , 1 )
245
- data [key .strip ()] = value .strip ()
246
256
247
- return data
257
def find_in_cdi(devices: List[str]) -> tuple[List[str], List[str]]:
    """Split ``devices`` into those with and without a CDI configuration.

    The CDI specification is loaded from ``/etc/cdi`` and ``/var/run/cdi``
    via ``load_cdi_config``. A device counts as configured when it equals a
    device name in the specification, or when it starts with ``"GPU"`` and is
    a prefix of one of those names (a GPU may be referred to by a prefix of
    its UUID). Each device without a configuration is reported with
    ``perror``.

    Args:
        devices: device identifiers to look up.

    Returns:
        A ``(configured, unconfigured)`` pair of lists, each preserving the
        order of ``devices``.
    """
    cdi = load_cdi_config(['/etc/cdi', '/var/run/cdi'])
    # Tolerate a missing spec and specs without a "devices" list.
    cdi_devices = (cdi.get("devices") if isinstance(cdi, dict) else None) or []
    # str(): an unquoted YAML name such as `name: 0` is parsed as an int,
    # which would break both the join below and the membership test.
    cdi_device_names = [
        str(cdi_device["name"])
        for cdi_device in cdi_devices
        if isinstance(cdi_device, dict) and "name" in cdi_device
    ]

    logger.debug(f"cdi_device_names: {','.join(cdi_device_names)}")

    configured = []
    unconfigured = []
    for device in devices:
        # A device can be specified by a prefix of the uuid
        matches_uuid_prefix = device.startswith("GPU") and any(
            name.startswith(device) for name in cdi_device_names
        )
        if device in cdi_device_names or matches_uuid_prefix:
            logger.debug(f"device {device} found")
            configured.append(device)
        else:
            perror(f"Device {device} does not have a CDI configuration")
            unconfigured.append(device)

    return configured, unconfigured
256
281
257
282
258
283
def check_asahi () -> Literal ["asahi" ] | None :
@@ -278,27 +303,37 @@ def check_metal(args: ContainerArgType) -> bool:
278
303
@lru_cache(maxsize=1)
def check_nvidia() -> Literal["cuda"] | None:
    """Detect usable NVIDIA GPUs and export ``CUDA_VISIBLE_DEVICES``.

    ``nvidia-smi`` is queried for the index and UUID of every GPU. The
    devices requested through ``CUDA_VISIBLE_DEVICES`` (``"all"`` when the
    variable is unset, every detected UUID when it is set but empty) are then
    looked up in the CDI configuration with ``find_in_cdi``.

    Returns:
        ``"cuda"`` when every requested device has a CDI configuration; in
        that case ``CUDA_VISIBLE_DEVICES`` is set to the configured devices
        (``"0"`` if only ``"all"`` was requested). ``None`` when
        ``nvidia-smi`` cannot be run, reports no GPUs, or a requested device
        lacks a CDI configuration.

    The result is cached, so detection runs once per process.
    """
    command = ['nvidia-smi', '--query-gpu=index,uuid', '--format=csv,noheader']
    try:
        result = run_cmd(command, encoding="utf-8")
    except Exception as e:
        # nvidia-smi being absent or failing is the normal case on hosts
        # without NVIDIA hardware: report "no CUDA" quietly instead of
        # returning a truthy value that is not a valid accelerator type.
        logger.debug(f"Unable to get device(s) from nvidia-smi: {e}")
        return None

    # Each well-formed line looks like "<index>, <uuid>"; ignore anything else
    # so blank or unexpected output cannot break the parsing.
    uuids = []
    for line in (result.stdout or "").splitlines():
        fields = [item.strip() for item in line.split(',')]
        if len(fields) == 2 and all(fields):
            uuids.append(fields[1])
    if not uuids:
        logger.debug("nvidia-smi reported no GPUs")
        return None

    # Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "all")
    # Drop empty tokens (e.g. from a trailing comma); an empty value means
    # "every detected GPU".
    visible_devices = [dev.strip() for dev in cuda_visible_devices.split(',') if dev.strip()] or uuids

    logger.debug(f"visible devices {','.join(visible_devices)}")
    configured, unconfigured = find_in_cdi(visible_devices)

    if unconfigured:
        print(f'No CDI configuration found for {",".join(unconfigured)}')
        print("You can use the \"nvidia-ctk cdi generate\" command from the ")
        print("nvidia-container-toolkit to generate a CDI configuration.")
        print("See ramalama-cuda(7).")
        return None

    if configured:
        logger.debug(f"configured devices: {','.join(configured)}")
        # "all" is a CDI device name, not a valid CUDA_VISIBLE_DEVICES entry;
        # if it was the only configured device, default to device 0.
        configured = [dev for dev in configured if dev != "all"] or ["0"]
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(configured)
        return "cuda"

    return None
304
339
0 commit comments