Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/lomas_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from logging import NullHandler

from .client import Client
from .client import Client, ClientIO

logger = logging.getLogger(__name__)
logger.addHandler(NullHandler())

__all__ = ("Client",)
__all__ = ("Client", "ClientIO")
231 changes: 153 additions & 78 deletions client/lomas_client/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import base64
import json
import logging
import pickle
from functools import partial

import pandas as pd
import polars as pl
from fastapi import status
from opendp.mod import enable_features
from opendp_logger import enable_logging
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from pydantic import ValidationError
from returns.io import IOResultE
from returns.pipeline import flow
from returns.pointfree import bind, map_

from lomas_client.constants import (
DUMMY_NB_ROWS,
Expand All @@ -20,7 +24,7 @@
from lomas_client.libraries.smartnoise_sql import SmartnoiseSQLClient
from lomas_client.libraries.smartnoise_synth import SmartnoiseSynthClient
from lomas_client.models.config import ClientConfig
from lomas_client.utils import raise_error, validate_model_response_direct
from lomas_client.utils import parse_if_ok, unwrap, unwrap_all_clsmethods
from lomas_core.constants import DPLibraries
from lomas_core.instrumentation import init_telemetry
from lomas_core.models.requests import GetDummyDataset, LomasRequestModel, OpenDPQueryModel
Expand All @@ -32,12 +36,14 @@
)
from lomas_core.opendp_utils import reconstruct_measurement_pipeline

logger = logging.getLogger(__name__)

# Opendp_logger
enable_logging()
enable_features("contrib")


class Client:
class ClientIO:
"""Client class to send requests to the server.

Handle all serialisation and deserialisation steps
Expand All @@ -53,9 +59,10 @@ def __init__(self, **kwargs: ClientConfig.model_config):
try:
self.config = ClientConfig(**kwargs)
except ValidationError as exc:
for err in exc.errors():
logger.error(f"{err['loc'][0]} --> {err['msg']}")
raise ValueError(
"Missing one of or invalid: client_id, client_secret, keycloak_url"
"or realm when using jwt authentication method."
"Missing/Invalid fields"
"If you are using this library from a managed environment and don't know "
"about your credentials, please contact your system administrator."
) from exc
Expand All @@ -70,29 +77,33 @@ def __init__(self, **kwargs: ClientConfig.model_config):
self.opendp = OpenDPClient(self.http_client)
self.diffprivlib = DiffPrivLibClient(self.http_client)

def get_dataset_metadata(self) -> LomasRequestModel:
def get_dataset_metadata(self) -> IOResultE[LomasRequestModel]:
"""This function retrieves metadata for the dataset.

Returns:
LomasRequestModel:
A dictionary containing dataset metadata.
"""
body_dict = {"dataset_name": self.config.dataset_name}
body = LomasRequestModel.model_validate(body_dict)
res = self.http_client.post("get_dataset_metadata", body)
if res.status_code == status.HTTP_200_OK:
data = res.content.decode("utf8")
metadata = json.loads(data)
return metadata

raise_error(res)
return flow(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks much better indeed

# construct request body
{"dataset_name": self.config.dataset_name},
# validate request model
LomasRequestModel.model_validate,
# post the validated body to the corresponding endpoint
lambda body: self.http_client.post("get_dataset_metadata", body),
# parse reply if HTTP 200
bind(parse_if_ok),
# load successful response as json
map_(json.loads),
)

def get_dummy_dataset(
self,
nb_rows: int = DUMMY_NB_ROWS,
seed: int = DUMMY_SEED,
lazy: bool = False,
) -> pd.DataFrame | pl.LazyFrame:
) -> IOResultE[pd.DataFrame | pl.LazyFrame]:
"""This function retrieves a dummy dataset with optional parameters.

