Commit e295638

Move Hq / Hk into function constants.
1 parent 93d0f48 commit e295638

5 files changed: +18, -28 lines changed

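Before this commit, the query head count Hq and the key/value head count Hk were fields of AttentionKernelDescriptor, and the generated Metal source hard-coded Hq via a #define. After this commit they travel as Metal function constants, so the generated kernel source no longer varies with head counts; only the pipeline specialization does. A minimal sketch of the new constant layout, reconstructed from the hunks below (the rest of the generated preamble is elided):

```cpp
#include <string>

// Function-constant layout after this commit (reconstructed from the diff):
//   0: R (row dimension)          1: C (column dimension)
//   2: Hq (query head count)      3: H_Hk_ratio (Hq / Hk)
//   4 + operand.bufferIndex():    per-operand batch stride
// The generated MSL preamble therefore begins like this:
const std::string preamble = R"(
constant uint R [[function_constant(0)]];
constant uint C [[function_constant(1)]];

constant uint Hq [[function_constant(2)]];
constant uint H_Hk_ratio [[function_constant(3)]];
)";
```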

lib/nnc/mfa/v2/AttentionDescriptor.cpp

Lines changed: 7 additions & 3 deletions
@@ -121,9 +121,9 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
   };
 
   if (device->supportsFamily(MTL::GPUFamily(1009))) {
-    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), Hq, Hk, createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
+    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
   } else {
-    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), Hq, Hk, createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
+    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
   }
 }

@@ -137,6 +137,10 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
   uint32_t columnDimension = matrixDimensions[1];
   constants->setConstantValue(&rowDimension, MTL::DataTypeUInt, NS::Integer(0));
   constants->setConstantValue(&columnDimension, MTL::DataTypeUInt, 1);
+  uint32_t Hq = this->Hq;
+  constants->setConstantValue(&Hq, MTL::DataTypeUInt, 2);
+  uint32_t HHkRatio = this->Hq / this->Hk;
+  constants->setConstantValue(&HHkRatio, MTL::DataTypeUInt, 3);
   std::vector<AttentionOperand> operands;
   switch (type.value) {
   case AttentionKernelType::forward:

@@ -151,7 +155,7 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
   }
   for (const auto& operand : operands) {
     uint32_t batchStride = batchStrides[operand].value_or(0);
-    constants->setConstantValue(&batchStride, MTL::DataTypeUInt, 2 + operand.bufferIndex());
+    constants->setConstantValue(&batchStride, MTL::DataTypeUInt, 4 + operand.bufferIndex());
   }
 
   NS::String* swiftName = NS::String::string("attention", NS::UTF8StringEncoding);
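For context, here is a minimal sketch of how a pipeline could be specialized against the new constant layout with metal-cpp. This is not this repo's pipeline code: the library and device handles, the example sizes, and the assumption that Q's bufferIndex() is 0 are placeholders; only the constant indices and the "attention" function name come from the hunks above.

```cpp
#include <Metal/Metal.hpp>
#include <cstdint>

// Hedged sketch: specialize the "attention" kernel against the new layout
// (0: R, 1: C, 2: Hq, 3: Hq/Hk, 4+: per-operand batch strides).
MTL::ComputePipelineState *makeAttentionPipeline(MTL::Device *device, MTL::Library *library) {
  uint32_t R = 1024, C = 1024;            // example sequence lengths
  uint32_t Hq = 32, HHkRatio = 32 / 8;    // example: 32 query heads, 8 K/V heads
  uint32_t qBatchStride = 0;              // example stride; 0 when unused

  auto *constants = MTL::FunctionConstantValues::alloc()->init();
  constants->setConstantValue(&R, MTL::DataTypeUInt, NS::UInteger(0));
  constants->setConstantValue(&C, MTL::DataTypeUInt, NS::UInteger(1));
  constants->setConstantValue(&Hq, MTL::DataTypeUInt, NS::UInteger(2));
  constants->setConstantValue(&HHkRatio, MTL::DataTypeUInt, NS::UInteger(3));
  // Batch strides now start at index 4 (here assuming Q's bufferIndex() is 0).
  constants->setConstantValue(&qBatchStride, MTL::DataTypeUInt, NS::UInteger(4));

  NS::Error *error = nullptr;
  MTL::Function *function = library->newFunction(
      NS::String::string("attention", NS::UTF8StringEncoding), constants, &error);
  return device->newComputePipelineState(function, &error);
}
```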

lib/nnc/mfa/v2/AttentionKernel.cpp

Lines changed: 9 additions & 11 deletions
@@ -17,8 +17,6 @@ AttentionKernel::AttentionKernel(AttentionKernelDescriptor descriptor, MTL::Devi
 
   blockDimensions = descriptor.blockDimensions;
   headDimension = descriptor.headDimension;
-  Hq = descriptor.Hq;
-  Hk = descriptor.Hk;
   leadingDimensions = descriptor.leadingDimensions;
   scale = descriptor.scale;
   disableAsyncCopy = false;

@@ -498,10 +496,10 @@ std::string AttentionKernel::createConstants() const noexcept {
     operands = {AttentionOperand::Q, AttentionOperand::K, AttentionOperand::V, AttentionOperand::O, AttentionOperand::dO, AttentionOperand::dV, AttentionOperand::dK};
     break;
   }
-  std::string output = "#define Hq (" + std::to_string(Hq) + ")\n";
+  std::string output = "";
   for (const auto& operand : operands) {
     output += " constant uint " + operand.name() + "_batch_stride [[function_constant(";
-    output += std::to_string(operand.bufferIndex() + 2) + ")]];\n";
+    output += std::to_string(operand.bufferIndex() + 4) + ")]];\n";
   }
   return R"(

@@ -511,6 +509,9 @@ std::string AttentionKernel::createConstants() const noexcept {
 constant uint R [[function_constant(0)]];
 constant uint C [[function_constant(1)]];
 
+constant uint Hq [[function_constant(2)]];
+constant uint H_Hk_ratio [[function_constant(3)]];
+
 )" + output;
 }

@@ -542,15 +543,14 @@ std::string AttentionKernel::operandLocationWithHeadOffsetValue(AttentionOperand
   source.SetValue("OPERAND", operand.name());
   if (operand.value == AttentionOperand::L || operand.value == AttentionOperand::D) {
     source += "{{OPERAND}} + (gid.z * Hq + gid.y) * R\\";
-  } else if (Hq > 1) {
+  } else {
     source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
-    if (Hq != Hk && (operand.value == AttentionOperand::K || operand.value == AttentionOperand::V || operand.value == AttentionOperand::dK || operand.value == AttentionOperand::dV)) {
-      source.SetValue("H_HK_RATIO", std::to_string(Hq / Hk));
+    if (operand.value == AttentionOperand::K || operand.value == AttentionOperand::V || operand.value == AttentionOperand::dK || operand.value == AttentionOperand::dV) {
       if (!transposed(operand)) {
-        source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y / {{H_HK_RATIO}} * {{HEAD_DIMENSION}}\\";
+        source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y / H_Hk_ratio * {{HEAD_DIMENSION}}\\";
       } else {
         source.SetValue("SEQUENCE_LENGTH", sequenceLength(operand));
-        source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y / {{H_HK_RATIO}} * {{HEAD_DIMENSION}} * {{SEQUENCE_LENGTH}}\\";
+        source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y / H_Hk_ratio * {{HEAD_DIMENSION}} * {{SEQUENCE_LENGTH}}\\";
       }
     } else {
       if (!transposed(operand)) {

@@ -560,8 +560,6 @@ std::string AttentionKernel::operandLocationWithHeadOffsetValue(AttentionOperand
         source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y * {{HEAD_DIMENSION}} * {{SEQUENCE_LENGTH}}\\";
       }
     }
-  } else {
-    source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride\\";
   }
   return source.ToString();
 }
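The hunks above drop the compile-time {{H_HK_RATIO}} substitution in favor of the H_Hk_ratio function constant, so the grouped-query mapping from query head to K/V head now happens when the pipeline is specialized rather than when the source is generated. A small worked example of the arithmetic the generated kernel performs (the values are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical grouped-query attention: 32 query heads sharing 8 K/V heads.
  const uint32_t Hq = 32, Hk = 8;
  const uint32_t H_Hk_ratio = Hq / Hk;             // 4, passed as function_constant(3)
  const uint32_t headDimension = 64;

  // gid.y indexes the query head in the generated kernel.
  const uint32_t gid_y = 13;
  const uint32_t kvHead = gid_y / H_Hk_ratio;      // query head 13 -> K/V head 3
  const uint32_t kOffset = kvHead * headDimension; // column offset into an untransposed K

  assert(kvHead == 3);
  assert(kOffset == 192);
  return 0;
}
```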

lib/nnc/mfa/v2/AttentionKernel.hpp

Lines changed: 0 additions & 4 deletions
@@ -39,10 +39,6 @@ struct AttentionKernel {
 
   unsigned short headDimension;
 
-  unsigned short Hq;
-
-  unsigned short Hk;
-
   bool disableAsyncCopy;
 
   unsigned short threadgroupMemoryAllocation;

lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp

Lines changed: 1 addition & 5 deletions
@@ -9,7 +9,6 @@ bool AttentionKernelDescriptor::operator==(const AttentionKernelDescriptor& rhs)
   simd_all(blockDimensions == rhs.blockDimensions) &&
   cacheState == rhs.cacheState &&
   headDimension == rhs.headDimension &&
-  Hq == rhs.Hq && Hk == rhs.Hk &&
   memoryPrecisions == rhs.memoryPrecisions &&
   (preferAsyncCache == rhs.preferAsyncCache) &&
   (preferAsyncLoad == rhs.preferAsyncLoad) &&

@@ -25,19 +24,16 @@ std::size_t std::hash<AttentionKernelDescriptor>::operator()(const AttentionKern
   using namespace ccv::nnc::mfa::hash;
   combine_64(seed, pack_64(simd_make_ushort4(hash.blockDimensions, 0)));
   combine_32(seed, pack_32(simd::ushort2 { hash.headDimension, hash.type.value }));
-  combine_32(seed, pack_32(simd::ushort2 { hash.Hq, hash.Hk }));
   combine_32(seed, pack_32(simd::uchar4 { hash.preferAsyncCache, hash.preferAsyncLoad, 0, 0 }));
   return seed;
 }
 
 // MARK: - Initializer
 
-AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, unsigned short Hq, unsigned short Hk, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept {
+AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept {
   this->blockDimensions = blockDimensions;
   this->cacheState = cacheState;
   this->headDimension = headDimension;
-  this->Hq = Hq;
-  this->Hk = Hk;
   this->memoryPrecisions = memoryPrecisions;
   this->preferAsyncCache = preferAsyncCache;
   this->preferAsyncLoad = preferAsyncLoad;
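A consequence of dropping Hq / Hk from operator== and the hash above: two attention problems that differ only in head counts now produce identical kernel descriptors, so a descriptor-keyed cache (presumably what the std::hash specialization serves) can reuse one generated kernel and vary only the function constants. A simplified, self-contained illustration; the struct below is a stand-in, not the real descriptor:

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in key: after this commit, head counts are no longer part of it.
struct KernelKey {
  unsigned short headDimension;
  bool operator==(const KernelKey &rhs) const { return headDimension == rhs.headDimension; }
};
struct KernelKeyHash {
  std::size_t operator()(const KernelKey &k) const { return std::hash<unsigned short>()(k.headDimension); }
};

int main() {
  std::unordered_map<KernelKey, std::string, KernelKeyHash> cache;
  cache[{64}] = "generated attention kernel, D = 64";  // first built for, say, Hq = 32, Hk = 8

  // A later problem with Hq = 16, Hk = 16 but the same head dimension hits the
  // same entry; its head counts are supplied via function constants instead.
  return cache.count({64}) == 1 ? 0 : 1;
}
```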

lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp

Lines changed: 1 addition & 5 deletions
@@ -20,10 +20,6 @@ struct AttentionKernelDescriptor {
   /// Required. The problem size along the head dimension.
   unsigned short headDimension;
 
-  unsigned short Hq;
-
-  unsigned short Hk;
-
   AttentionOperands<GEMMOperandPrecision> memoryPrecisions;
 
   /// Reads with a one-to-one mapping to threads (like GEMM store) and writes.

@@ -62,7 +58,7 @@ struct AttentionKernelDescriptor {
   AttentionKernelDescriptor() = delete;
 
   /// Initialize the kernel descriptor.
-  AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, unsigned short Hq, unsigned short Hk, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept;
+  AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept;
 
   bool operator==(const AttentionKernelDescriptor& rhs) const;
 };
