Skip to content

Commit 25c694b

Browse files
committed
fix validate
1 parent bfe4a11 commit 25c694b

File tree

5 files changed

+20
-20
lines changed

5 files changed

+20
-20
lines changed

client/lomas_client/libraries/diffprivlib.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from returns.curry import partial
33
from returns.io import IOResultE
44
from returns.pipeline import flow
5-
from returns.pointfree import map_
5+
from returns.pointfree import bind
66
from sklearn.pipeline import Pipeline
77

88
from lomas_client.constants import (
@@ -75,7 +75,7 @@ def cost(
7575
},
7676
DiffPrivLibRequestModel.model_validate,
7777
partial(self.http_client.post, "estimate_diffprivlib_cost"),
78-
map_(validate_model_response(self.http_client, CostResponse)),
78+
bind(validate_model_response(self.http_client, CostResponse)),
7979
)
8080

8181
def query(
@@ -137,11 +137,11 @@ def query(
137137
{**body_dict, "dummy_nb_rows": nb_rows, "dummy_seed": seed},
138138
DiffPrivLibDummyQueryModel.model_validate,
139139
lambda body: self.http_client.post("dummy_diffprivlib_query", body),
140-
map_(validate_model_response(self.http_client, QueryResponse)),
140+
bind(validate_model_response(self.http_client, QueryResponse)),
141141
)
142142
return flow(
143143
body_dict,
144144
DiffPrivLibQueryModel.model_validate,
145145
lambda body: self.http_client.post("diffprivlib_query", body),
146-
map_(validate_model_response(self.http_client, QueryResponse)),
146+
bind(validate_model_response(self.http_client, QueryResponse)),
147147
)

client/lomas_client/libraries/opendp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from returns.curry import partial
44
from returns.io import IOResultE
55
from returns.pipeline import flow
6-
from returns.pointfree import map_
6+
from returns.pointfree import bind
77

88
from lomas_client.constants import DUMMY_NB_ROWS, DUMMY_SEED
99
from lomas_client.http_client import LomasHttpClient
@@ -104,7 +104,7 @@ def cost(
104104
body_json,
105105
OpenDPRequestModel.model_validate,
106106
partial(self.http_client.post, "estimate_opendp_cost"),
107-
map_(validate_model_response(self.http_client, CostResponse)),
107+
bind(validate_model_response(self.http_client, CostResponse)),
108108
)
109109

110110
def query(
@@ -154,11 +154,11 @@ def query(
154154
{**body_dict, "dummy_nb_rows": nb_rows, "dummy_seed": seed},
155155
OpenDPDummyQueryModel.model_validate,
156156
partial(self.http_client.post, "dummy_opendp_query"),
157-
map_(validate_model_response(self.http_client, QueryResponse)),
157+
bind(validate_model_response(self.http_client, QueryResponse)),
158158
)
159159
return flow(
160160
body_dict,
161161
OpenDPQueryModel.model_validate,
162162
partial(self.http_client.post, "opendp_query"),
163-
map_(validate_model_response(self.http_client, QueryResponse)),
163+
bind(validate_model_response(self.http_client, QueryResponse)),
164164
)

client/lomas_client/libraries/smartnoise_sql.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from returns.io import IOResultE
22
from returns.pipeline import flow
3-
from returns.pointfree import map_
3+
from returns.pointfree import bind
44

55
from lomas_client.constants import DUMMY_NB_ROWS, DUMMY_SEED
66
from lomas_client.http_client import LomasHttpClient
@@ -51,7 +51,7 @@ def cost(
5151
},
5252
SmartnoiseSQLRequestModel.model_validate,
5353
lambda body: self.http_client.post("estimate_smartnoise_sql_cost", body),
54-
map_(validate_model_response(self.http_client, CostResponse)),
54+
bind(validate_model_response(self.http_client, CostResponse)),
5555
)
5656

5757
def query(
@@ -108,11 +108,11 @@ def query(
108108
{**body_dict, "dummy_nb_rows": nb_rows, "dummy_seed": seed},
109109
SmartnoiseSQLDummyQueryModel.model_validate,
110110
lambda body: self.http_client.post("dummy_smartnoise_sql_query", body),
111-
map_(validate_model_response(self.http_client, QueryResponse)),
111+
bind(validate_model_response(self.http_client, QueryResponse)),
112112
)
113113
return flow(
114114
body_dict,
115115
SmartnoiseSQLQueryModel.model_validate,
116116
lambda body: self.http_client.post("smartnoise_sql_query", body),
117-
map_(validate_model_response(self.http_client, QueryResponse)),
117+
bind(validate_model_response(self.http_client, QueryResponse)),
118118
)

client/lomas_client/libraries/smartnoise_synth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from returns.io import IOResultE
22
from returns.pipeline import flow
3-
from returns.pointfree import map_
3+
from returns.pointfree import bind
44
from smartnoise_synth_logger import serialise_constraints
55

66
from lomas_client.constants import (
@@ -97,7 +97,7 @@ def cost(
9797
lambda body: self.http_client.post(
9898
"estimate_smartnoise_synth_cost", body, SMARTNOISE_SYNTH_READ_TIMEOUT
9999
),
100-
map_(validate_model_response(self.http_client, CostResponse)),
100+
bind(validate_model_response(self.http_client, CostResponse)),
101101
)
102102

103103
def query(
@@ -195,12 +195,12 @@ def query(
195195
lambda body: self.http_client.post(
196196
"dummy_smartnoise_synth_query", body, SMARTNOISE_SYNTH_READ_TIMEOUT
197197
),
198-
map_(validate_model_response(self.http_client, QueryResponse)),
198+
bind(validate_model_response(self.http_client, QueryResponse)),
199199
)
200200
return flow(
201201
body_dict,
202202
# tap(lambda _: validate_synthesizer(synth_name, return_model)),
203203
SmartnoiseSynthQueryModel.model_validate,
204204
lambda body: self.http_client.post("smartnoise_synth_query", body, SMARTNOISE_SYNTH_READ_TIMEOUT),
205-
map_(validate_model_response(self.http_client, QueryResponse)),
205+
bind(validate_model_response(self.http_client, QueryResponse)),
206206
)

client/lomas_client/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def validate_synthesizer(synth_name: str, return_model: bool = False) -> None:
7777

7878
def validate_model_response(
7979
client: LomasHttpClient, response_model: type[ResponseT]
80-
) -> Callable[[requests.Response], ResponseT]:
80+
) -> Callable[[requests.Response], IOResultE[ResponseT]]:
8181
"""Validate and process a HTTP response.
8282
8383
Args:
@@ -87,17 +87,17 @@ def validate_model_response(
8787
response_model: Model for responses requests.
8888
"""
8989

90-
def validate(response: requests.Response) -> ResponseT:
90+
def validate(response: requests.Response) -> IOResultE[ResponseT]:
9191
if response.status_code != status.HTTP_202_ACCEPTED:
92-
parse_server_error(response).bind_result(specify_error_from_model).alt(raise_exception)
92+
return parse_server_error(response).bind_result(specify_error_from_model)
9393

9494
job_uid = response.json()["uid"]
9595
job = client.wait_for_job(job_uid)
9696
if job.status == "failed":
9797
assert job.error is not None, f"job {job_uid} failed without error !"
9898
specify_error_from_model(job.error)
9999

100-
return response_model.model_validate(job.result)
100+
return impure_safe(response_model.model_validate)(job.result)
101101

102102
return validate
103103

0 commit comments

Comments
 (0)