
Commit b26a82c

make ramalama-client-core send default model to server

Also move most of the helper functions into the RamaLamaShell class.

Signed-off-by: Daniel J Walsh <[email protected]>

1 parent 8a604e3 commit b26a82c

File tree

2 files changed: +90 -75 lines changed


.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ jobs:
           sudo apt-get install podman bats bash codespell python3-argcomplete pipx git cmake libcurl4-openssl-dev
           make install-requirements
           sudo ./container-images/scripts/build_llama_and_whisper.sh
+          sudo python -m pip install . --prefix=/usr
 
       - name: install ollama
        shell: bash

libexec/ramalama/ramalama-client-core

Lines changed: 89 additions & 75 deletions
@@ -12,15 +12,6 @@ import urllib.error
 import urllib.request
 
 
-def construct_request_data(conversation_history):
-    data = {
-        "stream": True,
-        "messages": conversation_history,
-    }
-
-    return data
-
-
 def should_colorize():
     t = os.getenv("TERM")
     return t and t != "dumb" and sys.stdout.isatty()
@@ -53,42 +44,6 @@ def res(response, color):
     return assistant_response
 
 
-def req(conversation_history, url, parsed_args):
-    data = construct_request_data(conversation_history)
-    json_data = json.dumps(data).encode("utf-8")
-    headers = {
-        "Content-Type": "application/json",
-    }
-
-    # Create a request
-    request = urllib.request.Request(url, data=json_data, headers=headers, method="POST")
-
-    # Send request
-    i = 0.01
-    response = None
-    for c in itertools.cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']):
-        try:
-            response = urllib.request.urlopen(request)
-            break
-        except Exception:
-            if sys.stdout.isatty():
-                print(f"\r{c}", end="", flush=True)
-
-            if i > 32:
-                break
-
-            time.sleep(i)
-            i *= 2
-
-    if response:
-        return res(response, parsed_args.color)
-
-    from ramalama.common import perror
-    perror(f"\rError: could not connect to: {url}")
-    do_kills(parsed_args)
-
-    return None
-
 class RamaLamaShell(cmd.Cmd):
     def __init__(self, parsed_args):
         super().__init__()
@@ -101,6 +56,23 @@ class RamaLamaShell(cmd.Cmd):
         self.prompt = parsed_args.prefix
 
         self.url = f"{parsed_args.host}/v1/chat/completions"
+        self.models_url = f"{parsed_args.host}/v1/models"
+        self.models = []
+
+    def model(self, index=0):
+        try:
+            if len(self.models) == 0:
+                self.models = self.get_models()
+            return self.models[index]
+        except urllib.error.URLError:
+            return ""
+
+    def get_models(self):
+        request = urllib.request.Request(self.models_url, method="GET")
+        response = urllib.request.urlopen(request)
+        for line in response:
+            line = line.decode("utf-8").strip()
+            return([d['id'] for d in json.loads(line)["data"]])
 
     def do_EOF(self, user_content):
         print("")
@@ -112,7 +84,7 @@
 
         self.conversation_history.append({"role": "user", "content": user_content})
         self.request_in_process = True
-        response = req(self.conversation_history, self.url, self.parsed_args)
+        response = self._req()
         if not response:
             return True
 
@@ -121,11 +93,74 @@
         )
         self.request_in_process = False
 
-def do_kills(parsed_args):
-    if parsed_args.pid2kill:
-        os.kill(parsed_args.pid2kill, signal.SIGINT)
-        os.kill(parsed_args.pid2kill, signal.SIGTERM)
-        os.kill(parsed_args.pid2kill, signal.SIGKILL)
+    def _req(self):
+        data = {
+            "stream": True,
+            "messages": self.conversation_history,
+            "model": self.model(),
+        }
+
+        json_data = json.dumps(data).encode("utf-8")
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        # Create a request
+        request = urllib.request.Request(self.url, data=json_data, headers=headers, method="POST")
+
+        # Send request
+        i = 0.01
+        response = None
+        for c in itertools.cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']):
+            try:
+                response = urllib.request.urlopen(request)
+                break
+            except Exception:
+                if sys.stdout.isatty():
+                    print(f"\r{c}", end="", flush=True)
+
+                if i > 32:
+                    break
+
+                time.sleep(i)
+                i *= 2
+
+        if response:
+            return res(response, self.parsed_args.color)
+
+        print(f"\rError: could not connect to: {self.url}", file=sys.stderr)
+        self.kills()
+
+        return None
+
+    def kills(self):
+        if self.parsed_args.pid2kill:
+            os.kill(self.parsed_args.pid2kill, signal.SIGINT)
+            os.kill(self.parsed_args.pid2kill, signal.SIGTERM)
+            os.kill(self.parsed_args.pid2kill, signal.SIGKILL)
+
+    def loop(self):
+        while True:
+            self.request_in_process = False
+            try:
+                self.cmdloop()
+            except KeyboardInterrupt:
+                print("")
+                if not self.request_in_process:
+                    print("Use Ctrl + d or /bye or exit to quit.")
+
+                continue
+
+            break
+
+    def handle_args(self):
+        if self.parsed_args.ARGS:
+            self.default(" ".join(self.parsed_args.ARGS))
+            self.kills()
+            return True
+
+        return False
+
 
 def parse_arguments(args):
     parser = argparse.ArgumentParser(description="Run ramalama client core")
@@ -156,38 +191,17 @@ def parse_arguments(args):
 
     return parser.parse_args(args)
 
-def handle_args(parsed_args, ramalama_shell):
-    if parsed_args.ARGS:
-        ramalama_shell.default(" ".join(parsed_args.ARGS))
-        do_kills(parsed_args)
-        return True
-
-    return False
-
-def run_shell_loop(ramalama_shell):
-    while True:
-        ramalama_shell.request_in_process = False
-        try:
-            ramalama_shell.cmdloop()
-        except KeyboardInterrupt:
-            print("")
-            if not ramalama_shell.request_in_process:
-                print("Use Ctrl + d or /bye or exit to quit.")
-
-            continue
-
-        break
 
 def main(args):
     sys.path.append('./')
 
     parsed_args = parse_arguments(args)
     ramalama_shell = RamaLamaShell(parsed_args)
-    if handle_args(parsed_args, ramalama_shell):
+    if ramalama_shell.handle_args():
         return 0
 
-    run_shell_loop(ramalama_shell)
-    do_kills(parsed_args)
+    ramalama_shell.loop()
+    ramalama_shell.kills()
 
 
 if __name__ == '__main__':

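For illustration, a minimal standalone sketch of the behavior this commit adds (not part of the diff above): fetch the list of served models from the /v1/models endpoint and pass the first id as the "model" field of the /v1/chat/completions request, as RamaLamaShell.get_models() and _req() now do. The host address, the prompt, and the non-streaming response parsing are assumptions made for this example only.

import json
import urllib.request

HOST = "http://127.0.0.1:8080"  # assumed server address; the real client builds URLs from parsed_args.host

# Ask the server which models it is serving; the first id becomes the default model.
models_request = urllib.request.Request(f"{HOST}/v1/models", method="GET")
with urllib.request.urlopen(models_request) as response:
    model_ids = [entry["id"] for entry in json.loads(response.read().decode("utf-8"))["data"]]

# The chat completion payload now names that default model explicitly.
payload = {
    "stream": False,  # the real client streams; False keeps this sketch short
    "messages": [{"role": "user", "content": "Hello"}],
    "model": model_ids[0] if model_ids else "",
}
chat_request = urllib.request.Request(
    f"{HOST}/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(chat_request) as response:
    # Assumes an OpenAI-compatible, non-streaming response shape.
    reply = json.loads(response.read().decode("utf-8"))
    print(reply["choices"][0]["message"]["content"])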