@@ -7,12 +7,12 @@ import {ShapeUtil} from '../../util';
77import  { AttributeWithCacheKey ,  createAttributeWithCacheKey }  from  '../attribute-with-cache-key' ; 
88import  { ComputeContext ,  ProgramInfo ,  ProgramUniform }  from  '../types' ; 
99
10- import  { getMaxComponents ,  inputVariable ,  outputVariable ,  ShaderHelper ,  tensorTypeToWsglStorageType ,  UniformsArrayType }  from  './common' ; 
10+ import  { createTensorShapeVariables ,   getMaxComponents ,  inputVariable ,  outputVariable ,  ShaderHelper ,  tensorTypeToWsglStorageType ,  UniformsArrayType }  from  './common' ; 
1111
1212//  TODO support quantization bits not equal to 4 
1313export  interface  MatMulNBitsAttributes  extends  AttributeWithCacheKey  { 
14-   k : number ; 
15-   n : number ; 
14+   K : number ; 
15+   N : number ; 
1616  accuracyLevel : number ; 
1717  bits : number ; 
1818  blockSize : number ; 
@@ -24,25 +24,25 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
2424  } 
2525  const  a  =  inputs [ 0 ] ; 
2626  const  aRank  =  a . dims . length ; 
27-   if  ( a . dims [ aRank  -  1 ]  !==  attributes . k )  { 
27+   if  ( a . dims [ aRank  -  1 ]  !==  attributes . K )  { 
2828    throw  new  Error ( 'The last dim of input shape does not match the k value' ) ; 
2929  } 
30-   const  nBlocksPerCol  =  Math . floor ( ( attributes . k  +  attributes . blockSize  -  1 )  /  attributes . blockSize ) ; 
30+   const  nBlocksPerCol  =  Math . floor ( ( attributes . K  +  attributes . blockSize  -  1 )  /  attributes . blockSize ) ; 
3131  const  blobSize  =  attributes . blockSize  /  8  *  attributes . bits ; 
3232  const  b  =  inputs [ 1 ] ; 
33-   if  ( ! ShapeUtil . areEqual ( b . dims ,  [ attributes . n ,  nBlocksPerCol ,  blobSize ] ) )  { 
33+   if  ( ! ShapeUtil . areEqual ( b . dims ,  [ attributes . N ,  nBlocksPerCol ,  blobSize ] ) )  { 
3434    throw  new  Error ( 'The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize' ) ; 
3535  } 
3636  const  scales  =  inputs [ 2 ] ; 
3737  const  scalesShape  =  scales . dims ; 
38-   if  ( ShapeUtil . size ( scalesShape )  !==  attributes . n  *  nBlocksPerCol )  { 
38+   if  ( ShapeUtil . size ( scalesShape )  !==  attributes . N  *  nBlocksPerCol )  { 
3939    throw  new  Error ( 'scales input size error.' ) ; 
4040  } 
4141  if  ( inputs . length  ===  4 )  { 
4242    const  zeroPoints  =  inputs [ 3 ] ; 
4343    const  zeroPointsShape  =  zeroPoints . dims ; 
4444    const  expectedZeroPointsSize  = 
45-         attributes . bits  >  4  ? ( attributes . n  *  nBlocksPerCol )  : attributes . n  *  Math . floor ( ( nBlocksPerCol  +  1 )  /  2 ) ; 
45+         attributes . bits  >  4  ? ( attributes . N  *  nBlocksPerCol )  : attributes . N  *  Math . floor ( ( nBlocksPerCol  +  1 )  /  2 ) ; 
4646    if  ( ShapeUtil . size ( zeroPointsShape )  !==  expectedZeroPointsSize )  { 
4747      throw  new  Error ( 'zeroPoints input size error.' ) ; 
4848    } 
@@ -53,19 +53,19 @@ export const createMatMulNBitsProgramInfo =
5353    ( inputs : readonly  TensorView [ ] ,  attributes : MatMulNBitsAttributes ) : ProgramInfo  =>  { 
5454      const  inputShape  =  inputs [ 0 ] . dims ; 
5555      const  aRank  =  inputShape . length ; 
56-       const  outputShape  =  inputShape . slice ( 0 ,  aRank  -  1 ) . concat ( attributes . n ) ; 
57-       const  m  =  inputShape [ aRank  -  2 ] ; 
56+       const  outputShape  =  inputShape . slice ( 0 ,  aRank  -  1 ) . concat ( attributes . N ) ; 
57+       const  M  =  inputShape [ aRank  -  2 ] ; 
5858      const  blobSize  =  attributes . blockSize  /  8  *  attributes . bits ; 
5959      const  blobSizeInWords  =  blobSize  /  4 ; 
60-       const  outputNumber  =  getMaxComponents ( m ) ; 
60+       const  outputNumber  =  getMaxComponents ( M ) ; 
6161      const  components  =  1 ;   // getMaxComponents(attributes.n); 
62-       const  aComponents  =  getMaxComponents ( attributes . k ) ; 
62+       const  aComponents  =  getMaxComponents ( attributes . K ) ; 
6363      const  bComponents  =  getMaxComponents ( blobSizeInWords ) ; 
64-       const  zComponents  =  1 ;   // getMaxComponents(attributes.n  / 8); 
64+       const  zComponents  =  1 ;   // getMaxComponents(attributes.N  / 8); 
6565      const  outputSize  =  ShapeUtil . size ( outputShape )  /  components  /  outputNumber ; 
6666      const  programUniforms : ProgramUniform [ ]  =  [ 
67-         { type : DataType . uint32 ,  data : outputSize } ,  { type : DataType . uint32 ,  data : attributes . k } , 
68-         { type : DataType . uint32 ,  data : attributes . n } ,  { type : DataType . uint32 ,  data : attributes . accuracyLevel } , 
67+         { type : DataType . uint32 ,  data : outputSize } ,  { type : DataType . uint32 ,  data : attributes . K } , 
68+         { type : DataType . uint32 ,  data : attributes . N } ,  { type : DataType . uint32 ,  data : attributes . accuracyLevel } , 
6969        { type : DataType . uint32 ,  data : attributes . bits } ,  { type : DataType . uint32 ,  data : attributes . blockSize } 
7070      ] ; 
7171      const  getShaderSource  =  ( shaderHelper : ShaderHelper )  =>  { 
@@ -88,7 +88,7 @@ export const createMatMulNBitsProgramInfo =
8888          { name : 'output_size' ,  type : 'u32' } ,  { name : 'K' ,  type : 'u32' } ,  { name : 'N' ,  type : 'u32' } , 
8989          { name : 'accuracy_level' ,  type : 'u32' } ,  { name : 'bits' ,  type : 'u32' } ,  { name : 'block_size' ,  type : 'u32' } 
9090        ] ; 
91-         const  nBlocksPerCol  =  Math . floor ( ( attributes . k  +  attributes . blockSize  -  1 )  /  attributes . blockSize ) ; 
91+         const  nBlocksPerCol  =  Math . floor ( ( attributes . K  +  attributes . blockSize  -  1 )  /  attributes . blockSize ) ; 
9292        const  dataType  =  tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ; 
9393        const  dequantizeArrayReturnType  =  ( ( )  =>  { 
9494          switch  ( aComponents )  { 
0 commit comments