@@ -12,15 +12,6 @@ import urllib.error
12
12
import urllib .request
13
13
14
14
15
- def construct_request_data (conversation_history ):
16
- data = {
17
- "stream" : True ,
18
- "messages" : conversation_history ,
19
- }
20
-
21
- return data
22
-
23
-
24
15
def should_colorize():
    """Report whether output should be colorized.

    Reads the ``TERM`` environment variable and checks whether stdout is a
    terminal.

    Returns:
        The falsy ``TERM`` value itself (``None`` or ``""``) when ``TERM`` is
        unset or empty, ``False`` when ``TERM`` is ``"dumb"``, and otherwise
        the result of ``sys.stdout.isatty()``.
    """
    term = os.getenv("TERM")

    # Hand back the falsy value unchanged so callers see exactly what the
    # short-circuiting ``and`` chain would have produced.
    if not term:
        return term

    if term == "dumb":
        return False

    return sys.stdout.isatty()
@@ -53,42 +44,6 @@ def res(response, color):
53
44
return assistant_response
54
45
55
46
56
- def req (conversation_history , url , parsed_args ):
57
- data = construct_request_data (conversation_history )
58
- json_data = json .dumps (data ).encode ("utf-8" )
59
- headers = {
60
- "Content-Type" : "application/json" ,
61
- }
62
-
63
- # Create a request
64
- request = urllib .request .Request (url , data = json_data , headers = headers , method = "POST" )
65
-
66
- # Send request
67
- i = 0.01
68
- response = None
69
- for c in itertools .cycle (['⠋' , '⠙' , '⠹' , '⠸' , '⠼' , '⠴' , '⠦' , '⠧' , '⠇' , '⠏' ]):
70
- try :
71
- response = urllib .request .urlopen (request )
72
- break
73
- except Exception :
74
- if sys .stdout .isatty ():
75
- print (f"\r { c } " , end = "" , flush = True )
76
-
77
- if i > 32 :
78
- break
79
-
80
- time .sleep (i )
81
- i *= 2
82
-
83
- if response :
84
- return res (response , parsed_args .color )
85
-
86
- from ramalama .common import perror
87
- perror (f"\r Error: could not connect to: { url } " )
88
- do_kills (parsed_args )
89
-
90
- return None
91
-
92
47
class RamaLamaShell (cmd .Cmd ):
93
48
def __init__ (self , parsed_args ):
94
49
super ().__init__ ()
@@ -101,6 +56,23 @@ class RamaLamaShell(cmd.Cmd):
101
56
self .prompt = parsed_args .prefix
102
57
103
58
self .url = f"{ parsed_args .host } /v1/chat/completions"
59
+ self .models_url = f"{ parsed_args .host } /v1/models"
60
+ self .models = []
61
+
62
def model(self, index=0):
    """Return the id of the model at ``index``, or ``""`` if unavailable.

    The model list is fetched through ``self.get_models()`` the first time it
    is needed and then kept in ``self.models`` for later calls.

    Args:
        index: Position in the cached model list (defaults to the first).

    Returns:
        The model id, or an empty string when the list cannot be fetched or
        parsed, or when it has no entry at ``index``.
    """
    if not self.models:
        try:
            # ``or []`` keeps the cache a list even if nothing came back, so
            # a later call retries instead of failing on ``None``.
            self.models = self.get_models() or []
        except (OSError, ValueError, KeyError, TypeError):
            # OSError covers urllib.error.URLError and other connection
            # failures; the rest cover a malformed or unexpected payload.
            return ""

    try:
        return self.models[index]
    except IndexError:
        # Empty list or index out of range: same fallback as a failed fetch.
        return ""
69
+
70
def get_models(self):
    """Fetch the model ids listed at ``self.models_url``.

    Issues a GET request and parses the whole response body as one JSON
    document, so both compact and pretty-printed (multi-line) payloads work.

    Returns:
        A list with the ``"id"`` of every entry under the top-level
        ``"data"`` key, or an empty list when the body is empty.

    Raises:
        urllib.error.URLError: If the request fails.
        json.JSONDecodeError: If the body is not valid JSON.
        KeyError: If ``"data"`` or an entry's ``"id"`` is missing.
    """
    request = urllib.request.Request(self.models_url, method="GET")

    # The context manager closes the connection even if reading fails.
    with urllib.request.urlopen(request) as response:
        body = response.read().decode("utf-8")

    if not body.strip():
        return []

    return [entry["id"] for entry in json.loads(body)["data"]]
