Skip to content

Commit 10d468e

Browse files
committed
Added support for safetensors to inspect command
Signed-off-by: Michael Engel <[email protected]>
1 parent 0947e11 commit 10d468e

File tree

9 files changed

+176
-58
lines changed

9 files changed

+176
-58
lines changed

ramalama/model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,13 @@
2727
from ramalama.config import CONFIG, DEFAULT_PORT, DEFAULT_PORT_RANGE
2828
from ramalama.console import should_colorize
2929
from ramalama.engine import Engine, dry_run
30-
from ramalama.gguf_parser import GGUFInfoParser
3130
from ramalama.kube import Kube
3231
from ramalama.logger import logger
33-
from ramalama.model_inspect import GGUFModelInfo, ModelInfoBase
32+
from ramalama.model_inspect.base_info import ModelInfoBase
33+
from ramalama.model_inspect.gguf_info import GGUFModelInfo
34+
from ramalama.model_inspect.gguf_parser import GGUFInfoParser
35+
from ramalama.model_inspect.safetensor_info import SafetensorModelInfo
36+
from ramalama.model_inspect.safetensor_parser import SafetensorInfoParser
3437
from ramalama.model_store.global_store import GlobalModelStore
3538
from ramalama.model_store.store import ModelStore
3639
from ramalama.quadlet import Quadlet
@@ -838,6 +841,10 @@ def inspect(self, args):
838841
gguf_info: GGUFModelInfo = GGUFInfoParser.parse(model_name, model_registry, model_path)
839842
print(gguf_info.serialize(json=args.json, all=args.all))
840843
return
844+
if SafetensorInfoParser.is_model_safetensor(model_name):
845+
safetensor_info: SafetensorModelInfo = SafetensorInfoParser.parse(model_name, model_registry, model_path)
846+
print(safetensor_info.serialize(json=args.json, all=args.all))
847+
return
841848

842849
print(ModelInfoBase(model_name, model_registry, model_path).serialize(json=args.json))
843850

ramalama/model_inspect/base_info.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import json
2+
import shutil
3+
import sys
4+
from dataclasses import dataclass
5+
6+
7+
def get_terminal_width():
    """Return the terminal width in columns, or 80 when stdout is not a tty."""
    return shutil.get_terminal_size().columns if sys.stdout.isatty() else 80


def adjust_new_line(line: str) -> str:
    """Fit *line* to the terminal width.

    Lines wider than the terminal are truncated and suffixed with "...";
    a trailing newline, if present, is re-appended after the filler.
    Lines that fit are returned newline-terminated.
    """
    filler = "..."
    max_width = get_terminal_width()
    adjusted_length = max_width - len(filler)

    adjust_for_newline = 1 if line.endswith("\n") else 0
    if len(line) - adjust_for_newline > max_width:
        # BUG FIX: previously the trailing conditional bound to the whole
        # expression -- `truncated if adjust_for_newline == 1 else ""` -- so an
        # over-long line WITHOUT a trailing newline was replaced by the empty
        # string and its content lost. Only the newline suffix is conditional.
        suffix = "\n" if adjust_for_newline == 1 else ""
        return line[: adjusted_length - adjust_for_newline] + filler + suffix
    return line if line.endswith("\n") else line + "\n"
20+
21+
22+
@dataclass
class Tensor:
    """Metadata describing a single tensor entry read from a model file header."""

    # Tensor identifier as stored in the file.
    name: str
    # Number of dimensions (presumably equals len(dimensions) -- set by the parser).
    n_dimensions: int
    # Size of each dimension.
    dimensions: list[int]
    # Data-type label as parsed from the file (e.g. a GGUF type name).
    type: str
    # Byte offset of the tensor data within the file.
    offset: int
29+
30+
31+
@dataclass
class ModelInfoBase:
    """Common, serializable information shared by all inspected model files."""

    Name: str
    Registry: str
    Path: str

    def serialize(self, json: bool = False) -> str:
        """Render the model info as indented JSON or as terminal-width plain text."""
        if json:
            return self.to_json()

        parts = [
            adjust_new_line(f"{self.Name}\n"),
            adjust_new_line(f"   Path: {self.Path}\n"),
            adjust_new_line(f"   Registry: {self.Registry}"),
        ]
        return "".join(parts)

    def to_json(self) -> str:
        """Return the instance fields as a sorted, 4-space-indented JSON object."""
        return json.dumps(self.__dict__, sort_keys=True, indent=4)

ramalama/model_inspect/error.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Basic error when parsing model files
class ParseError(Exception):
    """Raised when a model file cannot be read or its header cannot be parsed."""

ramalama/model_inspect.py renamed to ramalama/model_inspect/gguf_info.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,8 @@
11
import json
2-
import shutil
3-
import sys
4-
from dataclasses import dataclass
52
from typing import Any, Dict
63

