Skip to content

Commit 914b956

Browse files
authored
remove pickle from serializer to be safe (#1358)
1 parent de34d6c commit 914b956

File tree

9 files changed

+341
-73
lines changed

9 files changed

+341
-73
lines changed

src/datachain/catalog/loader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from importlib import import_module
44
from typing import TYPE_CHECKING, Any, Optional
55

6+
from datachain.plugins import ensure_plugins_loaded
67
from datachain.utils import get_envs_by_prefix
78

89
if TYPE_CHECKING:
@@ -24,6 +25,8 @@
2425

2526

2627
def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
28+
ensure_plugins_loaded()
29+
2730
from datachain.data_storage import AbstractMetastore
2831
from datachain.data_storage.serializer import deserialize
2932

@@ -64,6 +67,8 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
6467

6568

6669
def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
70+
ensure_plugins_loaded()
71+
6772
from datachain.data_storage import AbstractWarehouse
6873
from datachain.data_storage.serializer import deserialize
6974

Lines changed: 105 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,119 @@
11
import base64
2-
import pickle
2+
import json
33
from abc import abstractmethod
44
from collections.abc import Callable
5-
from typing import Any
5+
from typing import Any, ClassVar
6+
7+
from datachain.plugins import ensure_plugins_loaded
8+
9+
10+
class CallableRegistry:
11+
_registry: ClassVar[dict[str, Callable]] = {}
12+
13+
@classmethod
14+
def register(cls, callable_obj: Callable, name: str) -> str:
15+
cls._registry[name] = callable_obj
16+
return name
17+
18+
@classmethod
19+
def get(cls, name: str) -> Callable:
20+
return cls._registry[name]
621

722

823
class Serializable:
24+
@classmethod
25+
@abstractmethod
26+
def serialize_callable_name(cls) -> str:
27+
"""Return the registered name used for this class' factory callable."""
28+
929
@abstractmethod
1030
def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
11-
"""
12-
Returns the class, args, and kwargs needed to instantiate a cloned copy
13-
of this instance for use in separate processes or machines.
14-
"""
31+
"""Return (callable, args, kwargs) necessary to recreate this object."""
32+
33+
def _prepare(self, params: tuple) -> dict:
34+
callable, args, kwargs = params
35+
callable_name = callable.__self__.serialize_callable_name()
36+
return {
37+
"callable": callable_name,
38+
"args": args,
39+
"kwargs": {
40+
k: self._prepare(v) if isinstance(v, tuple) else v
41+
for k, v in kwargs.items()
42+
},
43+
}
1544

1645
def serialize(self) -> str:
17-
"""
18-
Returns a string representation of clone params.
19-
This is useful for storing the state of an object in environment variable.
20-
"""
21-
return base64.b64encode(pickle.dumps(self.clone_params())).decode()
46+
"""Return a base64-encoded JSON string with registered callable + params."""
47+
_ensure_default_callables_registered()
48+
data = self.clone_params()
49+
return base64.b64encode(json.dumps(self._prepare(data)).encode()).decode()
2250

2351

2452
def deserialize(s: str) -> Serializable:
53+
"""Deserialize from base64-encoded JSON using only registered callables.
54+
55+
Nested serialized objects are instantiated automatically except for those
56+
passed via clone parameter tuples (keys ending with ``_clone_params``),
57+
which must remain as (callable, args, kwargs) for later factory usage.
2558
"""
26-
Returns a new instance of the class represented by the string.
27-
"""
28-
(f, args, kwargs) = pickle.loads(base64.b64decode(s.encode())) # noqa: S301
29-
return f(*args, **kwargs)
59+
ensure_plugins_loaded()
60+
_ensure_default_callables_registered()
61+
decoded = base64.b64decode(s.encode())
62+
data = json.loads(decoded.decode())
63+
64+
def _is_serialized(obj: Any) -> bool:
65+
return isinstance(obj, dict) and {"callable", "args", "kwargs"}.issubset(
66+
obj.keys()
67+
)
68+
69+
def _reconstruct(obj: Any, nested: bool = False) -> Any:
70+
if not _is_serialized(obj):
71+
return obj
72+
callable_name: str = obj["callable"]
73+
args: list[Any] = obj["args"]
74+
kwargs: dict[str, Any] = obj["kwargs"]
75+
# Recurse only inside kwargs because serialize() only nests through kwargs
76+
for k, v in list(kwargs.items()):
77+
if _is_serialized(v):
78+
kwargs[k] = _reconstruct(v, True)
79+
callable_obj = CallableRegistry.get(callable_name)
80+
if nested:
81+
return (callable_obj, args, kwargs)
82+
# Otherwise instantiate
83+
return callable_obj(*args, **kwargs)
84+
85+
if not _is_serialized(data):
86+
raise ValueError("Invalid serialized data format")
87+
return _reconstruct(data, False)
88+
89+
90+
class _DefaultsState:
91+
registered = False
92+
93+
94+
def _ensure_default_callables_registered() -> None:
95+
if _DefaultsState.registered:
96+
return
97+
98+
from datachain.data_storage.sqlite import (
99+
SQLiteDatabaseEngine,
100+
SQLiteMetastore,
101+
SQLiteWarehouse,
102+
)
103+
104+
# Register (idempotent by name overwrite is fine) using class-level
105+
# serialization names to avoid hard-coded literals here.
106+
CallableRegistry.register(
107+
SQLiteDatabaseEngine.from_db_file,
108+
SQLiteDatabaseEngine.serialize_callable_name(),
109+
)
110+
CallableRegistry.register(
111+
SQLiteMetastore.init_after_clone,
112+
SQLiteMetastore.serialize_callable_name(),
113+
)
114+
CallableRegistry.register(
115+
SQLiteWarehouse.init_after_clone,
116+
SQLiteWarehouse.serialize_callable_name(),
117+
)
118+
119+
_DefaultsState.registered = True

src/datachain/data_storage/sqlite.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,14 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
201201
"""
202202
return (
203203
SQLiteDatabaseEngine.from_db_file,
204-
[self.db_file],
204+
[str(self.db_file)],
205205
{},
206206
)
207207

208+
@classmethod
209+
def serialize_callable_name(cls) -> str:
210+
return "sqlite.from_db_file"
211+
208212
def _reconnect(self) -> None:
209213
if not self.is_closed:
210214
raise RuntimeError("Cannot reconnect on still-open DB!")
@@ -403,6 +407,10 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
403407
},
404408
)
405409

