Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions libs/langchain/langchain/output_parsers/openai_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json
from typing import Any, Dict, List, Type, Union

Expand Down Expand Up @@ -25,8 +26,8 @@ def parse_result(self, result: List[Generation]) -> Any:
)
message = generation.message
try:
func_call = message.additional_kwargs["function_call"]
except ValueError as exc:
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
except KeyError as exc:
raise OutputParserException(f"Could not parse function call: {exc}")

if self.args_only:
Expand All @@ -38,11 +39,16 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as the Json object."""

def parse_result(self, result: List[Generation]) -> Any:
func = super().parse_result(result)
function_call_info = super().parse_result(result)
if self.args_only:
return json.loads(func)
func["arguments"] = json.loads(func["arguments"])
return func
try:
return json.loads(function_call_info)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
function_call_info["arguments"] = json.loads(function_call_info["arguments"])
return function_call_info


class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import json

import pytest

from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
)
from langchain.schema import BaseMessage, ChatGeneration, OutputParserException
from langchain.schema.messages import AIMessage, HumanMessage


@pytest.fixture
def ai_message() -> AIMessage:
"""Return a simple AIMessage."""
content = "This is a test message"

args = json.dumps(
{
"arg1": "value1",
}
)

function_call = {"name": "function_name", "arguments": args}
additional_kwargs = {"function_call": function_call}
return AIMessage(content=content, additional_kwargs=additional_kwargs)


def test_json_output_function_parser(ai_message: AIMessage) -> None:
"""Test that the JsonOutputFunctionsParser with full output."""
chat_generation = ChatGeneration(message=ai_message)

# Full output
parser = JsonOutputFunctionsParser(args_only=False)
result = parser.parse_result([chat_generation])
assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"}

# Args only
parser = JsonOutputFunctionsParser(args_only=True)
result = parser.parse_result([chat_generation])
assert result == {"arg1": "value1"}

# Verify that the original message is not modified
assert ai_message.additional_kwargs == {
"function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'}
}


@pytest.mark.parametrize(
"bad_message",
[
# Human message has no function call
HumanMessage(content="This is a test message"),
# AIMessage has no function call information.
AIMessage(content="This is a test message", additional_kwargs={}),
# Bad function call information (arguments should be a string)
AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": {}}
},
),
# Bad function call information (arguments should be proper json)
AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": "noqweqwe"}
},
),
],
)
def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:
"""Test exceptions raised correctly while using JSON parser."""
chat_generation = ChatGeneration(message=bad_message)

with pytest.raises(OutputParserException):
JsonOutputFunctionsParser().parse_result([chat_generation])