1
1
import os
2
2
import shutil
3
- import sys
4
3
import urllib
5
4
from dataclasses import dataclass
6
5
from datetime import datetime
12
11
import ramalama .go2jinja as go2jinja
13
12
import ramalama .oci
14
13
from ramalama .common import download_file , generate_sha256 , perror , verify_checksum
15
- from ramalama .endian import EndianMismatchError , GGUFEndian
14
+ from ramalama .endian import EndianMismatchError , get_system_endianness
16
15
from ramalama .gguf_parser import GGUFInfoParser , GGUFModelInfo
17
16
from ramalama .logger import logger
18
17
@@ -38,7 +37,6 @@ def __init__(
38
37
type : SnapshotFileType ,
39
38
should_show_progress : bool = False ,
40
39
should_verify_checksum : bool = False ,
41
- should_verify_endianness : bool = True ,
42
40
required : bool = True ,
43
41
):
44
42
self .url : str = url
@@ -48,7 +46,6 @@ def __init__(
48
46
self .type : SnapshotFileType = type
49
47
self .should_show_progress : bool = should_show_progress
50
48
self .should_verify_checksum : bool = should_verify_checksum
51
- self .should_verify_endianness : bool = should_verify_endianness
52
49
self .required : bool = required
53
50
54
51
def download (self , blob_file_path : str , snapshot_dir : str ) -> str :
@@ -69,7 +66,6 @@ def __init__(
69
66
type : SnapshotFileType ,
70
67
should_show_progress : bool = False ,
71
68
should_verify_checksum : bool = False ,
72
- should_verify_endianness : bool = True ,
73
69
required : bool = True ,
74
70
):
75
71
super ().__init__ (
@@ -80,7 +76,6 @@ def __init__(
80
76
type ,
81
77
should_show_progress ,
82
78
should_verify_checksum ,
83
- should_verify_endianness ,
84
79
required ,
85
80
)
86
81
self .content = content
@@ -439,7 +434,6 @@ def _prepare_new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_fil
439
434
os .makedirs (snapshot_directory , exist_ok = True )
440
435
441
436
def _download_snapshot_files (self , model_tag : str , snapshot_hash : str , snapshot_files : list [SnapshotFile ]):
442
- host_endianness = GGUFEndian .LITTLE if sys .byteorder == 'little' else GGUFEndian .BIG
443
437
ref_file = self .get_ref_file (model_tag )
444
438
445
439
for file in snapshot_files :
@@ -463,20 +457,6 @@ def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_
463
457
if not verify_checksum (dest_path ):
464
458
raise ValueError (f"Checksum verification failed for blob { dest_path } " )
465
459
466
- if file .should_verify_endianness and GGUFInfoParser .is_model_gguf (dest_path ):
467
- model_info = GGUFInfoParser .parse ("model" , "registry" , dest_path )
468
- if host_endianness != model_info .Endianness :
469
- os .remove (dest_path )
470
- perror ()
471
- perror (
472
- f"Failed to pull model: "
473
- f"host endian is { host_endianness } but the model endian is { model_info .Endianness } "
474
- )
475
- perror ("Failed to pull model: ramalama currently does not support transparent byteswapping" )
476
- raise EndianMismatchError (
477
- f"Unexpected model endianness: wanted { host_endianness } , got { model_info .Endianness } "
478
- )
479
-
480
460
os .symlink (blob_relative_path , self .get_snapshot_file_path (snapshot_hash , file .name ))
481
461
482
462
# save updated ref file
@@ -541,11 +521,52 @@ def _ensure_chat_template(self, model_tag: str, snapshot_hash: str, snapshot_fil
541
521
542
522
self .update_snapshot (model_tag , snapshot_hash , files )
543
523
524
+ def _verify_endianness (self , model_tag : str ):
525
+ ref_file = self .get_ref_file (model_tag )
526
+ if ref_file is None :
527
+ return
528
+
529
+ model_hash = self .get_blob_file_hash (ref_file .hash , ref_file .model_name )
530
+ model_path = self .get_blob_file_path (model_hash )
531
+
532
+ # only check endianness for gguf models
533
+ if not GGUFInfoParser .is_model_gguf (model_path ):
534
+ return
535
+
536
+ model_endianness = GGUFInfoParser .get_model_endianness (model_path )
537
+ host_endianness = get_system_endianness ()
538
+ if host_endianness != model_endianness :
539
+ raise EndianMismatchError (host_endianness , model_endianness )
540
+
541
+ def verify_snapshot (self , model_tag : str ):
542
+ self ._verify_endianness (model_tag )
543
+ self ._store .verify_snapshot ()
544
+
544
545
def new_snapshot (self , model_tag : str , snapshot_hash : str , snapshot_files : list [SnapshotFile ]):
545
546
snapshot_hash = sanitize_filename (snapshot_hash )
546
- self ._prepare_new_snapshot (model_tag , snapshot_hash , snapshot_files )
547
- self ._download_snapshot_files (model_tag , snapshot_hash , snapshot_files )
548
- self ._ensure_chat_template (model_tag , snapshot_hash , snapshot_files )
547
+
548
+ try :
549
+ self ._prepare_new_snapshot (model_tag , snapshot_hash , snapshot_files )
550
+ self ._download_snapshot_files (model_tag , snapshot_hash , snapshot_files )
551
+ self ._ensure_chat_template (model_tag , snapshot_hash , snapshot_files )
552
+ except urllib .error .HTTPError as ex :
553
+ perror (f"Failed to fetch required file: { ex } " )
554
+ perror ("Removing snapshot..." )
555
+ self .remove_snapshot (model_tag )
556
+ raise ex
557
+ except Exception as ex :
558
+ perror (f"Failed to create new snapshot: { ex } " )
559
+ perror ("Removing snapshot..." )
560
+ self .remove_snapshot (model_tag )
561
+ raise ex
562
+
563
+ try :
564
+ self .verify_snapshot (model_tag )
565
+ except EndianMismatchError as ex :
566
+ perror (f"Verification of snapshot failed: { ex } " )
567
+ perror ("Removing snapshot..." )
568
+ self .remove_snapshot (model_tag )
569
+ raise ex
549
570
550
571
def update_snapshot (self , model_tag : str , snapshot_hash : str , new_snapshot_files : list [SnapshotFile ]) -> bool :
551
572
validate_snapshot_files (new_snapshot_files )
@@ -595,6 +616,9 @@ def remove_snapshot(self, model_tag: str):
595
616
snapshot_directory = self .get_snapshot_directory_from_tag (model_tag )
596
617
shutil .rmtree (snapshot_directory , ignore_errors = False )
597
618
598
- # Remove ref file
619
+ # Remove ref file, ignore if file is not found
599
620
ref_file_path = self .get_ref_file_path (model_tag )
600
- os .remove (ref_file_path )
621
+ try :
622
+ os .remove (ref_file_path )
623
+ except FileNotFoundError :
624
+ pass
0 commit comments