-
-
Notifications
You must be signed in to change notification settings - Fork 9.3k
[Model] Add Gemma 2 #5908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Add Gemma 2 #5908
Changes from all commits
7db6122
df2c007
a176803
a1ddec8
7fbcf48
8d5c6e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
from vllm.tracing import is_otel_installed | ||
from vllm.transformers_utils.config import get_config, get_hf_text_config | ||
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, | ||
is_hip, is_neuron, is_tpu, is_xpu, | ||
is_hip, is_neuron, is_tpu, is_xpu, print_warning_once, | ||
update_environment_variables) | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -141,6 +141,17 @@ def __init__( | |
code_revision, rope_scaling, rope_theta) | ||
self.hf_text_config = get_hf_text_config(self.hf_config) | ||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) | ||
|
||
if (not self.disable_sliding_window | ||
and self.hf_text_config.model_type == "gemma2" | ||
and self.hf_text_config.sliding_window is not None): | ||
print_warning_once( | ||
"Gemma 2 uses sliding window attention for every odd layer, " | ||
"which is currently not supported by vLLM. Disabling sliding " | ||
"window and capping the max length to the sliding window size " | ||
f"({self.hf_text_config.sliding_window}).") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changes look good. I think this warning should be updated to say something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-neuralmagic Oh maybe I misunderstood the change here. My intention was to enable the full (8K) context length with global attention for all layers. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay sorry for the confusion the way you had it before will do this :) Setting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think either option is reasonable.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-neuralmagic Hmm... OK let's use 4K context length for now and see if people want 8K content length despite the difference from the original model. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-neuralmagic Updated the warning msg. PTAL! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM |
||
self.disable_sliding_window = True | ||
|
||
self.max_model_len = _get_and_verify_max_len( | ||
hf_config=self.hf_text_config, | ||
max_model_len=max_model_len, | ||
|
@@ -257,8 +268,7 @@ def verify_with_parallel_config( | |
"BitAndBytes quantization with TP or PP is not supported yet.") | ||
|
||
def get_hf_config_sliding_window(self) -> Optional[int]: | ||
"""Get the sliding window size, or None if disabled. | ||
""" | ||
"""Get the sliding window size, or None if disabled.""" | ||
|
||
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in | ||
# addition to sliding window size. We check if that field is present | ||
|
@@ -1256,10 +1266,16 @@ def _get_and_verify_dtype( | |
dtype = dtype.lower() | ||
if dtype == "auto": | ||
if config_dtype == torch.float32: | ||
# Following the common practice, we use float16 for float32 | ||
# models. | ||
logger.info("Casting torch.float32 to torch.float16.") | ||
torch_dtype = torch.float16 | ||
if config.model_type == "gemma2": | ||
logger.info( | ||
"For Gemma 2, we downcast float32 to bfloat16 instead " | ||
"of float16 by default. Please specify `dtype` if you " | ||
"want to use float16.") | ||
torch_dtype = torch.bfloat16 | ||
else: | ||
# Following the common practice, we use float16 for float32 | ||
# models. | ||
torch_dtype = torch.float16 | ||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
torch_dtype = config_dtype | ||
else: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,3 +95,49 @@ def extra_repr(self) -> str: | |
s = f"hidden_size={self.weight.data.size(0)}" | ||
s += f", eps={self.variance_epsilon}" | ||
return s | ||
|
||
|
||
class GemmaRMSNorm(CustomOp): | ||
"""RMS normalization for Gemma. | ||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Two differences from the above RMSNorm: | ||
1. x * (1 + w) instead of x * w. | ||
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
eps: float = 1e-6, | ||
) -> None: | ||
super().__init__() | ||
self.weight = nn.Parameter(torch.zeros(hidden_size)) | ||
self.variance_epsilon = eps | ||
|
||
def forward_native( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should try decorating this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I was thinking about it or writing a CUDA kernel. Let's discuss this in another PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good -- agree it should be in a different PR :) |
||
self, | ||
x: torch.Tensor, | ||
residual: Optional[torch.Tensor] = None, | ||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
"""PyTorch-native implementation equivalent to forward().""" | ||
orig_dtype = x.dtype | ||
if residual is not None: | ||
x = x + residual | ||
residual = x | ||
|
||
x = x.float() | ||
variance = x.pow(2).mean(dim=-1, keepdim=True) | ||
x = x * torch.rsqrt(variance + self.variance_epsilon) | ||
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) | ||
# See https://github.com/huggingface/transformers/pull/29402 | ||
x = x * (1.0 + self.weight.float()) | ||
x = x.to(orig_dtype) | ||
return x if residual is None else (x, residual) | ||
|
||
def forward_cuda( | ||
self, | ||
x: torch.Tensor, | ||
residual: Optional[torch.Tensor] = None, | ||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. | ||
return self.forward_native(x, residual) |
Uh oh!
There was an error while loading. Please reload this page.