@@ -121,9 +121,9 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
121121 };
122122
123123 if (device->supportsFamily (MTL::GPUFamily (1009 ))) {
124- return AttentionKernelDescriptor (createBlockDimensions (), createCacheState (), createHeadDimension (), Hq, Hk, createMemoryPrecisions (), true , false , createRegisterPrecisions (device), createTransposeState (), createLeadingDimensions (), type, scale);
124+ return AttentionKernelDescriptor (createBlockDimensions (), createCacheState (), createHeadDimension (), createMemoryPrecisions (), true , false , createRegisterPrecisions (device), createTransposeState (), createLeadingDimensions (), type, scale);
125125 } else {
126- return AttentionKernelDescriptor (createBlockDimensions (), createCacheState (), createHeadDimension (), Hq, Hk, createMemoryPrecisions (), false , true , createRegisterPrecisions (device), createTransposeState (), createLeadingDimensions (), type, scale);
126+ return AttentionKernelDescriptor (createBlockDimensions (), createCacheState (), createHeadDimension (), createMemoryPrecisions (), false , true , createRegisterPrecisions (device), createTransposeState (), createLeadingDimensions (), type, scale);
127127 }
128128}
129129
@@ -137,6 +137,10 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
137137 uint32_t columnDimension = matrixDimensions[1 ];
138138 constants->setConstantValue (&rowDimension, MTL::DataTypeUInt, NS::Integer (0 ));
139139 constants->setConstantValue (&columnDimension, MTL::DataTypeUInt, 1 );
140+ uint32_t Hq = this ->Hq ;
141+ constants->setConstantValue (&Hq, MTL::DataTypeUInt, 2 );
142+ uint32_t HHkRatio = this ->Hq / this ->Hk ;
143+ constants->setConstantValue (&HHkRatio, MTL::DataTypeUInt, 3 );
140144 std::vector<AttentionOperand> operands;
141145 switch (type.value ) {
142146 case AttentionKernelType::forward:
@@ -151,7 +155,7 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
151155 }
152156 for (const auto & operand : operands) {
153157 uint32_t batchStride = batchStrides[operand].value_or (0 );
154- constants->setConstantValue (&batchStride, MTL::DataTypeUInt, 2 + operand.bufferIndex ());
158+ constants->setConstantValue (&batchStride, MTL::DataTypeUInt, 4 + operand.bufferIndex ());
155159 }
156160
157161 NS::String* swiftName = NS::String::string (" attention" , NS::UTF8StringEncoding);
0 commit comments