Skip to content

Commit c4a46af

Browse files
snnnyuslepukhin
and authored
Cherry-pick PR #25626 to 1.23.0 release branch (#25640)
### Description <!-- Describe your changes. --> Move the step that moves weights to memory to the end of Graph::Resolve(). Modify Inject so it copies data into TensorProto according to the C API docs. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> TypeAndShape inference runs as a part of `Resolve()` and it is unable to inspect and load the initializers that point to OrtValues at that time. We chose to move the TensorProto to OrtValue conversion to the end of `Resolve()`. References: #25579 Co-authored-by: Dmitri Smirnov <[email protected]>
1 parent a033ad6 commit c4a46af

File tree

6 files changed

+63
-110
lines changed

6 files changed

+63
-110
lines changed

include/onnxruntime/core/graph/graph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
740740
common::Status InjectExternalInitializedTensors(const InlinedHashMap<std::string, OrtValue>& external_initializers);
741741

742742
/** This function takes externally provided files in memory for initializers with external
743-
* data and replaces graph initializers with its content.
743+
* data and replaces main graph initializers with its content.
744744
*/
745745
common::Status InjectExternalInitializersFromFilesInMemory(
746746
const InlinedHashMap<PathString, std::pair<char*, size_t>>& external_initializer_files);

onnxruntime/core/framework/session_state_utils.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ common::Status SaveInitializedTensors(
375375
// TODO: if the tensor need be copied, does it have enough room?
376376
ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc));
377377

378-
// ??? Should we ignore this session option if the EP is explictly providing the read only allocator?
378+
// ??? Should we ignore this session option if the EP is explicitly providing the read only allocator?
379379
// bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator;
380380
const bool use_device_allocator_for_initializers =
381381
session_options.config_options.GetConfigOrDefault(
@@ -402,9 +402,11 @@ common::Status SaveInitializedTensors(
402402
std::move(tensor), ort_value));
403403
}
404404
} else {
405+
// if in memory we were expecting to find it above.
406+
ORT_ENFORCE(!utils::HasExternalDataInMemory(tensor_proto));
407+
405408
// We need to deserialize the tensor proto into an OrtValue
406409
// using the preallocated buffer or allocator.
407-
408410
Status st = DeserializeTensorProto(env, graph_loc, tensor_proto,
409411
(memory_buffer.has_value()) ? &*memory_buffer : nullptr,
410412
alloc, default_cpu_alloc, ort_value, data_transfer_mgr,

onnxruntime/core/graph/graph.cc

Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,26 +1226,6 @@ Graph::Graph(const Model& owning_model,
12261226
ArgNameToTypeMap name_to_type_map;
12271227
const auto& model_path = ModelPath();
12281228

1229-
// If the tensor proto data is large enough, externalize it and replace with a tensor_proto
1230-
// with external data reference pointing to an OrtValue, otherwise do nothing.
1231-
auto put_data_maybe_in_memory = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) {
1232-
size_t size_in_bytes = 0;
1233-
ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
1234-
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
1235-
OrtValue ort_value;
1236-
ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
1237-
CPUAllocator::DefaultInstance(), ort_value));
1238-
constexpr const bool use_tensor_buffer_true = true;
1239-
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
1240-
use_tensor_buffer_true);
1241-
assert(ort_value.IsAllocated());
1242-
auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value));
1243-
ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(),
1244-
" in the initializer list.");
1245-
tensor_proto = std::move(tensor_proto_to_add);
1246-
}
1247-
};
1248-
12491229
// Process 'Constant' nodes
12501230
// Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list
12511231
for (auto& node : graph_proto_->node()) {
@@ -1265,8 +1245,6 @@ Graph::Graph(const Model& owning_model,
12651245
}
12661246
}
12671247

1268-
put_data_maybe_in_memory(*tensor);
1269-
12701248
// Ensure initializers are also graph inputs.
12711249
if (ir_version_ < 4) {
12721250
TypeProto t{utils::TypeProtoFromTensorProto(*tensor)};
@@ -1343,22 +1321,7 @@ Graph::Graph(const Model& owning_model,
13431321
}
13441322

13451323
// Copy initial tensors to a map.
1346-
for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) {
1347-
auto& tensor = *graph_proto_->mutable_initializer(i);
1348-
// If data is on disk, it will be loaded either by optimizers
1349-
// or during session state finalization.
1350-
// If data is already in memory, do nothing.
1351-
if (!utils::HasExternalData(tensor)) {
1352-
const bool is_sparse = sparse_tensor_names_.count(tensor.name());
1353-
if (is_sparse) {
1354-
sparse_tensor_names_.erase(tensor.name());
1355-
}
1356-
put_data_maybe_in_memory(tensor);
1357-
if (is_sparse) {
1358-
sparse_tensor_names_.emplace(tensor.name());
1359-
}
1360-
}
1361-
1324+
for (auto& tensor : graph_proto_->initializer()) {
13621325
auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor);
13631326
if (!p.second) {
13641327
LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name()
@@ -3420,7 +3383,32 @@ Status Graph::Resolve(const ResolveOptions& options) {
34203383

34213384
ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func));
34223385

3423-
return Status::OK();
3386+
auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
3387+
// if we have any initializers that are not in memory, put them there.
3388+
const auto& model_path = graph.ModelPath();
3389+
auto& graph_proto = *graph.graph_proto_;
3390+
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
3391+
auto& tensor_proto = *graph_proto.mutable_initializer(i);
3392+
if (utils::HasExternalData(tensor_proto)) {
3393+
continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize
3394+
}
3395+
3396+
size_t size_in_bytes = 0;
3397+
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
3398+
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
3399+
OrtValue ort_value;
3400+
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
3401+
CPUAllocator::DefaultInstance(), ort_value));
3402+
constexpr const bool use_tensor_buffer_true = true;
3403+
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
3404+
use_tensor_buffer_true);
3405+
ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
3406+
}
3407+
}
3408+
return Status::OK();
3409+
};
3410+
3411+
return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func);
34243412
}
34253413

