Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager';
import {
AdapterInfo,
ComputeContext,
DeviceInfo,
GpuArchitecture,
GpuData,
GpuVendor,
Expand Down Expand Up @@ -134,13 +135,34 @@ class AdapterInfoImpl implements AdapterInfo {
}
}

class DeviceInfoImpl implements DeviceInfo {
readonly subgroupsSupported: boolean;
readonly subgroupsF16Supported: boolean;
readonly subgroupSizeRange?: readonly [number, number];

constructor(device: GPUDevice) {
this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName);
this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName);
// Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to
// workaround the IDL type checks.
// TODO: clean this after subgroups feature is settled in IDL.
const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number };
if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) {
this.subgroupSizeRange = undefined;
} else {
this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize];
}
}
}

/**
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
* the first parameter so that it is stored for future use.
*/
export class WebGpuBackend {
adapterInfo: AdapterInfoImpl;
device: GPUDevice;
deviceInfo: DeviceInfoImpl;
/**
* an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
*/
Expand Down Expand Up @@ -243,16 +265,22 @@ export class WebGpuBackend {
requiredFeatures,
};

if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
} else if (adapter.features.has('timestamp-query')) {
requiredFeatures.push('timestamp-query');
// Try requiring WebGPU features
const requireFeatureIfAvailable = (feature: GPUFeatureName) =>
adapter.features.has(feature) && requiredFeatures.push(feature) && true;
// Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query
if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) {
requireFeatureIfAvailable('timestamp-query');
}
if (adapter.features.has('shader-f16')) {
requiredFeatures.push('shader-f16');
requireFeatureIfAvailable('shader-f16');
// Try subgroups
if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) {
// If subgroups feature is available, also try subgroups-f16
requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName);
}

this.device = await adapter.requestDevice(deviceDescriptor);
this.deviceInfo = new DeviceInfoImpl(this.device);
this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo()));
this.gpuDataManager = createGpuDataManager(this);
this.programManager = new ProgramManager(this);
Expand Down
22 changes: 9 additions & 13 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ import { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
import { TensorView } from './tensor-view';
import { ShapeUtil } from './util';
import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types';
import {
AdapterInfo,
ComputeContext,
ComputeContextInputsOutputsMapping,
DeviceInfo,
ProgramInfo,
} from './webgpu/types';
import { WebNNBackend } from './backend-webnn';

/* eslint-disable no-bitwise */
Expand Down Expand Up @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView {

class ComputeContextImpl implements ComputeContext {
readonly adapterInfo: AdapterInfo;
readonly deviceInfo: DeviceInfo;
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
Expand All @@ -87,6 +94,7 @@ class ComputeContextImpl implements ComputeContext {
contextDataOffset: number,
) {
this.adapterInfo = backend.adapterInfo;
this.deviceInfo = backend.deviceInfo;

// extract context data
const ptrSize = module.PTR_SIZE;
Expand All @@ -112,18 +120,6 @@ class ComputeContextImpl implements ComputeContext {
this.inputs = inputs;
}

getMaxComputeWorkgroupSizes(): [number, number, number] {
return [
this.backend.device.limits.maxComputeWorkgroupSizeX,
this.backend.device.limits.maxComputeWorkgroupSizeY,
this.backend.device.limits.maxComputeWorkgroupSizeZ,
];
}

getMaxComputeWorkgroupStoragesize(): number {
return this.backend.device.limits.maxComputeWorkgroupStorageSize;
}

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =
Expand Down
20 changes: 15 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,23 @@ export class ProgramManager {
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
extensions.push('enable f16;');
}
const enableDirectives: string[] = [];

// Enable WGSL extensions based on available WebGPU features
const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [
{ feature: 'shader-f16', extension: 'f16' },
{ feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' },
{ feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' },
];
extensionsInfo.forEach((info) => {
if (device.features.has(info.feature)) {
enableDirectives.push(`enable ${info.extension};`);
}
});

const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({ code, label: programInfo.name });
LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);

Expand Down
12 changes: 10 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ export interface AdapterInfo {
isArchitecture: (architecture: GpuArchitecture) => boolean;
isVendor: (vendor: GpuVendor) => boolean;
}
export interface DeviceInfo {
readonly subgroupsSupported: boolean;
readonly subgroupsF16Supported: boolean;
readonly subgroupSizeRange?: readonly [number, number];
}

export interface GpuData {
type: GpuDataType;
Expand Down Expand Up @@ -160,6 +165,11 @@ export interface ComputeContext {
*/
readonly adapterInfo: AdapterInfo;

/**
* gpu device info
*/
readonly deviceInfo: DeviceInfo;

/**
* stores the pointer to OpKernelContext
*/
Expand Down Expand Up @@ -187,8 +197,6 @@ export interface ComputeContext {

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[];
output(index: number, dims: readonly number[]): number;
getMaxComputeWorkgroupSizes(): [number, number, number];
getMaxComputeWorkgroupStoragesize(): number;
}

export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes';
Loading