Skip to content

Commit 4497e3f

Browse files
authored
enhance(client): question answering infer type support multiple args (#2997)
1 parent 8b19f10 commit 4497e3f

File tree

10 files changed

+212
-74
lines changed

10 files changed

+212
-74
lines changed

client/starwhale/api/_impl/service/service.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import typing as t
66
import functools
77

8+
from .types.types import ComponentSpec
9+
810
if sys.version_info >= (3, 9):
911
from importlib.resources import files
1012
else:
1113
from importlib_resources import files
1214

1315
from fastapi import FastAPI
14-
from pydantic import Field, BaseModel
16+
from pydantic import Field
1517
from starlette.responses import FileResponse
1618
from starlette.staticfiles import StaticFiles
1719

@@ -25,8 +27,17 @@
2527
)
2628

2729

28-
class Query(BaseModel):
29-
content: str
30+
class ApiSpec(SwBaseModel):
31+
uri: str
32+
inference_type: str
33+
components_hint: t.List[ComponentSpec] = Field(default_factory=list)
34+
35+
36+
class ServiceSpec(SwBaseModel):
37+
title: t.Optional[str]
38+
description: t.Optional[str]
39+
version: str
40+
apis: t.List[ApiSpec]
3041

3142

3243
class Api(SwBaseModel):
@@ -37,32 +48,28 @@ class Api(SwBaseModel):
3748
inputs: Inputs = Field(exclude=True)
3849
outputs: Outputs = Field(exclude=True)
3950

40-
@staticmethod
41-
def question_answering(func: t.Callable) -> t.Callable:
42-
def inter(query: Query) -> str:
43-
return func(query.content) # type: ignore
44-
45-
return inter
46-
4751
def view_func(self, ins: t.Any = None) -> t.Callable:
4852
func = self.func
4953
if ins is not None:
5054
func = functools.partial(func, ins)
5155
if self.inference_type is None:
5256
return func
53-
return getattr(self, self.inference_type.value)(func) # type: ignore
57+
return self.inference_type.router_fn(func)
5458

5559
def all_gradio_components(self) -> bool:
5660
if self.inference_type is not None:
5761
return False
5862
return all_components_are_gradio(inputs=self.inputs, outputs=self.outputs)
5963

64+
def to_spec(self) -> ApiSpec | None:
65+
if self.inference_type is None:
66+
return None
6067

61-
class ServiceSpec(SwBaseModel):
62-
title: t.Optional[str]
63-
description: t.Optional[str]
64-
version: str
65-
apis: t.List[Api]
68+
return ApiSpec(
69+
uri=self.uri,
70+
inference_type=self.inference_type.name,
71+
components_hint=self.inference_type.components_spec(),
72+
)
6673

6774

6875
class Service:
@@ -89,7 +96,10 @@ def decorator(func: t.Any) -> t.Any:
8996
return decorator
9097

9198
def get_spec(self) -> ServiceSpec:
92-
return ServiceSpec(version="0.0.1", apis=list(self.apis.values()))
99+
return ServiceSpec(
100+
version="0.0.2",
101+
apis=list(filter(None, [_api.to_spec() for _api in self.apis.values()])),
102+
)
93103