104
76
105
77
def do_EOF (self , user_content ):
106
78
print ("" )
@@ -112,7 +84,7 @@ class RamaLamaShell(cmd.Cmd):
112
84
113
85
self .conversation_history .append ({"role" : "user" , "content" : user_content })
114
86
self .request_in_process = True
115
- response = req ( self .conversation_history , self . url , self . parsed_args )
87
+ response = self ._req ( )
116
88
if not response :
117
89
return True
118
90
@@ -121,11 +93,74 @@ class RamaLamaShell(cmd.Cmd):
121
93
)
122
94
self .request_in_process = False
123
95
124
- def do_kills (parsed_args ):
125
- if parsed_args .pid2kill :
126
- os .kill (parsed_args .pid2kill , signal .SIGINT )
127
- os .kill (parsed_args .pid2kill , signal .SIGTERM )
128
- os .kill (parsed_args .pid2kill , signal .SIGKILL )
96
def _req(self):
    """POST the conversation to ``self.url`` and hand the reply to ``res``.

    The payload asks for a streamed reply and carries the conversation
    history plus the model id from ``self.model()``. Failed attempts are
    retried with a delay that starts at 0.01 seconds and doubles each time;
    retrying stops at the first failure where the delay exceeds 32 seconds.

    Returns:
        Whatever ``res`` returns for the opened response, or ``None`` after
        reporting the failure on stderr and calling ``self.kills()``.
    """
    payload = {
        "stream": True,
        "messages": self.conversation_history,
        "model": self.model(),
    }
    http_request = urllib.request.Request(
        self.url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    backoff = 0.01
    connection = None
    # One spinner glyph is consumed per attempt.
    for glyph in itertools.cycle("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"):
        try:
            connection = urllib.request.urlopen(http_request)
        except Exception:
            if sys.stdout.isatty():
                # NOTE(review): the spaces around the glyph look like an
                # extraction artifact of this listing - confirm against the
                # real file.
                print(f"\r {glyph} ", end="", flush=True)

            if backoff > 32:
                break

            time.sleep(backoff)
            backoff *= 2
        else:
            break

    if not connection:
        # NOTE(review): the space after "\r" and the trailing space may be
        # extraction artifacts as well - confirm.
        print(f"\r Error: could not connect to: {self.url} ", file=sys.stderr)
        self.kills()
        return None

    return res(connection, self.parsed_args.color)
135
+
136
def kills(self):
    """Signal the process named by ``parsed_args.pid2kill``, if any.

    Sends SIGINT, SIGTERM and SIGKILL in that order. Does nothing when
    ``pid2kill`` is unset or zero. If the process has already exited, the
    remaining signals are skipped instead of raising, which makes it safe to
    call this method more than once.
    """
    pid = self.parsed_args.pid2kill
    if not pid:
        return

    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGKILL):
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            # The target is already gone; nothing left to signal.
            return
141
+
142
def loop(self):
    """Run ``cmdloop`` until it returns, surviving Ctrl-C.

    ``request_in_process`` is reset before every ``cmdloop`` run. When a
    ``KeyboardInterrupt`` escapes ``cmdloop``, a blank line is printed and,
    unless a request was in progress, a hint on how to quit; the command
    loop is then restarted.
    """
    finished = False
    while not finished:
        self.request_in_process = False
        try:
            self.cmdloop()
        except KeyboardInterrupt:
            print("")
            if self.request_in_process:
                # The interrupt only cancelled the running request.
                continue

            print("Use Ctrl + d or /bye or exit to quit.")
        else:
            finished = True
155
+
156
def handle_args(self):
    """Process a prompt supplied through ``parsed_args.ARGS``, if present.

    The words are joined with single spaces and passed to ``self.default``,
    after which ``self.kills()`` is called.

    Returns:
        ``True`` when ``ARGS`` was non-empty and has been handled, ``False``
        otherwise.
    """
    prompt_words = self.parsed_args.ARGS
    if not prompt_words:
        return False

    self.default(" ".join(prompt_words))
    self.kills()
    return True
163
+
129
164
130
165
def parse_arguments (args ):
131
166
parser = argparse .ArgumentParser (description = "Run ramalama client core" )
@@ -156,38 +191,17 @@ def parse_arguments(args):
156
191
157
192
return parser .parse_args (args )
158
193
159
- def handle_args (parsed_args , ramalama_shell ):
160
- if parsed_args .ARGS :
161
- ramalama_shell .default (" " .join (parsed_args .ARGS ))
162
- do_kills (parsed_args )
163
- return True
164
-
165
- return False
166
-
167
- def run_shell_loop (ramalama_shell ):
168
- while True :
169
- ramalama_shell .request_in_process = False
170
- try :
171
- ramalama_shell .cmdloop ()
172
- except KeyboardInterrupt :
173
- print ("" )
174
- if not ramalama_shell .request_in_process :
175
- print ("Use Ctrl + d or /bye or exit to quit." )
176
-
177
- continue
178
-
179
- break
180
194
181
195
def main(args):
    """Entry point: parse ``args`` and run the shell.

    Args:
        args: Argument list passed on to ``parse_arguments``.

    Returns:
        ``0`` when a prompt given on the command line was handled by
        ``handle_args``; ``None`` after the interactive loop ends.
    """
    # NOTE(review): presumably lets modules in the working directory be
    # imported - confirm it is still needed.
    sys.path.append('./')

    shell = RamaLamaShell(parse_arguments(args))
    if not shell.handle_args():
        shell.loop()
        shell.kills()
        return None

    return 0
191
205
192
206
193
207
if __name__ == '__main__' :
0 commit comments