Multi-Turn Rollout on a single GPU #1365

@ltdemey

TL;DR

Running into an assertion error at /verl/workers/rollout/vllm_rollout/vllm_async_server.py:55 when running a single-node variant of verl/tests/rollout/test_vllm_multi_turn.py (see the script in Appendix A).

Description

Given the exciting work on multi-turn rollouts introduced in #1138, we are trying to add multimodal environment interactions (a similar idea to the one described in the VAGEN paper).

However, I run into an assertion error when running the AsyncLLMManager with a single AsyncLLMWorker. I spent quite a while stepping through the debugger but cannot figure out why the Executor does not find the Actors.
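
For context, the assertion that fires (full traceback below) checks that the executor can discover exactly vllm_dp_size * vllm_tp_size named actors for its instance. The following is only a rough sketch of that check, reconstructed from the error message rather than copied from verl:

# Rough sketch reconstructed from the assertion message below (NOT the actual
# verl implementation): the external executor looks up the rollout worker
# actors by name and expects exactly vllm_dp_size * vllm_tp_size of them.
# In my run the lookup comes back empty.
import ray

def sketch_find_rollout_actors(name_prefix: str, vllm_dp_size: int, vllm_tp_size: int):
    # list_named_actors(all_namespaces=True) returns dicts with "name" and "namespace" keys
    all_actors = ray.util.list_named_actors(all_namespaces=True)
    actor_names = [a["name"] for a in all_actors if name_prefix in a["name"]]
    assert len(actor_names) == vllm_dp_size * vllm_tp_size, (
        f"{len(actor_names)} actors found, but vllm_dp_size: {vllm_dp_size} * "
        f"vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected."
    )
    return actor_names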

Reproduction

(see Appendix A for script)

python tests/rollout/test_vllm_multi_turn.py
2025-05-02 14:57:34,894 DEBUG worker.py:1576 -- Automatically increasing RLIMIT_NOFILE to max value of 65536
2025-05-02 14:57:34,897 DEBUG node.py:293 -- Setting node ID to d1cce7a158262a4f3567d99a4612737ffef92c57256a3bc1d555541b
2025-05-02 14:57:34,903 DEBUG node.py:1401 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2025-05-02_14-57-34_895398_2793/logs.
2025-05-02 14:57:35,829 DEBUG node.py:1430 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2025-05-02_14-57-34_895398_2793/logs.
2025-05-02 14:57:35,861 DEBUG npu.py:60 -- Could not import AscendCL: No module named 'acl'
2025-05-02 14:57:35,863 DEBUG tpu.py:115 -- Failed to detect number of TPUs: [Errno 2] No such file or directory: '/dev/vfio'
2025-05-02 14:57:35,864 DEBUG services.py:2140 -- Determine to start the Plasma object store with 9.49 GB memory using /dev/shm.
2025-05-02 14:57:35,912 INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2025-05-02 14:57:36,852 WARNING __init__.py:161 -- DeprecationWarning: `ray.state.available_resources_per_node` is a private attribute and access will be removed in a future Ray version.
colocated worker base class <class 'verl.single_controller.base.worker.Worker'>
bind role actor_rollout method execute_method to class <class 'verl.single_controller.ray.base.create_colocated_worker_cls.<locals>.WorkerDict'>
WARNING:2025-05-02 14:57:37,950:Waiting for register center actor 5Y8l0k_register_center to be ready. Elapsed time: 0 seconds out of 300 seconds.
(WorkerDict pid=3433) Model config after override: Qwen2Config {
(WorkerDict pid=3433)   "architectures": [
(WorkerDict pid=3433)     "Qwen2ForCausalLM"
(WorkerDict pid=3433)   ],
(WorkerDict pid=3433)   "attention_dropout": 0.0,
(WorkerDict pid=3433)   "eos_token_id": 151645,
(WorkerDict pid=3433)   "hidden_act": "silu",
(WorkerDict pid=3433)   "hidden_size": 896,
(WorkerDict pid=3433)   "initializer_range": 0.02,
(WorkerDict pid=3433)   "intermediate_size": 4864,
(WorkerDict pid=3433)   "max_position_embeddings": 32768,
(WorkerDict pid=3433)   "max_window_layers": 21,
(WorkerDict pid=3433)   "model_type": "qwen2",
(WorkerDict pid=3433)   "num_attention_heads": 14,
(WorkerDict pid=3433)   "num_hidden_layers": 24,
(WorkerDict pid=3433)   "num_key_value_heads": 2,
(WorkerDict pid=3433)   "pad_token_id": 151643,
(WorkerDict pid=3433)   "rms_norm_eps": 1e-06,
(WorkerDict pid=3433)   "rope_scaling": null,
(WorkerDict pid=3433)   "rope_theta": 1000000.0,
(WorkerDict pid=3433)   "sliding_window": 32768,
(WorkerDict pid=3433)   "tie_word_embeddings": true,
(WorkerDict pid=3433)   "torch_dtype": "bfloat16",
(WorkerDict pid=3433)   "transformers_version": "4.51.1",
(WorkerDict pid=3433)   "use_cache": true,
(WorkerDict pid=3433)   "use_sliding_window": false,
(WorkerDict pid=3433)   "vocab_size": 151936
(WorkerDict pid=3433) }
(WorkerDict pid=3433) 
(WorkerDict pid=3433) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=3433) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=3433) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=3433) [rank0]:[W502 14:57:46.146030738 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
(WorkerDict pid=3433) Qwen2ForCausalLM contains 494.03M parameters
(WorkerDict pid=3433) wrap_policy: functools.partial(<function _or_policy at 0x7f8717a2e170>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7f8717a2e050>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})])
(WorkerDict pid=3433) /usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py:444: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.FULL_SHARD since the world size is 1.
(WorkerDict pid=3433)   warnings.warn(
(WorkerDict pid=3433) Total steps: -1, num_warmup_steps: 0
(WorkerDict pid=3433) Actor use_remove_padding=False
(WorkerDict pid=3433) /usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=3433)   warnings.warn(
INFO 05-02 14:57:49 [__init__.py:239] Automatically detected platform cuda.
(AsyncvLLMServer pid=3620) FastAPI startup
(AsyncvLLMServer pid=3620) override_generation_config: {'n': 1, 'logprobs': 0, 'max_tokens': 1024, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(AsyncvLLMServer pid=3620) WARNING 05-02 14:58:04 [arg_utils.py:1713] Detected VLLM_USE_V1=1 with Engine in background thread. Usage should be considered experimental. Please report any issues on Github.
(AsyncvLLMServer pid=3620) WARNING 05-02 14:58:04 [cuda.py:96] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
(AsyncvLLMServer pid=3620) WARNING 05-02 14:58:04 [core_client.py:368] SIGUSR1 handler not installed because we are not running in the main thread. In this case the forked engine process may not be killed when an exception is raised, and you need to handle the engine process shutdown manually.
(AsyncvLLMServer pid=3620) 2025-05-02 14:58:11,322      INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8266 
(AsyncvLLMServer pid=3620) EngineCoreProc.run_engine_core
(AsyncvLLMServer pid=3620) EngineCoreProc.__init__
(AsyncvLLMServer pid=3620) EngineCore.__init__ - Start
(AsyncvLLMServer pid=3620) EngineCore.__init__ - Before model executor initialization
(AsyncvLLMServer pid=3620) namespace: 39fff769-940c-4c88-bb08-0ad63fb5e321
(AsyncvLLMServer pid=3620) actor_names: []
(AsyncvLLMServer pid=3620) all actors: []
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406] EngineCore hit an exception: Traceback (most recent call last):
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]   File "/home/ec2-user/vllm/vllm/v1/engine/core.py", line 394, in run_engine_core
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]     engine_core = EngineCoreProc(*args, **kwargs)
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]   File "/home/ec2-user/vllm/vllm/v1/engine/core.py", line 335, in __init__
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]     super().__init__(vllm_config, executor_class, log_stats)
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]   File "/home/ec2-user/vllm/vllm/v1/engine/core.py", line 68, in __init__
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]     self.model_executor = executor_class(vllm_config)
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]   File "/home/ec2-user/vllm/vllm/executor/executor_base.py", line 52, in __init__
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]     self._init_executor()
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]   File "/home/ec2-user/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 58, in _init_executor
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406]     assert len(actor_names) == vllm_dp_size * vllm_tp_size, f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected."
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406] AssertionError: instance_id: 39fff769-940c-4c88-bb08-0ad63fb5e321:5Y8l0k:1:0 has 0 actors, but vllm_dp_size: 1 * vllm_tp_size: 1 = 1 is expected.
(AsyncvLLMServer pid=3620) ERROR 05-02 14:58:12 [core.py:406] 
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff4d25a17abd5c9be7ad69593a01000000 Worker ID: 5e1994c85b7e38c75d974f2886346b5662a2e33db3c8f160d894f1fb Node ID: d1cce7a158262a4f3567d99a4612737ffef92c57256a3bc1d555541b Worker IP address: xxx Worker port: 44215 Worker PID: 3620 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Traceback (most recent call last):
  File "/home/ec2-user/verl/tests/rollout/test_vllm_multi_turn.py", line 179, in <module>
    test_vllm_multi_turn()
  File "/home/ec2-user/verl/tests/rollout/test_vllm_multi_turn.py", line 100, in test_vllm_multi_turn
    async_rollout_manager = AsyncLLMServerManager(
  File "/home/ec2-user/verl/verl/workers/rollout/async_server.py", line 271, in __init__
    ray.get([worker.init_engine.remote() for worker in self.async_llm_servers])
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2771, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 921, in get_objects
    raise value
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
        class_name: AsyncvLLMServer
        actor_id: 4d25a17abd5c9be7ad69593a01000000
        pid: 3620
        name: async_llm_worker_0
        namespace: 39fff769-940c-4c88-bb08-0ad63fb5e321
        ip: xxx
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
root@ip-172-31-44-216:/home/ec2-user/verl# 

I do see the expected Actors running in the Ray dashboard (screenshot below).

(Screenshot: Ray dashboard showing the running actors)
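
To cross-check from outside the test, I can attach a second driver to the running cluster and print every named actor together with its namespace, and compare that against the namespace 39fff769-... reported by the AsyncvLLMServer above. This is just a debugging snippet of my own, not verl code:

# Debugging aid (my own, not from verl): attach to the Ray cluster already
# running on this node and list all named actors with their namespaces.
import ray

ray.init(address="auto")
for actor in ray.util.list_named_actors(all_namespaces=True):
    print(actor["namespace"], actor["name"])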

Appendix A

verl/tests/rollout/test_vllm_multi_turn.py

# FIRST PART OF FILE UNCHANGED 

def test_vllm_multi_turn():
    config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
    model_path = "Qwen/Qwen2.5-0.5B-Instruct"
    model_name = "/".join(model_path.split("/")[-2:])
    config.actor_rollout_ref.model.path = model_path
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.chat_scheduler = "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler"
    config.actor_rollout_ref.rollout.prompt_length = 2048
    config.actor_rollout_ref.rollout.response_length = 1024


    # test sleep/wake_up with fsdp offload
    config.actor_rollout_ref.actor.fsdp_config.param_offload = False
    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = False

    # add VM-specific config
    config.ray_init.num_cpus = 8
    config.trainer.n_gpus_per_node = int(os.environ.get("NGPUS_PER_NODE", 1))
    config.actor_rollout_ref.rollout.tensor_model_parallel_size = int(os.environ.get("GEN_TP", 1))
    
    # =========================== 1. Create hybrid ActorRollout workers ===========================
    # make openai client happy
    os.environ["no_proxy"] = ""
    os.environ["http_proxy"] = ""
    os.environ["https_proxy"] = ""

    ray.init(
        logging_level=logging.DEBUG,
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "WARN",
                "VLLM_USE_V1": "1",
            }
        }
    )
   
## REST OF FILE UNCHANGED
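
Not part of the original test, but a quick sanity check that can be dropped in right after the ray.init(...) call above, to confirm which namespace the driver ends up in (for comparison with the namespace printed by the AsyncvLLMServer):

    # Sanity check (not in the original test): print this driver's Ray namespace
    # and the named actors it can currently see.
    print("driver namespace:", ray.get_runtime_context().namespace)
    print("named actors:", ray.util.list_named_actors(all_namespaces=True))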