74
from ramalama.endian import GGUFEndian
8-
9-
10-
def get_terminal_width():
11-
if sys.stdout.isatty():
12-
return shutil.get_terminal_size().columns
13-
return 80
14-
15-
16-
def adjust_new_line(line: str) -> str:
17-
filler = "..."
18-
max_width = get_terminal_width()
19-
adjusted_length = max_width - len(filler)
20-
21-
adjust_for_newline = 1 if line.endswith("\n") else 0
22-
if len(line) - adjust_for_newline > max_width:
23-
return line[: adjusted_length - adjust_for_newline] + filler + "\n" if adjust_for_newline == 1 else ""
24-
if not line.endswith("\n"):
25-
return line + "\n"
26-
return line
27-
28-
29-
@dataclass
30-
class Tensor:
31-
name: str
32-
n_dimensions: int
33-
dimensions: list[int]
34-
type: str
35-
offset: int
36-
37-
38-
@dataclass
39-
class ModelInfoBase:
40-
Name: str
41-
Registry: str
42-
Path: str
43-
44-
def serialize(self, json: bool = False) -> str:
45-
ret = adjust_new_line(f"{self.Name}\n")
46-
ret = ret + adjust_new_line(f" Path: {self.Path}\n")
47-
ret = ret + adjust_new_line(f" Registry: {self.Registry}")
48-
return ret
49-
50-
def to_json(self) -> str:
51-
return json.dumps(self, sort_keys=True, indent=4)
5+
from ramalama.model_inspect.base_info import ModelInfoBase, Tensor, adjust_new_line
526

537

548
class GGUFModelInfo(ModelInfoBase):

ramalama/gguf_parser.py renamed to ramalama/model_inspect/gguf_parser.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from enum import IntEnum
44
from typing import Any, Dict
55

6-
import ramalama.console as console
76
from ramalama.endian import GGUFEndian
8-
from ramalama.model_inspect import GGUFModelInfo, Tensor
7+
from ramalama.model_inspect.error import ParseError
8+
from ramalama.model_inspect.gguf_info import GGUFModelInfo, Tensor
99

1010

1111
# Based on ggml_type in
@@ -99,19 +99,14 @@ class GGUFValueType(IntEnum):
9999
]
100100

101101

102-
class ParseError(Exception):
103-
pass
104-
105-
106102
class GGUFInfoParser:
    """Parses GGUF model files."""

    @staticmethod
    def is_model_gguf(model_path: str) -> bool:
        """Return True if the file at model_path begins with the GGUF magic number.

        Reads the first 4 bytes (little-endian) and compares them against
        GGUFModelInfo.MAGIC_NUMBER.
        """
        try:
            with open(model_path, "rb") as model_file:
                magic_number = GGUFInfoParser.read_string(model_file, GGUFEndian.LITTLE, 4)
                return magic_number == GGUFModelInfo.MAGIC_NUMBER
        except Exception:
            # Missing, unreadable, or too-short files are simply not GGUF models;
            # treat any read failure as "not GGUF" rather than raising.
            return False
116111

117112
@staticmethod
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import json
2+
from typing import Any, Dict
3+
4+
from ramalama.model_inspect.base_info import ModelInfoBase, adjust_new_line
5+
6+
7+
class SafetensorModelInfo(ModelInfoBase):
    """Model information for a safetensors file, including its parsed JSON header."""

    def __init__(
        self,
        Name: str,
        Registry: str,
        Path: str,
        header_data: Dict[str, Any],
    ):
        super().__init__(Name, Registry, Path)

        # Parsed safetensors JSON header: tensor entries keyed by name, plus an
        # optional "__metadata__" mapping (e.g. {"format": "pt"}).
        self.Header: Dict[str, Any] = header_data

    def serialize(self, json: bool = False, all: bool = False) -> str:
        """Render the model info.

        json -- emit JSON instead of plain text.
        all  -- include every header entry instead of just an entry count.
        """
        if json:
            return self.to_json(all)

        # The model format (e.g. "pt") lives under the optional "__metadata__" key.
        fmt = ""
        metadata = self.Header.get("__metadata__", {})
        if isinstance(metadata, dict):
            fmt = metadata.get("format", "")

        ret = super().serialize()
        ret = ret + adjust_new_line(f"   Format: {fmt}")
        metadata_header = "   Header: "
        if not all:
            metadata_header = metadata_header + f"{len(self.Header)} entries"
        ret = ret + adjust_new_line(metadata_header)
        if all:
            for key, value in sorted(self.Header.items()):
                ret = ret + adjust_new_line(f"      {key}: {value}")

        return ret

    def to_json(self, all: bool = False) -> str:
        """Return JSON; when all is False the Header is summarized as an entry count.

        NOTE(review): the summary key is "Metadata" while the text path labels the
        same count "Header" -- confirm this asymmetry is intentional.
        """
        if all:
            return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)

        d = {k: v for k, v in self.__dict__.items() if k != "Header"}
        d["Metadata"] = len(self.Header)
        return json.dumps(d, sort_keys=True, indent=4)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import json
