21 changes: 19 additions & 2 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -5,6 +5,8 @@

#ifndef SHARED_PROVIDER
#include <memory>
#include <optional>
#include <string_view>
#include <unordered_map>
#include <unordered_set>

@@ -62,6 +64,9 @@ using RunOptions = ::OrtRunOptions;
enum class DataLayout {
NCHW,
NHWC,

// NCHW is the default ONNX standard data layout. So default to it.
Default = NCHW,
};

class IExecutionProvider {
@@ -322,9 +327,21 @@ class IExecutionProvider {
}

virtual DataLayout GetPreferredLayout() const {
-    // NCHW is the default ONNX standard data layout. So default to it.
    // EPs which prefer a different layout should override to return their preferred layout.
-    return DataLayout::NCHW;
+    return DataLayout::Default;
}

/**
Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout should be
converted to `target_data_layout`.
If the EP prefers a non-default data layout (see `GetPreferredLayout()`), this function will be called during
layout transformation with `target_data_layout` set to the EP's preferred data layout.
A return value of `std::nullopt` indicates that this decision is left to ORT.
*/
virtual std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view /*domain*/,
std::string_view /*op_type*/,
DataLayout /*target_data_layout*/) const {
return std::nullopt;
}

virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}
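For illustration, a minimal sketch of how an execution provider might override the new hook. `MyNhwcEp` and its Resize carve-out are hypothetical, not part of this diff:

class MyNhwcEp : public IExecutionProvider {
 public:
  // constructor and other overrides omitted for brevity
  DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; }

  std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view domain,
                                                   std::string_view op_type,
                                                   DataLayout target_data_layout) const override {
    if (target_data_layout != DataLayout::NHWC) {
      return std::nullopt;  // only customize conversion to the preferred layout
    }
    if (domain == kOnnxDomain && op_type == "Resize") {
      return false;  // hypothetical: keep Resize in its original layout
    }
    return std::nullopt;  // defer all other ops to ORT's default layout-sensitive op list
  }
};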
36 changes: 35 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -136,12 +136,18 @@ struct OrtEpApi {
};

/**
- * \brief The data layout type that is preferred by an EP.
+ * \brief The data layout type.
 *
+ * EPs may specify a preferred data layout type. ORT's default layout type is OrtEpDataLayout_NCHW
+ * (also exposed as OrtEpDataLayout_Default).
*
* \since Version 1.23.
*/
typedef enum OrtEpDataLayout {
OrtEpDataLayout_NCHW = 0,
OrtEpDataLayout_NHWC,

OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
} OrtEpDataLayout;

/**
@@ -257,6 +263,34 @@ struct OrtEp {
OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr,
_Out_ OrtEpDataLayout* preferred_data_layout);

/** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout
* should be converted to `target_data_layout`.
* If the EP prefers a non-default data layout (see `GetPreferredDataLayout()`), this function will be called
* during layout transformation with `target_data_layout` set to the EP's preferred data layout.
*
* \note Implementation of this function is optional.
* If an EP prefers a non-default data layout, it may implement this function to customize its data layout
* preferences for specific ops at a finer granularity.
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] domain The op domain. An empty string means the ONNX domain.
* \param[in] op_type The op type.
* \param[in] target_data_layout The target data layout.
* \param[out] should_convert Whether the associated node's data layout should be converted to `target_data_layout`.
* If greater than 0, convert.
* If 0, don't convert.
* Otherwise, if less than 0, leave the decision to ORT.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr,
_In_z_ const char* domain,
_In_z_ const char* op_type,
_In_ OrtEpDataLayout target_data_layout,
_Outptr_ int* should_convert);
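For illustration, a hypothetical plugin EP (`MyEp`, not part of this diff) could implement the callback above as follows, mapping its decisions onto the documented `should_convert` contract:

// Requires <string.h> for strcmp.
static OrtStatus* ORT_API_CALL MyEp_ShouldConvertDataLayoutForOp(_In_ OrtEp* this_ptr,
                                                                 _In_z_ const char* domain,
                                                                 _In_z_ const char* op_type,
                                                                 _In_ OrtEpDataLayout target_data_layout,
                                                                 _Outptr_ int* should_convert) {
  (void)this_ptr;
  *should_convert = -1;  // default: leave the decision to ORT

  // Hypothetical policy: keep Resize (empty domain string == ONNX domain) in its original layout.
  if (target_data_layout == OrtEpDataLayout_NHWC &&
      strcmp(domain, "") == 0 && strcmp(op_type, "Resize") == 0) {
    *should_convert = 0;
  }
  return NULL;  // success
}

Since implementing this function is optional, an EP that is happy with the default behavior can presumably leave the `ShouldConvertDataLayoutForOp` member unset.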

/** \brief Set dynamic options on this EP.
*
* Dynamic options can be set by the user at any time after session creation with `OrtApi::SetEpDynamicOptions()`.
onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
@@ -30,93 +30,29 @@
  return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose);
}

-#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
-// TODO(mtavenrath) generate list from registered kernels using nhwc domain
-const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
-  static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
-    return std::unordered_set<std::string_view>{
-        "BatchNormalization",
-        "Conv",
-        "ConvTranspose",
-        "GlobalMaxPool",
-        "MaxPool",
-        "GlobalAveragePool",
-        "AveragePool",
-        "GridSample",
-        "DepthToSpace",
-        "SpaceToDepth",
-        "LRN"};
-  }();
-  return cuda_nhwc_ops;
-}
-#endif
-
/// <summary>
/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the
/// default set of layout sensitive operators if required.
-///
-/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we
-/// don't hardcode it here.
/// </summary>
+/// <param name="execution_provider">The EP instance.</param>
/// <param name="node">Node to check</param>
/// <returns>true if the node should have its layout converted to NHWC.</returns>
-bool ConvertNodeLayout(const api::NodeRef& node) {
+bool ShouldConvertNodeLayoutToNhwc(const IExecutionProvider& execution_provider, const api::NodeRef& node) {
  // skip if op is not an ONNX or contrib op
-  auto domain = node.Domain();
+  const auto domain = node.Domain();
  if (domain != kOnnxDomain && domain != kMSDomain) {
    return false;
  }

-  const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
-
-  // handle special cases
-#if defined(USE_JSEP)
-  // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed
-  if (node.GetExecutionProviderType() == kJsExecutionProvider) {
-    if (node.OpType() == "Resize") {
-      // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain
-      // with the original input layout.
-      return false;
-    }
-  }
-#endif
-
-  // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider
-#if defined(USE_WEBGPU)
-  if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) {
-    if (node.OpType() == "Resize") {
-      return false;
-    }
-  }
-#endif
-
-  // TODO: We don't need to check USE_CUDA || USE_CUDA_PROVIDER_INTERFACE in this function because we're already
-  // checking if the node is assigned to the desired EP (e.g., CUDA EP). We should only need to check
-  // ENABLE_CUDA_NHWC_OPS.
-#if (defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)) && ENABLE_CUDA_NHWC_OPS
-  if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
-    if (layout_sensitive_ops.count(node.OpType())) {
-      const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps();
-      if (!cuda_nhwc_ops.count(node.OpType())) {
-        return false;
-      }
-    }
-  }
-#endif
-
-  // TODO: We don't really need EP pre-processor macros in this function because we're already checking if the
-  // node is assigned to the desired EP (e.g., QNN EP). There's nothing about this code that absolutely requires
-  // conditional compilation.
-#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE)
-  if (node.GetExecutionProviderType() == kQnnExecutionProvider) {
-    if (node.OpType() == "Upsample") {
-      // Upsample is translated to QNN's Resize, which requires the NHWC layout for processing.
-      return true;
-    }
-  }
-#endif
-
-  return layout_sensitive_ops.count(node.OpType()) != 0;
+  const auto op_type = node.OpType();
+  if (auto should_convert_from_ep = execution_provider.ShouldConvertDataLayoutForOp(domain, op_type, DataLayout::NHWC);
+      should_convert_from_ep.has_value()) {
+    return *should_convert_from_ep;
+  }
+
+  const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
+  const auto op_identifier = MakeORTLayoutSensitiveOpId(domain, op_type);
+  return layout_sensitive_ops.find(op_identifier) != layout_sensitive_ops.end();
}
}  // namespace

@@ -126,25 +62,37 @@
// Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove
// as many of the layout transposes as possible.
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps() {
-  static std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
-    const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps();
+  static const std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
+    const auto& layout_sensitive_onnx_ops = onnx_transpose_optimization::GetLayoutSensitiveOps();
+
+    // Define a static local string array so we can refer to the elements with string_views.
+    static const std::string layout_sensitive_contrib_ops[]{
+        MakeORTLayoutSensitiveOpId(kMSDomain, "FusedConv"),
+        MakeORTLayoutSensitiveOpId(kMSDomain, "GridSample"),
+        MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearAveragePool"),
+        MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearGlobalAveragePool"),
+    };
+
    std::unordered_set<std::string_view> ort_specific_ops =
        {
-            "FusedConv",
-            "QLinearAveragePool",
-            "QLinearGlobalAveragePool",
            // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default
            // as EPs tend to only support one layout.
            "Resize",
        };

-    ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend());
+    ort_specific_ops.insert(std::begin(layout_sensitive_onnx_ops), std::end(layout_sensitive_onnx_ops));
+    ort_specific_ops.insert(std::begin(layout_sensitive_contrib_ops), std::end(layout_sensitive_contrib_ops));
    return ort_specific_ops;
  }();

  return ort_layout_sensitive_ops;
}

+// Returns "op_type" if the op is from the ONNX domain, "domain:op_type" otherwise.
+std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type) {
+  return (domain == kOnnxDomain) ? std::string(op_type) : MakeString(domain, ":", op_type);
+}

[cpplint] layout_transformation.cc:93: Add #include <string> for string [build/include_what_you_use] [4]

Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
AllocatorPtr cpu_allocator,
const DebugGraphFn& debug_graph_fn) {
@@ -159,7 +107,7 @@
continue;
}

-    if (ConvertNodeLayout(*node)) {
+    if (ShouldConvertNodeLayoutToNhwc(execution_provider, *node)) {
// domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv.
// So, change the OpType to "Conv" for FusedConv.
std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType();
onnxruntime/core/optimizer/layout_transformation/layout_transformation.h
@@ -68,10 +68,19 @@
/// Gets a list of layout sensitive ops for ORT. This list contains ONNX standard defined
/// layout sensitive ops + contrib ops + ops which are not layout sensitive but are treated as
/// layout sensitive by ORT EPs (example Resize).
///
/// Note: The format of the returned op identifiers is "<op type>" for ops in the ONNX domain and
/// "<domain>:<op type>" for ops in other domains. `MakeORTLayoutSensitiveOpId()` can be used to
/// create an op identifier with this format.
/// </summary>
-/// <returns>unordered set of op_types which are layout sensitive</returns>
+/// <returns>set of op identifiers which are layout sensitive</returns>
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps();

/// <summary>
/// Creates an op identifier compatible with `GetORTLayoutSensitiveOps()`.
/// </summary>
std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type);

[cpplint] layout_transformation.h:82: Add #include <string> for string [build/include_what_you_use] [4]
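For illustration, a short sketch of how the op identifier format interacts with the set lookup, assuming ORT's usual domain constants (`kOnnxDomain` is the empty string, `kMSDomain` is "com.microsoft"):

const auto& ops = GetORTLayoutSensitiveOps();
// ONNX-domain ops are stored as the bare op type:
bool conv_sensitive = ops.count(MakeORTLayoutSensitiveOpId(kOnnxDomain, "Conv")) > 0;      // looks up "Conv"
// Other domains use the "domain:op_type" form:
bool fused_sensitive = ops.count(MakeORTLayoutSensitiveOpId(kMSDomain, "FusedConv")) > 0;  // looks up "com.microsoft:FusedConv"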

/// <summary>
/// Inserts transposes around op inputs/outputs. Alternatively transposes initializers or uses existing Transpose
/// nodes if possible. Populates shape information on affected node inputs/outputs to reflect the change.
31 changes: 31 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -323,6 +323,37 @@
return this->IsNHWCPreferred() ? DataLayout::NHWC : DataLayout::NCHW;
}

std::optional<bool> CUDAExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain,
std::string_view node_op_type,
DataLayout target_data_layout) const {
#if defined(ENABLE_CUDA_NHWC_OPS)
if (target_data_layout != DataLayout::NHWC) {
return std::nullopt;
}

[cpplint] cuda_execution_provider.cc:335: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

  // TODO(mtavenrath) generate list from registered kernels using nhwc domain
  static const std::unordered_set<std::string_view> cuda_nhwc_onnx_ops{
"BatchNormalization",
"Conv",
"ConvTranspose",
"GlobalMaxPool",
"MaxPool",
"GlobalAveragePool",
"AveragePool",
"GridSample",
"DepthToSpace",
"SpaceToDepth",
"LRN",
};

return (node_domain == kOnnxDomain && cuda_nhwc_onnx_ops.find(node_op_type) != cuda_nhwc_onnx_ops.end()) ||
(node_domain == kMSDomain && node_op_type == "GridSample");

#else // defined(ENABLE_CUDA_NHWC_OPS)
return std::nullopt;
#endif
}

CUDAExecutionProvider::~CUDAExecutionProvider() {
// clean up thread local context caches
{
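As a quick illustration of the override above, in a build with ENABLE_CUDA_NHWC_OPS defined the following expectations should hold (a sketch, not a test from this diff):

// ShouldConvertDataLayoutForOp(kOnnxDomain, "Conv",       DataLayout::NHWC) -> true         (in the NHWC op set)
// ShouldConvertDataLayoutForOp(kMSDomain,   "GridSample", DataLayout::NHWC) -> true         (contrib op special case)
// ShouldConvertDataLayoutForOp(kOnnxDomain, "Gemm",       DataLayout::NHWC) -> false        (no NHWC kernel here)
// ShouldConvertDataLayoutForOp(kOnnxDomain, "Conv",       DataLayout::NCHW) -> std::nullopt (decision left to ORT)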
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -39,6 +39,10 @@ class CUDAExecutionProvider : public IExecutionProvider {

DataLayout GetPreferredLayout() const override;

std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view node_domain,
std::string_view node_op_type,
DataLayout target_data_layout) const override;

const void* GetExecutionHandle() const noexcept override {
// The CUDA interface does not return anything interesting.
return nullptr;
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -24,8 +24,8 @@

namespace onnxruntime::cuda {

-// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout
-// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
+// When adding new supported NHWC operations make sure to also integrate them into
+// CUDAExecutionProvider::ShouldConvertDataLayoutForOp()

class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization);
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
@@ -849,6 +849,23 @@ std::unique_ptr<onnxruntime::IExternalDataLoader> JsExecutionProvider::GetExtern
return std::make_unique<js::ExternalDataLoader>();
}

std::optional<bool> JsExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain,
std::string_view node_op_type,
DataLayout target_data_layout) const {
if (target_data_layout != DataLayout::NHWC) {
return std::nullopt;
}

  // TODO(fs-eire): Remove special-case handling of JSEP once the NHWC Resize implementation is fixed
  if (node_domain == kOnnxDomain && node_op_type == "Resize") {
    // Leave Resize as-is pending a bugfix for the NHWC implementation. This means the node will remain in the
    // ONNX domain with the original input layout.
return false;
}

return std::nullopt;
}

JsExecutionProvider::~JsExecutionProvider() {
}

4 changes: 4 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.h
@@ -54,6 +54,10 @@ class JsExecutionProvider : public IExecutionProvider {

DataLayout GetPreferredLayout() const override { return preferred_data_layout_; }

std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view node_domain,
std::string_view node_op_type,
DataLayout target_data_layout) const override;

FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }

// JSEP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to work,
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -1066,6 +1066,21 @@ DataLayout QNNExecutionProvider::GetPreferredLayout() const {
return DataLayout::NHWC;
}

std::optional<bool> QNNExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain,
std::string_view node_op_type,
DataLayout target_data_layout) const {
if (target_data_layout != DataLayout::NHWC) {
return std::nullopt;
}

if (node_domain == kOnnxDomain && node_op_type == "Upsample") {
// Upsample is translated to QNN's Resize, which requires the NHWC layout for processing.
return true;
}

return std::nullopt;
}

Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& node_compute_funcs,
const logging::Logger& logger) {
NodeComputeInfo compute_info;