94104
def add_api(
95105
self,

client/starwhale/api/_impl/service/types.py

Lines changed: 0 additions & 42 deletions
This file was deleted.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .types import Inputs, Outputs, ServiceType, all_components_are_gradio
2+
3+
__all__ = ["Inputs", "Outputs", "ServiceType", "all_components_are_gradio"]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from typing import Any, Set, List, Callable, Optional
5+
6+
from pydantic import BaseModel
7+
from pydantic.dataclasses import dataclass
8+
9+
from .types import ServiceType, ComponentSpec
10+
11+
12+
@dataclass
13+
class Message:
14+
content: str
15+
role: str
16+
17+
18+
class Query(BaseModel):
19+
user_input: str
20+
history: List[Message]
21+
confidence: Optional[float]
22+
top_k: Optional[float]
23+
top_p: Optional[float]
24+
temperature: Optional[float]
25+
max_new_tokens: Optional[int]
26+
27+
28+
class LLMChat(ServiceType):
29+
name = "llm_chat"
30+
31+
# TODO use pydantic model annotations generated arg_types
32+
arg_types = {
33+
"user_input": str,
34+
"history": list, # list of Message
35+
"top_k": float,
36+
"top_p": float,
37+
"temperature": float,
38+
"max_new_tokens": int,
39+
}
40+
41+
def __init__(self, args: Set | None = None) -> None:
42+
if args is None:
43+
args = set(self.arg_types.keys())
44+
else:
45+
# check if all args are in arg_types
46+
for arg in args:
47+
if arg not in self.arg_types:
48+
raise ValueError(f"Argument {arg} is not in arg_types.")
49+
50+
self.args = args
51+
52+
def components_spec(self) -> List[ComponentSpec]:
53+
return [
54+
ComponentSpec(name=arg, type=self.arg_types[arg].__name__)
55+
for arg in self.args
56+
]
57+
58+
def router_fn(self, func: Callable) -> Callable:
59+
params = inspect.signature(func).parameters
60+
61+
def wrapper(query: Query) -> Any:
62+
return func(**{k: getattr(query, k) for k in params})
63+
64+
return wrapper
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import abc
2+
import inspect
3+
from typing import Any, Dict, List, Callable
4+
5+
from starwhale.utils import console
6+
from starwhale.base.models.base import SwBaseModel
7+
8+
Inputs = Any
9+
Outputs = Any
10+
11+
12+
class ComponentSpec(SwBaseModel):
13+
name: str
14+
type: str
15+
16+
def __hash__(self) -> int:
17+
return hash((self.name, self.type))
18+
19+
20+
class ServiceType(abc.ABC):
21+
"""Protocol for service types."""
22+
23+
@property
24+
@abc.abstractmethod
25+
def arg_types(self) -> Dict[str, Any]:
26+
...
27+
28+
@property
29+
@abc.abstractmethod
30+
def name(self) -> str:
31+
...
32+
33+
def _validate_fn_with_arg_types(self, func: Callable) -> None:
34+
"""Validate the function with the argument types."""
35+
sig = inspect.signature(func)
36+
params = sig.parameters
37+
38+
# check the type of each argument
39+
for name, param in params.items():
40+
expected_type = self.arg_types[name]
41+
arg_type = param.annotation
42+
if arg_type is inspect.Parameter.empty:
43+
console.warn(f"Argument type {name} is not specified.")
44+
continue
45+
if arg_type is not expected_type:
46+
raise ValueError(
47+
f"Argument type {name} should be {expected_type}, not {arg_type}."
48+
)
49+
50+
def validate(self, value: Callable) -> None:
51+
"""
52+
Validate the service type
53+
The function should raise a ValueError if the function is not valid.
54+
:param value: the function to validate
55+
"""
56+
self._validate_fn_with_arg_types(value)
57+
58+
@abc.abstractmethod
59+
def router_fn(self, func: Callable) -> Callable:
60+
...
61+
62+
@abc.abstractmethod
63+
def components_spec(self) -> List[ComponentSpec]:
64+
...
65+
66+
67+
def all_components_are_gradio(
68+
inputs: Inputs, outputs: Outputs
69+
) -> bool: # pragma: no cover
70+
"""Check if all components are Gradio components."""
71+
if inputs is None and outputs is None:
72+
return False
73+
74+
if not isinstance(inputs, list):
75+
inputs = inputs is not None and [inputs] or []
76+
if not isinstance(outputs, list):
77+
outputs = outputs is not None and [outputs] or []
78+
79+
try:
80+
import gradio
81+
except ImportError:
82+
gradio = None
83+
84+
return all(
85+
[
86+
gradio is not None,
87+
all([isinstance(inp, gradio.components.Component) for inp in inputs]),
88+
all([isinstance(out, gradio.components.Component) for out in outputs]),
89+
]
90+
)

client/starwhale/api/service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from starwhale.api._impl.service.service import api, Service, ServiceType
2+
from starwhale.api._impl.service.types.llm import LLMChat
23

34
__all__ = [
45
"Service",
56
"api",
67
"ServiceType",
8+
"LLMChat",
79
]

client/tests/core/test_model.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757
from starwhale.core.instance.view import InstanceTermView
5858
from starwhale.base.scheduler.step import Step
5959
from starwhale.core.runtime.process import Process
60-
from starwhale.api._impl.service.types import ServiceType
61-
from starwhale.api._impl.service.service import Api, ServiceSpec
60+
from starwhale.api._impl.service.service import ApiSpec, ServiceSpec
6261
from starwhale.base.client.models.models import (
6362
UserVo,
6463
ModelVo,
@@ -137,10 +136,9 @@ def test_build_workflow(
137136
svc.get_spec.return_value = ServiceSpec(
138137
version="0.0.1",
139138
apis=[
140-
Api(
141-
func=lambda x: x,
139+
ApiSpec(
142140
uri="",
143-
inference_type=ServiceType.QUESTION_ANSWERING,
141+
inference_type="question_answering",
144142
)
145143
],
146144
)
@@ -1378,9 +1376,7 @@ def test_build_with_custom_config_file(
13781376
svc = MagicMock(spec=Service)
13791377
svc.get_spec.return_value = ServiceSpec(
13801378
version="1",
1381-
apis=[
1382-
Api(func=lambda x: x, uri="", inference_type=ServiceType.QUESTION_ANSWERING)
1383-
],
1379+
apis=[ApiSpec(uri="", inference_type="question_answering")],
13841380
)
13851381
svc.example_resources = [example]
13861382
m_get_service.return_value = svc

client/tests/data/sdk/service/default_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from starwhale import PipelineHandler
44
from starwhale.api import service
5-
from starwhale.api._impl.service.types import ServiceType
5+
from starwhale.api._impl.service.types.llm import LLMChat
66

77

88
class MyDefaultClass(PipelineHandler):
@@ -12,6 +12,6 @@ def ppl(self, data: bytes, **kw: t.Any) -> t.Any:
1212
def handler_foo(self, data: t.Any) -> t.Any:
1313
return
1414

15-
@service.api(inference_type=ServiceType.QUESTION_ANSWERING)
15+
@service.api(inference_type=LLMChat(args={"user_input", "history", "temperature"}))
1616
def cmp(self, ppl_result: t.Iterator) -> t.Any:
1717
pass

client/tests/sdk/test_service.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from tests import ROOT_DIR, BaseTestCase
88
from starwhale.core.model.model import StandaloneModel
9+
from starwhale.api._impl.service.types.types import ComponentSpec
910

1011

1112
class ServiceTestCase(BaseTestCase):
@@ -36,7 +37,13 @@ def test_default_class(self):
3637
spec = svc.get_spec()
3738
assert len(spec.apis) == 1
3839
assert spec.apis[0].uri == "cmp"
39-
assert spec.apis[0].inference_type.value == "question_answering"
40+
assert spec.apis[0].inference_type == "llm_chat"
41+
components = spec.apis[0].components_hint
42+
assert set(components) == {
43+
ComponentSpec(name="user_input", type="str"),
44+
ComponentSpec(name="history", type="list"),
45+
ComponentSpec(name="temperature", type="float"),
46+
}
4047

4148
def test_class_without_api(self):
4249
svc = StandaloneModel._get_service(["no_api:NoApi"], self.root)

example/web-handler/main.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1-
from starwhale.api.service import api, ServiceType
1+
from typing import List
22

3+
from starwhale.api.service import api, LLMChat
34

4-
@api(ServiceType.QUESTION_ANSWERING)
5-
def fake_chat_bot(content: str) -> str:
6-
return f"hello from chat bot with {content}"
5+
6+
@api(inference_type=LLMChat(args={"user_input", "history", "temperature"}))
7+
def fake_chat_bot(
8+
user_input: str, history: List[dict], temperature: float
9+
) -> List[dict]:
10+
result = f"hello from chat bot with {user_input}, and temperature {temperature}"
11+
history.extend(
12+
[{"content": user_input, "role": "user"}, {"content": result, "role": "bot"}]
13+
)
14+
return history

0 commit comments

Comments
 (0)