Skip to content

Commit 28a7b4e

Browse files
authored
Fix various vulnerabilities (#50)
Fix: - GHSA-f7qq-56ww-84cr - GHSA-mjqp-26hc-grxg - GHSA-jgw4-cr84-mqxg Improve logs
1 parent 1931c2d commit 28a7b4e

File tree

6 files changed

+74
-17
lines changed

6 files changed

+74
-17
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = picklescan
3-
version = 0.0.30
3+
version = 0.0.31
44
author = Matthieu Maitre
55
author_email = [email protected]
66
description = Security scanner detecting Python Pickle files performing suspicious actions

src/picklescan/relaxed_zipfile.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# A more forgiving implementation of zipfile.ZipFile
22
# Modified from Python code at
33
# https://github.com/python/cpython/blob/edb69578ed74ff04ab78ab953355faa343a7e0ee/Lib/zipfile/__init__.py#L1606
4-
# Changes: removed flag/password/filename checks to align better with PyTorch's zip decoding
4+
# Changes: removed flag/password/filename/CRC checks to align better with PyTorch's zip decoding
55

66
import struct
77
import zipfile
@@ -85,7 +85,12 @@ def open(self, name, mode="r", pwd=None, *, force_zip64=False):
8585
if fheader[_FH_EXTRA_FIELD_LENGTH]:
8686
zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
8787

88-
return zipfile.ZipExtFile(zef_file, mode, zinfo, pwd, True)
88+
zef = zipfile.ZipExtFile(zef_file, mode, zinfo, pwd, True)
89+
90+
# Disable CRC validation as PyTorch may not use it
91+
zef._expected_crc = None
92+
93+
return zef
8994
except BaseException:
9095
zef_file.close()
9196
raise

src/picklescan/scanner.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __str__(self) -> str:
116116
"open",
117117
"breakpoint",
118118
}, # Pickle versions 3, 4 have those function under 'builtins'
119-
"aiohttp.client": "*",
119+
"aiohttp": "*",
120120
"asyncio": "*",
121121
"bdb": "*",
122122
"commands": "*", # Python 2 precursor to subprocess
@@ -134,7 +134,6 @@ def __str__(self) -> str:
134134
"ssl": "*", # DNS exfiltration via ssl.get_server_certificate()
135135
"subprocess": "*",
136136
"sys": "*",
137-
"asyncio.unix_events": {"_UnixSubprocessTransport._start"},
138137
"code": {"InteractiveInterpreter.runcode"},
139138
"cProfile": {"runctx", "run"},
140139
"doctest": {"debug_script"},
@@ -257,6 +256,7 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]
257256
for op in pickletools.genops(data):
258257
ops.append(op)
259258
except Exception as e:
259+
_log.debug(f"Error parsing pickle: {e}", exc_info=True)
260260
parsing_pkl_error = str(e)
261261
last_byte = data.read(1)
262262
data.seek(-1, 1)
@@ -329,6 +329,11 @@ def _build_scan_result_from_raw_globals(
329329
g = Global(rg[0], rg[1], SafetyLevel.Dangerous)
330330
safe_filter = _safe_globals.get(g.module)
331331
unsafe_filter = _unsafe_globals.get(g.module)
332+
333+
# If the module as a whole is marked as dangerous, submodules are also dangerous
334+
if unsafe_filter is None and "." in g.module and _unsafe_globals.get(g.module.split(".")[0]) == "*":
335+
unsafe_filter = "*"
336+
332337
if "unknown" in g.module or "unknown" in g.name:
333338
g.safety = SafetyLevel.Dangerous
334339
_log.warning("%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name)
@@ -348,11 +353,12 @@ def _build_scan_result_from_raw_globals(
348353

349354
def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult:
350355
"""Disassemble a Pickle stream and report issues"""
356+
_log.debug(f"scan_pickle_bytes({file_id})")
351357

352358
try:
353359
raw_globals = _list_globals(data, multiple_pickles)
354360
except GenOpsError as e:
355-
_log.error(f"ERROR: parsing pickle in {file_id}: {e}")
361+
_log.error(f"ERROR: parsing pickle in {file_id}: {e}", exc_info=_log.isEnabledFor(logging.DEBUG))
356362
if e.globals is not None:
357363
return _build_scan_result_from_raw_globals(e.globals, file_id, scan_err=True)
358364
else:
@@ -365,6 +371,8 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanRe
365371

366372
# XXX: it appears there is not way to get the byte stream for a given file within the 7z archive and thus forcing us to unzip to disk before scanning
367373
def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:
374+
_log.debug(f"scan_7z_bytes({file_id})")
375+
368376
try:
369377
import py7zr
370378
except ImportError:
@@ -387,6 +395,8 @@ def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:
387395

388396

389397
def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
398+
_log.debug(f"scan_zip_bytes({file_id})")
399+
390400
result = ScanResult([])
391401

392402
with RelaxedZipFile(data, "r") as zip:
@@ -415,6 +425,8 @@ def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
415425

416426

417427
def scan_numpy(data: IO[bytes], file_id) -> ScanResult:
428+
_log.debug(f"scan_numpy({file_id})")
429+
418430
# Delay import to avoid dependency on NumPy
419431
import numpy as np
420432

@@ -445,6 +457,8 @@ def scan_numpy(data: IO[bytes], file_id) -> ScanResult:
445457

446458

447459
def scan_pytorch(data: IO[bytes], file_id) -> ScanResult:
460+
_log.debug(f"scan_pytorch({file_id})")
461+
448462
# new pytorch format
449463
if _is_zipfile(data):
450464
return scan_zip_bytes(data, file_id)
@@ -473,26 +487,34 @@ def scan_pytorch(data: IO[bytes], file_id) -> ScanResult:
473487

474488

475489
def scan_bytes(data: IO[bytes], file_id, file_ext: Optional[str] = None) -> ScanResult:
490+
_log.debug(f"scan_bytes({file_id})")
491+
476492
if file_ext is not None and file_ext in _pytorch_file_extensions:
477493
try:
478494
return scan_pytorch(data, file_id)
479495
except InvalidMagicError as e:
480-
_log.error(f"ERROR: Invalid magic number for file {e}")
481-
return ScanResult([], scan_err=True)
482-
elif file_ext is not None and file_ext in _numpy_file_extensions:
496+
_log.warning(
497+
f"WARNING: Invalid PyTorch magic number for file {e}. Trying to scan as non-PyTorch file.",
498+
exc_info=_log.isEnabledFor(logging.DEBUG),
499+
)
500+
data.seek(0)
501+
502+
if file_ext is not None and file_ext in _numpy_file_extensions:
483503
return scan_numpy(data, file_id)
504+
505+
is_zip = zipfile.is_zipfile(data)
506+
data.seek(0)
507+
if is_zip:
508+
return scan_zip_bytes(data, file_id)
509+
elif _is_7z_file(data):
510+
return scan_7z_bytes(data, file_id)
484511
else:
485-
is_zip = zipfile.is_zipfile(data)
486-
data.seek(0)
487-
if is_zip:
488-
return scan_zip_bytes(data, file_id)
489-
elif _is_7z_file(data):
490-
return scan_7z_bytes(data, file_id)
491-
else:
492-
return scan_pickle_bytes(data, file_id)
512+
return scan_pickle_bytes(data, file_id)
493513

494514

495515
def scan_huggingface_model(repo_id):
516+
_log.debug(f"scan_huggingface_model({repo_id})")
517+
496518
# List model files
497519
model = json.loads(_http_get(f"https://huggingface.co/api/models/{repo_id}").decode("utf-8"))
498520
file_names = [file_name for file_name in (sibling.get("rfilename") for sibling in model["siblings"]) if file_name is not None]
@@ -512,6 +534,8 @@ def scan_huggingface_model(repo_id):
512534

513535

514536
def scan_directory_path(path) -> ScanResult:
537+
_log.debug(f"scan_directory_path({path})")
538+
515539
scan_result = ScanResult([])
516540

517541
for base_path, _, file_names in os.walk(path):
@@ -532,10 +556,14 @@ def scan_directory_path(path) -> ScanResult:
532556

533557

534558
def scan_file_path(path) -> ScanResult:
559+
_log.debug(f"scan_file_path({path})")
560+
535561
file_ext = os.path.splitext(path)[1]
536562
with open(path, "rb") as file:
537563
return scan_bytes(file, path, file_ext)
538564

539565

540566
def scan_url(url) -> ScanResult:
567+
_log.debug(f"scan_url({url})")
568+
541569
return scan_bytes(io.BytesIO(_http_get(url)), url)
92 Bytes
Binary file not shown.

tests/data2/malicious1_crc.zip

165 Bytes
Binary file not shown.

tests/test_scanner.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,21 @@ def initialize_corrupt_zip_file_central_directory(path: str, file_name: str, dat
438438
f.write(modified_data)
439439

440440

441+
def initialize_corrupt_zip_file_crc(path: str, file_name: str, data: bytes):
442+
if not os.path.exists(path):
443+
with io.BytesIO() as buffer:
444+
with zipfile.ZipFile(buffer, "w") as zip:
445+
zip.writestr(file_name, data)
446+
data = buffer.getbuffer().tobytes()
447+
448+
# Corrupt the data, leading to a CRC mismatch
449+
modified_data = data.replace(b"print('456')", b"print('123')", 1)
450+
451+
# Write the corrupted content
452+
with open(path, "wb") as f:
453+
f.write(modified_data)
454+
455+
441456
def initialize_numpy_files():
442457
import numpy as np
443458

@@ -687,6 +702,12 @@ def initialize_pickle_files():
687702
pickle.dumps(Malicious1(), protocol=4),
688703
)
689704

705+
initialize_corrupt_zip_file_crc(
706+
f"{_root_path}/data2/malicious1_crc.zip",
707+
"data.pkl",
708+
pickle.dumps(Malicious1(), protocol=4),
709+
)
710+
690711
initialize_zip_file(
691712
f"{_root_path}/data/malicious1_wrong_ext.zip",
692713
"data.txt", # Pickle file with a non-standard extension
@@ -744,6 +765,7 @@ def initialize_pickle_files():
744765
initialize_pickle_file_from_reduce("GHSA-9w88-8rmg-7g2p.pkl", reduce_GHSA_9w88_8rmg_7g2p)
745766
initialize_pickle_file_from_reduce("GHSA-49gj-c84q-6qm9.pkl", reduce_GHSA_49gj_c84q_6qm9)
746767
initialize_pickle_file_from_reduce("GHSA-q77w-mwjj-7mqx.pkl", reduce_GHSA_q77w_mwjj_7mqx)
768+
initialize_pickle_file_from_reduce("GHSA-jgw4-cr84-mqxg.bin", reduce_GHSA_q77w_mwjj_7mqx)
747769

748770

749771
initialize_pickle_files()
@@ -1022,6 +1044,8 @@ def test_scan_file_path():
10221044
assert_scan("GHSA-9w88-8rmg-7g2p.pkl", [Global("cProfile", "runctx", SafetyLevel.Dangerous)])
10231045
assert_scan("GHSA-49gj-c84q-6qm9.pkl", [Global("cProfile", "run", SafetyLevel.Dangerous)])
10241046
assert_scan("GHSA-q77w-mwjj-7mqx.pkl", [Global("asyncio.unix_events", "_UnixSubprocessTransport._start", SafetyLevel.Dangerous)])
1047+
assert_scan("GHSA-jgw4-cr84-mqxg.bin", [Global("asyncio.unix_events", "_UnixSubprocessTransport._start", SafetyLevel.Dangerous)])
1048+
assert_scan("malicious1_crc.zip", [Global("builtins", name="eval", safety=SafetyLevel.Dangerous)])
10251049

10261050

10271051
def test_scan_file_path_npz():

0 commit comments

Comments
 (0)