@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
366
366
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n " # noqa: E501
367
367
368
368
369
def enable_tqdm(use_tqdm_on_load: bool) -> bool:
    """Decide whether a tqdm progress bar should be shown.

    A bar is displayed only when the caller asked for one and, in a
    distributed run, only on rank 0 so the output is not duplicated
    across workers.
    """
    if not use_tqdm_on_load:
        return False
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
369
374
def np_cache_weights_iterator (
370
- model_name_or_path : str , cache_dir : Optional [str ], hf_folder : str ,
371
- hf_weights_files : List [str ]
375
+ model_name_or_path : str ,
376
+ cache_dir : Optional [str ],
377
+ hf_folder : str ,
378
+ hf_weights_files : List [str ],
379
+ use_tqdm_on_load : bool ,
372
380
) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
373
381
"""Iterate over the weights in the model np files.
374
382
375
383
Will dump the model weights to numpy files if they are not already dumped.
376
384
"""
377
- enable_tqdm = not torch .distributed .is_initialized (
378
- ) or torch .distributed .get_rank () == 0
379
385
# Convert the model weights from torch tensors to numpy arrays for
380
386
# faster loading.
381
387
np_folder = os .path .join (hf_folder , "np" )
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
389
395
for bin_file in tqdm (
390
396
hf_weights_files ,
391
397
desc = "Loading np_cache checkpoint shards" ,
392
- disable = not enable_tqdm ,
398
+ disable = not enable_tqdm ( use_tqdm_on_load ) ,
393
399
bar_format = _BAR_FORMAT ,
394
400
):
395
401
state = torch .load (bin_file ,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(
414
420
415
421
416
422
def safetensors_weights_iterator (
417
- hf_weights_files : List [str ]
423
+ hf_weights_files : List [str ],
424
+ use_tqdm_on_load : bool ,
418
425
) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
419
426
"""Iterate over the weights in the model safetensor files."""
420
- enable_tqdm = not torch .distributed .is_initialized (
421
- ) or torch .distributed .get_rank () == 0
422
427
for st_file in tqdm (
423
428
hf_weights_files ,
424
429
desc = "Loading safetensors checkpoint shards" ,
425
- disable = not enable_tqdm ,
430
+ disable = not enable_tqdm ( use_tqdm_on_load ) ,
426
431
bar_format = _BAR_FORMAT ,
427
432
):
428
433
with safe_open (st_file , framework = "pt" ) as f :
@@ -432,32 +437,30 @@ def safetensors_weights_iterator(
432
437
433
438
434
439
def runai_safetensors_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Stream tensors out of the given model safetensor files.

    Each file is fed through the Runai Model Streamer in turn and its
    tensors are yielded as ``(name, tensor)`` pairs.  The progress bar
    is shown only when ``use_tqdm_on_load`` is set and (for distributed
    runs) only on rank 0.
    """
    show_bar = enable_tqdm(use_tqdm_on_load)
    with SafetensorsStreamer() as streamer:
        progress = tqdm(
            hf_weights_files,
            desc="Loading safetensors using Runai Model Streamer",
            disable=not show_bar,
            bar_format=_BAR_FORMAT,
        )
        for weights_file in progress:
            streamer.stream_file(weights_file)
            yield from streamer.get_tensors()
449
453
450
454
451
455
def pt_weights_iterator (
452
- hf_weights_files : List [str ]
456
+ hf_weights_files : List [str ],
457
+ use_tqdm_on_load : bool ,
453
458
) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
454
459
"""Iterate over the weights in the model bin/pt files."""
455
- enable_tqdm = not torch .distributed .is_initialized (
456
- ) or torch .distributed .get_rank () == 0
457
460
for bin_file in tqdm (
458
461
hf_weights_files ,
459
462
desc = "Loading pt checkpoint shards" ,
460
- disable = not enable_tqdm ,
463
+ disable = not enable_tqdm ( use_tqdm_on_load ) ,
461
464
bar_format = _BAR_FORMAT ,
462
465
):
463
466
state = torch .load (bin_file , map_location = "cpu" , weights_only = True )
0 commit comments