Commit 75b36dc

Merge pull request #1458 from engelmi/snapshot-verification
Snapshot verification
2 parents: c15a1e3 + b84527b

File tree: 5 files changed, +74 −52 lines changed


ramalama/endian.py

Lines changed: 8 additions & 4 deletions
@@ -1,16 +1,20 @@
+import sys
 from enum import IntEnum


 class GGUFEndian(IntEnum):
     LITTLE = 0
     BIG = 1

-    little = LITTLE
-    big = BIG
-
     def __str__(self):
         return self.name


+def get_system_endianness() -> GGUFEndian:
+    return GGUFEndian.LITTLE if sys.byteorder == 'little' else GGUFEndian.BIG
+
+
 class EndianMismatchError(Exception):
-    pass
+
+    def __init__(self, host_endianness: GGUFEndian, model_endianness: GGUFEndian, *args):
+        super().__init__(f"Endian mismatch of host ({host_endianness}) and model ({model_endianness})", *args)
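
Taken together, the helper and the reworked exception make host/model comparisons a one-liner. A minimal usage sketch, assuming only the names introduced in this diff (the BIG value passed to the exception is purely illustrative):

from ramalama.endian import EndianMismatchError, GGUFEndian, get_system_endianness

# The helper derives host endianness from sys.byteorder.
host = get_system_endianness()

# The exception now builds its own message from the two enum values.
try:
    raise EndianMismatchError(host, GGUFEndian.BIG)
except EndianMismatchError as ex:
    print(ex)  # e.g. "Endian mismatch of host (LITTLE) and model (BIG)"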

ramalama/gguf_parser.py

Lines changed: 13 additions & 3 deletions
@@ -169,7 +169,7 @@ def read_value(model: io.BufferedReader, value_type: GGUFValueType, model_endian
         raise ParseError(f"Unknown type '{value_type}'")

     @staticmethod
-    def parse(model_name: str, model_registry: str, model_path: str) -> GGUFModelInfo:
+    def get_model_endianness(model_path: str) -> GGUFEndian:
         # Pin model endianness to Little Endian by default.
         # Models downloaded via HuggingFace are majority Little Endian.
         model_endianness = GGUFEndian.LITTLE
@@ -182,9 +182,19 @@ def parse(model_name: str, model_registry: str, model_path: str) -> GGUFModelInf
             gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
             if gguf_version & 0xFFFF == 0x0000:
                 model_endianness = GGUFEndian.BIG
-                model.seek(4) # Backtrack the reader by 4 bytes to re-read
-                gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)

+        return model_endianness
+
+    @staticmethod
+    def parse(model_name: str, model_registry: str, model_path: str) -> GGUFModelInfo:
+        model_endianness = GGUFInfoParser.get_model_endianness(model_path)
+
+        with open(model_path, "rb") as model:
+            magic_number = GGUFInfoParser.read_string(model, model_endianness, 4)
+            if magic_number != GGUFModelInfo.MAGIC_NUMBER:
+                raise ParseError(f"Invalid GGUF magic number '{magic_number}'")
+
+            gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
             if gguf_version != GGUFModelInfo.VERSION:
                 raise ParseError(f"Expected GGUF version '{GGUFModelInfo.VERSION}', but got '{gguf_version}'")
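
The probe in get_model_endianness works because the GGUF version is a small UINT32: read with the wrong byte order, its low 16 bits come out as zero, which is what the gguf_version & 0xFFFF == 0x0000 check keys on before switching to GGUFEndian.BIG. A rough caller-side sketch (the file path argument is a hypothetical placeholder):

from ramalama.gguf_parser import GGUFInfoParser

def report_endianness(model_path: str) -> None:
    # model_path is a placeholder; point it at a local GGUF blob
    if GGUFInfoParser.is_model_gguf(model_path):
        # reads only the magic number and version field, not the full metadata
        print(GGUFInfoParser.get_model_endianness(model_path))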

ramalama/model_store.py

Lines changed: 50 additions & 26 deletions
@@ -1,6 +1,5 @@
 import os
 import shutil
-import sys
 import urllib
 from dataclasses import dataclass
 from datetime import datetime
@@ -12,7 +11,7 @@
 import ramalama.go2jinja as go2jinja
 import ramalama.oci
 from ramalama.common import download_file, generate_sha256, perror, verify_checksum
-from ramalama.endian import EndianMismatchError, GGUFEndian
+from ramalama.endian import EndianMismatchError, get_system_endianness
 from ramalama.gguf_parser import GGUFInfoParser, GGUFModelInfo
 from ramalama.logger import logger

@@ -38,7 +37,6 @@ def __init__(
         type: SnapshotFileType,
         should_show_progress: bool = False,
         should_verify_checksum: bool = False,
-        should_verify_endianness: bool = True,
         required: bool = True,
     ):
         self.url: str = url
@@ -48,7 +46,6 @@ def __init__(
         self.type: SnapshotFileType = type
         self.should_show_progress: bool = should_show_progress
         self.should_verify_checksum: bool = should_verify_checksum
-        self.should_verify_endianness: bool = should_verify_endianness
         self.required: bool = required

     def download(self, blob_file_path: str, snapshot_dir: str) -> str:
@@ -69,7 +66,6 @@ def __init__(
         type: SnapshotFileType,
         should_show_progress: bool = False,
         should_verify_checksum: bool = False,
-        should_verify_endianness: bool = True,
         required: bool = True,
     ):
         super().__init__(
@@ -80,7 +76,6 @@ def __init__(
             type,
             should_show_progress,
             should_verify_checksum,
-            should_verify_endianness,
             required,
         )
         self.content = content
@@ -439,7 +434,6 @@ def _prepare_new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_fil
         os.makedirs(snapshot_directory, exist_ok=True)

     def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
-        host_endianness = GGUFEndian.LITTLE if sys.byteorder == 'little' else GGUFEndian.BIG
         ref_file = self.get_ref_file(model_tag)

         for file in snapshot_files:
@@ -463,20 +457,6 @@ def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_
                 if not verify_checksum(dest_path):
                     raise ValueError(f"Checksum verification failed for blob {dest_path}")

-            if file.should_verify_endianness and GGUFInfoParser.is_model_gguf(dest_path):
-                model_info = GGUFInfoParser.parse("model", "registry", dest_path)
-                if host_endianness != model_info.Endianness:
-                    os.remove(dest_path)
-                    perror()
-                    perror(
-                        f"Failed to pull model: "
-                        f"host endian is {host_endianness} but the model endian is {model_info.Endianness}"
-                    )
-                    perror("Failed to pull model: ramalama currently does not support transparent byteswapping")
-                    raise EndianMismatchError(
-                        f"Unexpected model endianness: wanted {host_endianness}, got {model_info.Endianness}"
-                    )
-
             os.symlink(blob_relative_path, self.get_snapshot_file_path(snapshot_hash, file.name))

         # save updated ref file
@@ -541,11 +521,52 @@ def _ensure_chat_template(self, model_tag: str, snapshot_hash: str, snapshot_fil

         self.update_snapshot(model_tag, snapshot_hash, files)

+    def _verify_endianness(self, model_tag: str):
+        ref_file = self.get_ref_file(model_tag)
+        if ref_file is None:
+            return
+
+        model_hash = self.get_blob_file_hash(ref_file.hash, ref_file.model_name)
+        model_path = self.get_blob_file_path(model_hash)
+
+        # only check endianness for gguf models
+        if not GGUFInfoParser.is_model_gguf(model_path):
+            return
+
+        model_endianness = GGUFInfoParser.get_model_endianness(model_path)
+        host_endianness = get_system_endianness()
+        if host_endianness != model_endianness:
+            raise EndianMismatchError(host_endianness, model_endianness)
+
+    def verify_snapshot(self, model_tag: str):
+        self._verify_endianness(model_tag)
+        self._store.verify_snapshot()
+
     def new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
         snapshot_hash = sanitize_filename(snapshot_hash)
-        self._prepare_new_snapshot(model_tag, snapshot_hash, snapshot_files)
-        self._download_snapshot_files(model_tag, snapshot_hash, snapshot_files)
-        self._ensure_chat_template(model_tag, snapshot_hash, snapshot_files)
+
+        try:
+            self._prepare_new_snapshot(model_tag, snapshot_hash, snapshot_files)
+            self._download_snapshot_files(model_tag, snapshot_hash, snapshot_files)
+            self._ensure_chat_template(model_tag, snapshot_hash, snapshot_files)
+        except urllib.error.HTTPError as ex:
+            perror(f"Failed to fetch required file: {ex}")
+            perror("Removing snapshot...")
+            self.remove_snapshot(model_tag)
+            raise ex
+        except Exception as ex:
+            perror(f"Failed to create new snapshot: {ex}")
+            perror("Removing snapshot...")
+            self.remove_snapshot(model_tag)
+            raise ex
+
+        try:
+            self.verify_snapshot(model_tag)
+        except EndianMismatchError as ex:
+            perror(f"Verification of snapshot failed: {ex}")
+            perror("Removing snapshot...")
+            self.remove_snapshot(model_tag)
+            raise ex

     def update_snapshot(self, model_tag: str, snapshot_hash: str, new_snapshot_files: list[SnapshotFile]) -> bool:
         validate_snapshot_files(new_snapshot_files)
@@ -595,6 +616,9 @@ def remove_snapshot(self, model_tag: str):
         snapshot_directory = self.get_snapshot_directory_from_tag(model_tag)
         shutil.rmtree(snapshot_directory, ignore_errors=False)

-        # Remove ref file
+        # Remove ref file, ignore if file is not found
         ref_file_path = self.get_ref_file_path(model_tag)
-        os.remove(ref_file_path)
+        try:
+            os.remove(ref_file_path)
+        except FileNotFoundError:
+            pass
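
With download, symlinking, and verification now wrapped inside new_snapshot itself (partial snapshots are removed before the error is re-raised), callers no longer need their own cleanup, which is exactly what the ollama.py and repo_model_base.py hunks below remove. A rough caller-side sketch, assuming a store object of this class and placeholder tag/hash/file arguments:

from ramalama.endian import EndianMismatchError

def pull_snapshot(store, tag, snapshot_hash, files):
    # store is a ModelStore-like object; the other arguments are placeholders
    try:
        store.new_snapshot(tag, snapshot_hash, files)  # downloads, links, then verifies
    except EndianMismatchError as ex:
        # the partial snapshot was already removed before the re-raise
        print(f"cannot use this model on this host: {ex}")
        raise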

ramalama/ollama.py

Lines changed: 1 addition & 8 deletions
@@ -250,14 +250,7 @@ def _pull_with_modelstore(self, args):
         self.print_pull_message(f"ollama://{name}:{tag}")

         model_hash = ollama_repo.get_model_hash(manifest)
-        try:
-            self.store.new_snapshot(tag, model_hash, files)
-        except urllib.error.HTTPError as e:
-            if "Not Found" in e.reason:
-                raise KeyError(f"{name}:{tag} was not found in the Ollama registry")
-
-            err = str(e).strip("'")
-            raise KeyError(f"failed to fetch snapshot files: {err}")
+        self.store.new_snapshot(tag, model_hash, files)

         # If a model has been downloaded via ollama cli, only create symlink in the snapshots directory
         if is_model_in_ollama_cache:

ramalama/repo_model_base.py

Lines changed: 2 additions & 11 deletions
@@ -333,17 +333,8 @@ def _pull_with_model_store(self, args):
             repo = self.create_repository(name, organization)
             snapshot_hash = repo.model_hash
             files = repo.get_file_list(cached_files)
-            try:
-                self.store.new_snapshot(tag, snapshot_hash, files)
-            except Exception as e:
-                # Cleanup failed snapshot
-                try:
-                    self.store.remove_snapshot(tag)
-                except Exception as exc:
-                    logger.debug(f"ignoring failure to remove snapshot: {exc}")
-                    # ignore any error when removing snapshot
-                    pass
-                raise e
+            self.store.new_snapshot(tag, snapshot_hash, files)
+

         except Exception as e:
             if not available(self.get_cli_command()):
                 perror(f"URL pull failed and {self.get_cli_command()} not available")
