Skip to content

Commit aa07eee

Browse files
committed
Use nb::new_ for factory functions with non-movable types
Changed from nb::init to nb::new_ pattern for all factory function constructors. This allows nanobind to properly handle the raw pointers returned from .release() without trying to move-construct the objects. The nb::new_ pattern is specifically designed for custom allocation and works with types that have deleted copy/move constructors.
1 parent 6988ef8 commit aa07eee

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/python/python.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ NB_MODULE(onnxruntime_genai, m) {
361361
.def("decode", [](OgaTokenizerStream& t, int32_t token) { return t.Decode(token); });
362362

363363
nb::class_<OgaNamedTensors>(m, "NamedTensors")
364-
.def(nb::init([]() { return OgaNamedTensors::Create(); }))
364+
.def(nb::new_([]() { return OgaNamedTensors::Create().release(); }), nb::rv_policy::take_ownership)
365365
.def("__getitem__", [](OgaNamedTensors& named_tensors, const std::string& name) {
366366
auto tensor = named_tensors.Get(name.c_str());
367367
if (!tensor)
@@ -391,14 +391,14 @@ NB_MODULE(onnxruntime_genai, m) {
391391
return keys; });
392392

393393
nb::class_<OgaTensor>(m, "Tensor")
394-
.def(nb::init([](nb::ndarray<>& v) { return ToOgaTensor(v); }))
394+
.def(nb::new_([](nb::ndarray<>& v) { return ToOgaTensor(v).release(); }), nb::rv_policy::take_ownership)
395395
.def("shape", &OgaTensor::Shape)
396396
.def("type", &OgaTensor::Type)
397397
// .def("data", &OgaTensor::Data) // Exposing raw void* might be unsafe. Consider alternatives.
398398
.def("as_numpy", [](OgaTensor& t) { return ToNumpy(t); });
399399

400400
nb::class_<OgaTokenizer>(m, "Tokenizer")
401-
.def(nb::init([](const OgaModel& model) { return OgaTokenizer::Create(model); }))
401+
.def(nb::new_([](const OgaModel& model) { return OgaTokenizer::Create(model).release(); }), nb::rv_policy::take_ownership)
402402
.def_prop_ro("bos_token_id", &OgaTokenizer::GetBosTokenId)
403403
.def_prop_ro("eos_token_ids", [](const OgaTokenizer& t) { return ToPython(t.GetEosTokenIds()); })
404404
.def_prop_ro("pad_token_id", &OgaTokenizer::GetPadTokenId)
@@ -452,7 +452,7 @@ NB_MODULE(onnxruntime_genai, m) {
452452
.def("create_stream", [](const OgaTokenizer& t) { return OgaTokenizerStream::Create(t); });
453453

454454
nb::class_<OgaConfig>(m, "Config")
455-
.def(nb::init([](const std::string& config_path) { return OgaConfig::Create(config_path.c_str()); }))
455+
.def(nb::new_([](const std::string& config_path) { return OgaConfig::Create(config_path.c_str()).release(); }), nb::rv_policy::take_ownership)
456456
.def("append_provider", &OgaConfig::AppendProvider)
457457
.def("set_provider_option", &OgaConfig::SetProviderOption)
458458
.def("clear_providers", &OgaConfig::ClearProviders)
@@ -488,8 +488,8 @@ NB_MODULE(onnxruntime_genai, m) {
488488
.def("clear_decoder_provider_options_hardware_vendor_id", &OgaConfig::ClearDecoderProviderOptionsHardwareVendorId);
489489

490490
nb::class_<OgaModel>(m, "Model")
491-
.def(nb::init([](const OgaConfig& config) { return OgaModel::Create(config); }))
492-
.def(nb::init([](const std::string& config_path) { return OgaModel::Create(config_path.c_str()); }))
491+
.def(nb::new_([](const OgaConfig& config) { return OgaModel::Create(config).release(); }), nb::rv_policy::take_ownership)
492+
.def(nb::new_([](const std::string& config_path) { return OgaModel::Create(config_path.c_str()).release(); }), nb::rv_policy::take_ownership)
493493
.def_prop_ro("type", [](const OgaModel& model) -> std::string { return model.GetType().p_; })
494494
.def_prop_ro("device_type", [](const OgaModel& model) -> std::string { return model.GetDeviceType().p_; }, "The device type the model is running on")
495495
// Assuming OgaMultiModalProcessor::Create returns a unique_ptr or needs management
@@ -636,12 +636,12 @@ NB_MODULE(onnxruntime_genai, m) {
636636
});
637637

638638
nb::class_<OgaAdapters>(m, "Adapters")
639-
.def(nb::init([](OgaModel& model) { return OgaAdapters::Create(model); }))
639+
.def(nb::new_([](OgaModel& model) { return OgaAdapters::Create(model).release(); }), nb::rv_policy::take_ownership)
640640
.def("unload", &OgaAdapters::UnloadAdapter)
641641
.def("load", &OgaAdapters::LoadAdapter);
642642

643643
nb::class_<OgaRequest>(m, "Request")
644-
.def(nb::init([](PyGeneratorParams& params) { return OgaRequest::Create(*params.params_); }))
644+
.def(nb::new_([](PyGeneratorParams& params) { return OgaRequest::Create(*params.params_).release(); }), nb::rv_policy::take_ownership)
645645
.def("add_tokens", [](OgaRequest& request, nb::ndarray<const int32_t, nb::shape<-1>, nb::c_contig, nb::device::cpu> tokens) {
646646
auto sequences = OgaSequences::Create(); // Needs management?
647647
auto tokens_span = ToSpan(tokens);
@@ -666,7 +666,7 @@ NB_MODULE(onnxruntime_genai, m) {
666666
return nb::borrow<nb::object>(static_cast<PyObject*>(opaque_data_ptr)); });
667667

668668
nb::class_<OgaEngine>(m, "Engine")
669-
.def(nb::init([](OgaModel& model) { return OgaEngine::Create(model); }))
669+
.def(nb::new_([](OgaModel& model) { return OgaEngine::Create(model).release(); }), nb::rv_policy::take_ownership)
670670
.def("add_request", &OgaEngine::Add)
671671
.def("step", &OgaEngine::Step)
672672
.def("remove_request", &OgaEngine::Remove)

0 commit comments

Comments
 (0)