12 changes: 8 additions & 4 deletions ramalama/endian.py
@@ -1,16 +1,20 @@
import sys
from enum import IntEnum


class GGUFEndian(IntEnum):
LITTLE = 0
BIG = 1

little = LITTLE
big = BIG

def __str__(self):
return self.name


def get_system_endianness() -> GGUFEndian:
return GGUFEndian.LITTLE if sys.byteorder == 'little' else GGUFEndian.BIG


class EndianMismatchError(Exception):
pass

def __init__(self, host_endianness: GGUFEndian, model_endianness: GGUFEndian, *args):
super().__init__(f"Endian mismatch of host ({host_endianness}) and model ({model_endianness})", *args)
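
Taken together, the additions give callers a uniform way to query host endianness and a readable mismatch error. A minimal usage sketch (the call site is illustrative, not part of this PR):

    from ramalama.endian import EndianMismatchError, GGUFEndian, get_system_endianness

    host = get_system_endianness()  # GGUFEndian.LITTLE on x86_64 and aarch64
    assert GGUFEndian["little"] is GGUFEndian.LITTLE  # the lowercase names are enum aliases

    err = EndianMismatchError(host, GGUFEndian.BIG)
    print(err)  # "Endian mismatch of host (LITTLE) and model (BIG)", rendered via GGUFEndian.__str__
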
16 changes: 13 additions & 3 deletions ramalama/gguf_parser.py
@@ -169,7 +169,7 @@ def read_value(model: io.BufferedReader, value_type: GGUFValueType, model_endian
raise ParseError(f"Unknown type '{value_type}'")

@staticmethod
def parse(model_name: str, model_registry: str, model_path: str) -> GGUFModelInfo:
def get_model_endianness(model_path: str) -> GGUFEndian:
# Pin model endianness to Little Endian by default.
# Models downloaded via HuggingFace are predominantly Little Endian.
model_endianness = GGUFEndian.LITTLE
@@ -182,9 +182,19 @@ def parse(model_name: str, model_registry: str, model_path: str) -> GGUFModelInf
gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
if gguf_version & 0xFFFF == 0x0000:
model_endianness = GGUFEndian.BIG
model.seek(4) # Backtrack the reader by 4 bytes to re-read the version with the correct endianness
gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)

return model_endianness

@staticmethod
def parse(model_name: str, model_registry: str, model_path: str) -> GGUFModelInfo:
model_endianness = GGUFInfoParser.get_model_endianness(model_path)

with open(model_path, "rb") as model:
magic_number = GGUFInfoParser.read_string(model, model_endianness, 4)
if magic_number != GGUFModelInfo.MAGIC_NUMBER:
raise ParseError(f"Invalid GGUF magic number '{magic_number}'")

gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
if gguf_version != GGUFModelInfo.VERSION:
raise ParseError(f"Expected GGUF version '{GGUFModelInfo.VERSION}', but got '{gguf_version}'")

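
The detection works because valid GGUF versions are small integers: reading the version field of a big-endian file as little-endian yields a value whose low 16 bits are zero. A self-contained sketch of the same heuristic (a standalone helper, not part of this PR):

    import struct

    def sniff_gguf_endianness(path: str) -> str:
        with open(path, "rb") as f:
            f.seek(4)  # skip the 4-byte "GGUF" magic
            version = struct.unpack("<I", f.read(4))[0]  # force a little-endian read
        # a big-endian version 3 reads as 0x03000000 here, so the low 16 bits are zero
        return "big" if version & 0xFFFF == 0 else "little"
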
76 changes: 50 additions & 26 deletions ramalama/model_store.py
@@ -1,6 +1,5 @@
import os
import shutil
import sys
import urllib
from dataclasses import dataclass
from datetime import datetime
@@ -12,7 +11,7 @@
import ramalama.go2jinja as go2jinja
import ramalama.oci
from ramalama.common import download_file, generate_sha256, perror, verify_checksum
from ramalama.endian import EndianMismatchError, GGUFEndian
from ramalama.endian import EndianMismatchError, get_system_endianness
from ramalama.gguf_parser import GGUFInfoParser, GGUFModelInfo
from ramalama.logger import logger

@@ -38,7 +37,6 @@ def __init__(
type: SnapshotFileType,
should_show_progress: bool = False,
should_verify_checksum: bool = False,
should_verify_endianness: bool = True,
required: bool = True,
):
self.url: str = url
@@ -48,7 +46,6 @@
self.type: SnapshotFileType = type
self.should_show_progress: bool = should_show_progress
self.should_verify_checksum: bool = should_verify_checksum
self.should_verify_endianness: bool = should_verify_endianness
self.required: bool = required

def download(self, blob_file_path: str, snapshot_dir: str) -> str:
@@ -69,7 +66,6 @@ def __init__(
type: SnapshotFileType,
should_show_progress: bool = False,
should_verify_checksum: bool = False,
should_verify_endianness: bool = True,
required: bool = True,
):
super().__init__(
@@ -80,7 +76,6 @@ def __init__(
type,
should_show_progress,
should_verify_checksum,
should_verify_endianness,
required,
)
self.content = content
@@ -439,7 +434,6 @@ def _prepare_new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_fil
os.makedirs(snapshot_directory, exist_ok=True)

def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
host_endianness = GGUFEndian.LITTLE if sys.byteorder == 'little' else GGUFEndian.BIG
ref_file = self.get_ref_file(model_tag)

for file in snapshot_files:
@@ -463,20 +457,6 @@ def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_
if not verify_checksum(dest_path):
raise ValueError(f"Checksum verification failed for blob {dest_path}")

if file.should_verify_endianness and GGUFInfoParser.is_model_gguf(dest_path):
model_info = GGUFInfoParser.parse("model", "registry", dest_path)
if host_endianness != model_info.Endianness:
os.remove(dest_path)
perror()
perror(
f"Failed to pull model: "
f"host endian is {host_endianness} but the model endian is {model_info.Endianness}"
)
perror("Failed to pull model: ramalama currently does not support transparent byteswapping")
raise EndianMismatchError(
f"Unexpected model endianness: wanted {host_endianness}, got {model_info.Endianness}"
)

os.symlink(blob_relative_path, self.get_snapshot_file_path(snapshot_hash, file.name))

# save updated ref file
@@ -541,11 +521,52 @@ def _ensure_chat_template(self, model_tag: str, snapshot_hash: str, snapshot_fil

self.update_snapshot(model_tag, snapshot_hash, files)

def _verify_endianness(self, model_tag: str):
ref_file = self.get_ref_file(model_tag)
if ref_file is None:
return

model_hash = self.get_blob_file_hash(ref_file.hash, ref_file.model_name)
model_path = self.get_blob_file_path(model_hash)

# only check endianness for gguf models
if not GGUFInfoParser.is_model_gguf(model_path):
return

model_endianness = GGUFInfoParser.get_model_endianness(model_path)
host_endianness = get_system_endianness()
if host_endianness != model_endianness:
raise EndianMismatchError(host_endianness, model_endianness)

def verify_snapshot(self, model_tag: str):
self._verify_endianness(model_tag)
self._store.verify_snapshot()

def new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
snapshot_hash = sanitize_filename(snapshot_hash)
self._prepare_new_snapshot(model_tag, snapshot_hash, snapshot_files)
self._download_snapshot_files(model_tag, snapshot_hash, snapshot_files)
self._ensure_chat_template(model_tag, snapshot_hash, snapshot_files)

try:
self._prepare_new_snapshot(model_tag, snapshot_hash, snapshot_files)
self._download_snapshot_files(model_tag, snapshot_hash, snapshot_files)
self._ensure_chat_template(model_tag, snapshot_hash, snapshot_files)
except urllib.error.HTTPError as ex:
perror(f"Failed to fetch required file: {ex}")
perror("Removing snapshot...")
self.remove_snapshot(model_tag)
raise ex
except Exception as ex:
perror(f"Failed to create new snapshot: {ex}")
perror("Removing snapshot...")
self.remove_snapshot(model_tag)
raise ex

try:
self.verify_snapshot(model_tag)
except EndianMismatchError as ex:
perror(f"Verification of snapshot failed: {ex}")
perror("Removing snapshot...")
self.remove_snapshot(model_tag)
raise ex

def update_snapshot(self, model_tag: str, snapshot_hash: str, new_snapshot_files: list[SnapshotFile]) -> bool:
validate_snapshot_files(new_snapshot_files)
@@ -595,6 +616,9 @@ def remove_snapshot(self, model_tag: str):
snapshot_directory = self.get_snapshot_directory_from_tag(model_tag)
shutil.rmtree(snapshot_directory, ignore_errors=False)

# Remove ref file
# Remove ref file, ignore if file is not found
ref_file_path = self.get_ref_file_path(model_tag)
os.remove(ref_file_path)
try:
os.remove(ref_file_path)
except FileNotFoundError:
pass
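
The net effect is that new_snapshot now owns its cleanup: a failed download or a failed post-download verification removes the partial snapshot before re-raising. A hypothetical call site (store construction and file list elided):

    try:
        store.new_snapshot(tag, snapshot_hash, files)
    except EndianMismatchError as ex:
        # the snapshot has already been removed; just surface the error
        perror(f"Cannot use this model on the current host: {ex}")

As a design note, the FileNotFoundError guard added to remove_snapshot could equivalently use contextlib.suppress from the standard library:

    import contextlib

    with contextlib.suppress(FileNotFoundError):
        os.remove(ref_file_path)
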
9 changes: 1 addition & 8 deletions ramalama/ollama.py
@@ -250,14 +250,7 @@ def _pull_with_modelstore(self, args):
self.print_pull_message(f"ollama://{name}:{tag}")

model_hash = ollama_repo.get_model_hash(manifest)
try:
self.store.new_snapshot(tag, model_hash, files)
except urllib.error.HTTPError as e:
if "Not Found" in e.reason:
raise KeyError(f"{name}:{tag} was not found in the Ollama registry")

err = str(e).strip("'")
raise KeyError(f"failed to fetch snapshot files: {err}")
self.store.new_snapshot(tag, model_hash, files)

# If a model has been downloaded via ollama cli, only create symlink in the snapshots directory
if is_model_in_ollama_cache:
13 changes: 2 additions & 11 deletions ramalama/repo_model_base.py
@@ -333,17 +333,8 @@ def _pull_with_model_store(self, args):
repo = self.create_repository(name, organization)
snapshot_hash = repo.model_hash
files = repo.get_file_list(cached_files)
try:
self.store.new_snapshot(tag, snapshot_hash, files)
except Exception as e:
# Cleanup failed snapshot
try:
self.store.remove_snapshot(tag)
except Exception as exc:
logger.debug(f"ignoring failure to remove snapshot: {exc}")
# ignore any error when removing snapshot
pass
raise e
self.store.new_snapshot(tag, snapshot_hash, files)

except Exception as e:
if not available(self.get_cli_command()):
perror(f"URL pull failed and {self.get_cli_command()} not available")