Skip to content

Commit eb43ed3

Browse files
authored
Merge pull request #1764 from ramalama-labs/imp/typing
Typing and bug squashes
2 parents 1845ef9 + 9ec66d5 commit eb43ed3

23 files changed

+154
-115
lines changed

Makefile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,8 @@ docs:
111111
.PHONY: lint
112112
lint:
113113
ifneq (,$(wildcard /usr/bin/python3))
114-
/usr/bin/python3 -m compileall -q .
114+
/usr/bin/python3 -m compileall -q -x '\.venv' .
115115
endif
116-
117116
! grep -ri --exclude-dir ".venv" --exclude-dir "*/.venv" "#\!/usr/bin/python3" .
118117
flake8 $(PROJECT_DIR) $(PYTHON_SCRIPTS)
119118
shellcheck *.sh */*.sh */*/*.sh

ramalama/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(self, args: ChatArgsType, operational_args: ChatOperationalArgs | N
9696
operational_args = ChatOperationalArgs()
9797

9898
super().__init__()
99-
self.conversation_history = []
99+
self.conversation_history: list[dict] = []
100100
self.args = args
101101
self.request_in_process = False
102102
self.prompt = args.prefix

ramalama/cli.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
try:
1717
import argcomplete
1818

19-
suppressCompleter = argcomplete.completers.SuppressCompleter
19+
suppressCompleter: type[argcomplete.completers.SuppressCompleter] | None = argcomplete.completers.SuppressCompleter
2020
except Exception:
2121
suppressCompleter = None
2222

@@ -44,7 +44,6 @@
4444

4545

4646
class ParsedGenerateInput:
47-
4847
def __init__(self, gen_type: str, output_dir: str):
4948
self.gen_type = gen_type
5049
self.output_dir = output_dir
@@ -1235,7 +1234,6 @@ def inspect_cli(args):
12351234

12361235

12371236
def main():
1238-
12391237
def eprint(e, exit_code):
12401238
perror("Error: " + str(e).strip("'\""))
12411239
sys.exit(exit_code)

ramalama/common.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
import string
1414
import subprocess
1515
import sys
16+
from collections.abc import Callable, Iterable
1617
from functools import lru_cache
17-
from typing import TYPE_CHECKING, Callable, List, Literal, Protocol, cast, get_args
18+
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast, get_args
1819

1920
import ramalama.amdkfd as amdkfd
2021
from ramalama.logger import logger
@@ -230,15 +231,23 @@ def engine_version(engine: SUPPORTED_ENGINES) -> str:
230231
return run_cmd(cmd_args).stdout.decode("utf-8").strip()
231232

232233

233-
def load_cdi_yaml(stream) -> dict:
234+
class CDI_DEVICE(TypedDict):
235+
name: str
236+
237+
238+
class CDI_RETURN_TYPE(TypedDict):
239+
devices: list[CDI_DEVICE]
240+
241+
242+
def load_cdi_yaml(stream: Iterable[str]) -> CDI_RETURN_TYPE:
234243
# Returns a dict containing just the "devices" key, whose value is
235244
# a list of dicts, each mapping the key "name" to a device name.
236245
# For example: {'devices': [{'name': 'all'}]}
237246
# This depends on the key "name" being unique to the list of dicts
238247
# under "devices" and the value of the "name" key being on the
239248
# same line following a colon.
240249

241-
data = {"devices": []}
250+
data: CDI_RETURN_TYPE = {"devices": []}
242251
for line in stream:
243252
if ':' in line:
244253
key, value = line.split(':', 1)
@@ -247,7 +256,7 @@ def load_cdi_yaml(stream) -> dict:
247256
return data
248257

249258

250-
def load_cdi_config(spec_dirs: List[str]) -> dict | None:
259+
def load_cdi_config(spec_dirs: list[str]) -> CDI_RETURN_TYPE | None:
251260
# Loads the first YAML or JSON CDI configuration file found in the
252261
# given directories."""
253262

@@ -275,7 +284,7 @@ def load_cdi_config(spec_dirs: List[str]) -> dict | None:
275284
return None
276285

277286

278-
def find_in_cdi(devices: List[str]) -> tuple[List[str], List[str]]:
287+
def find_in_cdi(devices: list[str]) -> tuple[list[str], list[str]]:
279288
# Attempts to find a CDI configuration for each device in devices
280289
# and returns a list of configured devices and a list of
281290
# unconfigured devices.
@@ -327,11 +336,12 @@ def check_nvidia() -> Literal["cuda"] | None:
327336
return None
328337

329338
smi_lines = result.stdout.splitlines()
330-
parsed_lines = [[item.strip() for item in line.split(',')] for line in smi_lines if line]
339+
parsed_lines: list[list[str]] = [[item.strip() for item in line.split(',')] for line in smi_lines if line]
340+
331341
if not parsed_lines:
332342
return None
333343

334-
indices, uuids = zip(*parsed_lines) if parsed_lines else (tuple(), tuple())
344+
indices, uuids = map(list, zip(*parsed_lines))
335345
# Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
336346
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
337347
visible_devices = cuda_visible_devices.split(',') if cuda_visible_devices else []
@@ -342,14 +352,14 @@ def check_nvidia() -> Literal["cuda"] | None:
342352

343353
configured, unconfigured = find_in_cdi(visible_devices + ["all"])
344354

345-
if unconfigured and "all" not in configured:
355+
if unconfigured and not (configured_has_all := "all" in configured):
346356
perror(f"No CDI configuration found for {','.join(unconfigured)}")
347357
perror("You can use the \"nvidia-ctk cdi generate\" command from the ")
348358
perror("nvidia-container-toolkit to generate a CDI configuration.")
349359
perror("See ramalama-cuda(7).")
350360
return None
351361
elif configured:
352-
if "all" in configured:
362+
if configured_has_all:
353363
configured.remove("all")
354364
if not configured:
355365
configured = indices
@@ -442,7 +452,7 @@ def check_mthreads() -> Literal["musa"] | None:
442452
return None
443453

444454

445-
AccelType = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]
455+
AccelType: TypeAlias = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]
446456

447457

448458
def get_accel() -> AccelType | Literal["none"]:
@@ -474,7 +484,7 @@ def set_gpu_type_env_vars():
474484
get_accel()
475485

476486

477-
GPUEnvVar = Literal[
487+
GPUEnvVar: TypeAlias = Literal[
478488
"ASAHI_VISIBLE_DEVICES",
479489
"ASCEND_VISIBLE_DEVICES",
480490
"CUDA_VISIBLE_DEVICES",
@@ -486,10 +496,10 @@ def set_gpu_type_env_vars():
486496

487497

488498
def get_gpu_type_env_vars() -> dict[GPUEnvVar, str]:
489-
return {k: os.environ[k] for k in get_args(GPUEnvVar) if k in os.environ}
499+
return {k: v for k in get_args(GPUEnvVar) if (v := os.environ.get(k))}
490500

491501

492-
AccelEnvVar = Literal[
502+
AccelEnvVar: TypeAlias = Literal[
493503
"CUDA_LAUNCH_BLOCKING",
494504
"HSA_VISIBLE_DEVICES",
495505
"HSA_OVERRIDE_GFX_VERSION",
@@ -498,7 +508,7 @@ def get_gpu_type_env_vars() -> dict[GPUEnvVar, str]:
498508

499509
def get_accel_env_vars() -> dict[GPUEnvVar | AccelEnvVar, str]:
500510
gpu_env_vars: dict[GPUEnvVar, str] = get_gpu_type_env_vars()
501-
accel_env_vars: dict[AccelEnvVar, str] = {k: os.environ[k] for k in get_args(AccelEnvVar) if k in os.environ}
511+
accel_env_vars: dict[AccelEnvVar, str] = {k: v for k in get_args(AccelEnvVar) if (v := os.environ.get(k))}
502512
return gpu_env_vars | accel_env_vars
503513

504514

@@ -599,7 +609,9 @@ class AccelImageArgsOtherRuntimeRAG(Protocol):
599609
quiet: bool
600610

601611

602-
AccelImageArgs = None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
612+
AccelImageArgs: TypeAlias = (
613+
None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
614+
)
603615

604616

605617
def accel_image(config: Config) -> str:
@@ -627,7 +639,7 @@ def accel_image(config: Config) -> str:
627639
vers = minor_release()
628640

629641
should_pull = config.pull in ["always", "missing"] and not config.dryrun
630-
if attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
642+
if config.engine and attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
631643
return f"{image}:{vers}"
632644

633645
return f"{image}:latest"

ramalama/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
import sys
44
from dataclasses import dataclass, field
55
from pathlib import Path
6-
from typing import Any, Literal, Mapping
6+
from typing import Any, Literal, Mapping, TypeAlias
77

88
from ramalama.common import available
99
from ramalama.layered_config import LayeredMixin, deep_merge
1010
from ramalama.toml_parser import TOMLParser
1111

12-
PathStr = str
12+
PathStr: TypeAlias = str
1313
DEFAULT_PORT_RANGE: tuple[int, int] = (8080, 8090)
1414
DEFAULT_PORT: int = DEFAULT_PORT_RANGE[0]
15-
DEFAULT_IMAGE = "quay.io/ramalama/ramalama"
16-
SUPPORTED_ENGINES = Literal["podman", "docker"] | PathStr
17-
SUPPORTED_RUNTIMES = Literal["llama.cpp", "vllm", "mlx"]
18-
COLOR_OPTIONS = Literal["auto", "always", "never"]
15+
DEFAULT_IMAGE: str = "quay.io/ramalama/ramalama"
16+
SUPPORTED_ENGINES: TypeAlias = Literal["podman", "docker"] | PathStr
17+
SUPPORTED_RUNTIMES: TypeAlias = Literal["llama.cpp", "vllm", "mlx"]
18+
COLOR_OPTIONS: TypeAlias = Literal["auto", "always", "never"]
1919

2020

2121
def get_default_engine() -> SUPPORTED_ENGINES | None:
@@ -158,7 +158,7 @@ def load_env_config(env: Mapping[str, str] | None = None) -> dict[str, Any]:
158158
if env is None:
159159
env = os.environ
160160

161-
config = {}
161+
config: dict[str, Any] = {}
162162
for k, v in env.items():
163163
if not k.startswith("RAMALAMA"):
164164
continue

ramalama/file_loaders/file_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _get_loader(self, file: str) -> base.BaseFileLoader:
2323
return loader
2424

2525
@abstractmethod
26-
def load(self):
26+
def load(self, *args, **kwargs):
2727
pass
2828

2929
@classmethod
@@ -121,12 +121,11 @@ def load(self, file_path: str) -> list[dict]:
121121
if unsupported_files:
122122
unsupported_files_warning(unsupported_files, list(self.supported_extensions()))
123123

124-
messages = []
124+
messages: list[dict] = []
125125
if text_files:
126126
messages.append({"role": "system", "content": self.text_manager.load(text_files)})
127127
if image_files:
128-
message = {"role": "system", "content": []}
129-
for content in self.image_manager.load(image_files):
130-
message['content'].append({"type": "image_url", "image_url": {"url": content}})
128+
content = [{"type": "image_url", "image_url": {"url": c}} for c in self.image_manager.load(image_files)]
129+
message = {"role": "system", "content": content}
131130
messages.append(message)
132131
return messages

ramalama/hf_style_repo_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, name: str, organization: str, tag: str = 'latest'):
6464
self.name = name
6565
self.organization = organization
6666
self.tag = tag
67-
self.headers = {}
67+
self.headers: dict = {}
6868
self.blob_url = None
6969
self.model_filename = None
7070
self.model_hash = None

ramalama/model.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import socket
66
import sys
77
import time
8+
from abc import ABC, abstractmethod
89
from typing import Optional
910

1011
import ramalama.chat as chat
@@ -54,7 +55,6 @@
5455

5556

5657
class NoRefFileFound(Exception):
57-
5858
def __init__(self, model: str, *args):
5959
super().__init__(*args)
6060

@@ -74,7 +74,10 @@ def trim_model_name(model):
7474
return model
7575

7676

77-
class ModelBase:
77+
class ModelBase(ABC):
78+
model: str
79+
type: str
80+
7881
def __not_implemented_error(self, param):
7982
return NotImplementedError(f"ramalama {param} for '{type(self).__name__}' not implemented")
8083

@@ -90,40 +93,46 @@ def pull(self, args):
9093
def push(self, source_model, args):
9194
raise self.__not_implemented_error("push")
9295

96+
@abstractmethod
9397
def remove(self, args):
9498
raise self.__not_implemented_error("rm")
9599

100+
@abstractmethod
96101
def bench(self, args):
97102
raise self.__not_implemented_error("bench")
98103

104+
@abstractmethod
99105
def run(self, args):
100106
raise self.__not_implemented_error("run")
101107

108+
@abstractmethod
102109
def perplexity(self, args):
103110
raise self.__not_implemented_error("perplexity")
104111

112+
@abstractmethod
105113
def serve(self, args):
106114
raise self.__not_implemented_error("serve")
107115

116+
@abstractmethod
108117
def exists(self) -> bool:
109118
raise self.__not_implemented_error("exists")
110119

120+
@abstractmethod
111121
def inspect(self, args):
112122
raise self.__not_implemented_error("inspect")
113123

114124

115125
class Model(ModelBase):
116126
"""Model super class"""
117127

118-
model = ""
119-
type = "Model"
128+
type: str = "Model"
120129

121-
def __init__(self, model, model_store_path):
130+
def __init__(self, model: str, model_store_path: str):
122131
self.model = model
123132

124-
split = self.model.rsplit("/", 1)
125-
self.directory = split[0] if len(split) > 1 else ""
126-
self.filename = split[1] if len(split) > 1 else split[0]
133+
split: list[str] = self.model.rsplit("/", 1)
134+
self.directory: str = split[0] if len(split) > 1 else ""
135+
self.filename: str = split[1] if len(split) > 1 else split[0]
127136

128137
self._model_name: str
129138
self._model_tag: str
@@ -432,7 +441,7 @@ def _handle_mlx_chat(self, args):
432441
chat.chat(args)
433442
break
434443
else:
435-
logger.debug(f"MLX server not ready, waiting... (attempt {i+1}/{max_retries})")
444+
logger.debug(f"MLX server not ready, waiting... (attempt {i + 1}/{max_retries})")
436445
time.sleep(3)
437446
continue
438447

@@ -441,7 +450,7 @@ def _handle_mlx_chat(self, args):
441450
perror(f"Error: Failed to connect to MLX server after {max_retries} attempts: {e}")
442451
self._cleanup_server_process(args.pid2kill)
443452
raise e
444-
logger.debug(f"Connection attempt failed, retrying... (attempt {i+1}/{max_retries}): {e}")
453+
logger.debug(f"Connection attempt failed, retrying... (attempt {i + 1}/{max_retries}): {e}")
445454
time.sleep(3)
446455

447456
args.initial_connection = False
@@ -701,7 +710,6 @@ def handle_runtime(self, args, exec_args):
701710
return exec_args
702711

703712
def generate_container_config(self, args, exec_args):
704-
705713
# Get the blob paths (src) and mounted paths (dest)
706714
model_src_path = self._get_entry_model_path(False, False, args.dryrun)
707715
chat_template_src_path = self._get_chat_template_path(False, False, args.dryrun)
@@ -791,7 +799,7 @@ def kube(self, model_paths, chat_template_paths, mmproj_paths, args, exec_args,
791799
kube = Kube(self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args)
792800
kube.generate().write(output_dir)
793801

794-
def inspect(self, args):
802+
def inspect(self, args) -> None:
795803
self.ensure_model_exists(args)
796804

797805
model_name = self.filename

0 commit comments

Comments
 (0)