
Fix Bad Outputs in Fast Path for GraniteMoeHybrid #39033


Merged 1 commit into huggingface:main on Jun 26, 2025

Conversation

@alex-jw-brooks (Contributor)

This PR fixes garbage output produced when running Granite MoE Hybrid models on the fast path. The root cause is inconsistent handling of has_previous_state on the fast path: in Bamba's slow path (and in GraniteMoeHybrid via modular), the cache param is updated directly inside the forward call of the state-space layer, instead of handling the cache params at the end of the model forward as other state-space models do.
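For context, a minimal sketch of the pattern the fix moves toward, updating the flag once at the end of the model forward rather than inside each state-space layer. The helper name run_decoder_layers is hypothetical, and the has_previous_state attribute follows the Bamba/Jamba-style hybrid caches; this is illustrative, not the actual diff:

def model_forward(model, input_ids, past_key_values=None, use_cache=True):
    # Run every decoder layer (attention and SSM) for this step first.
    hidden_states = model.run_decoder_layers(input_ids, past_key_values)  # hypothetical helper
    # Mark the cache once, after all SSM layers have run, so every layer
    # within a single forward pass sees the same flag value.
    if use_cache and past_key_values is not None:
        past_key_values.has_previous_state = True
    return hidden_states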

Quick repro case:

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch

model_path = "ibm-granite/granite-4.0-tiny-preview"
device = "cuda"

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

conv = [{"role": "user", "content": "What is a rattlesnake? "}]

input_ids = tokenizer.apply_chat_template(
    conv,
    return_tensors="pt",
    thinking=True,
    return_dict=True,
    add_generation_prompt=True,
).to(device)

set_seed(42)
output = model.generate(
    **input_ids,
    max_new_tokens=32,
)

prediction = tokenizer.decode(output[0, input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print(prediction)

Current output with mamba_ssm and causal_conv1d installed:

< the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the

vs. without them installed:

<think>A rattlesnake is a venomous snake species known for its distinctive rattle at the end of its tail, which it

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@ArthurZucker can you please take a look?

@ArthurZucker (Collaborator) left a comment:


Makes sense, but we should call this in https://github.com/huggingface/transformers/pull/39033/files#diff-75325057d583641e2cd9953d8803fa3ee0759f945df1e965152f28f39015c65cL468 instead, no?
This would be cleaner, as it would be contained where the cache is updated.

@alex-jw-brooks (Contributor, Author) commented on Jun 25, 2025

Hey @ArthurZucker, thanks for the quick review! IMO it's better to keep it out here, since this is independent of the fast/slow path implementations of forward and better aligns with the way other architectures (e.g., falcon, jamba, etc.) update has_previous_state.

The torch/CUDA forwards are also called per SSM layer, but this is a global property, so moving it in would also set it to True for the 2nd+ SSM layers on the first decode, which may cause indexing errors if we aren't careful about checking when to use_precomputed_states in the forward implementations.
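A toy illustration of that footgun (ToyCache and ssm_layer_forward are made-up names, not transformers code): if each SSM layer flips the flag itself, the second layer of the very first decode step already sees it as True and would try to read precomputed states that don't exist yet.

class ToyCache:
    def __init__(self):
        self.has_previous_state = False

def ssm_layer_forward(cache, set_flag_inside_layer):
    # Each layer decides whether to use precomputed states from this flag.
    use_precomputed_states = cache.has_previous_state
    if set_flag_inside_layer:
        cache.has_previous_state = True  # flipped before later layers of the same step run
    return use_precomputed_states

cache = ToyCache()
# First decode step over two SSM layers, with the flag set inside each layer:
print([ssm_layer_forward(cache, True) for _ in range(2)])  # [False, True]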

@ArthurZucker (Collaborator): Okay! Merging then, thanks 🤗

@ArthurZucker ArthurZucker merged commit 5995cfa into huggingface:main Jun 26, 2025
12 checks passed
@tdoublep: Thanks for the fix, we just hit this comparing against transformers in vLLM CI.
