Skip to content

Move moving weights to memory to the end of Graph::Resolve() #25626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 2, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
common::Status InjectExternalInitializedTensors(const InlinedHashMap<std::string, OrtValue>& external_initializers);

/** This function takes externally provided files in memory for initializers with external
* data and replaces graph initializers with its content.
 * data and replaces main graph initializers with their content.
*/
common::Status InjectExternalInitializersFromFilesInMemory(
const InlinedHashMap<PathString, std::pair<char*, size_t>>& external_initializer_files);
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/session_state_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ common::Status SaveInitializedTensors(
// TODO: if the tensor needs to be copied, does it have enough room?
ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc));

// ??? Should we ignore this session option if the EP is explictly providing the read only allocator?
// ??? Should we ignore this session option if the EP is explicitly providing the read only allocator?
// bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator;
const bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(
Expand All @@ -402,9 +402,11 @@ common::Status SaveInitializedTensors(
std::move(tensor), ort_value));
}
} else {
// If the data were already in memory, we would have found it above.
ORT_ENFORCE(!utils::HasExternalDataInMemory(tensor_proto));

// We need to deserialize the tensor proto into an OrtValue
// using the preallocated buffer or allocator.

Status st = DeserializeTensorProto(env, graph_loc, tensor_proto,
(memory_buffer.has_value()) ? &*memory_buffer : nullptr,
alloc, default_cpu_alloc, ort_value, data_transfer_mgr,
Expand Down
110 changes: 52 additions & 58 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1226,26 +1226,6 @@ Graph::Graph(const Model& owning_model,
ArgNameToTypeMap name_to_type_map;
const auto& model_path = ModelPath();

// If the tensor proto data is large enough, externalize it and replace with a tensor_proto
// with external data reference pointing to an OrtValue, otherwise do nothing.
auto put_data_maybe_in_memory = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) {
size_t size_in_bytes = 0;
ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));
constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
assert(ort_value.IsAllocated());
auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value));
ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(),
" in the initializer list.");
tensor_proto = std::move(tensor_proto_to_add);
}
};

// Process 'Constant' nodes
// Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list
for (auto& node : graph_proto_->node()) {
Expand All @@ -1265,8 +1245,6 @@ Graph::Graph(const Model& owning_model,
}
}

put_data_maybe_in_memory(*tensor);

// Ensure initializers are also graph inputs.
if (ir_version_ < 4) {
TypeProto t{utils::TypeProtoFromTensorProto(*tensor)};
Expand Down Expand Up @@ -1343,22 +1321,7 @@ Graph::Graph(const Model& owning_model,
}

// Copy initial tensors to a map.
for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) {
auto& tensor = *graph_proto_->mutable_initializer(i);
// If data is on disk, it will be loaded either by optimizers
// or during session state finalization.
// If data is already in memory, do nothing.
if (!utils::HasExternalData(tensor)) {
const bool is_sparse = sparse_tensor_names_.count(tensor.name());
if (is_sparse) {
sparse_tensor_names_.erase(tensor.name());
}
put_data_maybe_in_memory(tensor);
if (is_sparse) {
sparse_tensor_names_.emplace(tensor.name());
}
}

for (auto& tensor : graph_proto_->initializer()) {
auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor);
if (!p.second) {
LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name()
Expand Down Expand Up @@ -3420,7 +3383,32 @@ Status Graph::Resolve(const ResolveOptions& options) {

ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func));

return Status::OK();
auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// If any initializers' data is not yet in memory, move it there now.
const auto& model_path = graph.ModelPath();
auto& graph_proto = *graph.graph_proto_;
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
auto& tensor_proto = *graph_proto.mutable_initializer(i);
if (utils::HasExternalData(tensor_proto)) {
continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize
}

size_t size_in_bytes = 0;
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));
constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
}
}
return Status::OK();
};

return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func);
}

void Graph::SetName(const std::string& name) {
Expand Down Expand Up @@ -3659,6 +3647,15 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi
ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(),
"Replacement tensor's data type does not match.");

bool is_sparse = false;
{
auto sparse_tensor_it = sparse_tensor_names_.find(initializer_name);
if (sparse_tensor_it != sparse_tensor_names_.end()) {
sparse_tensor_names_.erase(sparse_tensor_it);
is_sparse = true;
}
}

