@@ -114,7 +114,8 @@ def create_dataloader(path,
114114                      image_weights = False ,
115115                      quad = False ,
116116                      prefix = '' ,
117-                       shuffle = False ):
117+                       shuffle = False ,
118+                       size_conf = None ):
118119    if  rect  and  shuffle :
119120        LOGGER .warning ('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False' )
120121        shuffle  =  False 
@@ -131,7 +132,8 @@ def create_dataloader(path,
131132            stride = int (stride ),
132133            pad = pad ,
133134            image_weights = image_weights ,
134-             prefix = prefix )
135+             prefix = prefix ,
136+             size_conf = size_conf )
135137
136138    batch_size  =  min (batch_size , len (dataset ))
137139    nd  =  torch .cuda .device_count ()  # number of CUDA devices 
@@ -393,7 +395,8 @@ def __init__(self,
393395                 single_cls = False ,
394396                 stride = 32 ,
395397                 pad = 0.0 ,
396-                  prefix = '' ):
398+                  prefix = '' ,
399+                  size_conf = None ):
397400        self .img_size  =  img_size 
398401        self .augment  =  augment 
399402        self .hyp  =  hyp 
@@ -430,11 +433,12 @@ def __init__(self,
430433        self .label_files  =  img2label_paths (self .im_files )  # labels 
431434        cache_path  =  (p  if  p .is_file () else  Path (self .label_files [0 ]).parent ).with_suffix ('.cache' )
432435        try :
436+             cache_path .unlink ()  # remove old cache 
433437            cache , exists  =  np .load (cache_path , allow_pickle = True ).item (), True   # load dict 
434438            assert  cache ['version' ] ==  self .cache_version   # matches current version 
435439            assert  cache ['hash' ] ==  get_hash (self .label_files  +  self .im_files )  # identical hash 
436440        except  Exception :
437-             cache , exists  =  self .cache_labels (cache_path , prefix ), False   # run cache ops 
441+             cache , exists  =  self .cache_labels (cache_path , prefix ,  size_conf ), False   # run cache ops 
438442
439443        # Display cache 
440444        nf , nm , ne , nc , n  =  cache .pop ('results' )  # found, missing, empty, corrupt, total 
@@ -517,13 +521,14 @@ def __init__(self,
517521                pbar .desc  =  f'{ prefix } { gb  /  1E9 :.1f} { cache_images }  
518522            pbar .close ()
519523
520-     def  cache_labels (self , path = Path ('./labels.cache' ), prefix = '' ):
524+     def  cache_labels (self , path = Path ('./labels.cache' ), prefix = '' ,  size_conf = None ):
521525        # Cache dataset labels, check images and read shapes 
522526        x  =  {}  # dict 
523527        nm , nf , ne , nc , msgs  =  0 , 0 , 0 , 0 , []  # number missing, found, empty, corrupt, messages 
524528        desc  =  f"{ prefix } { path .parent  /  path .stem }  
525529        with  Pool (NUM_THREADS ) as  pool :
526-             pbar  =  tqdm (pool .imap (verify_image_label , zip (self .im_files , self .label_files , repeat (prefix ))),
530+             pbar  =  tqdm (pool .imap (verify_image_label ,
531+                                   zip (self .im_files , self .label_files , repeat (prefix ), repeat (size_conf ))),
527532                        desc = desc ,
528533                        total = len (self .im_files ),
529534                        bar_format = BAR_FORMAT )
@@ -902,7 +907,7 @@ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), ann
902907
903908def  verify_image_label (args ):
904909    # Verify one image-label pair 
905-     im_file , lb_file , prefix  =  args 
910+     im_file , lb_file , prefix ,  size_conf  =  args 
906911    nm , nf , ne , nc , msg , segments  =  0 , 0 , 0 , 0 , '' , []  # number (missing, found, empty, corrupt), message, segments 
907912    try :
908913        # verify images 
@@ -928,6 +933,14 @@ def verify_image_label(args):
928933                    segments  =  [np .array (x [1 :], dtype = np .float32 ).reshape (- 1 , 2 ) for  x  in  lb ]  # (cls, xy1...) 
929934                    lb  =  np .concatenate ((classes .reshape (- 1 , 1 ), segments2boxes (segments )), 1 )  # (cls, xywh) 
930935                lb  =  np .array (lb , dtype = np .float32 )
936+                 if  size_conf  is  not None :
937+                     size_thres  =  np .array ([size_conf [int (i )] for  i  in  lb [:, 0 ]], dtype = np .float32 ) **  2 
938+                     areas  =  (lb [:, 3 :] *  np .array (shape , dtype = np .float32 )).prod (1 )
939+                     idx  =  (areas  >  size_thres [:, 0 ]) &  (areas  <=  size_thres [:, 1 ])
940+                     if  idx .any ():
941+                         lb  =  lb [idx .T ]
942+                     else :
943+                         lb  =  np .zeros ((0 , 5 ), dtype = np .float32 )
931944            nl  =  len (lb )
932945            if  nl :
933946                assert  lb .shape [1 ] ==  5 , f'labels require 5 columns, { lb .shape [1 ]}  
0 commit comments