@@ -140,8 +140,19 @@ def __str__(self) -> str:
     "pip": "*",
     "pydoc": "pipepager",  # pydoc.pipepager('help','echo pwned')
     "timeit": "*",
+    "torch._dynamo.guards": {"GuardBuilder.get"},
     "torch._inductor.codecache": "compile_file",  # compile_file('', '', ['sh', '-c','$(echo pwned)'])
+    "torch.fx.experimental.symbolic_shapes": {"ShapeEnv.evaluate_guards_expression"},
+    "torch.jit.unsupported_tensor_ops": {"execWrapper"},
     "torch.serialization": "load",  # pickle could be used to load a different file
+    "torch.utils._config_module": {
+        "ConfigModule.load_config"
+    },  # allows storing a pickle inside a pickle (if this has valid use cases, scan the input bytes instead of flagging the global)
+    "torch.utils.bottleneck.__main__": {"run_cprofile"},
+    "torch.utils.collect_env": {"run"},
+    "torch.utils.data.datapipes.utils.decoder": {
+        "basichandlers"
+    },  # allows storing a pickle inside a pickle (if this has valid use cases, scan the input bytes instead of flagging the global)
     "venv": "*",
     "webbrowser": "*",  # Includes webbrowser.open()
 }
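These denylist entries exist because a pickle can reference any importable callable and invoke it on load. A minimal sketch of the `pydoc.pipepager` case flagged above (the `Evil` class, protocol choice, and print are mine, not part of this commit), showing both how such a payload is built and how a static opcode walk spots it without ever unpickling:

```python
import pickle
import pickletools

# Hypothetical attacker payload: __reduce__ makes loading the pickle call
# pydoc.pipepager('help', 'echo pwned'), matching the comment in the list.
class Evil:
    def __reduce__(self):
        import pydoc
        return (pydoc.pipepager, ("help", "echo pwned"))

payload = pickle.dumps(Evil(), protocol=0)

# Static detection: at protocol 0 the reference appears as a single GLOBAL
# opcode whose argument is the "module name" pair, no code execution needed.
for op, arg, pos in pickletools.genops(payload):
    if op.name == "GLOBAL":
        print(op.name, arg)  # -> GLOBAL pydoc pipepager
```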
@@ -200,9 +211,7 @@ def _http_get(url) -> bytes:
     _log.debug(f"Request: GET {url}")
 
     parsed_url = urllib.parse.urlparse(url)
-    path_and_query = parsed_url.path + (
-        "?" + parsed_url.query if len(parsed_url.query) > 0 else ""
-    )
+    path_and_query = parsed_url.path + ("?" + parsed_url.query if len(parsed_url.query) > 0 else "")
 
     conn = http.client.HTTPSConnection(parsed_url.netloc)
     try:
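The joined line is behavior-preserving: it rebuilds the request target that `HTTPSConnection.request()` expects from the parsed URL. A quick check of the construction (the example URL is mine):

```python
from urllib.parse import urlparse

url = "https://huggingface.co/api/models/bert-base-uncased?full=true"
parsed_url = urlparse(url)
# Reattach the query string only when one is present.
path_and_query = parsed_url.path + ("?" + parsed_url.query if len(parsed_url.query) > 0 else "")
assert path_and_query == "/api/models/bert-base-uncased?full=true"
```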
@@ -271,18 +280,14 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]]:
                         "BINSTRING",
                         "SHORT_BINSTRING",
                     ]:
-                        _log.debug(
-                            "Presence of non-string opcode, categorizing as an unknown dangerous import"
-                        )
+                        _log.debug("Presence of non-string opcode, categorizing as an unknown dangerous import")
                         values.append("unknown")
                     else:
                         values.append(ops[n - offset][1])
                     if len(values) == 2:
                         break
                 if len(values) != 2:
-                    raise ValueError(
-                        f"Found {len(values)} values for STACK_GLOBAL at position {n} instead of 2."
-                    )
+                    raise ValueError(f"Found {len(values)} values for STACK_GLOBAL at position {n} instead of 2.")
                 globals.add((values[1], values[0]))
 
     if not multiple_pickles:
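For context on what this loop resolves: at protocol 4+, a global reference is encoded as two string pushes followed by a `STACK_GLOBAL` opcode, so the scanner walks backwards from `STACK_GLOBAL` collecting the last two string arguments. A self-contained sketch of that idea (simplified, without the memo handling the real function has):

```python
import pickle
import pickletools

payload = pickle.dumps(len, protocol=4)  # functions pickle as global references
ops = list(pickletools.genops(payload))

for n, (op, arg, pos) in enumerate(ops):
    if op.name == "STACK_GLOBAL":
        values = []
        # Walk backwards over earlier opcodes, keeping the two string pushes
        # (name first, then module) and skipping bookkeeping opcodes.
        for prev_op, prev_arg, _ in reversed(ops[:n]):
            if prev_op.name in ("SHORT_BINUNICODE", "UNICODE", "BINUNICODE", "BINUNICODE8"):
                values.append(prev_arg)
            if len(values) == 2:
                break
        print((values[1], values[0]))  # -> ('builtins', 'len')
```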
@@ -312,17 +317,11 @@ def _build_scan_result_from_raw_globals(
         unsafe_filter = _unsafe_globals.get(g.module)
         if "unknown" in g.module or "unknown" in g.name:
             g.safety = SafetyLevel.Dangerous
-            _log.warning(
-                "%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name
-            )
+            _log.warning("%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name)
             issues_count += 1
-        elif unsafe_filter is not None and (
-            unsafe_filter == "*" or g.name in unsafe_filter
-        ):
+        elif unsafe_filter is not None and (unsafe_filter == "*" or g.name in unsafe_filter):
             g.safety = SafetyLevel.Dangerous
-            _log.warning(
-                "%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name
-            )
+            _log.warning("%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name)
             issues_count += 1
         elif safe_filter is not None and (safe_filter == "*" or g.name in safe_filter):
             g.safety = SafetyLevel.Innocuous
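A condensed sketch of the filter semantics in the condition above (the `is_dangerous` helper is mine): `"*"` flags every name in the module, while a set value, as in the entries added to `_unsafe_globals` in this commit, flags only the listed names.

```python
_unsafe_globals = {
    "pip": "*",                                   # wildcard: any name is flagged
    "torch._dynamo.guards": {"GuardBuilder.get"}, # set: only listed names
}

def is_dangerous(module, name):
    unsafe_filter = _unsafe_globals.get(module)
    return unsafe_filter is not None and (unsafe_filter == "*" or name in unsafe_filter)

assert is_dangerous("pip", "main")
assert is_dangerous("torch._dynamo.guards", "GuardBuilder.get")
assert not is_dangerous("torch._dynamo.guards", "GuardBuilder")
assert not is_dangerous("collections", "OrderedDict")
```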
@@ -341,9 +340,7 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult:
     except GenOpsError as e:
         _log.error(f"ERROR: parsing pickle in {file_id}: {e}")
         if e.globals is not None:
-            return _build_scan_result_from_raw_globals(
-                e.globals, file_id, scan_err=True
-            )
+            return _build_scan_result_from_raw_globals(e.globals, file_id, scan_err=True)
         else:
             return ScanResult([], scan_err=True)
 
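A hedged sketch of when this error path fires (whether this exact truncation trips `GenOpsError` depends on scanner internals not shown in this hunk): a stream whose opcode parse fails partway is reported with `scan_err=True` rather than crashing, and any globals parsed before the failure can still be flagged.

```python
import io
import pickle

# Drop the trailing STOP opcode so opcode parsing fails at end of stream.
payload = pickle.dumps(len, protocol=4)[:-1]
result = scan_pickle_bytes(io.BytesIO(payload), "truncated.pkl")
print(result.scan_err)  # -> True (sketch; partial globals may still be reported)
```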
@@ -357,9 +354,7 @@ def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:
     try:
         import py7zr
     except ImportError:
-        raise Exception(
-            "py7zr is required to scan 7z archives, install picklescan using: 'pip install picklescan[7z]'"
-        )
+        raise Exception("py7zr is required to scan 7z archives, install picklescan using: 'pip install picklescan[7z]'")
     result = ScanResult([])
 
     with py7zr.SevenZipFile(data, mode="r") as archive:
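The rest of the archive walk is outside this hunk; a hedged sketch of how such a loop could iterate members (py7zr's `readall()` maps member names to file-like objects; the helper name is mine):

```python
import py7zr  # optional extra: pip install picklescan[7z]

def iter_7z_members(data):
    # Yield (name, file-like object) pairs for every member of the archive.
    with py7zr.SevenZipFile(data, mode="r") as archive:
        for name, member in (archive.readall() or {}).items():
            yield name, member
```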
@@ -389,24 +384,18 @@ def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
                 magic_bytes = file.read(8)
             file_ext = os.path.splitext(file_name)[1]
 
-            if file_ext in _pickle_file_extensions or any(
-                magic_bytes.startswith(mn) for mn in _pickle_magic_bytes
-            ):
+            if file_ext in _pickle_file_extensions or any(magic_bytes.startswith(mn) for mn in _pickle_magic_bytes):
                 _log.debug("Scanning file %s in zip archive %s", file_name, file_id)
                 with zip.open(file_name, "r") as file:
                     result.merge(scan_pickle_bytes(file, f"{file_id}:{file_name}"))
 
-            elif file_ext in _numpy_file_extensions or magic_bytes.startswith(
-                _numpy_magic_bytes
-            ):
+            elif file_ext in _numpy_file_extensions or magic_bytes.startswith(_numpy_magic_bytes):
                 _log.debug("Scanning file %s in zip archive %s", file_name, file_id)
                 with zip.open(file_name, "r") as file:
                     result.merge(scan_numpy(file, f"{file_id}:{file_name}"))
         except (zipfile.BadZipFile, RuntimeError) as e:
             # Log decompression issues (password protected, corrupted, etc.)
-            _log.warning(
-                "Invalid file %s in zip archive %s: %s", file_name, file_id, str(e)
-            )
+            _log.warning("Invalid file %s in zip archive %s: %s", file_name, file_id, str(e))
 
     return result
 
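The magic-byte check above lets the scanner catch pickles and numpy arrays even when the file extension lies. A hedged sketch of that sniffing; the actual constants live at module level in the scanner, and the values here are assumptions (protocol-2+ pickles start with `0x80` plus a protocol byte, `.npy` files with `b"\x93NUMPY"`):

```python
import io
import pickle
import numpy

_pickle_magic_bytes = [bytes([0x80, proto]) for proto in (2, 3, 4, 5)]  # assumption
_numpy_magic_bytes = b"\x93NUMPY"

def sniff(magic_bytes: bytes) -> str:
    # Classify a file by its first bytes, mirroring the branch logic above.
    if any(magic_bytes.startswith(mn) for mn in _pickle_magic_bytes):
        return "pickle"
    if magic_bytes.startswith(_numpy_magic_bytes):
        return "numpy"
    return "unknown"

assert sniff(pickle.dumps([1, 2, 3])[:8]) == "pickle"
buf = io.BytesIO()
numpy.save(buf, numpy.arange(3))
assert sniff(buf.getvalue()[:8]) == "numpy"
```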
@@ -491,24 +480,14 @@ def scan_bytes(data: IO[bytes], file_id, file_ext: Optional[str] = None) -> ScanResult:
 
 def scan_huggingface_model(repo_id):
     # List model files
-    model = json.loads(
-        _http_get(f"https://huggingface.co/api/models/{repo_id}").decode("utf-8")
-    )
-    file_names = [
-        file_name
-        for file_name in (sibling.get("rfilename") for sibling in model["siblings"])
-        if file_name is not None
-    ]
+    model = json.loads(_http_get(f"https://huggingface.co/api/models/{repo_id}").decode("utf-8"))
+    file_names = [file_name for file_name in (sibling.get("rfilename") for sibling in model["siblings"]) if file_name is not None]
 
     # Scan model files
     scan_result = ScanResult([])
     for file_name in file_names:
         file_ext = os.path.splitext(file_name)[1]
-        if (
-            file_ext not in _zip_file_extensions
-            and file_ext not in _pickle_file_extensions
-            and file_ext not in _pytorch_file_extensions
-        ):
+        if file_ext not in _zip_file_extensions and file_ext not in _pickle_file_extensions and file_ext not in _pytorch_file_extensions:
             continue
         _log.debug("Scanning file %s in model %s", file_name, repo_id)
         url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"
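For reference, the Hugging Face model-info endpoint used here returns a `siblings` list of `{"rfilename": ...}` dicts, which is what the comprehension above filters. A standalone sketch of the same listing using `urllib` instead of the module's `_http_get` helper (the repo id is a placeholder):

```python
import json
import urllib.request

repo_id = "bert-base-uncased"  # placeholder repo id
with urllib.request.urlopen(f"https://huggingface.co/api/models/{repo_id}") as resp:
    model = json.load(resp)

# Keep only siblings that actually carry a file name.
file_names = [s.get("rfilename") for s in model["siblings"] if s.get("rfilename") is not None]
# Each file is then fetched from:
#   https://huggingface.co/{repo_id}/resolve/main/{file_name}
```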