@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
66import { ShapeUtil } from '../../util' ;
77import { ComputeContext , ProgramInfo } from '../types' ;
88
9- import { inputVariable , outputVariable , ShaderHelper } from './common' ;
9+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper } from './common' ;
1010
1111const getRepeats = ( repeatsTensorView : TensorView ) : readonly number [ ] =>
1212 Array . from ( repeatsTensorView . getBigInt64Array ( ) , Number ) ;
@@ -54,30 +54,33 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
5454 const outputSize = ShapeUtil . size ( outputShape ) ;
5555
5656 const dataType = inputs [ 0 ] . dataType ;
57- const input = inputVariable ( 'input' , dataType , inputShape ) ;
58- const output = outputVariable ( 'output' , dataType , outputShape ) ;
57+ const input = inputVariable ( 'input' , dataType , inputShape . length ) ;
58+ const output = outputVariable ( 'output' , dataType , outputShape . length ) ;
5959
6060 const getShaderSource = ( shaderHelper : ShaderHelper ) => `
6161 const inputShape = ${ input . indices ( ...inputShape ) } ;
62- ${ shaderHelper . declareVariables ( input , output ) }
62+ ${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
6363 ${ shaderHelper . mainStart ( ) }
64- ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( outputSize ) }
65- let outputIndices = ${ output . offsetToIndices ( 'global_idx' ) } ;
66- var inputIndices : ${ input . type . indices } ;
64+ ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
65+ let output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
66+ var input_indices : ${ input . type . indices } ;
6767 for (var i = 0; i < ${ inputShape . length } ; i++) {
68- let inputDimValue = ${ output . indicesGet ( 'outputIndices' , 'i' ) } % ${ input . indicesGet ( 'inputShape' , 'i' ) } ;
68+ let input_dim_i = ${ input . indicesGet ( 'uniforms.input_shape' , 'i' ) } ;
69+ let input_dim_value = ${ output . indicesGet ( 'output_indices' , 'i' ) } % input_dim_i;
6970
70- ${ input . indicesSet ( 'inputIndices ' , 'i' , 'inputDimValue ' ) }
71+ ${ input . indicesSet ( 'input_indices ' , 'i' , 'input_dim_value ' ) }
7172 }
72- ${ output . setByOffset ( 'global_idx' , input . getByIndices ( 'inputIndices ' ) ) }
73+ ${ output . setByOffset ( 'global_idx' , input . getByIndices ( 'input_indices ' ) ) }
7374 }` ;
7475
7576 return {
7677 name : 'Tile' ,
77- shaderCache : { hint : `${ repeats } ` } ,
78+ shaderCache : { hint : `${ repeats } ` , inputDependencies : [ 'rank' ] } ,
7879 getRunData : ( ) => ( {
7980 outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
8081 dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
82+ programUniforms :
83+ [ { type : DataType . uint32 , data : outputSize } , ...createTensorShapeVariables ( inputs [ 0 ] . dims , outputShape ) ] ,
8184 } ) ,
8285 getShaderSource,
8386 } ;
0 commit comments