
Commit 3c35dd5

Feature/shepherd (#20)
* wip: data preparation scripts
* wip: update req builder
* update workflow
* add simulator
* minor
1 parent 6b372ef commit 3c35dd5

File tree

24 files changed, +738 -2 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ pyrightconfig.json
 .local
 .vscode
 .zed
+.data

meta/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -18,3 +18,4 @@ prometheus_client
 tqdm
 einops
 pillow
+tenacity

scratchpad/utils/client/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .req import LLM

scratchpad/utils/client/req.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import os
+from typing import Optional
+import requests
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+
+class LLM:
+    def __init__(
+        self,
+        model: str,
+        endpoint: Optional[str] = None,
+        api_key: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+    ):
+        if not endpoint:
+            endpoint = os.environ.get("RC_API_BASE", None)
+        if not api_key:
+            api_key = os.environ.get("RC_API_KEY", None)
+        if not endpoint or not api_key:
+            raise ValueError("API key or endpoint not found")
+        if not system_prompt:
+            system_prompt = "You are a helpful assistant."
+        self.model = model
+        self.endpoint = endpoint + "/chat/completions"
+        self.api_key = api_key
+        self.system_prompt = system_prompt
+        self.headers = {"Authorization": f"Bearer {self.api_key}"}
+
+    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
+    def __call__(self, prompt: str):
+        data = {
+            "model": self.model,
+            "messages": [
+                {"role": "system", "content": self.system_prompt},
+                {
+                    "role": "user",
+                    "content": prompt,
+                },
+            ],
+        }
+        try:
+            res = requests.post(
+                self.endpoint,
+                headers=self.headers,
+                json=data,
+            )
+            result = res.json()
+        except Exception as e:
+            print(f"Error calling LLM: {e}")
+            return None
+        return result["choices"][0]["message"]["content"]
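
For reference, a minimal usage sketch of the new client (a hypothetical example, assuming RC_API_BASE and RC_API_KEY are exported and point at an OpenAI-compatible API; the model name is a placeholder):

    from scratchpad.utils.client import LLM

    # Endpoint and key fall back to RC_API_BASE / RC_API_KEY when not passed explicitly.
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")

    # Sends the system prompt plus the user prompt to /chat/completions and returns
    # the assistant message content, or None if the request fails.
    answer = llm("Summarize what a unified diff is in one sentence.")
    print(answer)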

scratchpad/utils/ui/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .ui import make_table

scratchpad/utils/ui/ui.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from rich.console import Console
+from rich.table import Table
+import numpy as np
+
+console = Console()
+
+
+def make_table(title, data):
+    keys = data[0].keys()
+    table = Table(title=title)
+    colors = ["cyan", "magenta", "green", "yellow", "blue", "red", "black"]
+
+    for idx, column in enumerate(keys):
+        table.add_column(column, justify="right", style=colors[idx % len(colors)])
+
+    for row in data:
+        for key in keys:
+            if type(row[key]) == np.float64 or type(row[key]) == float:
+                row[key] = str(round(row[key], 2))
+        table.add_row(*[row[key] for key in keys])
+    return table
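
As a quick illustration (hypothetical rows; float columns are rounded to two decimals before rendering):

    from rich.console import Console
    from scratchpad.utils.ui import make_table

    rows = [
        {"model": "llama-1b", "latency_ms": 12.345, "throughput": 812.0},
        {"model": "llama-8b", "latency_ms": 48.901, "throughput": 273.5},
    ]
    # Each dict key becomes a right-justified, color-cycled column.
    Console().print(make_table("Benchmark", rows))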

scripts/serve_llama_1b.sh

File mode changed: 100644 → 100755

tools/client/__init__.py

Whitespace-only changes.

tools/client/register.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import time
+import requests
+
+
+def update_peer(endpoint, model, ipv4, port):
+    peer = {
+        "service": [
+            {
+                "name": "llm",
+                "status": "online",
+                "hardware": [],
+                "host": ipv4,
+                "port": port,
+                "identity_group": [f"model={model}"],
+            }
+        ]
+    }
+    res = requests.post(endpoint + "/v1/dnt/_node", json=peer)
+    print(res.text)
+
+
+def health_check(args):
+    local_address = f"http://{args.local_ip}:{args.service_port}/health"
+    is_healthy = False
+    while not is_healthy:
+        try:
+            print(f"Checking health of service at {local_address}", flush=True)
+            res = requests.get(local_address, timeout=5)
+            print(f"Service health check response: {res.status_code}", flush=True)
+            if res.status_code == 200:
+                is_healthy = True
+            else:
+                print("Service not ready yet, waiting for 5 seconds", flush=True)
+                time.sleep(5)
+        except Exception:
+            print("Service not ready yet, waiting for 5 seconds", flush=True)
+            time.sleep(5)
+    return is_healthy
+
+
+def register(args):
+    print(f"Registering service with config: {args}")
+    if health_check(args):
+        update_peer(args.ocf_addr, args.model_name, args.local_ip, args.service_port)
+    else:
+        print("Service is not healthy")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Register local service to OCF network"
+    )
+    parser.add_argument("--model-name", type=str, help="Model Name")
+    parser.add_argument("--ocf-addr", type=str, default="http://localhost:8092")
+    parser.add_argument("--local-ip", type=str, default="localhost")
+    parser.add_argument("--service-port", type=str, default="8000")
+    register(parser.parse_args())
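
For context, a hypothetical invocation next to a running serving instance (model name, addresses, and ports are placeholders; the defaults match the argparse setup above):

    python tools/client/register.py \
        --model-name meta-llama/Llama-3.2-1B-Instruct \
        --ocf-addr http://localhost:8092 \
        --local-ip 10.0.0.12 \
        --service-port 8000

The script polls http://<local-ip>:<service-port>/health until it returns 200, then POSTs the peer record to <ocf-addr>/v1/dnt/_node.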

tools/client/req.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,9 @@
+import os
+import anyio
+import requests
 import aiohttp
 import asyncio
-from typing import Dict
+from typing import Dict, Optional


 async def async_request(endpoint, req: Dict):
