Skip to content

Commit 6bd0450

Browse files
adityaomar3rmackay9
authored andcommitted
chat: enabled chat streaming feature
Co-authored-by: Randy Mackay <[email protected]>
1 parent 8f5bc6d commit 6bd0450

File tree

2 files changed

+100
-102
lines changed

2 files changed

+100
-102
lines changed

MAVProxy/modules/mavproxy_chat/chat_openai.py

Lines changed: 90 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,54 @@
1414
import re
1515
from datetime import datetime
1616
from threading import Thread, Lock
17+
from typing_extensions import override
1718
import json
1819
import math
1920
from MAVProxy.modules.lib import param_help
2021

2122
try:
22-
from openai import OpenAI
23+
from openai import OpenAI, AssistantEventHandler
2324
except Exception:
2425
print("chat: failed to import openai. See https://ardupilot.org/mavproxy/docs/modules/chat.html")
2526
exit()
2627

2728

29+
class EventHandler(AssistantEventHandler):
30+
def __init__(self, chat_openai):
31+
# record reference to chat_openai object
32+
self.chat_openai = chat_openai
33+
34+
# initialise parent class (assistant event handler)
35+
super().__init__()
36+
37+
@override
38+
def on_event(self, event):
39+
# record run id so that it can be cancelled if required
40+
self.chat_openai.latest_run_id = event.data.id
41+
42+
# display the event in the status field
43+
event_string_array = event.event.split(".")
44+
event_string = event_string_array[-1]
45+
self.chat_openai.send_status(event_string)
46+
47+
# requires_action events handled by function calls
48+
if event.event == 'thread.run.requires_action':
49+
self.chat_openai.handle_function_call(event)
50+
51+
# display reply text in the reply window
52+
if (event.event == "thread.message.delta"):
53+
stream_text = event.data.delta.content[0].text.value
54+
self.chat_openai.send_reply(stream_text)
55+
56+
2857
class chat_openai():
29-
def __init__(self, mpstate, status_cb=None, wait_for_command_ack_fn=None):
58+
def __init__(self, mpstate, status_cb=None, reply_cb=None, wait_for_command_ack_fn=None):
3059
# keep reference to mpstate
3160
self.mpstate = mpstate
3261

3362
# keep reference to status callback
3463
self.status_cb = status_cb
64+
self.reply_cb = reply_cb
3565

3666
# keep reference to wait_for_command_ack_fn
3767
self.wait_for_command_ack_fn = wait_for_command_ack_fn
@@ -48,7 +78,7 @@ def __init__(self, mpstate, status_cb=None, wait_for_command_ack_fn=None):
4878
self.client = None
4979
self.assistant = None
5080
self.assistant_thread = None
51-
self.latest_run = None
81+
self.latest_run_id = None
5282

5383
# check connection to OpenAI assistant and connect if necessary
5484
# returns True if connection is good, False if not
@@ -102,22 +132,15 @@ def set_api_key(self, api_key_str):
102132

103133
# cancel the active run
104134
def cancel_run(self):
105-
# check the active thread and run
106-
if (self.assistant_thread and self.latest_run is not None):
107-
run_status = self.latest_run.status
108-
if (run_status != "completed" and run_status != "cancelled" and
109-
run_status != "cancelling"):
110-
111-
# cancel the active run
112-
self.client.beta.threads.runs.cancel(
113-
thread_id=self.assistant_thread.id,
114-
run_id=self.run.id
115-
)
116-
else:
117-
if (self.latest_run.status == "completed"):
118-
print("Chat is completed, cannot be cancelled")
119-
elif (self.latest_run.status == "cancelled"):
120-
print("Chat is cancelled")
135+
# check the active thread and run id
136+
if self.assistant_thread and self.latest_run_id is not None:
137+
# cancel the run
138+
self.client.beta.threads.runs.cancel(
139+
thread_id=self.assistant_thread.id,
140+
run_id=self.latest_run_id
141+
)
142+
else:
143+
self.send_status("No active run to cancel")
121144

