Commit 8108c3a

Merge pull request #627 from h2oai/sajith/main/fic-client-text-completion-return-value
[Client] Parse the return value from `/submit_nochat_api` to extract the response
2 parents 4324ae5 + c234868 commit 8108c3a

client/h2ogpt_client/_core.py

Lines changed: 17 additions & 8 deletions
@@ -1,3 +1,4 @@
+import ast
 import asyncio
 from typing import Any, Dict, List, Optional, OrderedDict, Tuple, ValuesView

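Side note (an illustration, not part of the commit): the new `import ast` is there because the `/submit_nochat_api` endpoint evidently returns a Python-repr dict string rather than strict JSON, which `json.loads` would reject. A minimal sketch of the distinction, with a hypothetical payload:

```python
# Illustration only: ast.literal_eval accepts Python-literal syntax,
# while json.loads requires double-quoted JSON.
import ast
import json

raw = "{'response': 'Hello!'}"  # hypothetical payload with Python quoting

print(ast.literal_eval(raw)["response"])  # -> Hello!

try:
    json.loads(raw)
except json.JSONDecodeError as err:
    print(f"json.loads rejects it: {err}")  # single quotes are not valid JSON
```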
@@ -69,7 +70,7 @@ def create(
         number_returns: int = 1,
         system_pre_context: str = "",
         langchain_mode: LangChainMode = LangChainMode.DISABLED,
-        system_prompt: str = '',
+        system_prompt: str = "",
     ) -> "TextCompletion":
         """
         Creates a new text completion.
@@ -95,7 +96,8 @@ def create(
         :param system_pre_context: directly pre-appended without prompt processing
         :param langchain_mode: LangChain mode
         :param add_chat_history_to_context: Whether to add chat history to context
-        :param system_prompt: Universal system prompt to override prompt_type's system prompt
+        :param system_prompt: Universal system prompt to override prompt_type's system
+            prompt
         """
         params = _utils.to_h2ogpt_params(locals().copy())
         params["instruction"] = ""  # empty when chat_mode is False
@@ -116,7 +118,7 @@ def create(
         params["document_choice"] = []
         params["pre_prompt_summary"] = ""
         params["prompt_summary"] = ""
-        params['system_prompt'] = ''
+        params["system_prompt"] = ""
         return TextCompletion(self._client, params)
@@ -133,6 +135,10 @@ def _get_parameters(self, prompt: str) -> OrderedDict[str, Any]:
         self._parameters["instruction_nochat"] = prompt
         return self._parameters
 
+    @staticmethod
+    def _get_reply(response: str) -> str:
+        return ast.literal_eval(response)["response"]
+
     async def complete(self, prompt: str) -> str:
         """
         Complete this text completion.
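To make the new helper concrete, here is a small sketch of what `_get_reply` does; the payload below is hypothetical, and only the `response` key is attested by the diff:

```python
import ast

def get_reply(response: str) -> str:
    # Same logic as the new TextCompletion._get_reply static method:
    # evaluate the Python-literal string and extract the "response" field.
    return ast.literal_eval(response)["response"]

raw = str({"response": "Paris is the capital of France."})  # hypothetical payload
assert get_reply(raw) == "Paris is the capital of France."
```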
@@ -141,9 +147,10 @@ async def complete(self, prompt: str) -> str:
         :return: response from the model
         """
 
-        return await self._client._predict_async(
+        response = await self._client._predict_async(
             str(dict(self._get_parameters(prompt))), api_name=self._API_NAME
         )
+        return self._get_reply(response)
 
     def complete_sync(self, prompt: str) -> str:
         """
@@ -152,9 +159,10 @@ def complete_sync(self, prompt: str) -> str:
         :param prompt: text prompt to generate completion for
         :return: response from the model
         """
-        return self._client._predict(
+        response = self._client._predict(
             str(dict(self._get_parameters(prompt))), api_name=self._API_NAME
         )
+        return self._get_reply(response)
 
 
 class ChatCompletionCreator:
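With both methods routed through `_get_reply`, callers now receive the reply text itself instead of a stringified dict. A usage sketch, assuming the package's `Client` entry point and a `text_completion` creator attribute (names beyond the diff, and the URL, are placeholders):

```python
from h2ogpt_client import Client  # assumed public entry point

client = Client("http://localhost:7860")      # placeholder server URL
completion = client.text_completion.create()  # assumed creator attribute

# Before this commit the caller got "{'response': '...'}";
# after it, just the reply text.
reply = completion.complete_sync("Who are you?")
print(reply)

# The async variant behaves the same:
#   reply = await completion.complete("Who are you?")
```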
@@ -180,7 +188,7 @@ def create(
         number_returns: int = 1,
         system_pre_context: str = "",
         langchain_mode: LangChainMode = LangChainMode.DISABLED,
-        system_prompt: str = '',
+        system_prompt: str = "",
     ) -> "ChatCompletion":
         """
         Creates a new chat completion.
@@ -205,7 +213,8 @@ def create(
         :param number_returns:
         :param system_pre_context: directly pre-appended without prompt processing
         :param langchain_mode: LangChain mode
-        :param system_prompt: Universal system prompt to override prompt_type's system prompt
+        :param system_prompt: Universal system prompt to override prompt_type's system
+            prompt
         """
         params = _utils.to_h2ogpt_params(locals().copy())
         params["instruction"] = None  # future prompts
@@ -217,7 +226,7 @@ def create(
         params["instruction_nochat"] = ""  # empty when chat_mode is True
         params["langchain_mode"] = langchain_mode.value  # convert to serializable type
         params["add_chat_history_to_context"] = False  # relevant only for the UI
-        params["system_prompt"] = ''
+        params["system_prompt"] = ""
         params["langchain_action"] = LangChainAction.QUERY.value
         params["langchain_agents"] = []
         params["top_k_docs"] = 4  # langchain: number of document chunks