410+
@classmethod
411+
def serialize_callable_name(cls) -> str:
412+
return "sqlite.metastore.init_after_clone"
413+
406414
@classmethod
407415
def init_after_clone(
408416
cls,
@@ -610,6 +618,10 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
610618
{"db_clone_params": self.db.clone_params()},
611619
)
612620

621+
@classmethod
622+
def serialize_callable_name(cls) -> str:
623+
return "sqlite.warehouse.init_after_clone"
624+
613625
@classmethod
614626
def init_after_clone(
615627
cls,

src/datachain/plugins.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Plugin loader for DataChain callables.
2+
3+
Discovers and invokes entry points in the group "datachain.callables" once
4+
per process. This enables external packages (e.g., Studio) to register
5+
their callables with the serializer registry without explicit imports.
6+
"""
7+
8+
from importlib import metadata as importlib_metadata
9+
10+
_plugins_loaded = False
11+
12+
13+
def ensure_plugins_loaded() -> None:
14+
global _plugins_loaded # noqa: PLW0603
15+
if _plugins_loaded:
16+
return
17+
18+
# Compatible across importlib.metadata versions
19+
eps_obj = importlib_metadata.entry_points()
20+
if hasattr(eps_obj, "select"):
21+
eps_list = eps_obj.select(group="datachain.callables")
22+
else:
23+
# Compatibility for older versions of importlib_metadata, Python 3.9
24+
eps_list = eps_obj.get("datachain.callables", []) # type: ignore[attr-defined]
25+
26+
for ep in eps_list:
27+
func = ep.load()
28+
func()
29+
30+
_plugins_loaded = True

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def clean_environment(
126126
working_dir = str(tmp_path_factory.mktemp("default_working_dir"))
127127
monkeypatch_session.chdir(working_dir)
128128
monkeypatch_session.delenv(DataChainDir.ENV_VAR, raising=False)
129+
monkeypatch_session.delenv(DataChainDir.ENV_VAR_DATACHAIN_ROOT, raising=False)
129130

130131

131132
@pytest.fixture

tests/unit/test_database_engine.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import base64
2+
import json
23
import os
3-
import pickle
44

55
import pytest
66
from sqlalchemy import Column, Integer, Table
77

88
from datachain.data_storage.serializer import deserialize
9-
from datachain.data_storage.sqlite import SQLiteDatabaseEngine, get_db_file_in_memory
9+
from datachain.data_storage.sqlite import (
10+
SQLiteDatabaseEngine,
11+
get_db_file_in_memory,
12+
)
1013
from tests.utils import skip_if_not_sqlite
1114

1215

@@ -24,6 +27,7 @@ def test_init_clone(tmp_dir, db_file, expected_db_file):
2427
expected_db_file = os.fspath(tmp_dir / expected_db_file)
2528

2629
with SQLiteDatabaseEngine.from_db_file(db_file) as db:
30+
assert isinstance(db, SQLiteDatabaseEngine)
2731
assert db.db_file == expected_db_file
2832

2933
# Test clone
@@ -53,17 +57,15 @@ def test_get_db_file_in_memory(db_file, in_memory, expected):
5357

5458

5559
def test_serialize(sqlite_db):
56-
# Test serialization
60+
# JSON serialization format
5761
serialized = sqlite_db.serialize()
5862
assert serialized
59-
serialized_pickled = base64.b64decode(serialized.encode())
60-
assert serialized_pickled
61-
(f, args, kwargs) = pickle.loads(serialized_pickled) # noqa: S301
62-
assert str(f) == str(SQLiteDatabaseEngine.from_db_file)
63-
assert args == [":memory:"]
64-
assert kwargs == {}
65-
66-
# Test deserialization
63+
raw = base64.b64decode(serialized.encode())
64+
data = json.loads(raw.decode())
65+
assert data["callable"] == "sqlite.from_db_file"
66+
assert data["args"] == [":memory:"]
67+
assert data["kwargs"] == {}
68+
6769
obj3 = deserialize(serialized)
6870
assert isinstance(obj3, SQLiteDatabaseEngine)
6971
assert obj3.db_file == ":memory:"

tests/unit/test_metastore.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import base64
2-
import pickle
2+
import json
33

44
import pytest
55

@@ -24,18 +24,19 @@ def test_sqlite_metastore(sqlite_db):
2424
assert obj2.db.db_file == sqlite_db.db_file
2525
assert obj2.clone_params() == obj.clone_params()
2626

27-
# Test serialization
27+
# Test serialization JSON format
2828
serialized = obj.serialize()
2929
assert serialized
30-
serialized_pickled = base64.b64decode(serialized.encode())
31-
assert serialized_pickled
32-
(f, args, kwargs) = pickle.loads(serialized_pickled) # noqa: S301
33-
assert str(f) == str(SQLiteMetastore.init_after_clone)
34-
assert args == []
35-
assert kwargs["uri"] == uri
36-
assert str(kwargs["db_clone_params"]) == str(sqlite_db.clone_params())
37-
38-
# Test deserialization
30+
raw = base64.b64decode(serialized.encode())
31+
data = json.loads(raw.decode())
32+
assert data["callable"] == "sqlite.metastore.init_after_clone"
33+
assert data["args"] == []
34+
assert data["kwargs"]["uri"] == uri
35+
nested = data["kwargs"]["db_clone_params"]
36+
assert nested["callable"] == "sqlite.from_db_file"
37+
assert nested["args"] == [":memory:"]
38+
assert nested["kwargs"] == {}
39+
3940
obj3 = deserialize(serialized)
4041
assert isinstance(obj3, SQLiteMetastore)
4142
assert obj3.uri == uri

0 commit comments

Comments
 (0)