Skip to content

fix(adserver): adserver returns cloudevents compatible response #5348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions components/alibi-detect-server/adserver/base/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import logging
import tempfile
from distutils.util import strtobool
from typing import Optional

ARTIFACT_DOWNLOAD_LOCATION = os.environ.get("DRIFT_ARTIFACTS_DIR", "/tmp")

Expand All @@ -18,13 +18,13 @@


class Rclone:
def __init__(self, cfg_file: str = None):
def __init__(self, cfg_file: Optional[str] = None):
self.cfg_file = cfg_file

def copy(self, src: str, dest: str = None):
def copy(self, src: str, dest: Optional[str] = None):
if rclone is None:
raise RuntimeError(
"rclone binary not found - rclone-based storage funcionality disabled"
"rclone binary not found - rclone-based storage functionality disabled"
)

if dest is None:
Expand Down
4 changes: 2 additions & 2 deletions components/alibi-detect-server/adserver/cm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
SELDON_PREDICTOR_ID = DEFAULT_LABELS["predictor_name"]


def _load_class_module(module_path: str) -> str:
def _load_class_module(module_path: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? It still seems to return str, or not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually returns a Module. It was raising type hint warnings because it's used like so in L77-79

            # Load from locally available models
            MetricsClass = _load_class_module(self.storage_uri)
            self.model = MetricsClass()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so potentially -> "Module" (or other appropriate way to annotate the return type) could be good?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably you'd need to import appropriate type

components = module_path.split(".")
mod = __import__(".".join(components[:-1]))
for comp in components[1:]:
Expand All @@ -32,7 +32,7 @@ def _load_class_module(module_path: str) -> str:

class CustomMetricsModel(CEModel): # pylint:disable=c-extension-no-member
def __init__(
self, name: str, storage_uri: str, elasticsearch_uri: str = None, model=None
self, name: str, storage_uri: str, elasticsearch_uri: Optional[str] = None, model=None
):
"""
Custom Metrics Model
Expand Down
104 changes: 71 additions & 33 deletions components/alibi-detect-server/adserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
event_type: str,
event_source: str,
http_port: int = DEFAULT_HTTP_PORT,
reply_url: str = None,
reply_url: Optional[str] = None,
):
"""
CloudEvents server
Expand Down Expand Up @@ -146,29 +146,21 @@ def get_request_handler(protocol, request: Dict) -> RequestHandler:
raise Exception(f"Unknown protocol {protocol}")


def sendCloudEvent(event: v1.Event, url: str):
def forward_request(headers, data, url):
"""
Send CloudEvent
Forward request

Parameters
----------
event
CloudEvent to send
headers
Headers to forward
data
Data to forward
url
Url to send event
Url to forward to

"""
http_marshaller = marshaller.NewDefaultHTTPMarshaller()
binary_headers, binary_data = http_marshaller.ToRequest(
event, converters.TypeBinary, json.dumps
)

logging.info("binary CloudEvent")
for k, v in binary_headers.items():
logging.info("{0}: {1}\r\n".format(k, v))
logging.info(binary_data)

response = requests.post(url, headers=binary_headers, data=binary_data)
response = requests.post(url, headers=headers, data=data)
response.raise_for_status()


Expand Down Expand Up @@ -252,27 +244,73 @@ def post(self):
else:
logging.error("Metrics returned are invalid: " + str(runtime_metrics))

if response.data is not None:
revent = create_cloud_event(
response.data,
self.event_type,
self.event_source,
event_id=event.EventID(),
extensions=event.Extensions(),
)

if response.data is not None:
# Create event from response if reply_url is active
revent_headers, revent_data = http_marshaller.ToRequest(
revent, converters.TypeBinary, json.dumps
)

if not self.reply_url == "":
if event.EventID() is None or event.EventID() == "":
resp_event_id = uuid.uuid1().hex
else:
resp_event_id = event.EventID()
revent = (
v1.Event()
.SetContentType("application/json")
.SetData(response.data)
.SetEventID(resp_event_id)
.SetSource(self.event_source)
.SetEventType(self.event_type)
.SetExtensions(event.Extensions())
)
logging.debug(json.dumps(revent.Properties()))
sendCloudEvent(revent, self.reply_url)
self.write(json.dumps(response.data))
logging.info("binary CloudEvent")
for k, v in revent_headers.items():
logging.info("{0}: {1}\r\n".format(k, v))
logging.info(revent_data)
forward_request(revent_headers, revent_data, self.reply_url)

self.set_header("Content-Type", "application/json")
for headers in revent_headers:
self.set_header(headers, revent_headers[headers])
self.write(revent_data)


def create_cloud_event(
data: dict,
event_type: str,
event_source: str,
extensions: dict,
event_id: str = None,
) -> v1.Event:
"""
Create a CloudEvent

Parameters
----------
data
The data to send
event_type
The CE event type
event_source
The CE event source
extensions
Any extensions to add
event_id
The event id
Returns
-------
A CloudEvent

"""
if event_id is None or event_id == "":
event_id = uuid.uuid1().hex

event = (
v1.Event()
.SetData(data)
.SetEventID(event_id if event_id else str(uuid.uuid1().hex))
.SetSource(event_source)
.SetEventType(event_type)
.SetExtensions(extensions)
)
return event

class LivenessHandler(tornado.web.RequestHandler):
def get(self):
Expand Down
23 changes: 23 additions & 0 deletions components/alibi-detect-server/adserver/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing import List, Dict, Optional, Union
import json
import requests_mock
from cloudevents.sdk import converters
from cloudevents.sdk import marshaller
from cloudevents.sdk.event import v1


class TestProtocol(AsyncHTTPTestCase):
Expand Down Expand Up @@ -74,11 +77,31 @@ def test_basic(self):
)
self.assertEqual(response.code, 200)
expectedResponse = DummyModel.getResponse().data
# assert that the expected response conforms to the CloudEvent spec
event = v1.Event()
http_marshaller = marshaller.NewDefaultHTTPMarshaller()
try:
event = http_marshaller.FromRequest(
event, response.headers, response.body, json.loads
)
except Exception as e:
assert False, f"Failed to unmarshall data with error: {type(e).__name__}('{e}')"

# assert cloud event properties have been set correctly in response
self.assertEqual(event.Data(), expectedResponse)
self.assertEqual(event.Source(), self.eventSource)
self.assertEqual(event.EventType(), self.eventType)
self.assertEqual(event.ContentType(), "application/json")
self.assertEqual(event.EventID(), "1234")
self.assertEqual(event.CloudEventVersion(), "1.0")
self.assertEqual(response.body.decode("utf-8"), json.dumps(expectedResponse))

# assert requests have been made with the correct headers and data
self.assertEqual(m.request_history[0].json(), expectedResponse)
headers: Dict = m.request_history[0]._request.headers
self.assertEqual(headers["ce-source"], self.eventSource)
self.assertEqual(headers["ce-type"], self.eventType)
self.assertNotIn("ce-datacontenttype", headers)


class TestKFservingV2HttpModel(AsyncHTTPTestCase):
Expand Down