Skip to content

Commit 7f994d6

Browse files
authored
Missing detection when calling PyTorch functions (#47)
Address security advisories
1 parent 58983e1 commit 7f994d6

19 files changed

+270
-248
lines changed

.devcontainer/Dockerfile

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
FROM mcr.microsoft.com/azurelinux/base/python:3.12
2+
3+
# Install Dev Container dependencies
4+
RUN tdnf install -y git tar awk && tdnf clean all
5+
6+
# Remove root access
7+
USER nonroot
8+
WORKDIR /home/nonroot
9+
10+
# Install Python packages
11+
# Create a virtual environment to avoid access-denied issues when running pip install
12+
ENV PYTHONUNBUFFERED=1
13+
ENV VIRTUAL_ENV=/home/nonroot/.venv
14+
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
15+
COPY requirements.txt .
16+
RUN python3 -m venv .venv && .venv/bin/pip install --disable-pip-version-check --no-cache-dir -r requirements.txt

.devcontainer/devcontainer.json

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"name": "Python",
3+
"build": {
4+
"dockerfile": "Dockerfile",
5+
"context": ".."
6+
},
7+
"customizations": {
8+
"vscode": {
9+
"extensions": [
10+
"ms-python.python"
11+
],
12+
"settings": {
13+
"chat.tools.autoApprove": true,
14+
"python.defaultInterpreterPath": "/home/nonroot/.venv/bin/python3",
15+
"python.selectInterpreter": "/home/nonroot/.venv/bin/python3"
16+
}
17+
}
18+
}
19+
}

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
python -m pip install -r requirements.txt
2424
python -m pip install -e '.[7z]'
2525
- name: Check code format
26-
run: black --check src tests
26+
run: python -m black src tests --check --line-length 140
2727
- name: Lint with flake8
2828
run: python -m flake8 . --count --show-source --statistics
2929
- name: Test with pytest

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ picklescan -l DEBUG -u https://huggingface.co/prajjwal1/bert-tiny/resolve/main/p
8585

8686
Lint the code:
8787
```
88-
black src tests
88+
black src tests --line-length 140
8989
flake8 src tests --count --show-source
9090
```
9191

conda.extras.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
name: picklescan-extras
2+
channels:
3+
- nodefaults
4+
dependencies:
5+
- python=3.9
6+
- pip
7+
- pip:
8+
- -r requirements_extras.txt

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ pytest-cov==3.0.0
55
requests==2.31.0
66
aiohttp==3.9.1
77
black==22.8.0
8-
numpy>1.24.0
8+
numpy>1.24.0,<2.0.0
99
py7zr==0.22.0

requirements_extras.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Add packages needed to generate tests
2+
-r requirements.txt
3+
torch==2.8.0

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = picklescan
3-
version = 0.0.27
3+
version = 0.0.28
44
author = Matthieu Maitre
55
author_email = [email protected]
66
description = Security scanner detecting Python Pickle files performing suspicious actions

src/picklescan/cli.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,17 @@ def main():
2828
_log.setLevel(logging.INFO)
2929
_log.addHandler(logging.StreamHandler(stream=sys.stdout))
3030

