Skip to content

Commit c44e3ea

Browse files
committed
less ugly shadowing
1 parent 9d38112 commit c44e3ea

File tree

2 files changed

+46
-62
lines changed

2 files changed

+46
-62
lines changed

client/lomas_client/client.py

Lines changed: 16 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import base64
22
import json
33
import pickle
4-
from functools import wraps
54

65
import pandas as pd
76
import polars as pl
@@ -12,7 +11,6 @@
1211
from returns.io import IOResultE
1312
from returns.pipeline import flow
1413
from returns.pointfree import bind, map_
15-
from returns.unsafe import unsafe_perform_io
1614

1715
from lomas_client.constants import (
1816
DUMMY_NB_ROWS,
@@ -24,7 +22,7 @@
2422
from lomas_client.libraries.smartnoise_sql import SmartnoiseSQLClient
2523
from lomas_client.libraries.smartnoise_synth import SmartnoiseSynthClient
2624
from lomas_client.models.config import ClientConfig
27-
from lomas_client.utils import parse_if_ok
25+
from lomas_client.utils import parse_if_ok, unwrap_all_clsmethods
2826
from lomas_core.constants import DPLibraries
2927
from lomas_core.instrumentation import init_telemetry
3028
from lomas_core.models.requests import GetDummyDataset, LomasRequestModel, OpenDPQueryModel
@@ -272,63 +270,20 @@ def post_processes_queries(queries: list[dict]) -> list[dict]:
272270
)
273271

274272

275-
# FIXME: how to cleanly shadow Client without to much python __darkmagic__ ...
276-
277-
278-
def call_and_unwrap_wrapper(method):
279-
@wraps(method)
280-
def call_and_unwrap(*args, **kwargs):
281-
result = method(*args, **kwargs)
282-
if hasattr(result, "unwrap"):
283-
return unsafe_perform_io(result.unwrap())
284-
return result
285-
286-
return call_and_unwrap
287-
288-
289-
class SmartnoiseSQLClientU(SmartnoiseSQLClient):
290-
def __getattribute__(self, name, *args):
291-
attr = super().__getattribute__(name)
292-
if callable(attr):
293-
return call_and_unwrap_wrapper(attr)
294-
return attr
295-
296-
297-
class OpenDPClientU(OpenDPClient):
298-
def __getattribute__(self, name, *args):
299-
attr = super().__getattribute__(name)
300-
if callable(attr):
301-
return call_and_unwrap_wrapper(attr)
302-
return attr
303-
304-
305-
class SmartnoiseSynthClientU(SmartnoiseSynthClient):
306-
def __getattribute__(self, name, *args):
307-
attr = super().__getattribute__(name)
308-
if callable(attr):
309-
return call_and_unwrap_wrapper(attr)
310-
return attr
311-
312-
313-
class DiffPrivLibClientU(DiffPrivLibClient):
314-
def __getattribute__(self, name, *args):
315-
attr = super().__getattribute__(name)
316-
if callable(attr):
317-
return call_and_unwrap_wrapper(attr)
318-
return attr
273+
@unwrap_all_clsmethods
274+
class Client(ClientIO):
275+
"""Original Client interface to shadow ClientIO whilst unwrapping results (for now)."""
319276

277+
def __init__(self, **kwargs: ClientConfig.model_config):
278+
super().__init__(**kwargs)
320279

321-
class Client(ClientIO):
322-
def __getattribute__(self, name, *args):
323-
attr = super().__getattribute__(name)
324-
if callable(attr):
325-
return call_and_unwrap_wrapper(attr)
326-
return attr
327-
328-
def __init__(self, *args, **kwargs):
329-
super().__init__(*args, **kwargs)
330-
331-
self.smartnoise_sql = SmartnoiseSQLClientU(self.http_client)
332-
self.smartnoise_synth = SmartnoiseSynthClientU(self.http_client)
333-
self.opendp = OpenDPClientU(self.http_client)
334-
self.diffprivlib = DiffPrivLibClientU(self.http_client)
280+
self.smartnoise_sql = unwrap_all_clsmethods(type("SmartnoiseSQLClientU", (SmartnoiseSQLClient,), {}))(
281+
self.http_client
282+
)
283+
self.smartnoise_synth = unwrap_all_clsmethods(
284+
type("SmartnoiseSynthClientU", (SmartnoiseSynthClient,), {})
285+
)(self.http_client)
286+
self.opendp = unwrap_all_clsmethods(type("OpenDPClientU", (OpenDPClient,), {}))(self.http_client)
287+
self.diffprivlib = unwrap_all_clsmethods(type("DiffPrivLibClientU", (DiffPrivLibClient,), {}))(
288+
self.http_client
289+
)

client/lomas_client/utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import argparse
2+
import inspect
13
import warnings
4+
from collections.abc import Callable
5+
from functools import wraps
26
from json import JSONDecodeError
37
from typing import TypeVar
48

@@ -7,6 +11,7 @@
711
from pydantic import ValidationError
812
from returns.functions import raise_exception
913
from returns.io import IOFailure, IOResultE, IOSuccess, impure_safe
14+
from returns.unsafe import unsafe_perform_io
1015

1116
from lomas_client.http_client import LomasHttpClient
1217
from lomas_core.constants import SSynthGanSynthesizer, SSynthMarginalSynthesizer
@@ -22,7 +27,7 @@ def parse_if_ok(res: requests.Response) -> IOResultE[str]:
2227
return parse_server_error(res).bind_result(specify_error_from_model)
2328

2429

25-
def parse_server_error(response: requests.Response) -> LomasServerExceptionType:
30+
def parse_server_error(response: requests.Response) -> IOResultE[LomasServerExceptionType]:
2631
"""Parse a server error message based on the HTTP response.
2732
2833
Args:
@@ -91,3 +96,27 @@ def validate_model_response(
9196
specify_error_from_model(job.error)
9297

9398
return response_model.model_validate(job.result)
99+
100+
101+
def call_and_unwrap_wrapper(method: Callable) -> Callable:
102+
"""Unwrap IOResultE[T] back to T in the unsafest way possible."""
103+
104+
@wraps(method)
105+
def call_and_unwrap(*args: argparse.Namespace, **kwargs: dict) -> Callable:
106+
result = method(*args, **kwargs)
107+
if hasattr(result, "unwrap"):
108+
# First raise the internal Exception if the container is a failure
109+
inner_success = result.alt(raise_exception).unwrap()
110+
# Otherwise force-escape IO
111+
return unsafe_perform_io(inner_success)
112+
return result
113+
114+
return call_and_unwrap
115+
116+
117+
def unwrap_all_clsmethods(cls: type) -> type:
118+
"""Add a wrapper to all (public) methods of the given Class."""
119+
for name, method in inspect.getmembers(cls, inspect.isfunction):
120+
if not name.startswith("_"):
121+
setattr(cls, name, call_and_unwrap_wrapper(method))
122+
return cls

0 commit comments

Comments
 (0)