File tree Expand file tree Collapse file tree 4 files changed +59
-2
lines changed
model_executor/guided_decoding Expand file tree Collapse file tree 4 files changed +59
-2
lines changed Original file line number Diff line number Diff line change @@ -837,6 +837,39 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
837
837
assert loaded == {"result" : 2 }, loaded
838
838
839
839
840
@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
    """A chat completion with a ``json_schema`` response_format must return
    valid JSON that matches the requested schema, on repeated requests."""
    prompt = ('what is 1+1? please respond with a JSON object, '
              'the format is {"result": 2}')
    schema_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "foo_test",
            "schema": {
                "type": "object",
                "properties": {
                    "result": {
                        "type": "integer"
                    },
                },
            },
        }
    }

    # Run the request twice so that any guided-decoding state carried over
    # between calls would be caught.
    for _ in range(2):
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "user",
                "content": prompt
            }],
            response_format=schema_format)

        content = response.choices[0].message.content
        assert content is not None

        loaded = json.loads(content)
        assert loaded == {"result": 2}, loaded
840
873
@pytest .mark .asyncio
841
874
async def test_extra_fields (client : openai .AsyncOpenAI ):
842
875
with pytest .raises (BadRequestError ) as exc_info :
Original file line number Diff line number Diff line change @@ -85,9 +85,19 @@ class UsageInfo(OpenAIBaseModel):
85
85
completion_tokens : Optional [int ] = 0
86
86
87
87
88
class JsonSchemaResponseFormat(OpenAIBaseModel):
    """The object nested under ``response_format.json_schema`` in an
    OpenAI-style chat completion request (``type == "json_schema"``)."""
    # Identifier for this schema, e.g. "foo_test".
    name: str
    # Optional human-readable description of what the schema represents.
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    # NOTE(review): presumably mirrors OpenAI's "strict" flag (exact schema
    # conformance) — confirm against the serving layer before relying on it.
    strict: Optional[bool] = None
95
+
96
+
88
97
class ResponseFormat(OpenAIBaseModel):
    """The request-level ``response_format`` field of a chat completion:
    plain text, any JSON object, or JSON constrained by a caller schema."""
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Schema payload; only consulted when type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
91
101
92
102
93
103
class StreamOptions (OpenAIBaseModel ):
Original file line number Diff line number Diff line change @@ -49,6 +49,13 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
49
49
and request .response_format .type == "json_object" ):
50
50
character_level_parser = JsonSchemaParser (
51
51
None ) # None means any json object
52
+ elif (request .response_format is not None
53
+ and request .response_format .type == "json_schema"
54
+ and request .response_format .json_schema is not None
55
+ and request .response_format .json_schema .json_schema is not None ):
56
+ schema = _normalize_json_schema_object (
57
+ request .response_format .json_schema .json_schema )
58
+ character_level_parser = JsonSchemaParser (schema )
52
59
else :
53
60
return None
54
61
Original file line number Diff line number Diff line change @@ -127,6 +127,13 @@ def _get_guide_and_mode(
127
127
and request .response_format is not None
128
128
and request .response_format .type == "json_object" ):
129
129
return JSON_GRAMMAR , GuidedDecodingMode .GRAMMAR
130
+ elif (not isinstance (request , GuidedDecodingRequest )
131
+ and request .response_format is not None
132
+ and request .response_format .type == "json_schema"
133
+ and request .response_format .json_schema is not None
134
+ and request .response_format .json_schema .json_schema is not None ):
135
+ json = json_dumps (request .response_format .json_schema .json_schema )
136
+ return json , GuidedDecodingMode .JSON
130
137
else :
131
138
return None , None
132
139
You can’t perform that action at this time.
0 commit comments