2+
import struct
3+
4+
import ramalama.console as console
5+
from ramalama.model_inspect.error import ParseError
6+
from ramalama.model_inspect.safetensor_info import SafetensorModelInfo
7+
8+
# Based on safetensor format description:
9+
# https://github.com/huggingface/safetensors?tab=readme-ov-file#format
10+
11+
12+
class SafetensorInfoParser:
    """Parses the header of safetensors model files."""

    @staticmethod
    def is_model_safetensor(model_name: str) -> bool:
        """Return True if the file name looks like a safetensors model.

        There is no magic number or something similar, so we only rely on the
        naming of the file here.
        """
        # endswith accepts a tuple of suffixes -- one call instead of an `or` chain.
        return model_name.endswith((".safetensor", ".safetensors"))

    @staticmethod
    def parse(model_name: str, model_registry: str, model_path: str) -> "SafetensorModelInfo":
        """Read the safetensors header and wrap it in a SafetensorModelInfo.

        Raises ParseError (with the original exception chained) if the file
        cannot be read or its header is not valid JSON.
        """
        try:
            with open(model_path, "rb") as model_file:
                # The file starts with an unsigned little-endian 64-bit integer
                # ("<Q") giving the size of the JSON header that follows.
                header_size = struct.unpack("<Q", model_file.read(8))[0]
                header = json.loads(model_file.read(header_size))

            return SafetensorModelInfo(model_name, model_registry, model_path, header)

        except Exception as ex:
            msg = f"Failed to parse safetensor model '{model_path}': {ex}"
            console.warning(msg)
            # Chain the cause so the root failure is preserved in tracebacks.
            raise ParseError(msg) from ex

ramalama/model_store/store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import ramalama.model_store.go2jinja as go2jinja
1010
from ramalama.common import perror, verify_checksum
1111
from ramalama.endian import EndianMismatchError, get_system_endianness
12-
from ramalama.gguf_parser import GGUFInfoParser, GGUFModelInfo
1312
from ramalama.logger import logger
13+
from ramalama.model_inspect.gguf_parser import GGUFInfoParser, GGUFModelInfo
1414
from ramalama.model_store.constants import DIRECTORY_NAME_BLOBS, DIRECTORY_NAME_REFS, DIRECTORY_NAME_SNAPSHOTS
1515
from ramalama.model_store.global_store import GlobalModelStore
1616
from ramalama.model_store.reffile import RefFile

test/system/100-inspect.bats

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,34 @@ load setup_suite
3737
is "${lines[7]}" " general.architecture: llama" "metadata general.architecture"
3838
}
3939

40+
# bats test_tags=distro-integration
# Pull a small safetensors model and verify the summary output of
# `ramalama inspect`: name, store path, registry, format, and header entry count.
@test "ramalama inspect safetensors model" {
    ST_MODEL="https://huggingface.co/LiheYoung/depth-anything-small-hf/resolve/main/model.safetensors"

    run_ramalama pull $ST_MODEL
    run_ramalama inspect $ST_MODEL

    is "${lines[0]}" "model.safetensors" "model name"
    is "${lines[1]}" "   Path: .*store/https/huggingface.co/.*" "model path"
    is "${lines[2]}" "   Registry: https" "model registry"
    is "${lines[3]}" "   Format: pt" "model format"
    is "${lines[4]}" "   Header: 288 entries" "# of metadata entries"
}
53+
54+
# bats test_tags=distro-integration
# With --all, individual header entries are printed under "Header:" instead of
# a count. Relies on the model pulled by the previous test; removes it at the end.
@test "ramalama inspect safetensors model with --all" {
    ST_MODEL="https://huggingface.co/LiheYoung/depth-anything-small-hf/resolve/main/model.safetensors"

    run_ramalama inspect --all $ST_MODEL

    is "${lines[0]}" "model.safetensors" "model name"
    is "${lines[1]}" "   Path: .*store/https/huggingface.co/.*" "model path"
    is "${lines[2]}" "   Registry: https" "model registry"
    is "${lines[3]}" "   Format: pt" "model format"
    is "${lines[4]}" "   Header: " "metadata header"
    is "${lines[5]}" "      __metadata__: {'format': 'pt'}" "metadata"

    # Clean up the model pulled by the preceding safetensors test.
    run_ramalama rm $ST_MODEL
}
69+
4070
# vim: filetype=sh

0 commit comments

Comments
 (0)