
Conversation

@defei-coder defei-coder commented Dec 13, 2023

FA2 supports a deterministic computation feature; fixes issue #429

  • Add a deterministic flag, like in FA1, to control whether the results are deterministic.
  • This feature needs extra workspace for semaphores.
  • For convenience, users may omit the extra workspace; it will then be allocated automatically and a warning message will be emitted.
  • We do not support deterministic=True and local=True at the same time; we will work around this limitation later.

Example (usage and test):

    # Allocate the int32 semaphore workspace required by the deterministic bwd.
    workspace = torch.zeros(flash_get_bwd_workspace_size_func(batch_size, nheads, seqlen_q, d), device=device, dtype=torch.int32)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, workspace=workspace, deterministic=True)
    g = torch.randn_like(out)

    (
        dq,
        dk,
        dv,
     ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)

    (
        dq1,
        dk1,
        dv1,
    ) = torch.autograd.grad(out, (q, k, v), g)
    # The two backward passes produce bitwise-identical dq.
    assert (dq - dq1).abs().max().item() == 0

This way, the result of dq is deterministic. We run the backward pass twice, obtaining dq and dq1, and the two results are completely identical.

@evanluyifan

Since the extra workspace can be reused at runtime, to avoid frequent redundant GPU driver alloc/free calls, we may need to expose a Python API for allocating the extra workspace and let the framework's allocator handle GPU memory reuse.
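A minimal sketch of that idea, assuming the flash_get_bwd_workspace_size_func helper and workspace kwarg proposed in this PR: allocating through torch lets PyTorch's caching allocator serve repeated requests of the same size from its memory pool instead of issuing a driver alloc/free every iteration.

    import torch

    def get_bwd_workspace(batch_size, nheads, seqlen_q, d, device="cuda"):
        # flash_get_bwd_workspace_size_func is the size helper proposed in this PR.
        # torch.zeros goes through PyTorch's caching allocator, so repeated
        # allocations of the same size are reused rather than hitting the driver.
        n = flash_get_bwd_workspace_size_func(batch_size, nheads, seqlen_q, d)
        return torch.zeros(n, device=device, dtype=torch.int32)

    # Reuse one workspace across training iterations:
    # ws = get_bwd_workspace(batch_size, nheads, seqlen_q, d)
    # out = flash_attn_func(q, k, v, 0.0, causal=causal, workspace=ws, deterministic=True)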

@hwu36

hwu36 commented Dec 13, 2023

https://research.colfax-intl.com/nvidia-hopper-flashattention-2/ replaced Ampere mma.sync and cp.async with TMA and WGMMA on Hopper. It needs a smaller tile size to prevent register spilling and uses warp specialization to hide more data movement.

@jayhshah

@evanluyifan

Hi @tridao, this PR has been open for a while; could you help with the code review?
BTW, this change has been tested on cases from Meituan, and the bwd outputs are deterministic.


tridao commented Dec 21, 2023

Thanks, just got back from some travel, let me review it this week.


tridao commented Dec 24, 2023

Thanks, I've incorporated some of the ideas here with a slightly different approach, and there's an option for deterministic bwd as of v2.4.1. I've acknowledged your contribution in the README.
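For reference, a minimal sketch of enabling the deterministic backward in upstream flash-attn v2.4.1+ (assuming the deterministic keyword of flash_attn_func added in that release; no explicit workspace argument is needed there):

    import torch
    from flash_attn import flash_attn_func

    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    # deterministic=True selects the deterministic backward kernel (slightly slower).
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True, deterministic=True)
    g = torch.randn_like(out)

    dq1, = torch.autograd.grad(out, q, g, retain_graph=True)
    dq2, = torch.autograd.grad(out, q, g)
    # Expected to be bitwise-identical across repeated backward passes.
    assert (dq1 - dq2).abs().max().item() == 0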