34263414
void Graph::SetName(const std::string& name) {
@@ -3659,6 +3647,15 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi
36593647
ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(),
36603648
"Replacement tensor's data type does not match.");
36613649

3650+
bool is_sparse = false;
3651+
{
3652+
auto sparse_tensor_it = sparse_tensor_names_.find(initializer_name);
3653+
if (sparse_tensor_it != sparse_tensor_names_.end()) {
3654+
sparse_tensor_names_.erase(sparse_tensor_it);
3655+
is_sparse = true;
3656+
}
3657+
}
3658+
36623659
auto& mutable_initializers = *(graph_proto_->mutable_initializer());
36633660
// use cheaper pointer comparison to find old entry
36643661
auto existing_entry = std::find(mutable_initializers.pointer_begin(), mutable_initializers.pointer_end(),
@@ -3675,6 +3672,9 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi
36753672
}
36763673

36773674
**existing_entry = std::move(new_initializer);
3675+
if (is_sparse) {
3676+
sparse_tensor_names_.insert((**existing_entry).name());
3677+
}
36783678

36793679
return Status::OK();
36803680
}
@@ -3720,7 +3720,7 @@ Status Graph::InjectExternalInitializedTensors(const InlinedHashMap<std::string,
37203720
Status Graph::InjectExternalInitializersFromFilesInMemory(
37213721
const InlinedHashMap<PathString, std::pair<char*, size_t>>& external_initializer_files) {
37223722
for (const auto& [tensor_name, tensor_proto] : name_to_initial_tensor_) {
3723-
if (tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) {
3723+
if (utils::HasExternalDataInFile(*tensor_proto)) {
37243724
std::unique_ptr<ExternalDataInfo> external_data_info;
37253725
ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info));
37263726

@@ -3729,25 +3729,27 @@ Status Graph::InjectExternalInitializersFromFilesInMemory(
37293729
const size_t external_data_length = external_data_info->GetLength();
37303730
SafeInt<size_t> tensor_byte_size;
37313731
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size));
3732+
37323733
ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
37333734
"TensorProto: ", tensor_name, " external data size mismatch. Computed size: ",
37343735
*&tensor_byte_size, ", external_data.length: ", external_data_length);
37353736

37363737
SafeInt<FileOffsetType> end_of_read(file_offset);
37373738
end_of_read += tensor_byte_size;
37383739

3739-
auto external_file_pos = external_initializer_files.find(external_file);
3740-
ORT_RETURN_IF(external_file_pos == external_initializer_files.end(),
3740+
auto user_provided_entry = external_initializer_files.find(external_file);
3741+
ORT_RETURN_IF(user_provided_entry == external_initializer_files.end(),
37413742
"External file: ", ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(external_file),
37423743
" not found from the table user provided.");
3743-
auto external_file_length = external_file_pos->second.second;
37443744

3745-
ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow<FileOffsetType>(external_file_length),
3745+
auto user_provided_length = user_provided_entry->second.second;
3746+
3747+
ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow<FileOffsetType>(user_provided_length),
37463748
"External initializer: ", tensor_name,
37473749
" offset: ", file_offset, " size to read: ", external_data_length,
3748-
" given file_length: ", external_file_length, " are out of bounds or can not be read in full.");
3749-
char* external_file_buffer = static_cast<char*>(external_file_pos->second.first);
3750-
char* tensor_buffer = external_file_buffer + file_offset;
3750+
" given file_length: ", user_provided_length, " are out of bounds or can not be read in full.");
3751+
char* user_provided_file_buffer = static_cast<char*>(user_provided_entry->second.first);
3752+
char* user_provided_tensor_buffer = user_provided_file_buffer + file_offset;
37513753

