Support fused add rmsnorm for LLaMA #1667

Merged · 9 commits · Nov 19, 2023
Changes from 1 commit
10 changes: 10 additions & 0 deletions csrc/layernorm.cpp
@@ -6,9 +6,19 @@ void rms_norm(
   torch::Tensor& weight,
   float epsilon);
 
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "rms_norm",
     &rms_norm,
     "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+  m.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
 }
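Both arguments of the new op are mutated in place: `residual` receives `input + residual`, and `input` receives the RMS-normalized sum scaled by `weight`. For orientation, a minimal unfused PyTorch sketch of that contract (the helper name is ours, not part of the PR, and it glosses over the kernel's intermediate cast of the sum back to the storage dtype):

```python
import torch

def unfused_add_rms_norm_reference(
    x: torch.Tensor,         # [..., hidden_size]; overwritten with the output
    residual: torch.Tensor,  # [..., hidden_size]; overwritten with x + residual
    weight: torch.Tensor,    # [hidden_size]
    eps: float,
) -> None:
    # residual <- x + residual; x <- RMSNorm(x + residual) * weight
    z = x.float() + residual.float()
    inv_rms = torch.rsqrt(z.pow(2).mean(dim=-1, keepdim=True) + eps)
    residual.copy_(z.to(residual.dtype))
    x.copy_((z * inv_rms).to(x.dtype) * weight)
```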
55 changes: 55 additions & 0 deletions csrc/layernorm_kernels.cu
@@ -34,6 +34,36 @@ __global__ void rms_norm_kernel(
   }
 }
 
+// TODO: Further optimize this kernel.
+template<typename scalar_t>
+__global__ void fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,         // [..., hidden_size]
+  scalar_t* __restrict__ residual,      // [..., hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    x += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) residual[blockIdx.x * hidden_size + idx];
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
 } // namespace vllm
 
 void rms_norm(
@@ -60,3 +90,28 @@ void rms_norm(
         hidden_size);
     });
 }
+
+void fused_add_rms_norm(
+  torch::Tensor& input,    // [..., hidden_size]
+  torch::Tensor& residual, // [..., hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "fused_add_rms_norm_kernel",
+    [&] {
+      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        input.data_ptr<scalar_t>(),
+        residual.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
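A quick numerical sanity check of the new kernel against the unfused path (a sketch: it assumes the built extension is importable as `layernorm_ops`, matching its use in the Python changes below, and uses loose fp16 tolerances since the kernel rounds the sum to the storage dtype before normalizing):

```python
import torch
from vllm import layernorm_ops  # assumed import path for the built extension

torch.manual_seed(0)
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")
eps = 1e-6

# Unfused reference, computed in fp32.
z = x.float() + residual.float()
ref = (z * torch.rsqrt(z.pow(2).mean(-1, keepdim=True) + eps)).half() * weight

layernorm_ops.fused_add_rms_norm(x, residual, weight, eps)  # mutates in place

assert torch.allclose(residual, z.half(), atol=1e-3, rtol=1e-3)
assert torch.allclose(x, ref, atol=1e-2, rtol=1e-2)
```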
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/layernorm.py
@@ -1,4 +1,5 @@
 """Custom normalization layers."""
+from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
@@ -21,7 +22,19 @@ def __init__(
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            layernorm_ops.fused_add_rms_norm(
+                x,
+                residual,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return x, residual
         out = torch.empty_like(x)
         layernorm_ops.rms_norm(
             out,
(remaining diff collapsed)
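With this change, `forward` has two call patterns; a sketch of how callers are expected to use it (the `RMSNorm` class name and sizes are taken from this file, the tensors are illustrative):

```python
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(hidden_size=4096).cuda().half()
hidden_states = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Without a residual: plain RMSNorm; the single-tensor API is unchanged.
hidden_states = norm(hidden_states)

# With a residual: takes the fused in-place path and returns both the
# normalized activations and the updated residual for the next layer.
residual = torch.randn_like(hidden_states)
hidden_states, residual = norm(hidden_states, residual)
```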
23 changes: 14 additions & 9 deletions vllm/model_executor/models/llama.py
@@ -201,25 +201,28 @@ def forward(
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
+        residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
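Note the new contract: the `hidden_states` returned here has not had the last residual added yet; that add is deferred into the next fused norm call. An unfused sketch of one layer under this contract (a reading aid, not the module code; the callables stand in for the layer's submodules):

```python
from typing import Callable, Optional, Tuple
import torch

def decoder_layer_unfused(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    input_norm: Callable, self_attn: Callable,
    post_attn_norm: Callable, mlp: Callable,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_layernorm step: fold the pending residual in, then normalize.
    residual = hidden_states if residual is None else hidden_states + residual
    hidden_states = input_norm(residual)
    hidden_states = self_attn(hidden_states)
    # post_attention_layernorm step: same fold-then-normalize.
    residual = hidden_states + residual
    hidden_states = post_attn_norm(residual)
    hidden_states = mlp(hidden_states)
    # The MLP branch's residual add is deferred to the next layer's input
    # norm (or the model's final norm), which keeps every add fusable.
    return hidden_states, residual
```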


class LlamaModel(nn.Module):
@@ -254,20 +257,22 @@ def forward(
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
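At the top of the stack, `residual = None` makes the first layer take the unfused path for its input norm; at the bottom, `self.norm(hidden_states, residual)` performs the final deferred add inside the fused kernel, and the returned residual is discarded.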


(remaining diff collapsed)