@@ -211,3 +211,80 @@ def open_folder(path):
211
211
subprocess .Popen (["explorer.exe" , subprocess .check_output (["wslpath" , "-w" , path ])])
212
212
else :
213
213
subprocess .Popen (["xdg-open" , path ])
214
+
215
+
216
+ def load_file_from_url (
217
+ url : str ,
218
+ * ,
219
+ model_dir : str ,
220
+ progress : bool = True ,
221
+ file_name : str | None = None ,
222
+ hash_prefix : str | None = None ,
223
+ re_download : bool = False ,
224
+ ) -> str :
225
+ """Download a file from `url` into `model_dir`, using the file present if possible.
226
+ Returns the path to the downloaded file.
227
+
228
+ file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
229
+ file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
230
+ hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
231
+ if the hash does not match, the temporary file is deleted and a ValueError is raised.
232
+ re_download: forcibly re-download the file even if it already exists.
233
+ """
234
+ from urllib .parse import urlparse
235
+ import requests
236
+ try :
237
+ from tqdm import tqdm
238
+ except ImportError :
239
+ class tqdm :
240
+ def __init__ (self , * args , ** kwargs ):
241
+ pass
242
+
243
+ def update (self , n = 1 , * args , ** kwargs ):
244
+ pass
245
+
246
+ def __enter__ (self ):
247
+ return self
248
+
249
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
250
+ pass
251
+
252
+ if not file_name :
253
+ parts = urlparse (url )
254
+ file_name = os .path .basename (parts .path )
255
+
256
+ cached_file = os .path .abspath (os .path .join (model_dir , file_name ))
257
+
258
+ if re_download or not os .path .exists (cached_file ):
259
+ os .makedirs (model_dir , exist_ok = True )
260
+ temp_file = os .path .join (model_dir , f"{ file_name } .tmp" )
261
+ print (f'\n Downloading: "{ url } " to { cached_file } ' )
262
+ response = requests .get (url , stream = True )
263
+ response .raise_for_status ()
264
+ total_size = int (response .headers .get ('content-length' , 0 ))
265
+ with tqdm (total = total_size , unit = 'B' , unit_scale = True , desc = file_name , disable = not progress ) as progress_bar :
266
+ with open (temp_file , 'wb' ) as file :
267
+ for chunk in response .iter_content (chunk_size = 1024 ):
268
+ if chunk :
269
+ file .write (chunk )
270
+ progress_bar .update (len (chunk ))
271
+
272
+ if hash_prefix and not compare_sha256 (temp_file , hash_prefix ):
273
+ print (f"Hash mismatch for { temp_file } . Deleting the temporary file." )
274
+ os .remove (temp_file )
275
+ raise ValueError (f"File hash does not match the expected hash prefix { hash_prefix } !" )
276
+
277
+ os .rename (temp_file , cached_file )
278
+ return cached_file
279
+
280
+
281
+ def compare_sha256 (file_path : str , hash_prefix : str ) -> bool :
282
+ """Check if the SHA256 hash of the file matches the given prefix."""
283
+ import hashlib
284
+ hash_sha256 = hashlib .sha256 ()
285
+ blksize = 1024 * 1024
286
+
287
+ with open (file_path , "rb" ) as f :
288
+ for chunk in iter (lambda : f .read (blksize ), b"" ):
289
+ hash_sha256 .update (chunk )
290
+ return hash_sha256 .hexdigest ().startswith (hash_prefix .strip ().lower ())
0 commit comments