122145
# send text to assistant
123146
def send_to_assistant(self, text):
@@ -126,7 +149,8 @@ def send_to_assistant(self, text):
126149

127150
# check connection
128151
if not self.check_connection():
129-
return "chat: failed to connect to OpenAI"
152+
self.send_reply("chat: failed to connect to OpenAI")
153+
return
130154

131155
# create a new message
132156
input_message = self.client.beta.threads.messages.create(
@@ -135,91 +159,46 @@ def send_to_assistant(self, text):
135159
content=text
136160
)
137161
if input_message is None:
138-
return "chat: failed to create input message"
162+
self.send_reply("chat: failed to create input message")
163+
return
164+
165+
# create event handler
166+
event_handler = EventHandler(self)
139167

140168
# create a run
141-
self.run = self.client.beta.threads.runs.create(
169+
with self.client.beta.threads.runs.stream(
142170
thread_id=self.assistant_thread.id,
143-
assistant_id=self.assistant.id
144-
)
145-
if self.run is None:
146-
return "chat: failed to create run"
147-
148-
# wait for run to complete
149-
run_done = False
150-
while not run_done:
151-
# wait for one second
152-
time.sleep(0.1)
153-
154-
# retrieve the run
155-
self.latest_run = self.client.beta.threads.runs.retrieve(
171+
assistant_id=self.assistant.id,
172+
event_handler=event_handler
173+
) as stream:
174+
stream.until_done()
175+
176+
# retrieve the run and print its final status
177+
if self.assistant_thread and self.latest_run_id:
178+
run = self.client.beta.threads.runs.retrieve(
156179
thread_id=self.assistant_thread.id,
157-
run_id=self.run.id
180+
run_id=self.latest_run_id
158181
)
159-
160-
# init failure message
161-
failure_message = None
162-
163-
# check run status
164-
if self.latest_run.status in ["queued", "in_progress", "cancelling"]:
165-
run_done = False
166-
elif self.latest_run.status in ["cancelled", "completed", "expired"]:
167-
run_done = True
168-
elif self.latest_run.status in ["failed"]:
169-
failure_message = self.latest_run.last_error.message
170-
run_done = True
171-
elif self.latest_run.status in ["requires_action"]:
172-
self.handle_function_call(self.latest_run)
173-
run_done = False
174-
else:
175-
print("chat: unrecognised run status" + self.latest_run.status)
176-
run_done = True
177-
178-
# send status to status callback
179-
status_message = self.latest_run.status
180-
if failure_message is not None:
181-
status_message = status_message + ": " + failure_message
182-
self.send_status(status_message)
183-
184-
# retrieve messages on the thread
185-
reply_messages = self.client.beta.threads.messages.list(self.assistant_thread.id,
186-
order="asc",
187-
after=input_message.id)
188-
189-
if (self.latest_run.status == "cancelled"):
190-
return "cancelled successfully"
191-
elif reply_messages is None:
192-
return "chat: failed to retrieve messages"
193-
194-
# concatenate all messages into a single reply skipping the first which is our question
195-
reply = ""
196-
need_newline = False
197-
for message in reply_messages.data:
198-
reply = reply + message.content[0].text.value
199-
if need_newline:
200-
reply = reply + "\n"
201-
need_newline = True
202-
203-
if reply is None or reply == "":
204-
return "chat: failed to retrieve latest reply"
205-
return reply
182+
if run is not None:
183+
self.send_status(run.status)
184+
else:
185+
self.send_status("done")
206186

207187
# handle function call request from assistant
208-
# on success this returns the text response that should be sent to the assistant, returns None on failure
209-
def handle_function_call(self, run):
188+
def handle_function_call(self, event):
210189

211190
# sanity check required action (this should never happen)
212-
if run.required_action is None:
191+
if (event.event != "thread.run.requires_action"):
213192
print("chat::handle_function_call: assistant function call empty")
214-
return None
193+
return
215194

216195
# check format
217-
if run.required_action.submit_tool_outputs is None:
196+
if event.data.required_action.submit_tool_outputs is None:
218197
print("chat::handle_function_call: submit tools outputs empty")
219-
return None
198+
return
220199

221200
tool_outputs = []
222-
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
201+
for tool_call in event.data.required_action.submit_tool_outputs.tool_calls:
223202
# init output to None
224203
output = "invalid function call"
225204

@@ -274,13 +253,26 @@ def handle_function_call(self, run):
274253

275254
# send function replies to assistant
276255
try:
277-
self.client.beta.threads.runs.submit_tool_outputs(
278-
thread_id=run.thread_id,
279-
run_id=run.id,
280-
tool_outputs=tool_outputs)
256+
stream = self.client.beta.threads.runs.submit_tool_outputs(
257+
thread_id=event.data.thread_id,
258+
run_id=event.data.id,
259+
tool_outputs=tool_outputs,
260+
stream=True)
261+
262+
for event in stream:
263+
# requires_action events handled by function calls
264+
if event.event == 'thread.run.requires_action':
265+
self.handle_function_call(event)
266+
267+
# display reply text in the reply window
268+
if (event.event == "thread.message.delta"):
269+
stream_text = event.data.delta.content[0].text.value
270+
self.send_reply(stream_text)
271+
281272
except Exception:
282273
print("chat: error replying to function call")
283274
print(tool_outputs)
275+
return
284276

285277
# get the current date and time in the format, Saturday, June 24, 2023 6:14:14 PM
286278
def get_current_datetime(self, arguments):
@@ -800,6 +792,10 @@ def send_status(self, status):
800792
if self.status_cb is not None:
801793
self.status_cb(status)
802794

795+
def send_reply(self, reply):
796+
if self.reply_cb is not None:
797+
self.reply_cb(reply)
798+
803799
# returns true if string contains regex characters
804800
def contains_regex(self, string):
805801
regex_characters = ".^$*+?{}[]\\|()"

MAVProxy/modules/mavproxy_chat/chat_window.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def __init__(self, mpstate, wait_for_command_ack_fn):
1818
self.mpstate = mpstate
1919

2020
# create chat_openai object
21-
self.chat_openai = chat_openai.chat_openai(self.mpstate, self.set_status_text, wait_for_command_ack_fn)
21+
self.chat_openai = chat_openai.chat_openai(self.mpstate, self.set_status_text, self.append_chat_replies,
22+
wait_for_command_ack_fn)
2223

2324
# create chat_voice_to_text object
2425
self.chat_voice_to_text = chat_voice_to_text.chat_voice_to_text()
@@ -182,15 +183,11 @@ def send_text_to_assistant(self):
182183
wx.CallAfter(self.text_input.Clear)
183184

184185
# copy user input text to reply box
185-
orig_text_attr = self.text_reply.GetDefaultStyle()
186186
wx.CallAfter(self.text_reply.SetDefaultStyle, wx.TextAttr(wx.RED))
187-
wx.CallAfter(self.text_reply.AppendText, send_text + "\n")
187+
wx.CallAfter(self.text_reply.AppendText, "\n" + send_text + "\n")
188188

189-
# send text to assistant and place reply in reply box
190-
reply = self.chat_openai.send_to_assistant(send_text)
191-
if reply:
192-
wx.CallAfter(self.text_reply.SetDefaultStyle, orig_text_attr)
193-
wx.CallAfter(self.text_reply.AppendText, reply + "\n\n")
189+
# send text to assistant. replies will be handled by append_chat_replies
190+
self.chat_openai.send_to_assistant(send_text)
194191

195192
# reenable buttons and text input (can't be done from a thread or must use CallAfter)
196193
# disable the cancel button
@@ -206,3 +203,8 @@ def send_text_to_assistant(self):
206203
# set status text
207204
def set_status_text(self, text):
208205
wx.CallAfter(self.text_status.SetValue, text)
206+
207+
# append chat to reply box
208+
def append_chat_replies(self, text):
209+
wx.CallAfter(self.text_reply.SetDefaultStyle, wx.TextAttr(wx.BLACK))
210+
wx.CallAfter(self.text_reply.AppendText, text)

0 commit comments

Comments
 (0)