31-
parser = argparse.ArgumentParser(
32-
description="Security scanner detecting Python Pickle files performing suspicious actions."
33-
)
31+
parser = argparse.ArgumentParser(description="Security scanner detecting Python Pickle files performing suspicious actions.")
3432
group = parser.add_mutually_exclusive_group()
35-
group.add_argument(
36-
"-p", "--path", help="Path to the file or folder to scan", dest="path"
37-
)
38-
group.add_argument(
39-
"-u", "--url", help="URL to the file or folder to scan", dest="url"
40-
)
33+
group.add_argument("-p", "--path", help="Path to the file or folder to scan", dest="path")
34+
group.add_argument("-u", "--url", help="URL to the file or folder to scan", dest="url")
4135
group.add_argument(
4236
"-hf",
4337
"--huggingface",
4438
help="Name of the Hugging Face model to scan",
4539
dest="huggingface_model",
4640
)
47-
parser.add_argument(
48-
"-g", "--globals", help="list all globals found", action="store_true"
49-
)
41+
parser.add_argument("-g", "--globals", help="list all globals found", action="store_true")
5042
parser.set_defaults(globals=False)
5143
parser.add_argument(
5244
"-l",
@@ -75,9 +67,7 @@ def main():
7567
elif args.huggingface_model is not None:
7668
scan_result = scan_huggingface_model(args.huggingface_model)
7769
else:
78-
raise ValueError(
79-
"Command line must include either a path, a URL, or a Hugging Face model"
80-
)
70+
raise ValueError("Command line must include either a path, a URL, or a Hugging Face model")
8171

8272
print_summary(args.globals, scan_result)
8373

src/picklescan/scanner.py

Lines changed: 25 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,19 @@ def __str__(self) -> str:
140140
"pip": "*",
141141
"pydoc": "pipepager", # pydoc.pipepager('help','echo pwned')
142142
"timeit": "*",
143+
"torch._dynamo.guards": {"GuardBuilder.get"},
143144
"torch._inductor.codecache": "compile_file", # compile_file('', '', ['sh', '-c','$(echo pwned)'])
145+
"torch.fx.experimental.symbolic_shapes": {"ShapeEnv.evaluate_guards_expression"},
146+
"torch.jit.unsupported_tensor_ops": {"execWrapper"},
144147
"torch.serialization": "load", # pickle could be used to load a different file
148+
"torch.utils._config_module": {
149+
"ConfigModule.load_config"
150+
}, # allows storing a pickle inside a pickle (if this has valid use cases, scan the input bytes instead of flagging the global)
151+
"torch.utils.bottleneck.__main__": {"run_cprofile"},
152+
"torch.utils.collect_env": {"run"},
153+
"torch.utils.data.datapipes.utils.decoder": {
154+
"basichandlers"
155+
}, # allows storing a pickle inside a pickle (if this has valid use cases, scan the input bytes instead of flagging the global)
145156
"venv": "*",
146157
"webbrowser": "*", # Includes webbrowser.open()
147158
}
@@ -200,9 +211,7 @@ def _http_get(url) -> bytes:
200211
_log.debug(f"Request: GET {url}")
201212

202213
parsed_url = urllib.parse.urlparse(url)
203-
path_and_query = parsed_url.path + (
204-
"?" + parsed_url.query if len(parsed_url.query) > 0 else ""
205-
)
214+
path_and_query = parsed_url.path + ("?" + parsed_url.query if len(parsed_url.query) > 0 else "")
206215

207216
conn = http.client.HTTPSConnection(parsed_url.netloc)
208217
try:
@@ -271,18 +280,14 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]
271280
"BINSTRING",
272281
"SHORT_BINSTRING",
273282
]:
274-
_log.debug(
275-
"Presence of non-string opcode, categorizing as an unknown dangerous import"
276-
)
283+
_log.debug("Presence of non-string opcode, categorizing as an unknown dangerous import")
277284
values.append("unknown")
278285
else:
279286
values.append(ops[n - offset][1])
280287
if len(values) == 2:
281288
break
282289
if len(values) != 2:
283-
raise ValueError(
284-
f"Found {len(values)} values for STACK_GLOBAL at position {n} instead of 2."
285-
)
290+
raise ValueError(f"Found {len(values)} values for STACK_GLOBAL at position {n} instead of 2.")
286291
globals.add((values[1], values[0]))
287292

