Skip to content

Commit 767c009

Browse files
committed
Improve NVIDIA GPU detection.
Allow GPUs to be specified by UUID as well as index since the index is not guaranteed to persist across reboots. Crosscheck requested GPUs with nvidia-smi and CDI configuration. If any requested GPUs lack corresponding CDI configuration, print a message with a pointer to documentation. If the only GPU specified in the CDI configuration is "all", as appears to be the case on WSL2, use "all" as the default. Add an optional encoding argument to run_cmd() to facilitate checking the output of the command. Add pyYAML as a dependency for parsing the CDI configuration. Signed-off-by: John Wiele <[email protected]>
1 parent 64e22ee commit 767c009

File tree

2 files changed

+79
-43
lines changed

2 files changed

+79
-43
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ requires-python = ">=3.10"
1111
keywords = ["ramalama", "llama", "AI"]
1212
dependencies = [
1313
"argcomplete",
14+
"pyYAML",
1415
]
1516
maintainers = [
1617
{ name="Dan Walsh", email = "[email protected]" },

ramalama/common.py

Lines changed: 78 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from functools import lru_cache
1717
from typing import TYPE_CHECKING, Callable, List, Literal, Protocol, cast, get_args
1818

19+
import yaml
20+
1921
import ramalama.amdkfd as amdkfd
2022
from ramalama.logger import logger
2123
from ramalama.version import version
@@ -127,7 +129,7 @@ def exec_cmd(args, stdout2null: bool = False, stderr2null: bool = False):
127129
raise
128130

129131

130-
def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_all=False):
132+
def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_all=False, encoding=None):
131133
"""
132134
Run the given command arguments.
133135
@@ -137,6 +139,7 @@ def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_
137139
stdout: standard output configuration
138140
ignore_stderr: if True, ignore standard error
139141
ignore_all: if True, ignore both standard output and standard error
142+
encoding: encoding to apply to the result text
140143
"""
141144
logger.debug(f"run_cmd: {quoted(args)}")
142145
logger.debug(f"Working directory: {cwd}")
@@ -151,7 +154,7 @@ def run_cmd(args, cwd=None, stdout=subprocess.PIPE, ignore_stderr=False, ignore_
151154
if ignore_all:
152155
sout = subprocess.DEVNULL
153156

154-
result = subprocess.run(args, check=True, cwd=cwd, stdout=sout, stderr=serr)
157+
result = subprocess.run(args, check=True, cwd=cwd, stdout=sout, stderr=serr, encoding=encoding)
155158
logger.debug(f"Command finished with return code: {result.returncode}")
156159

157160
return result
@@ -225,34 +228,56 @@ def engine_version(engine: SUPPORTED_ENGINES) -> str:
225228
return run_cmd(cmd_args).stdout.decode("utf-8").strip()
226229

227230

228-
def resolve_cdi(spec_dirs: List[str]):
229-
"""Loads all CDI specs from the given directories."""
231+
def load_cdi_config(spec_dirs: List[str]) -> dict:
232+
"""Load the first YAML or JSON CDI configuration file found in the given directories."""
230233
for spec_dir in spec_dirs:
231234
for root, _, files in os.walk(spec_dir):
232235
for file in files:
233-
if file.endswith('.json') or file.endswith('.yaml'):
234-
if load_spec(os.path.join(root, file)):
235-
return True
236-
237-
return False
238-
236+
_, ext = os.path.splitext(file)
237+
file_path = os.path.join(root, file)
238+
if ext == ".yaml" or ext == ".yml":
239+
try:
240+
with open(file_path, "r") as stream:
241+
config = yaml.safe_load(stream)
242+
return config
243+
except Exception:
244+
continue
245+
elif ext == ".json":
246+
try:
247+
with open(file_path, "r") as stream:
248+
config = json.load(stream)
249+
return config
250+
except json.JSONDecodeError:
251+
continue
252+
except UnicodeDecodeError:
253+
continue
254+
return None
239255

240-
def yaml_safe_load(stream) -> dict:
241-
data = {}
242-
for line in stream:
243-
if ':' in line:
244-
key, value = line.split(':', 1)
245-
data[key.strip()] = value.strip()
246256

247-
return data
257+
def find_in_cdi(devices: List[str]) -> tuple[List[str], List[str]]:
    """Crosscheck requested GPU devices against the host CDI configuration.

    Args:
        devices: GPU identifiers requested by the user — indices, "GPU-..."
            UUIDs (or UUID prefixes), or "all".

    Returns:
        A ``(configured, unconfigured)`` pair: devices with a matching CDI
        entry and devices without one.  (Annotation fixed: the original
        ``(List[str], List[str])`` is not valid PEP 484 syntax for a tuple.)
    """
    cdi = load_cdi_config(['/etc/cdi', '/var/run/cdi'])
    # .get() guards against a spec that parses but lacks a "devices"
    # section; the original indexed directly and could raise KeyError.
    cdi_devices = cdi.get("devices", []) if cdi else []
    cdi_device_names = [dev["name"] for dev in cdi_devices if "name" in dev]

    logger.debug(f"cdi_device_names: {','.join(cdi_device_names)}")

    configured = []
    unconfigured = []
    for device in devices:
        if device in cdi_device_names:
            logger.debug(f"device {device} found")
            configured.append(device)
        # A device can be specified by a prefix of the uuid
        elif device.startswith("GPU") and any(name.startswith(device) for name in cdi_device_names):
            logger.debug(f"device {device} found")
            configured.append(device)
        else:
            perror(f"Device {device} does not have a CDI configuration")
            unconfigured.append(device)

    return configured, unconfigured
256281

257282

258283
def check_asahi() -> Literal["asahi"] | None:
@@ -278,27 +303,37 @@ def check_metal(args: ContainerArgType) -> bool:
278303
@lru_cache(maxsize=1)
279304
def check_nvidia() -> Literal["cuda"] | None:
280305
try:
281-
command = ['nvidia-smi']
282-
run_cmd(command).stdout.decode("utf-8")
283-
284-
# ensure at least one CDI device resolves
285-
if resolve_cdi(['/etc/cdi', '/var/run/cdi']):
286-
if "CUDA_VISIBLE_DEVICES" not in os.environ:
287-
dev_command = ['nvidia-smi', '--query-gpu=index', '--format=csv,noheader']
288-
try:
289-
result = run_cmd(dev_command)
290-
output = result.stdout.decode("utf-8").strip()
291-
if not output:
292-
raise ValueError("nvidia-smi returned empty GPU indices")
293-
devices = ','.join(output.split('\n'))
294-
except Exception:
295-
devices = "0"
296-
297-
os.environ["CUDA_VISIBLE_DEVICES"] = devices
298-
299-
return "cuda"
300-
except Exception:
301-
pass
306+
command = ['nvidia-smi', '--query-gpu=index,uuid', '--format=csv,noheader']
307+
result = run_cmd(command, encoding="utf-8")
308+
except Exception as e:
309+
# If nvidia-smi failed to run for some reason, default to all devices
310+
perror(f"Unable to get device(s) from nvidia-smi: {e}")
311+
return "all"
312+
313+
smi_lines = result.stdout.splitlines()
314+
indices, uuids = zip(*[[item.strip() for item in line.split(',')] for line in smi_lines if line])
315+
# Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
316+
cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "all"
317+
visible_devices = list(cuda_visible_devices.split(',') if cuda_visible_devices else uuids)
318+
319+
logger.debug(f"visible devices {','.join(visible_devices)}")
320+
configured, unconfigured = find_in_cdi(visible_devices)
321+
322+
if unconfigured:
323+
print(f'No CDI configuration found for {",".join(unconfigured)}')
324+
print("You can use the \"nvidia-ctk cdi generate\" command from the ")
325+
print("nvidia-container-toolkit to generate a CDI configuration.")
326+
print("See ramalama-cuda(7).")
327+
return None
328+
elif configured:
329+
logger.debug(f"configured devices: {','.join(configured)}")
330+
# If the only configured device is "any", default to 0
331+
if "all" in configured:
332+
configured.remove("all")
333+
if not configured:
334+
configured = ["0"]
335+
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(configured)
336+
return "cuda"
302337

303338
return None
304339

0 commit comments

Comments
 (0)