auto& mutable_initializers = *(graph_proto_->mutable_initializer());
// use cheaper pointer comparison to find old entry
auto existing_entry = std::find(mutable_initializers.pointer_begin(), mutable_initializers.pointer_end(),
Expand All @@ -3675,6 +3672,9 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi
}

**existing_entry = std::move(new_initializer);
if (is_sparse) {
sparse_tensor_names_.insert((**existing_entry).name());
}

return Status::OK();
}
Expand Down Expand Up @@ -3720,7 +3720,7 @@ Status Graph::InjectExternalInitializedTensors(const InlinedHashMap<std::string,
Status Graph::InjectExternalInitializersFromFilesInMemory(
const InlinedHashMap<PathString, std::pair<char*, size_t>>& external_initializer_files) {
for (const auto& [tensor_name, tensor_proto] : name_to_initial_tensor_) {
if (tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) {
if (utils::HasExternalDataInFile(*tensor_proto)) {
std::unique_ptr<ExternalDataInfo> external_data_info;
ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info));

Expand All @@ -3729,25 +3729,27 @@ Status Graph::InjectExternalInitializersFromFilesInMemory(
const size_t external_data_length = external_data_info->GetLength();
SafeInt<size_t> tensor_byte_size;
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size));

ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
"TensorProto: ", tensor_name, " external data size mismatch. Computed size: ",
*&tensor_byte_size, ", external_data.length: ", external_data_length);

SafeInt<FileOffsetType> end_of_read(file_offset);
end_of_read += tensor_byte_size;

auto external_file_pos = external_initializer_files.find(external_file);
ORT_RETURN_IF(external_file_pos == external_initializer_files.end(),
auto user_provided_entry = external_initializer_files.find(external_file);
ORT_RETURN_IF(user_provided_entry == external_initializer_files.end(),
"External file: ", ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(external_file),
" not found from the table user provided.");
auto external_file_length = external_file_pos->second.second;

ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow<FileOffsetType>(external_file_length),
auto user_provided_length = user_provided_entry->second.second;

ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow<FileOffsetType>(user_provided_length),
"External initializer: ", tensor_name,
" offset: ", file_offset, " size to read: ", external_data_length,
" given file_length: ", external_file_length, " are out of bounds or can not be read in full.");
char* external_file_buffer = static_cast<char*>(external_file_pos->second.first);
char* tensor_buffer = external_file_buffer + file_offset;
" given file_length: ", user_provided_length, " are out of bounds or can not be read in full.");
char* user_provided_file_buffer = static_cast<char*>(user_provided_entry->second.first);
char* user_provided_tensor_buffer = user_provided_file_buffer + file_offset;

const auto& old_initializer = *(tensor_proto);
auto& mutable_initializers = *(graph_proto_->mutable_initializer());
Expand All @@ -3762,19 +3764,11 @@ Status Graph::InjectExternalInitializersFromFilesInMemory(
const DataTypeImpl* const type =
DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType();
TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer);
auto tensor = Tensor(type, tensor_shape, tensor_buffer,
auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer,
OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));

constexpr const bool use_tensor_buffer_true = true;
auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_true);
// Implied that external data is in memory
const bool has_external_data_in_memory = utils::HasExternalData(new_tensor_proto);

OrtValue ort_value;
if (has_external_data_in_memory) {
Tensor::InitOrtValue(std::move(tensor), ort_value);
}
ortvalue_initializers_.insert_or_assign(tensor_name, std::move(ort_value));
constexpr const bool use_tensor_buffer_false = false;
auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false);
**existing_entry = std::move(new_tensor_proto);
}
}
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4918,3 +4918,16 @@ TEST(CApiTest, custom_cast) {
inputs, "output", expected_dims_y, expected_values_y, 0,
custom_op_domain, nullptr);
}

// TEST(CApiTest, TestPhi4) {
//
// constexpr const ORTCHAR_T* model_path = ORT_TSTR("D:\\dev\\data\\foundry_cache\\models\\Microsoft\\Phi-4-mini-instruct-cuda-gpu\\v3\\model.onnx");
// Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
// Ort::SessionOptions session_options;
//
// try {
// Ort::Session session(env, model_path, session_options);
// } catch (const Ort::Exception& ex) {
// std::cout << "Exception: " << ex.what() << std::endl;
// }
// }
Loading