Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ramalama.logger import configure_logger, logger
from ramalama.model import MODEL_TYPES
from ramalama.model_factory import ModelFactory, New
from ramalama.model_store import GlobalModelStore
from ramalama.model_store.global_store import GlobalModelStore
from ramalama.shortnames import Shortnames
from ramalama.stack import Stack
from ramalama.version import print_version, version
Expand Down
2 changes: 1 addition & 1 deletion ramalama/hf_style_repo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ramalama.common import available, exec_cmd, generate_sha256, perror, run_cmd
from ramalama.logger import logger
from ramalama.model import Model
from ramalama.model_store import SnapshotFile, SnapshotFileType
from ramalama.model_store.snapshot_file import SnapshotFile, SnapshotFileType


class HFStyleRepoFile(SnapshotFile):
Expand Down
2 changes: 1 addition & 1 deletion ramalama/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
fetch_checksum_from_api_base,
)
from ramalama.logger import logger
from ramalama.model_store import SnapshotFileType
from ramalama.model_store.snapshot_file import SnapshotFileType

missing_huggingface = """
Optional: Huggingface models require the huggingface-cli module.
Expand Down
3 changes: 2 additions & 1 deletion ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from ramalama.kube import Kube
from ramalama.logger import logger
from ramalama.model_inspect import GGUFModelInfo, ModelInfoBase
from ramalama.model_store import GlobalModelStore, ModelStore
from ramalama.model_store.global_store import GlobalModelStore
from ramalama.model_store.store import ModelStore
from ramalama.quadlet import Quadlet
from ramalama.rag import rag_image
from ramalama.version import version
Expand Down
3 changes: 3 additions & 0 deletions ramalama/model_store/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Names of the sub-directories that make up a model's on-disk store layout:
# blobs hold the raw downloaded files, refs map tags to snapshots, and
# snapshots contain the per-hash file sets.
DIRECTORY_NAME_BLOBS = "blobs"
DIRECTORY_NAME_REFS = "refs"
DIRECTORY_NAME_SNAPSHOTS = "snapshots"
90 changes: 90 additions & 0 deletions ramalama/model_store/global_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List

import ramalama.oci
from ramalama.arg_types import EngineArgs
from ramalama.model_store.constants import DIRECTORY_NAME_BLOBS, DIRECTORY_NAME_REFS, DIRECTORY_NAME_SNAPSHOTS
from ramalama.model_store.reffile import RefFile


@dataclass
class ModelFile:
    """Metadata for a single file belonging to a locally stored model."""

    name: str
    modified: float  # last-modified time as a unix timestamp
    size: int  # size in bytes
    is_partial: bool  # True when only a partially downloaded blob exists


class GlobalModelStore:
    """View over the shared on-disk model store rooted at <base_path>/store."""

    def __init__(
        self,
        base_path: str,
    ):
        self._store_base_path = os.path.join(base_path, "store")

    @property
    def path(self) -> str:
        """Absolute path of the store directory."""
        return self._store_base_path

    def list_models(self, engine: str, show_container: bool) -> Dict[str, List[ModelFile]]:
        """Collect all locally stored models, keyed by their model name.

        Walks the store tree looking for ref directories; each ref file found
        yields one model entry listing its snapshot files. When
        show_container is True, OCI container images reported by the given
        engine are appended as well.
        """
        models: Dict[str, List[ModelFile]] = {}

        for root, subdirs, _ in os.walk(self.path):
            if DIRECTORY_NAME_REFS not in subdirs:
                continue

            ref_dir = os.path.join(root, DIRECTORY_NAME_REFS)
            for ref_file_name in os.listdir(ref_dir):
                ref_file: RefFile = RefFile.from_path(os.path.join(ref_dir, ref_file_name))
                # relpath is safer than str.replace: it cannot accidentally
                # rewrite a second occurrence of the store path within root.
                model_path = os.path.relpath(root, self.path)

                # Split on os.sep (the separator os.walk produces) so this
                # also works on Windows; the original split on "/" only.
                parts = model_path.split(os.sep)
                model_source = parts[0]
                model_path_without_source = os.sep.join(parts[1:])

                separator = ":///" if model_source == "file" else "://"  # Use ':///' for file URLs, '://' otherwise
                model_name = f"{model_source}{separator}{model_path_without_source}:{ref_file_name}"

                collected_files = []
                for snapshot_file in ref_file.filenames:
                    is_partially_downloaded = False
                    snapshot_file_path = os.path.join(root, DIRECTORY_NAME_SNAPSHOTS, ref_file.hash, snapshot_file)
                    if not os.path.exists(snapshot_file_path):
                        blobs_partial_file_path = os.path.join(
                            root, DIRECTORY_NAME_BLOBS, ref_file.hash + ".partial"
                        )
                        if not os.path.exists(blobs_partial_file_path):
                            # Neither a snapshot file nor a partial blob: skip.
                            continue

                        snapshot_file_path = blobs_partial_file_path
                        is_partially_downloaded = True

                    collected_files.append(
                        ModelFile(
                            snapshot_file,
                            os.path.getmtime(snapshot_file_path),
                            os.path.getsize(snapshot_file_path),
                            is_partially_downloaded,
                        )
                    )
                models[model_name] = collected_files

        if show_container:
            oci_models = ramalama.oci.list_models(EngineArgs(engine=engine))
            for oci_model in oci_models:
                name, modified, size = (oci_model["name"], oci_model["modified"], oci_model["size"])
                # ramalama.oci.list_models provides modified as timestamp string, convert it to unix timestamp
                modified_unix = datetime.fromisoformat(modified).timestamp()
                models[name] = [ModelFile(name, modified_unix, size, is_partial=False)]

        return models

    # TODO:
    # iterating over all symlinks in snapshot dir, check valid
    def verify_snapshot(self):
        pass

    # TODO:
    # iterating over models and check
    # 1. for broken symlinks in snapshot dirs -> delete and update refs
    # 2. for blobs not reached by ref->snapshot chain -> delete
    # 3. for empty folders -> delete
    def cleanup(self):
        pass
File renamed without changes.
70 changes: 70 additions & 0 deletions ramalama/model_store/reffile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
class RefFile:
    """Reference file mapping a model tag to a snapshot hash and its files.

    Format: the first line is the snapshot hash, followed by one filename per
    line. A filename may carry a role marker "<name>---<suffix>" where the
    suffix is one of MODEL_SUFFIX, CHAT_TEMPLATE_SUFFIX or MMPROJ_SUFFIX;
    unmarked lines are plain snapshot files.
    """

    SEP = "---"
    MODEL_SUFFIX = "model"
    CHAT_TEMPLATE_SUFFIX = "chat"
    MMPROJ_SUFFIX = "mmproj"

    def __init__(self):
        self.hash: str = ""  # snapshot hash this ref points to
        self.filenames: list[str] = []  # all files of the snapshot
        self.model_name: str = ""  # file marked as the model, if any
        self.chat_template_name: str = ""  # file marked as the chat template, if any
        self.mmproj_name: str = ""  # file marked as the mmproj file, if any
        self._path: str = ""  # location of the ref file on disk

    @property
    def path(self) -> str:
        return self._path

    @staticmethod  # fix: was a plain function missing the @staticmethod decorator
    def from_path(path: str) -> "RefFile":
        """Parse the ref file at *path* and return the populated RefFile."""
        ref_file = RefFile()
        ref_file._path = path
        with open(path, "r") as file:
            ref_file.hash = file.readline().strip()
            filename = file.readline().strip()
            while filename != "":  # readline() yields "" at EOF
                parts = filename.split(RefFile.SEP)
                if len(parts) != 2:
                    # No (or malformed) role marker: a plain snapshot file.
                    ref_file.filenames.append(filename)
                    filename = file.readline().strip()
                    continue

                ref_file.filenames.append(parts[0])
                if parts[1] == RefFile.MODEL_SUFFIX:
                    ref_file.model_name = parts[0]
                if parts[1] == RefFile.CHAT_TEMPLATE_SUFFIX:
                    ref_file.chat_template_name = parts[0]
                if parts[1] == RefFile.MMPROJ_SUFFIX:
                    ref_file.mmproj_name = parts[0]

                filename = file.readline().strip()
        return ref_file

    def remove_file(self, name: str):
        """Remove *name* from the file list and clear any role it held."""
        if name in self.filenames:
            self.filenames.remove(name)

        if self.chat_template_name == name:
            self.chat_template_name = ""
        if self.model_name == name:
            self.model_name = ""
        if self.mmproj_name == name:
            self.mmproj_name = ""

    def serialize(self) -> str:
        """Render the ref file content; inverse of from_path()."""
        lines = [self.hash]
        for filename in self.filenames:
            # Bug fix: the original emitted the literal "(unknown)" instead of
            # the filename, breaking the serialize()/from_path() round-trip.
            line = f"{filename}{RefFile.SEP}"
            if filename == self.model_name:
                line = line + RefFile.MODEL_SUFFIX
            if filename == self.chat_template_name:
                line = line + RefFile.CHAT_TEMPLATE_SUFFIX
            if filename == self.mmproj_name:
                line = line + RefFile.MMPROJ_SUFFIX
            lines.append(line)
        return "\n".join(lines)

    def write_to_file(self):
        """Persist the serialized content back to self.path."""
        with open(self.path, "w") as file:
            file.write(self.serialize())
            file.flush()
96 changes: 96 additions & 0 deletions ramalama/model_store/snapshot_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os
from enum import IntEnum
from typing import Dict

from ramalama.common import download_file, generate_sha256
from ramalama.logger import logger


class SnapshotFileType(IntEnum):
    """Role a file plays within a model snapshot."""

    Model = 1
    ChatTemplate = 2
    Other = 3
    Mmproj = 4


class SnapshotFile:
    """A single remote file that belongs to a model snapshot.

    Holds the download location plus flags controlling how the file is
    fetched and verified; `download` places it into the blob store.
    """

    def __init__(
        self,
        url: str,
        header: Dict,
        hash: str,
        name: str,
        type: SnapshotFileType,
        should_show_progress: bool = False,
        should_verify_checksum: bool = False,
        required: bool = True,
    ):
        self.url: str = url  # where to fetch the file from
        self.header: Dict = header  # HTTP headers passed to the download
        self.hash: str = hash  # expected content hash
        self.name: str = name  # filename within the snapshot
        self.type: SnapshotFileType = type  # role of this file in the snapshot
        self.should_show_progress: bool = should_show_progress
        self.should_verify_checksum: bool = should_verify_checksum
        self.required: bool = required

    def download(self, blob_file_path: str, snapshot_dir: str) -> str:
        """Ensure the blob exists at blob_file_path and return its path relative to snapshot_dir."""
        if os.path.exists(blob_file_path):
            # Blob already present: skip the network round-trip.
            logger.debug(f"Using cached blob for {self.name} ({os.path.basename(blob_file_path)})")
            return os.path.relpath(blob_file_path, start=snapshot_dir)

        download_file(
            url=self.url,
            headers=self.header,
            dest_path=blob_file_path,
            show_progress=self.should_show_progress,
        )
        return os.path.relpath(blob_file_path, start=snapshot_dir)


class LocalSnapshotFile(SnapshotFile):
    """A snapshot file whose content is already in memory.

    Instead of downloading, `download` writes the given content into the
    blob store. The hash is derived from the content itself.
    """

    def __init__(
        self,
        content: str,
        name: str,
        type: SnapshotFileType,
        should_show_progress: bool = False,
        should_verify_checksum: bool = False,
        required: bool = True,
    ):
        super().__init__(
            "",  # no URL: the content is provided directly
            {},  # fix: header is declared as Dict; the original passed the string ""
            generate_sha256(content),
            name,
            type,
            should_show_progress,
            should_verify_checksum,
            required,
        )
        self.content = content

    def download(self, blob_file_path, snapshot_dir):
        """Write the in-memory content to the blob store and return the path relative to snapshot_dir."""
        with open(blob_file_path, "w") as file:
            file.write(self.content)
            file.flush()
        return os.path.relpath(blob_file_path, start=snapshot_dir)


def validate_snapshot_files(snapshot_files: list[SnapshotFile]):
    """Ensure a snapshot has at most one model, chat template and mmproj file.

    Raises:
        ValueError: when more than one file of a unique type is present.
    """
    # Group the files by the roles that must be unique within a snapshot.
    grouped: dict = {
        SnapshotFileType.Model: [],
        SnapshotFileType.ChatTemplate: [],
        SnapshotFileType.Mmproj: [],
    }
    for file in snapshot_files:
        if file.type in grouped:
            grouped[file.type].append(file)

    model_files = grouped[SnapshotFileType.Model]
    chat_template_files = grouped[SnapshotFileType.ChatTemplate]
    mmproj_files = grouped[SnapshotFileType.Mmproj]

    if len(model_files) > 1:
        raise ValueError(f"Only one model supported, got {len(model_files)}: {model_files}")
    if len(chat_template_files) > 1:
        raise ValueError(f"Only one chat template supported, got {len(chat_template_files)}: {chat_template_files}")
    if len(mmproj_files) > 1:
        raise ValueError(f"Only one mmproj supported, got {len(mmproj_files)}: {mmproj_files}")
Loading
Loading