Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,18 @@ export class WebGpuBackend {
requiredFeatures,
};

if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
} else if (adapter.features.has('timestamp-query')) {
requiredFeatures.push('timestamp-query');
}
if (adapter.features.has('shader-f16')) {
requiredFeatures.push('shader-f16');
// Try requiring WebGPU features
const requireFeatureIfAvailable = (feature: GPUFeatureName) =>
adapter.features.has(feature) && requiredFeatures.push(feature) && true;
// Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query
if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) {
requireFeatureIfAvailable('timestamp-query');
}
requireFeatureIfAvailable('shader-f16');
// Try subgroups
if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) {
// If subgroups feature is available, also try subgroups-f16
requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName);
}

this.device = await adapter.requestDevice(deviceDescriptor);
Expand Down
23 changes: 23 additions & 0 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ class ComputeContextImpl implements ComputeContext {
return this.backend.device.limits.maxComputeWorkgroupStorageSize;
}

isSubgroupsSupported(): boolean {
return this.backend.device.features.has('subgroups' as GPUFeatureName);
}

isSubgroupsF16Supported(): boolean {
return this.backend.device.features.has('subgroups-f16' as GPUFeatureName);
}

getSubgroupSizeRange(): [number, number] | undefined {
// Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to
// workaround the IDL type checks.
// TODO: clean this after subgroups feature is sattled in IDL.
const deviceSubgroupsLimits = this.backend.device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number };
if (
!this.isSubgroupsSupported() ||
!deviceSubgroupsLimits.minSubgroupSize ||
!deviceSubgroupsLimits.maxSubgroupSize
) {
return undefined;
}
return [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize];
}

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =
Expand Down
20 changes: 15 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,23 @@ export class ProgramManager {
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
extensions.push('enable f16;');
}
const enableDirectives: string[] = [];

// Enable WGSL extensions based on available WebGPU features
const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [
{ feature: 'shader-f16', extension: 'f16' },
{ feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' },
{ feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' },
];
extensionsInfo.forEach((info) => {
if (device.features.has(info.feature)) {
enableDirectives.push(`enable ${info.extension};`);
}
});

const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({ code, label: programInfo.name });
LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);

Expand Down
3 changes: 3 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ export interface ComputeContext {
output(index: number, dims: readonly number[]): number;
getMaxComputeWorkgroupSizes(): [number, number, number];
getMaxComputeWorkgroupStoragesize(): number;
isSubgroupsSupported(): boolean;
isSubgroupsF16Supported(): boolean;
getSubgroupSizeRange(): [number, number] | undefined;
}

export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes';
Loading