You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1276
1278
TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8");
1279
+
TORCH_CHECK((head_size_v < 256 && head_size < 256) || !deterministic, "FlashAttention backward only supports deterministic when head dimension less than 256");
1277
1280
int const max_headdim = get_max_headdim();
1278
1281
TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
1279
1282
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
0 commit comments