288293
if not multiple_pickles:
@@ -312,17 +317,11 @@ def _build_scan_result_from_raw_globals(
312317
unsafe_filter = _unsafe_globals.get(g.module)
313318
if "unknown" in g.module or "unknown" in g.name:
314319
g.safety = SafetyLevel.Dangerous
315-
_log.warning(
316-
"%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name
317-
)
320+
_log.warning("%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name)
318321
issues_count += 1
319-
elif unsafe_filter is not None and (
320-
unsafe_filter == "*" or g.name in unsafe_filter
321-
):
322+
elif unsafe_filter is not None and (unsafe_filter == "*" or g.name in unsafe_filter):
322323
g.safety = SafetyLevel.Dangerous
323-
_log.warning(
324-
"%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name
325-
)
324+
_log.warning("%s: %s import '%s %s' FOUND", file_id, g.safety.value, g.module, g.name)
326325
issues_count += 1
327326
elif safe_filter is not None and (safe_filter == "*" or g.name in safe_filter):
328327
g.safety = SafetyLevel.Innocuous
@@ -341,9 +340,7 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanRe
341340
except GenOpsError as e:
342341
_log.error(f"ERROR: parsing pickle in {file_id}: {e}")
343342
if e.globals is not None:
344-
return _build_scan_result_from_raw_globals(
345-
e.globals, file_id, scan_err=True
346-
)
343+
return _build_scan_result_from_raw_globals(e.globals, file_id, scan_err=True)
347344
else:
348345
return ScanResult([], scan_err=True)
349346

@@ -357,9 +354,7 @@ def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:
357354
try:
358355
import py7zr
359356
except ImportError:
360-
raise Exception(
361-
"py7zr is required to scan 7z archives, install picklescan using: 'pip install picklescan[7z]'"
362-
)
357+
raise Exception("py7zr is required to scan 7z archives, install picklescan using: 'pip install picklescan[7z]'")
363358
result = ScanResult([])
364359

365360
with py7zr.SevenZipFile(data, mode="r") as archive:
@@ -389,24 +384,18 @@ def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
389384
magic_bytes = file.read(8)
390385
file_ext = os.path.splitext(file_name)[1]
391386

392-
if file_ext in _pickle_file_extensions or any(
393-
magic_bytes.startswith(mn) for mn in _pickle_magic_bytes
394-
):
387+
if file_ext in _pickle_file_extensions or any(magic_bytes.startswith(mn) for mn in _pickle_magic_bytes):
395388
_log.debug("Scanning file %s in zip archive %s", file_name, file_id)
396389
with zip.open(file_name, "r") as file:
397390
result.merge(scan_pickle_bytes(file, f"{file_id}:{file_name}"))
398391

399-
elif file_ext in _numpy_file_extensions or magic_bytes.startswith(
400-
_numpy_magic_bytes
401-
):
392+
elif file_ext in _numpy_file_extensions or magic_bytes.startswith(_numpy_magic_bytes):
402393
_log.debug("Scanning file %s in zip archive %s", file_name, file_id)
403394
with zip.open(file_name, "r") as file:
404395
result.merge(scan_numpy(file, f"{file_id}:{file_name}"))
405396
except (zipfile.BadZipFile, RuntimeError) as e:
406397
# Log decompression issues (password protected, corrupted, etc.)
407-
_log.warning(
408-
"Invalid file %s in zip archive %s: %s", file_name, file_id, str(e)
409-
)
398+
_log.warning("Invalid file %s in zip archive %s: %s", file_name, file_id, str(e))
410399

411400
return result
412401

@@ -491,24 +480,14 @@ def scan_bytes(data: IO[bytes], file_id, file_ext: Optional[str] = None) -> Scan
491480

492481
def scan_huggingface_model(repo_id):
493482
# List model files
494-
model = json.loads(
495-
_http_get(f"https://huggingface.co/api/models/{repo_id}").decode("utf-8")
496-
)
497-
file_names = [
498-
file_name
499-
for file_name in (sibling.get("rfilename") for sibling in model["siblings"])
500-
if file_name is not None
501-
]
483+
model = json.loads(_http_get(f"https://huggingface.co/api/models/{repo_id}").decode("utf-8"))
484+
file_names = [file_name for file_name in (sibling.get("rfilename") for sibling in model["siblings"]) if file_name is not None]
502485

503486
# Scan model files
504487
scan_result = ScanResult([])
505488
for file_name in file_names:
506489
file_ext = os.path.splitext(file_name)[1]
507-
if (
508-
file_ext not in _zip_file_extensions
509-
and file_ext not in _pickle_file_extensions
510-
and file_ext not in _pytorch_file_extensions
511-
):
490+
if file_ext not in _zip_file_extensions and file_ext not in _pickle_file_extensions and file_ext not in _pytorch_file_extensions:
512491
continue
513492
_log.debug("Scanning file %s in model %s", file_name, repo_id)
514493
url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"

0 commit comments

Comments
 (0)