1414import re
1515from datetime import datetime
1616from threading import Thread , Lock
17+ from typing_extensions import override
1718import json
1819import math
1920from MAVProxy .modules .lib import param_help
2021
2122try :
22- from openai import OpenAI
23+ from openai import OpenAI , AssistantEventHandler
2324except Exception :
2425 print ("chat: failed to import openai. See https://ardupilot.org/mavproxy/docs/modules/chat.html" )
2526 exit ()
2627
2728
29+ class EventHandler (AssistantEventHandler ):
30+ def __init__ (self , chat_openai ):
31+ # record reference to chat_openai object
32+ self .chat_openai = chat_openai
33+
34+ # initialise parent class (assistant event handler)
35+ super ().__init__ ()
36+
37+ @override
38+ def on_event (self , event ):
39+ # record run id so that it can be cancelled if required
40+ self .chat_openai .latest_run_id = event .data .id
41+
42+ # display the event in the status field
43+ event_string_array = event .event .split ("." )
44+ event_string = event_string_array [- 1 ]
45+ self .chat_openai .send_status (event_string )
46+
47+ # requires_action events handled by function calls
48+ if event .event == 'thread.run.requires_action' :
49+ self .chat_openai .handle_function_call (event )
50+
51+ # display reply text in the reply window
52+ if (event .event == "thread.message.delta" ):
53+ stream_text = event .data .delta .content [0 ].text .value
54+ self .chat_openai .send_reply (stream_text )
55+
56+
2857class chat_openai ():
29- def __init__ (self , mpstate , status_cb = None , wait_for_command_ack_fn = None ):
58+ def __init__ (self , mpstate , status_cb = None , reply_cb = None , wait_for_command_ack_fn = None ):
3059 # keep reference to mpstate
3160 self .mpstate = mpstate
3261
3362 # keep reference to status callback
3463 self .status_cb = status_cb
64+ self .reply_cb = reply_cb
3565
3666 # keep reference to wait_for_command_ack_fn
3767 self .wait_for_command_ack_fn = wait_for_command_ack_fn
@@ -48,7 +78,7 @@ def __init__(self, mpstate, status_cb=None, wait_for_command_ack_fn=None):
4878 self .client = None
4979 self .assistant = None
5080 self .assistant_thread = None
51- self .latest_run = None
81+ self .latest_run_id = None
5282
5383 # check connection to OpenAI assistant and connect if necessary
5484 # returns True if connection is good, False if not
@@ -102,22 +132,15 @@ def set_api_key(self, api_key_str):
102132
103133 # cancel the active run
104134 def cancel_run (self ):
105- # check the active thread and run
106- if (self .assistant_thread and self .latest_run is not None ):
107- run_status = self .latest_run .status
108- if (run_status != "completed" and run_status != "cancelled" and
109- run_status != "cancelling" ):
110-
111- # cancel the active run
112- self .client .beta .threads .runs .cancel (
113- thread_id = self .assistant_thread .id ,
114- run_id = self .run .id
115- )
116- else :
117- if (self .latest_run .status == "completed" ):
118- print ("Chat is completed, cannot be cancelled" )
119- elif (self .latest_run .status == "cancelled" ):
120- print ("Chat is cancelled" )
135+ # check the active thread and run id
136+ if self .assistant_thread and self .latest_run_id is not None :
137+ # cancel the run
138+ self .client .beta .threads .runs .cancel (
139+ thread_id = self .assistant_thread .id ,
140+ run_id = self .latest_run_id
141+ )
142+ else :
143+ self .send_status ("No active run to cancel" )
121144
122145 # send text to assistant
123146 def send_to_assistant (self , text ):
@@ -126,7 +149,8 @@ def send_to_assistant(self, text):
126149
127150 # check connection
128151 if not self .check_connection ():
129- return "chat: failed to connect to OpenAI"
152+ self .send_reply ("chat: failed to connect to OpenAI" )
153+ return
130154
131155 # create a new message
132156 input_message = self .client .beta .threads .messages .create (
@@ -135,91 +159,46 @@ def send_to_assistant(self, text):
135159 content = text
136160 )
137161 if input_message is None :
138- return "chat: failed to create input message"
162+ self .send_reply ("chat: failed to create input message" )
163+ return
164+
165+ # create event handler
166+ event_handler = EventHandler (self )
139167
140168 # create a run
141- self . run = self .client .beta .threads .runs .create (
169+ with self .client .beta .threads .runs .stream (
142170 thread_id = self .assistant_thread .id ,
143- assistant_id = self .assistant .id
144- )
145- if self .run is None :
146- return "chat: failed to create run"
147-
148- # wait for run to complete
149- run_done = False
150- while not run_done :
151- # wait for one second
152- time .sleep (0.1 )
153-
154- # retrieve the run
155- self .latest_run = self .client .beta .threads .runs .retrieve (
171+ assistant_id = self .assistant .id ,
172+ event_handler = event_handler
173+ ) as stream :
174+ stream .until_done ()
175+
176+ # retrieve the run and print its final status
177+ if self .assistant_thread and self .latest_run_id :
178+ run = self .client .beta .threads .runs .retrieve (
156179 thread_id = self .assistant_thread .id ,
157- run_id = self .run . id
180+ run_id = self .latest_run_id
158181 )
159-
160- # init failure message
161- failure_message = None
162-
163- # check run status
164- if self .latest_run .status in ["queued" , "in_progress" , "cancelling" ]:
165- run_done = False
166- elif self .latest_run .status in ["cancelled" , "completed" , "expired" ]:
167- run_done = True
168- elif self .latest_run .status in ["failed" ]:
169- failure_message = self .latest_run .last_error .message
170- run_done = True
171- elif self .latest_run .status in ["requires_action" ]:
172- self .handle_function_call (self .latest_run )
173- run_done = False
174- else :
175- print ("chat: unrecognised run status" + self .latest_run .status )
176- run_done = True
177-
178- # send status to status callback
179- status_message = self .latest_run .status
180- if failure_message is not None :
181- status_message = status_message + ": " + failure_message
182- self .send_status (status_message )
183-
184- # retrieve messages on the thread
185- reply_messages = self .client .beta .threads .messages .list (self .assistant_thread .id ,
186- order = "asc" ,
187- after = input_message .id )
188-
189- if (self .latest_run .status == "cancelled" ):
190- return "cancelled successfully"
191- elif reply_messages is None :
192- return "chat: failed to retrieve messages"
193-
194- # concatenate all messages into a single reply skipping the first which is our question
195- reply = ""
196- need_newline = False
197- for message in reply_messages .data :
198- reply = reply + message .content [0 ].text .value
199- if need_newline :
200- reply = reply + "\n "
201- need_newline = True
202-
203- if reply is None or reply == "" :
204- return "chat: failed to retrieve latest reply"
205- return reply
182+ if run is not None :
183+ self .send_status (run .status )
184+ else :
185+ self .send_status ("done" )
206186
207187 # handle function call request from assistant
208- # on success this returns the text response that should be sent to the assistant, returns None on failure
209- def handle_function_call (self , run ):
188+ def handle_function_call (self , event ):
210189
211190 # sanity check required action (this should never happen)
212- if run . required_action is None :
191+ if ( event . event != "thread.run.requires_action" ) :
213192 print ("chat::handle_function_call: assistant function call empty" )
214- return None
193+ return
215194
216195 # check format
217- if run .required_action .submit_tool_outputs is None :
196+ if event . data .required_action .submit_tool_outputs is None :
218197 print ("chat::handle_function_call: submit tools outputs empty" )
219- return None
198+ return
220199
221200 tool_outputs = []
222- for tool_call in run .required_action .submit_tool_outputs .tool_calls :
201+ for tool_call in event . data .required_action .submit_tool_outputs .tool_calls :
223202 # init output to None
224203 output = "invalid function call"
225204
@@ -274,13 +253,26 @@ def handle_function_call(self, run):
274253
275254 # send function replies to assistant
276255 try :
277- self .client .beta .threads .runs .submit_tool_outputs (
278- thread_id = run .thread_id ,
279- run_id = run .id ,
280- tool_outputs = tool_outputs )
256+ stream = self .client .beta .threads .runs .submit_tool_outputs (
257+ thread_id = event .data .thread_id ,
258+ run_id = event .data .id ,
259+ tool_outputs = tool_outputs ,
260+ stream = True )
261+
262+ for event in stream :
263+ # requires_action events handled by function calls
264+ if event .event == 'thread.run.requires_action' :
265+ self .handle_function_call (event )
266+
267+ # display reply text in the reply window
268+ if (event .event == "thread.message.delta" ):
269+ stream_text = event .data .delta .content [0 ].text .value
270+ self .send_reply (stream_text )
271+
281272 except Exception :
282273 print ("chat: error replying to function call" )
283274 print (tool_outputs )
275+ return
284276
285277 # get the current date and time in the format, Saturday, June 24, 2023 6:14:14 PM
286278 def get_current_datetime (self , arguments ):
@@ -800,6 +792,10 @@ def send_status(self, status):
800792 if self .status_cb is not None :
801793 self .status_cb (status )
802794
795+ def send_reply (self , reply ):
796+ if self .reply_cb is not None :
797+ self .reply_cb (reply )
798+
803799 # returns true if string contains regex characters
804800 def contains_regex (self , string ):
805801 regex_characters = ".^$*+?{}[]\\ |()"
0 commit comments