@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
6060 if os .path .exists (download_target ) and not os .path .isfile (download_target ):
6161 raise RuntimeError (f"{ download_target } exists and is not a regular file" )
6262
63+ def compute_sha256 (file_path : str ) -> str :
64+ sha256 = hashlib .sha256 ()
65+ with open (file_path , "rb" ) as f :
66+ for chunk in iter (lambda : f .read (8192 ), b"" ):
67+ sha256 .update (chunk )
68+ return sha256 .hexdigest ()
69+
6370 if os .path .isfile (download_target ):
64- with open (download_target , "rb" ) as f :
65- model_bytes = f .read ()
66- if hashlib .sha256 (model_bytes ).hexdigest () == expected_sha256 :
67- return model_bytes if in_memory else download_target
71+ if compute_sha256 (download_target ) == expected_sha256 :
72+ if in_memory :
73+ with open (download_target , "rb" ) as f :
74+ return f .read ()
75+ else :
76+ return download_target
6877 else :
6978 warnings .warn (
7079 f"{ download_target } exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
8695 output .write (buffer )
8796 loop .update (len (buffer ))
8897
89- model_bytes = open (download_target , "rb" ).read ()
90- if hashlib .sha256 (model_bytes ).hexdigest () != expected_sha256 :
98+ if compute_sha256 (download_target ) != expected_sha256 :
9199 raise RuntimeError (
92100 "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
93101 )
94102
95- return model_bytes if in_memory else download_target
103+ if in_memory :
104+ with open (download_target , "rb" ) as f :
105+ return f .read ()
106+ else :
107+ return download_target
96108
97109
98110def available_models () -> List [str ]:
@@ -147,7 +159,7 @@ def load_model(
147159 with (
148160 io .BytesIO (checkpoint_file ) if in_memory else open (checkpoint_file , "rb" )
149161 ) as fp :
150- checkpoint = torch .load (fp , map_location = device )
162+ checkpoint = torch .load (fp , map_location = device , weights_only = True )
151163 del checkpoint_file
152164
153165 dims = ModelDimensions (** checkpoint ["dims" ])
@@ -157,4 +169,4 @@ def load_model(
157169 if alignment_heads is not None :
158170 model .set_alignment_heads (alignment_heads )
159171
160- return model .to (device )
172+ return model .to (device )