Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions cpp/json_schema_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ class JSONSchemaConverter {
std::string VisitConst(const picojson::object& schema, const std::string& rule_name);

/*! \brief Visit an enum schema. */
std::string VisitEnum(const picojson::object& schema, const std::string& rule_name);
std::string VisitEnum(
const picojson::object& schema, const std::string& rule_name, const JSONFormat json_format
);

/*! \brief Convert the JSON string to a printable string that can be shown in BNF. */
std::string JSONStrToPrintableStr(const std::string& json_str);
Expand Down Expand Up @@ -873,7 +875,7 @@ std::string JSONSchemaConverter::VisitSchema(
} else if (schema_obj.count("const")) {
return VisitConst(schema_obj, rule_name);
} else if (schema_obj.count("enum")) {
return VisitEnum(schema_obj, rule_name);
return VisitEnum(schema_obj, rule_name, json_format);
} else if (schema_obj.count("anyOf") || schema_obj.count("oneOf")) {
return VisitAnyOf(schema_obj, rule_name);
} else if (schema_obj.count("allOf")) {
Expand Down Expand Up @@ -979,7 +981,7 @@ std::string JSONSchemaConverter::VisitConst(
}

std::string JSONSchemaConverter::VisitEnum(
const picojson::object& schema, const std::string& rule_name
const picojson::object& schema, const std::string& rule_name, const JSONFormat json_format
) {
XGRAMMAR_CHECK(schema.count("enum"));
std::string result = "";
Expand All @@ -989,7 +991,17 @@ std::string JSONSchemaConverter::VisitEnum(
result += " | ";
}
++idx;
result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")";
if (json_format == JSONFormat::kJSON) {
result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")";
} else if (json_format == JSONFormat::kXML) {
auto inner = JSONStrToPrintableStr(value.serialize());
// If the inner is a json style string, remove the quotation marks.
if (inner.size() >= 4 && inner.substr(0, 2) == "\\\"" &&
inner.substr(inner.size() - 2, 2) == "\\\"") {
inner = inner.substr(2, inner.size() - 4);
}
result += "(\"" + inner + "\")";
}
}
return result;
}
Expand Down
50 changes: 50 additions & 0 deletions cpp/nanobind/nanobind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,61 @@
#include "../testing.h"
#include "python_methods.h"
#include "xgrammar/exception.h"
#include "xgrammar/grammar.h"
#include "xgrammar/matcher.h"

namespace nb = nanobind;

namespace xgrammar {

Grammar Grammar_ApplyStructuralTagTemplate(
const std::string& structural_tag_template, const nb::kwargs& kwargs
) {
std::unordered_map<std::string, std::vector<std::unordered_map<std::string, std::string>>> values;
for (const auto& [key, value] : kwargs) {
nb::str key_str;
if (!nb::try_cast(key, key_str)) {
throw nb::type_error("Expected a string key for structural tag template values");
}
nb::list value_list;
if (!nb::try_cast(value, value_list)) {
throw nb::type_error("Expected a list of dictionaries for structural tag template values");
}
std::vector<std::unordered_map<std::string, std::string>> value_vec;
value_vec.reserve(value_list.size());
for (const auto& item : value_list) {
nb::dict item_dict;
if (!nb::try_cast(item, item_dict)) {
throw nb::type_error(
"Expected a dictionary for each item in the list of structural tag template values"
);
}
std::unordered_map<std::string, std::string> item_map;
for (const auto& [item_key, item_value] : item_dict) {
nb::str item_key_str, item_value_str;
if (!nb::try_cast(item_key, item_key_str)) {
throw nb::type_error(
"Expected a string key for each item in the structural tag template dictionary"
);
}
if (!nb::try_cast(item_value, item_value_str)) {
throw nb::type_error(
"Expected a string for each value in the structural tag template dictionary"
);
}
item_map[item_key_str.c_str()] = item_value_str.c_str();
}
value_vec.push_back(std::move(item_map));
}
values[key_str.c_str()] = std::move(value_vec);
}
auto result = ApplyStructuralTagTemplate(structural_tag_template, values).ToVariant();
if (std::holds_alternative<StructuralTagError>(result)) {
ThrowVariantError(std::get<StructuralTagError>(result));
}
return std::get<Grammar>(result);
}

std::vector<std::string> CommonEncodedVocabType(
const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab
) {
Expand Down Expand Up @@ -203,6 +252,7 @@ NB_MODULE(xgrammar_bindings, m) {
&Grammar_FromStructuralTag,
nb::call_guard<nb::gil_scoped_release>()
)
.def_static("apply_structural_tag_template", &Grammar_ApplyStructuralTagTemplate)
.def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar)
.def_static("union", &Grammar::Union, nb::call_guard<nb::gil_scoped_release>())
.def_static("concat", &Grammar::Concat, nb::call_guard<nb::gil_scoped_release>())
Expand Down
1 change: 1 addition & 0 deletions cpp/nanobind/python_methods.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <vector>

#include "../grammar_impl.h"
#include "../structural_tag.h"
#include "../support/logging.h"
#include "../support/utils.h"
#include "xgrammar/exception.h"
Expand Down
2 changes: 2 additions & 0 deletions cpp/nanobind/python_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "../structural_tag.h"
#include "xgrammar/tokenizer_info.h"

namespace xgrammar {
Expand Down
Loading
Loading