Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con

program
.CacheHint(interleaved_)
.AddInputs({{input, ProgramTensorMetadataDependency::Rank},
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
{position_ids, ProgramTensorMetadataDependency::Rank},
{cos_cache, ProgramTensorMetadataDependency::Rank},
{sin_cache, ProgramTensorMetadataDependency::Rank}})
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,13 @@
program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank});
}

program.CacheHint(is_input_empty)
// TODO: the ReduceKernel class is designed to use `keepdims_`, `noop_with_empty_axes_` and input axes as uniform variables,

Check warning on line 250 in onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc:250: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// but the current implementation does not work without them in cache key.
// This is a temporary workaround to make it work. We should fix this in the future.
program.CacheHint(keepdims_,
noop_with_empty_axes_,
select_last_index_,
absl::StrJoin(input_axes, ","))
.AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode, config.preserve_device));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
Expand All @@ -794,7 +794,7 @@ void WebGpuContextFactory::ReleaseContext(int context_id) {
auto it = contexts_.find(context_id);
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");

if (--it->second.ref_count == 0) {
if (--it->second.ref_count == 0 && !it->second.context->preserve_device_) {
contexts_.erase(it);
}
}
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct WebGpuContextConfig {
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
bool preserve_device;
};

struct WebGpuBufferCacheConfig {
Expand Down Expand Up @@ -152,8 +153,8 @@ class WebGpuContext final {
AtPasses
};

WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode, bool preserve_device)
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device} {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);

std::vector<const char*> GetEnabledAdapterToggles() const;
Expand Down Expand Up @@ -229,6 +230,7 @@ class WebGpuContext final {

uint64_t gpu_timestamp_offset_ = 0;
bool is_profiling_ = false;
bool preserve_device_;

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,33 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
}
}

std::string preserve_device_str;
bool preserve_device = false;
if (config_options.TryGetConfigEntry(kPreserveDevice, preserve_device_str)) {
if (preserve_device_str == kPreserveDevice_ON) {
preserve_device = true;
} else if (preserve_device_str == kPreserveDevice_OFF) {
preserve_device = false;
} else {
ORT_THROW("Invalid preserve device: ", preserve_device_str);
}
}

webgpu::WebGpuContextConfig context_config{
context_id,
reinterpret_cast<WGPUInstance>(webgpu_instance),
reinterpret_cast<WGPUDevice>(webgpu_device),
reinterpret_cast<const void*>(dawn_proc_table),
validation_mode,
preserve_device,
};

LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << preserve_device;

//
// STEP.3 - prepare parameters for WebGPU context initialization.
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ constexpr const char* kValidationMode = "WebGPU:validationMode";
constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames";
constexpr const char* kEnablePIXCapture = "WebGPU:enablePIXCapture";

constexpr const char* kPreserveDevice = "WebGPU:preserveDevice";

// The following are the possible values for the provider options.

constexpr const char* kDawnBackendType_D3D12 = "D3D12";
Expand All @@ -44,6 +46,9 @@ constexpr const char* kEnableGraphCapture_OFF = "0";
constexpr const char* kEnablePIXCapture_ON = "1";
constexpr const char* kEnablePIXCapture_OFF = "0";

constexpr const char* kPreserveDevice_ON = "1";
constexpr const char* kPreserveDevice_OFF = "0";

constexpr const char* kBufferCacheMode_Disabled = "disabled";
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
constexpr const char* kBufferCacheMode_Simple = "simple";
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider() {
ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode,
webgpu::options::kBufferCacheMode_Disabled)
.IsOK());
// Disable device auto collect
ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreserveDevice,
webgpu::options::kPreserveDevice_ON)
.IsOK());
return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();
#else
return nullptr;
Expand Down
Loading