Skip to content
18 changes: 18 additions & 0 deletions csrc/rocm/custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
void* scale_b, const int M, const int K, const int Kp,
const int N, const int Otp_in, cudaStream_t stream,
const int CuCount);

void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in,
const int64_t Otp_in, const int64_t CuCount) {
auto M = in_a.size(0);
auto K = in_a.size(1);
auto Kp = in_a.stride(0);
int N = N_in;
int Otp = Otp_in;
wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(),
scale_a.data_ptr(), scale_b.data_ptr(), M, K, Kp, N, Otp,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

Expand Down
Loading