Args:
Expand All @@ -107,32 +118,27 @@ def get_dummy_dataset(
pd.DataFrame | pl.LazyFrame: A Pandas DataFrame representing
the dummy dataset (optionally in LazyFrame format).
"""
body_dict = {
"dataset_name": self.config.dataset_name,
"dummy_nb_rows": nb_rows,
"dummy_seed": seed,
}
body = GetDummyDataset.model_validate(body_dict)
res = self.http_client.post("get_dummy_dataset", body)

if res.status_code == status.HTTP_200_OK:
data = res.content.decode("utf8")
dummy_df = DummyDsResponse.model_validate_json(data).dummy_df
if lazy:
# Temporary: we use type string for datetime in polars
# Will be fixed in 0.13
for col in dummy_df.select_dtypes(include=["datetime"]):
dummy_df[col] = dummy_df[col].astype("string[python]")
print(
"Datetime type mismatch: The Polars LazyFrame currently uses 'str' for datetime fields, "
"which may not match the expected metadata types. This is a temporary workaround "
"and will be resolved in a future release (>=0.13)."
return flow(
# construct request body
{
"dataset_name": self.config.dataset_name,
"dummy_nb_rows": nb_rows,
"dummy_seed": seed,
},
# validate request model
GetDummyDataset.model_validate,
# post the validated body to the corresponding endpoint
lambda body: self.http_client.post("get_dummy_dataset", body),
# parse reply if HTTP 200
bind(parse_if_ok),
# build DummyDsResponse from successful json payload
map_(DummyDsResponse.model_validate_json),
map_(
lambda dummy_ds_res: (
pl.from_pandas(dummy_ds_res.dummy_df).lazy() if lazy else dummy_ds_res.dummy_df
)
return pl.from_pandas(dummy_df).lazy()

return dummy_df

raise_error(res)
),
)

def get_dummy_lf(self, nb_rows: int = DUMMY_NB_ROWS, seed: int = DUMMY_SEED) -> pl.LazyFrame:
"""
Expand All @@ -154,50 +160,68 @@ def get_dummy_lf(self, nb_rows: int = DUMMY_NB_ROWS, seed: int = DUMMY_SEED) ->
dummy_pandas[col] = dummy_pandas[col].astype(str)
return pl.from_pandas(dummy_pandas).lazy()

def get_initial_budget(self) -> InitialBudgetResponse:
def get_initial_budget(self) -> IOResultE[InitialBudgetResponse]:
"""This function retrieves the initial budget.

Returns:
InitialBudgetResponse: A dictionary
containing the initial budget.
"""

body_dict = {"dataset_name": self.config.dataset_name}

body = LomasRequestModel.model_validate(body_dict)
res = self.http_client.post("get_initial_budget", body)

return validate_model_response_direct(res, InitialBudgetResponse)

def get_total_spent_budget(self) -> SpentBudgetResponse:
return flow(
# construct request body
{"dataset_name": self.config.dataset_name},
# validate request model
LomasRequestModel.model_validate,
# post the validated body to the corresponding endpoint
lambda body: self.http_client.post("get_initial_budget", body),
# parse reply if HTTP 200
bind(parse_if_ok),
# build Budget Response from successful json payload
map_(InitialBudgetResponse.model_validate_json),
)

def get_total_spent_budget(self) -> IOResultE[SpentBudgetResponse]:
"""This function retrieves the total spent budget.

Returns:
SpentBudgetResponse: A dictionary containing
the total spent budget.
"""
body_dict = {"dataset_name": self.config.dataset_name}

body = LomasRequestModel.model_validate(body_dict)
res = self.http_client.post("get_total_spent_budget", body)

return validate_model_response_direct(res, SpentBudgetResponse)

def get_remaining_budget(self) -> RemainingBudgetResponse:
return flow(
# construct request body
{"dataset_name": self.config.dataset_name},
# validate request model
LomasRequestModel.model_validate,
# post the validated body to the corresponding endpoint
lambda body: self.http_client.post("get_total_spent_budget", body),
# parse reply if HTTP 200
bind(parse_if_ok),
# build Budget Response from successful json payload
map_(SpentBudgetResponse.model_validate_json),
)

def get_remaining_budget(self) -> IOResultE[RemainingBudgetResponse]:
"""This function retrieves the remaining budget.

Returns:
RemainingBudgetResponse: A dictionary
containing the remaining budget.
"""
body_dict = {"dataset_name": self.config.dataset_name}

body = LomasRequestModel.model_validate(body_dict)
res = self.http_client.post("get_remaining_budget", body)

return validate_model_response_direct(res, RemainingBudgetResponse)

def get_previous_queries(self) -> list[dict]:
return flow(
# construct request body
{"dataset_name": self.config.dataset_name},
# validate request model
LomasRequestModel.model_validate,
# post the validated body to the corresponding endpoint
lambda body: self.http_client.post("get_remaining_budget", body),
# parse reply if HTTP 200
bind(parse_if_ok),
# build Budget Response from successful json payload
map_(RemainingBudgetResponse.model_validate_json),
)

def get_previous_queries(self) -> IOResultE[list[dict]]:
"""This function retrieves the previous queries of the user.

Raises:
Expand All @@ -208,17 +232,8 @@ def get_previous_queries(self) -> list[dict]:
List[dict]: A list of dictionaries containing
the different queries on the private dataset.
"""
body_dict = {"dataset_name": self.config.dataset_name}

body = LomasRequestModel.model_validate(body_dict)
res = self.http_client.post("get_previous_queries", body)

if res.status_code == status.HTTP_200_OK:
queries = json.loads(res.content.decode("utf8"))["previous_queries"]

if not queries:
return queries

def post_processes_queries(queries: list[dict]) -> list[dict]:
deserialised_queries = []
for query in queries:
match query["dp_library"]:
Expand All @@ -233,8 +248,8 @@ def get_previous_queries(self) -> list[dict]:
query["response"]["result"] = pd.DataFrame(res)
case DPLibraries.OPENDP:
query_json = OpenDPQueryModel.model_validate(query["client_input"])
query["client_input"]["opendp_json"] = reconstruct_measurement_pipeline(
query_json, self.get_dataset_metadata()
query["client_input"]["opendp_json"] = self.get_dataset_metadata().map(
partial(reconstruct_measurement_pipeline, query_json)
)
case DPLibraries.DIFFPRIVLIB:
model = base64.b64decode(query["response"]["result"]["model"])
Expand All @@ -246,4 +261,64 @@ def get_previous_queries(self) -> list[dict]:

return deserialised_queries

raise_error(res)
return flow(
# construct request body
{"dataset_name": self.config.dataset_name},
# validate request model
LomasRequestModel.model_validate,
# post the validated body to the corresponding endpoint
lambda body: self.http_client.post("get_previous_queries", body),
# parse reply if HTTP 200
bind(parse_if_ok),
map_(lambda content: json.loads(content)["previous_queries"]),
map_(post_processes_queries),
)


class Client:
    """Backward-compatible facade that mirrors ``ClientIO`` with unwrapped results.

    Every public method forwards to the wrapped ``ClientIO`` instance and passes
    the value it returns through ``unwrap`` before handing it to the caller.
    Per the original note, this shadowing of ``ClientIO`` is meant to stay
    "for now".
    """

    def __init__(self, **kwargs: ClientConfig.model_config):
        """Create the underlying ``ClientIO`` and the per-library sub-clients.

        Args:
            **kwargs: Forwarded unchanged to ``ClientIO``.
        """
        self.client_io = ClientIO(**kwargs)

        def build_unwrapped(library_cls: type, alias: str):
            # A fresh subclass named ``alias`` is created so that
            # ``unwrap_all_clsmethods`` receives that subclass rather than the
            # library class itself; the decorated class is then instantiated
            # with the HTTP client owned by ``client_io``.
            derived = type(alias, (library_cls,), {})
            decorated = unwrap_all_clsmethods(derived)
            return decorated(self.client_io.http_client)

        self.smartnoise_sql = build_unwrapped(SmartnoiseSQLClient, "SmartnoiseSQLClientU")
        self.smartnoise_synth = build_unwrapped(SmartnoiseSynthClient, "SmartnoiseSynthClientU")
        self.opendp = build_unwrapped(OpenDPClient, "OpenDPClientU")
        self.diffprivlib = build_unwrapped(DiffPrivLibClient, "DiffPrivLibClientU")

    def get_dataset_metadata(self) -> LomasRequestModel:
        """Return ``ClientIO.get_dataset_metadata()`` passed through ``unwrap``."""
        wrapped = self.client_io.get_dataset_metadata()
        return unwrap(wrapped)

    def get_dummy_dataset(
        self,
        nb_rows: int = DUMMY_NB_ROWS,
        seed: int = DUMMY_SEED,
        lazy: bool = False,
    ) -> pd.DataFrame | pl.LazyFrame:
        """Return ``ClientIO.get_dummy_dataset(...)`` passed through ``unwrap``.

        Args:
            nb_rows: Forwarded as the first positional argument.
            seed: Forwarded as the second positional argument.
            lazy: Forwarded as the third positional argument.
        """
        wrapped = self.client_io.get_dummy_dataset(nb_rows, seed, lazy)
        return unwrap(wrapped)

    def get_initial_budget(self) -> InitialBudgetResponse:
        """Return ``ClientIO.get_initial_budget()`` passed through ``unwrap``."""
        wrapped = self.client_io.get_initial_budget()
        return unwrap(wrapped)

    def get_total_spent_budget(self) -> SpentBudgetResponse:
        """Return ``ClientIO.get_total_spent_budget()`` passed through ``unwrap``."""
        wrapped = self.client_io.get_total_spent_budget()
        return unwrap(wrapped)

    def get_remaining_budget(self) -> RemainingBudgetResponse:
        """Return ``ClientIO.get_remaining_budget()`` passed through ``unwrap``."""
        wrapped = self.client_io.get_remaining_budget()
        return unwrap(wrapped)

    def get_previous_queries(self) -> list[dict]:
        """Return ``ClientIO.get_previous_queries()`` passed through ``unwrap``."""
        wrapped = self.client_io.get_previous_queries()
        return unwrap(wrapped)
3 changes: 3 additions & 0 deletions client/lomas_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from oauthlib.oauth2 import BackendApplicationClient, TokenExpiredError
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from requests_oauthlib import OAuth2Session
from returns.io import impure_safe

from lomas_client.constants import CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
from lomas_client.models.config import ClientConfig
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self, config: ClientConfig) -> None:
# Fetch first token:
self._fetch_token()

@impure_safe
def _fetch_token(self) -> None:
"""Fetches an authorization token and stores it."""
self._oauth2_session.fetch_token(
Expand All @@ -50,6 +52,7 @@ def _fetch_token(self) -> None:
client_secret=self.config.client_secret,
)

@impure_safe
def post(
self,
endpoint: str,
Expand Down
Loading
Loading