37523754
const auto& old_initializer = *(tensor_proto);
37533755
auto& mutable_initializers = *(graph_proto_->mutable_initializer());
@@ -3762,19 +3764,11 @@ Status Graph::InjectExternalInitializersFromFilesInMemory(
37623764
const DataTypeImpl* const type =
37633765
DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType();
37643766
TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer);
3765-
auto tensor = Tensor(type, tensor_shape, tensor_buffer,
3767+
auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer,
37663768
OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
37673769

3768-
constexpr const bool use_tensor_buffer_true = true;
3769-
auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_true);
3770-
// Implied that external data is in memory
3771-
const bool has_external_data_in_memory = utils::HasExternalData(new_tensor_proto);
3772-
3773-
OrtValue ort_value;
3774-
if (has_external_data_in_memory) {
3775-
Tensor::InitOrtValue(std::move(tensor), ort_value);
3776-
}
3777-
ortvalue_initializers_.insert_or_assign(tensor_name, std::move(ort_value));
3770+
constexpr const bool use_tensor_buffer_false = false;
3771+
auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false);
37783772
**existing_entry = std::move(new_tensor_proto);
37793773
}
37803774
}

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,45 +7,6 @@
77
namespace onnxruntime {
88
namespace nnapi {
99

10-
namespace {
11-
bool HasExternalInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit) {
12-
const auto is_ext_initializer =
13-
[&](const NodeArg& node_arg) {
14-
const auto& input_name(node_arg.Name());
15-
const auto initializer = initializers.find(input_name);
16-
if (initializer == initializers.end())
17-
return false;
18-
19-
const auto& tensor = *initializer->second;
20-
if (tensor.has_data_location() &&
21-
tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
22-
LOGS_DEFAULT(VERBOSE) << "Initializer [" << input_name
23-
<< "] with external data location are not currently supported";
24-
return true;
25-
}
26-
27-
return false;
28-
};
29-
30-
const auto& inputs = node_unit.Inputs();
31-
for (const auto& input : inputs) {
32-
if (is_ext_initializer(input.node_arg))
33-
return true;
34-
35-
if (!input.quant_param)
36-
continue;
37-
38-
if (is_ext_initializer(input.quant_param->scale))
39-
return true;
40-
41-
if (input.quant_param->zero_point && is_ext_initializer(*input.quant_param->zero_point))
42-
return true;
43-
}
44-
45-
return false;
46-
}
47-
} // namespace
48-
4910
// Add operator related
5011

5112
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
@@ -86,10 +47,6 @@ bool BaseOpBuilder::IsOpSupported(const GraphViewer& graph_viewer, const NodeUni
8647
if (!HasSupportedInputOutputs(graph_viewer, node_unit, params))
8748
return false;
8849

89-
// We do not support external initializers for now
90-
if (HasExternalInitializer(graph_viewer.GetAllInitializedTensors(), node_unit))
91-
return false;
92-
9350
if (!HasSupportedOpSet(node_unit))
9451
return false;
9552

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ Status ModelBuilder::RegisterInitializers() {
283283
auto [index, size, padded_size] = initializers[i++];
284284
const uint8_t* src = nullptr;
285285
// TensorProto_DataType_UINT8 or TensorProto_DataType_FLOAT:
286-
Initializer unpacked_tensor(tensor, graph_viewer_.ModelPath());
286+
Initializer unpacked_tensor(graph_viewer_.GetGraph(), tensor, graph_viewer_.ModelPath());
287287
size_t size_in_bytes = unpacked_tensor.DataAsByteSpan().size();
288288
ORT_RETURN_IF_NOT(size == size_in_bytes,
289289
"initializer tensor: ", tensor.name(), "'s size: ",

orttraining/orttraining/test/optimizer/graph_transform_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) {
12351235
auto out_channel = 64;
12361236
auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
12371237

1238-
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel / group, 1}, {-1.0f, 1.0f});
1238+
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel / group, 1}, -1.0f, 1.0f);
12391239
auto* conv_output = builder.MakeOutput();
12401240

12411241
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
@@ -1280,8 +1280,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) {
12801280
auto out_channel = 64;
12811281
auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
12821282

1283-
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
1284-
auto* bias_arg = builder.MakeInitializer<float>({out_channel}, {-1.0f, 1.0f});
1283+
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, -1.0f, 1.0f);
1284+
auto* bias_arg = builder.MakeInitializer<float>({out_channel}, -1.0f, 1.0f);
12851285
auto* conv_output = builder.MakeOutput();
12861286

12871287
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output});
@@ -1314,7 +1314,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) {
13141314
auto out_channel = 64;
13151315
auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
13161316

1317-
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
1317+
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, -1.0f, 1.0f);
13181318
auto* conv_output = builder.MakeOutput();
13191319

13201320
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
@@ -1347,7 +1347,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) {
13471347
auto out_channel = 64;
13481348
auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
13491349

1350-
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
1350+
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, -1.0f, 1.0f);
13511351
auto* conv_output = builder.MakeOutput();
13521352

13531353
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});

0 commit comments

Comments
 (0)