|
1 | 1 | import base64 |
2 | 2 | import json |
3 | 3 | import pickle |
4 | | -from functools import wraps |
5 | 4 |
|
6 | 5 | import pandas as pd |
7 | 6 | import polars as pl |
|
12 | 11 | from returns.io import IOResultE |
13 | 12 | from returns.pipeline import flow |
14 | 13 | from returns.pointfree import bind, map_ |
15 | | -from returns.unsafe import unsafe_perform_io |
16 | 14 |
|
17 | 15 | from lomas_client.constants import ( |
18 | 16 | DUMMY_NB_ROWS, |
|
24 | 22 | from lomas_client.libraries.smartnoise_sql import SmartnoiseSQLClient |
25 | 23 | from lomas_client.libraries.smartnoise_synth import SmartnoiseSynthClient |
26 | 24 | from lomas_client.models.config import ClientConfig |
27 | | -from lomas_client.utils import parse_if_ok |
| 25 | +from lomas_client.utils import parse_if_ok, unwrap_all_clsmethods |
28 | 26 | from lomas_core.constants import DPLibraries |
29 | 27 | from lomas_core.instrumentation import init_telemetry |
30 | 28 | from lomas_core.models.requests import GetDummyDataset, LomasRequestModel, OpenDPQueryModel |
@@ -272,63 +270,20 @@ def post_processes_queries(queries: list[dict]) -> list[dict]: |
272 | 270 | ) |
273 | 271 |
|
274 | 272 |
|
275 | | -# FIXME: how to cleanly shadow Client without to much python __darkmagic__ ... |
276 | | - |
277 | | - |
278 | | -def call_and_unwrap_wrapper(method): |
279 | | - @wraps(method) |
280 | | - def call_and_unwrap(*args, **kwargs): |
281 | | - result = method(*args, **kwargs) |
282 | | - if hasattr(result, "unwrap"): |
283 | | - return unsafe_perform_io(result.unwrap()) |
284 | | - return result |
285 | | - |
286 | | - return call_and_unwrap |
287 | | - |
288 | | - |
289 | | -class SmartnoiseSQLClientU(SmartnoiseSQLClient): |
290 | | - def __getattribute__(self, name, *args): |
291 | | - attr = super().__getattribute__(name) |
292 | | - if callable(attr): |
293 | | - return call_and_unwrap_wrapper(attr) |
294 | | - return attr |
295 | | - |
296 | | - |
297 | | -class OpenDPClientU(OpenDPClient): |
298 | | - def __getattribute__(self, name, *args): |
299 | | - attr = super().__getattribute__(name) |
300 | | - if callable(attr): |
301 | | - return call_and_unwrap_wrapper(attr) |
302 | | - return attr |
303 | | - |
304 | | - |
305 | | -class SmartnoiseSynthClientU(SmartnoiseSynthClient): |
306 | | - def __getattribute__(self, name, *args): |
307 | | - attr = super().__getattribute__(name) |
308 | | - if callable(attr): |
309 | | - return call_and_unwrap_wrapper(attr) |
310 | | - return attr |
311 | | - |
312 | | - |
313 | | -class DiffPrivLibClientU(DiffPrivLibClient): |
314 | | - def __getattribute__(self, name, *args): |
315 | | - attr = super().__getattribute__(name) |
316 | | - if callable(attr): |
317 | | - return call_and_unwrap_wrapper(attr) |
318 | | - return attr |
| 273 | +@unwrap_all_clsmethods |
| 274 | +class Client(ClientIO): |
| 275 | + """Original Client interface to shadow ClientIO whilst unwrapping results (for now).""" |
319 | 276 |
|
| 277 | + def __init__(self, **kwargs: ClientConfig.model_config): |
| 278 | + super().__init__(**kwargs) |
320 | 279 |
|
321 | | -class Client(ClientIO): |
322 | | - def __getattribute__(self, name, *args): |
323 | | - attr = super().__getattribute__(name) |
324 | | - if callable(attr): |
325 | | - return call_and_unwrap_wrapper(attr) |
326 | | - return attr |
327 | | - |
328 | | - def __init__(self, *args, **kwargs): |
329 | | - super().__init__(*args, **kwargs) |
330 | | - |
331 | | - self.smartnoise_sql = SmartnoiseSQLClientU(self.http_client) |
332 | | - self.smartnoise_synth = SmartnoiseSynthClientU(self.http_client) |
333 | | - self.opendp = OpenDPClientU(self.http_client) |
334 | | - self.diffprivlib = DiffPrivLibClientU(self.http_client) |
| 280 | + self.smartnoise_sql = unwrap_all_clsmethods(type("SmartnoiseSQLClientU", (SmartnoiseSQLClient,), {}))( |
| 281 | + self.http_client |
| 282 | + ) |
| 283 | + self.smartnoise_synth = unwrap_all_clsmethods( |
| 284 | + type("SmartnoiseSynthClientU", (SmartnoiseSynthClient,), {}) |
| 285 | + )(self.http_client) |
| 286 | + self.opendp = unwrap_all_clsmethods(type("OpenDPClientU", (OpenDPClient,), {}))(self.http_client) |
| 287 | + self.diffprivlib = unwrap_all_clsmethods(type("DiffPrivLibClientU", (DiffPrivLibClient,), {}))( |
| 288 | + self.http_client |
| 289 | + ) |
0 commit comments