We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1458451 commit 05fe088Copy full SHA for 05fe088
tests/kernels/test_marlin_gemm.py
@@ -231,13 +231,12 @@ def test_gptq_marlin_gemm(
231
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
232
GPTQ_MARLIN_MAX_PARALLEL)
233
234
- opcheck(
235
- torch.ops._C.gptq_marlin_gemm,
236
- (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
237
- workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
238
- a_input.shape[1], is_k_full, False, use_atomic_add, use_fp32_reduce,
239
- False),
240
- test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+ opcheck(torch.ops._C.gptq_marlin_gemm,
+ (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
+ workspace.scratch, quant_type.id, a_input.shape[0],
+ b_weight.shape[1], a_input.shape[1], is_k_full, False,
+ use_atomic_add, use_fp32_reduce, False),
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS)
241
242
output = ops.gptq_marlin_gemm(
243
a_input,
0 commit comments