[Bug] Got unexpected answer while using EAGLE3 #8671

@zyksir

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

I wrote a small test to check the correctness of the model. The test sends different requests in parallel and validates each response. I am using the latest sglang version from the main branch.
I find that:

  • When I use EAGLE3 and set max_concurrency to 4, the JSON response has a high probability of being meaningless (see the minimal single-request sketch after this list), like '{\n"{\\n \\"title\\": \\"The Great Gatsby\\",\\n \\"author\\": \\"F. Scott Fitzgerald\\",\\n \\"publication_year\\": 1925,\\n \\"genres\\": [\\n \\"Fiction\\",\\n \\"Tragedy\\"\\n ]\\n}" \n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n '
  • When I set max_concurrency to 1 or 2, the issue disappears. When I disable EAGLE, the issue also disappears.
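A minimal single-request sketch of the failing case, using the same endpoint, model, and sampling parameters as the full test under Reproduction (only the parallelism is stripped away, so a single call may not reproduce the garbage reliably):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="dummy-key")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant designed to output JSON.",
        },
        {
            "role": "user",
            "content": "Provide a JSON object with two keys: 'name' (string) and 'age' (integer).",
        },
    ],
    temperature=0.0,
    max_tokens=100,
    response_format={"type": "json_object"},
)
# Expected: a small valid JSON object; observed with EAGLE3 at concurrency 4:
# whitespace-padded content that fails json.loads, like the example above.
print(response.choices[0].message.content)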

Reproduction

  • The server launch command is:
export MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT=yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
python3 -m sglang.launch_server \
  --model ${MODEL} \
  --cuda-graph-max-bs 4 \
  --context-length 8192 \
  --dtype bfloat16 --mem-frac=0.8 \
  --speculative-algo EAGLE3 \
  --speculative-draft ${DRAFT} \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4
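
Before running the test, a quick readiness check against the OpenAI-compatible endpoint (the /v1/models route is part of the OpenAI-compatible surface that sglang serves; adjust host and port if needed):

curl http://127.0.0.1:30000/v1/models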

The test code is below:

import asyncio
import json
import logging
from typing import Dict, Optional
from copy import deepcopy

from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion

def json_mode_assertions(
    test_name: str,
    response: ChatCompletion,
    logger: logging.Logger,
) -> None:
    # Run base assertions first
    logger.info(f"[{test_name}] Running json_mode_assertions...")

    # Focus on JSON-specific business logic
    choice = response.choices[0]
    assert choice.finish_reason in ["stop", "abort"], (
        f"[{test_name}] JSON mode: finish_reason should be 'stop', got {choice.finish_reason}"
    )

    content = choice.message.content
    assert content is not None, f"[{test_name}] JSON mode: Message content is None"

    # Validate JSON parsability
    try:
        json.loads(content)  # This will raise an error if the content is not valid JSON
    except json.JSONDecodeError as e:
        raise AssertionError(
            f"[{test_name}] JSON mode: Invalid JSON content. Error: {str(e)}"
        ) from e

json_schema_invalid = {
    "type": "object",
    "properties": {
        "value": {"type": "float"},  # float is not part of the json schema spec
    },
    "required": ["value"],
    "additionalProperties": False,
}

oai_test_cases_list: list[Dict] = [
    {
        "_test_name": "JSON Mode Test (Name and Age)",
        "_test_assertion_callback": json_mode_assertions,
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant designed to output JSON.",
            },
            {
                "role": "user",
                "content": 'Provide a JSON object with two keys: \'name\' (string) and \'age\' (integer). Example: {"name": "John Doe", "age": 30}',
            },
        ],
        "temperature": 0.0,
        "response_format": {"type": "json_object"},
    },
    {
        "_test_name": "JSON Schema Test (Invalid)",
        "_test_expected_error_code": 400,
        "messages": [{"role": "user", "content": "Generate a short value."}],
        "temperature": 0.0,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "test_schema", "schema": json_schema_invalid},
        },
    },
]

class OAITestSuite:
    def __init__(
        self,
        client: AsyncOpenAI,
        model_name: str,
        default_max_tokens: int,
        max_concurrency: Optional[int] = None,
    ):
        self.client = client
        self.model_name = model_name
        self.default_max_tokens = default_max_tokens
        self.semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def run(self, times=1):
        logger.info("Starting OAI conformance test suite.")

        # Create tasks for all test cases
        tasks = []
        for _ in range(times):
            for test_case_obj in oai_test_cases_list:
                current_max_tokens = (
                    test_case_obj["max_tokens"]
                    if "max_tokens" in test_case_obj and test_case_obj["max_tokens"] is not None
                    else self.default_max_tokens
                )
                # Create a task for each test case
                async def limited_request_func(func, *args, **kwargs):
                    if self.semaphore is None:
                        return await func(*args, **kwargs)
                    async with self.semaphore:
                        return await func(*args, **kwargs)
                task = limited_request_func(
                    self._run_single_oai_test,
                    client=self.client,
                    test_case=deepcopy(test_case_obj),
                    max_tokens=current_max_tokens,
                )
                tasks.append(task)

        # Run all tasks concurrently
        test_results = await asyncio.gather(*tasks)

        logger.info("OAI conformance test suite finished.")
        self._log_test_report(test_results)

    async def _run_single_oai_test(
        self, client: AsyncOpenAI, test_case: Dict, max_tokens: int
    ) -> dict:
        test_name = test_case.pop("_test_name")
        logger.info(f"--- Running OAI Conformance Test: {test_name} ---")
        test_expected_error_code = test_case.pop("_test_expected_error_code", None)
        test_assertion_callback = test_case.pop("_test_assertion_callback", None)
        payload = test_case
        payload["model"] = self.model_name
        payload["max_tokens"] = max_tokens
        result_details = {
            "name": test_name,
            "status": "UNKNOWN",
            "message": "",
        }

        logger.info(f"[{test_name}] Sending request with payload: {json.dumps(payload)}...")
        try:
            response = await client.chat.completions.create(**payload)

            if test_expected_error_code is not None:
                log_msg = f"Expected error status code {test_expected_error_code} but got success response"
                result_details["status"] = "FAILED"
                result_details["message"] = log_msg
            else:
                result_details["status"] = "PASSED"
                if test_assertion_callback:
                    try:
                        logger.info(f"[{test_name}] Response data for non-streaming: {str(response)}")
                        test_assertion_callback(test_name, response, logger)
                        result_details["message"] = "Non-streaming assertions passed."
                    except AssertionError as e:
                        error_msg = f"Assertion failed in non-streaming callback: {e}"
                        logger.error(f"[{test_name}] {error_msg}")
                        result_details["status"] = "FAILED_ASSERTION"
                        result_details["message"] = error_msg
                else:
                    info_msg = "Non-streaming test case has no assertion_callback. Logging raw response data."
                    logger.info(f"[{test_name}] {info_msg}")
                    logger.info(f"[{test_name}] Raw response data: {str(response)}")
                    result_details["message"] = info_msg
        except Exception as e:
            if test_expected_error_code is not None:
                error_status_code = getattr(e, "status_code", None)
                if error_status_code == test_expected_error_code:
                    log_msg = f"Correctly received expected error status code {test_expected_error_code}."
                    logger.info(f"[{test_name}] {log_msg}")
                    result_details["status"] = "PASSED_EXPECTED_ERROR"
                    result_details["message"] = log_msg
                else:
                    log_msg = f"Expected status code {test_expected_error_code} for error test, got {error_status_code}. Error: {str(e)}"
                    logger.error(f"[{test_name}] {log_msg}")
                    result_details["status"] = "PASSED"
                    result_details["message"] = log_msg
            else:
                log_msg = f"Unexpected error: {str(e)}"
                logger.error(f"[{test_name}] {log_msg}")
                result_details["status"] = "ERROR"
                result_details["message"] = log_msg
        logger.info(f"--- Finished OAI Conformance Test: {test_name} ---")
        return result_details

    def _log_test_report(self, test_results: list[dict]):
        logger.info("===== OAI Conformance Test Report =====")
        passed_count = 0
        failed_count = 0
        error_count = 0
        passed_expected_error_count = 0
        failed_assertion_count = 0

        for result in test_results:
            if result["status"] == "PASSED":
                passed_count += 1
            elif result["status"] == "FAILED":
                failed_count += 1
            elif result["status"] == "ERROR":
                error_count += 1
            elif result["status"] == "PASSED_EXPECTED_ERROR":
                passed_expected_error_count += 1
            elif result["status"] == "FAILED_ASSERTION":
                failed_assertion_count += 1
                failed_count += 1  # Count assertion failures as overall failures

        # Log the summary counts gathered above
        logger.info(
            f"Total: {len(test_results)} | Passed: {passed_count} | "
            f"Expected errors: {passed_expected_error_count} | "
            f"Failed: {failed_count} (assertion failures: {failed_assertion_count}) | "
            f"Errors: {error_count}"
        )
        if failed_count > 0 or error_count > 0:
            logger.info("----- Details for Failed/Errored Tests -----")
            for result in test_results:
                if result["status"] not in ["PASSED", "PASSED_EXPECTED_ERROR"]:
                    logger.info(f"  Test: {result['name']}")
                    logger.info(f"    Status: {result['status']}")
                    logger.info(f"    Message: {result['message']}")
        logger.info("=======================================")

logger = logging.getLogger('specforce')
logging.basicConfig(
    level=logging.INFO,  # or DEBUG, WARNING, etc.
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
base_url = "http://localhost:30000/v1"
model_name = "meta-llama/Llama-3.1-8B-Instruct"
client = AsyncOpenAI(
    base_url=base_url,
    api_key="dummy-key",
)
oai_test_suite = OAITestSuite(
    client=client,
    model_name=model_name,
    default_max_tokens=100,
    max_concurrency=4,
)
# The top-level await and !curl below assume a Jupyter/IPython cell (plain-script sketch follows the code)
await oai_test_suite.run(times=10)
!curl http://127.0.0.1:30000/flush_cache
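
The snippet ends with a top-level await and a !curl shell escape, so it is written for a Jupyter/IPython cell. A minimal sketch for running the same suite as a plain Python script, keeping the client and suite setup above (urllib stands in for the curl to /flush_cache):

import asyncio
import urllib.request

# Drive the async test suite without an already-running event loop
asyncio.run(oai_test_suite.run(times=10))

# Flush the server cache between runs, mirroring the curl call above
with urllib.request.urlopen("http://127.0.0.1:30000/flush_cache") as resp:
    print(resp.read().decode())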

Environment

I can reproduce this on a single H100. The output of python3 -m sglang.check_env is the following:

Python: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H100 80GB HBM3
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.127.08
PyTorch: 2.7.1+cu126
sglang: 0.4.10
sgl_kernel: 0.2.8
flashinfer_python: 0.2.9rc2
triton: 3.3.1
transformers: 4.54.1
torchao: 0.9.0
numpy: 2.3.2
aiohttp: 3.12.14
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.34.1
interegular: 0.3.3
modelscope: 1.28.0
orjson: 3.11.0
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.11.7
python-multipart: 0.0.20
pyzmq: 27.0.0
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.21
openai: 1.97.1
tiktoken: 0.9.0
anthropic: 0.58.2
litellm: 1.74.7
decord: 0.6.0
