@@ -228,7 +228,7 @@ def engine_version(engine: SUPPORTED_ENGINES) -> str:
228
228
return run_cmd (cmd_args ).stdout .decode ("utf-8" ).strip ()
229
229
230
230
231
- def load_cdi_config (spec_dirs : List [str ]) -> dict :
231
+ def load_cdi_config (spec_dirs : List [str ]) -> dict | None :
232
232
"""Load the first YAML or JSON CDI configuration file found in the given directories."""
233
233
for spec_dir in spec_dirs :
234
234
for root , _ , files in os .walk (spec_dir ):
@@ -240,7 +240,7 @@ def load_cdi_config(spec_dirs: List[str]) -> dict:
240
240
with open (file_path , "r" ) as stream :
241
241
config = yaml .safe_load (stream )
242
242
return config
243
- except Exception :
243
+ except ( yaml . YAMLError , OSError ) :
244
244
continue
245
245
elif ext == ".json" :
246
246
try :
@@ -254,12 +254,12 @@ def load_cdi_config(spec_dirs: List[str]) -> dict:
254
254
return None
255
255
256
256
257
- def find_in_cdi (devices : List [str ]) -> ( List [str ], List [str ]) :
257
+ def find_in_cdi (devices : List [str ]) -> tuple [ List [str ], List [str ]] :
258
258
# Attempt to find CDI configuration for each device in devices and
259
259
# return lists of configured and unconfigured devices.
260
260
cdi = load_cdi_config (['/etc/cdi' , '/var/run/cdi' ])
261
- cdi_devices = cdi [ "devices" ] if cdi else []
262
- cdi_device_names = [cdi_device [ " name" ] for cdi_device in cdi_devices ]
261
+ cdi_devices = cdi . get ( "devices" , []) if cdi else []
262
+ cdi_device_names = [name for cdi_device in cdi_devices if ( name := cdi_device . get ( "name" )) ]
263
263
264
264
logger .debug (f"cdi_device_names: { ',' .join (cdi_device_names )} " )
265
265
@@ -301,7 +301,7 @@ def check_metal(args: ContainerArgType) -> bool:
301
301
302
302
303
303
@lru_cache (maxsize = 1 )
304
- def check_nvidia () -> Literal ["cuda" ] | None :
304
+ def check_nvidia () -> Literal ["cuda" , "all" ] | None :
305
305
try :
306
306
command = ['nvidia-smi' , '--query-gpu=index,uuid' , '--format=csv,noheader' ]
307
307
result = run_cmd (command , encoding = "utf-8" )
@@ -311,9 +311,10 @@ def check_nvidia() -> Literal["cuda"] | None:
311
311
return "all"
312
312
313
313
smi_lines = result .stdout .splitlines ()
314
- indices , uuids = zip (* [[item .strip () for item in line .split (',' )] for line in smi_lines if line ])
314
+ parsed_lines = [[item .strip () for item in line .split (',' )] for line in smi_lines if line ]
315
+ indices , uuids = zip (* parsed_lines ) if parsed_lines else (tuple (), tuple ())
315
316
# Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
316
- cuda_visible_devices = os .environ [ "CUDA_VISIBLE_DEVICES" ] if "CUDA_VISIBLE_DEVICES" in os . environ else " all"
317
+ cuda_visible_devices = os .environ . get ( "CUDA_VISIBLE_DEVICES" , " all")
317
318
visible_devices = list (cuda_visible_devices .split (',' ) if cuda_visible_devices else uuids )
318
319
319
320
logger .debug (f"visible devices { ',' .join (visible_devices )} " )
0 commit comments