Merged
Changes from 26 commits
Commits
73 commits
953eae6
change the default bash of qqwen 2.5
zhaochen20 Aug 4, 2025
0384aef
add engine support
zhaochen20 Aug 5, 2025
a0ab8b0
OVER_SAMPLE_RATE
zhaochen20 Aug 5, 2025
652f0f7
update sgl rollout
zhaochen20 Aug 5, 2025
3fdfd00
is val
zhaochen20 Aug 5, 2025
d37f51d
self.abort
zhaochen20 Aug 5, 2025
51410e7
cancel task
zhaochen20 Aug 5, 2025
de8feb3
cancel then abort
zhaochen20 Aug 5, 2025
da37e6a
fix await
zhaochen20 Aug 5, 2025
5308507
finish over sample
zhaochen20 Aug 5, 2025
25a15b9
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 5, 2025
07ad3f9
Merge branch 'main' into over_sample_sgl
zhaochen20 Aug 5, 2025
1b8bfa9
add log to over sample
zhaochen20 Aug 5, 2025
519a1b0
Add benchmakr script
zhaochen20 Aug 5, 2025
ae43c5a
increase to 45
zhaochen20 Aug 5, 2025
45265ec
incrse
zhaochen20 Aug 5, 2025
ef6ab93
add over sample
zhaochen20 Aug 5, 2025
264bed6
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 5, 2025
0c63925
Merge branch 'main' into over_sample_sgl
zhaochen20 Aug 5, 2025
5d3b970
fix reward and loss cal
zhaochen20 Aug 5, 2025
89da6d7
revert loss agg
zhaochen20 Aug 5, 2025
614f0ad
revert reward padding
zhaochen20 Aug 5, 2025
08b7e02
revert reward padding
zhaochen20 Aug 5, 2025
36abd78
finish
zhaochen20 Aug 5, 2025
b979a73
revert non_aborted_mask
zhaochen20 Aug 5, 2025
b4fdfcf
clean up codes
zhaochen20 Aug 5, 2025
88a23ce
Merge branch 'over_sample' of https://github.com/zhaochenyang20/verl …
zhaochen20 Aug 5, 2025
4e2316b
clean up codes
zhaochen20 Aug 5, 2025
3a37c6e
modify over sample rate
zhaochen20 Aug 5, 2025
94d7c68
update examples
zhaochen20 Aug 5, 2025
5fd6b67
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 6, 2025
1b6eb79
Merge branch 'main' into over_sample
zhaochen20 Aug 6, 2025
271b493
[tmp file] benchmark
zhaochen20 Aug 6, 2025
6e98696
upd
zhaochen20 Aug 6, 2025
a32cfb4
upd
zhaochen20 Aug 6, 2025
ea29d31
fix
zhaochen20 Aug 6, 2025
cc8fb16
launch
zhaochen20 Aug 6, 2025
60c1610
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 6, 2025
96a8ada
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 7, 2025
903c2d3
Merge branch 'main' into over_sample
zhaochen20 Aug 7, 2025
53099d1
discard dynamic batch
zhaochen20 Aug 7, 2025
3d7fa4a
add micro_32
zhaochen20 Aug 7, 2025
5a733f8
log bs
zhaochen20 Aug 7, 2025
4cc54b8
Add log_prob_micro_batch_size_per_gpu
zhaochen20 Aug 7, 2025
c67ebff
Add ppo_max_token_len_per_gpu
zhaochen20 Aug 7, 2025
43a9c31
add qwen3 4b
zhaochen20 Aug 7, 2025
8e17a93
use qwen3 4b
zhaochen20 Aug 7, 2025
71687d7
descrease micro
zhaochen20 Aug 7, 2025
b084dfd
update dapo for benchmarking
zhaochen20 Aug 7, 2025
b66327e
fix hydra
zhaochen20 Aug 7, 2025
23c2957
use 8 gpu
zhaochen20 Aug 7, 2025
b944230
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 7, 2025
7ee096e
Merge branch 'main' into over_sample
zhaochen20 Aug 7, 2025
97d137d
delete not used variale
PopSoda2002 Aug 7, 2025
35b65eb
use full set to eval
zhaochen20 Aug 7, 2025
6f3e49e
Merge branch 'over_sample' of https://github.com/zhaochenyang20/verl …
zhaochen20 Aug 7, 2025
b0f0245
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 11, 2025
2f5d4f5
Merge branch 'main' into over_sample
zhaochen20 Aug 11, 2025
4c882d1
finish over sampling
zhaochen20 Aug 11, 2025
4822564
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 13, 2025
b28f178
Merge branch 'main' into over_sample
zhaochen20 Aug 13, 2025
a794b29
fix rollout config
zhaochen20 Aug 14, 2025
cdc2eff
do not capture all
zhaochen20 Aug 14, 2025
77e8491
clean up validation
zhaochen20 Aug 14, 2025
4fd6b36
fix conflicts register
zhaochen20 Aug 14, 2025
58f3570
fix future
zhaochen20 Aug 14, 2025
23452ed
use event loop
zhaochen20 Aug 14, 2025
3ce9b1f
get event loop
zhaochen20 Aug 14, 2025
d6128a2
upd, fix up
zhaochen20 Aug 14, 2025
ba87964
fix up
zhaochen20 Aug 14, 2025
96c5d75
delete comment
zhaochen20 Aug 14, 2025
c1e8884
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 14, 2025
06beaf0
Merge branch 'main' into over_sample
zhaochen20 Aug 14, 2025
16 changes: 12 additions & 4 deletions examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
@@ -8,6 +8,12 @@ ulimit -n 65535
PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"

function now() {
date '+%d-%H-%M'
}

EXPERIMENT_NAME="qwen2.5-3b_baseline_$(now)_$OVER_SAMPLE_RATE"

python3 -m verl.trainer.main_ppo \
--config-path="$CONFIG_PATH" \
--config-name='gsm8k_multiturn_grpo' \
@@ -31,21 +37,23 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.multi_stage_wake_up=True \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \
trainer.project_name='benchmark_over_sample_2' \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=20 \
trainer.val_before_train=True \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
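For orientation (not part of this diff): a minimal Python sketch of how the OVER_SAMPLE_RATE environment variable used in the experiment name above is consumed. It mirrors the reads in the sglang_rollout diff further down; the request count here is purely illustrative.

import os

# The rollout reads OVER_SAMPLE_RATE from the environment, defaulting to 1 (no over-sampling).
OVER_SAMPLE_RATE = float(os.getenv("OVER_SAMPLE_RATE", 1))

# During training, generation is treated as done once this many requests have finished;
# the remaining in-flight requests are aborted and replaced with padding requests.
total_requests = 256  # illustrative: train batch size * rollout.n
target_completion = int(total_requests * OVER_SAMPLE_RATE)
print(f"over-sample target: {target_completion}/{total_requests}")
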
36 changes: 28 additions & 8 deletions verl/trainer/ppo/metric_utils.py
@@ -118,6 +118,26 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]

# Detect aborted requests: requests whose response_mask is all zeros.
# An aborted request has no valid tokens in its response part.
# Use the response_mask from the batch to stay consistent with agg_loss.
aborted_mask = (response_length == 0).bool() # response_length == 0 means the request was aborted

non_aborted_mask = ~aborted_mask

print("over sample rate in metric_utils: ", non_aborted_mask.sum() / len(non_aborted_mask))
Collaborator Author: delete this


non_aborted_sequence_score = sequence_score[non_aborted_mask]
non_aborted_sequence_reward = sequence_reward[non_aborted_mask]

score_mean = torch.mean(non_aborted_sequence_score).detach().item()
score_max = torch.max(non_aborted_sequence_score).detach().item()
score_min = torch.min(non_aborted_sequence_score).detach().item()

reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
reward_max = torch.max(non_aborted_sequence_reward).detach().item()
reward_min = torch.min(non_aborted_sequence_reward).detach().item()

valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)

@@ -128,14 +148,14 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
return_var = torch.var(valid_returns)

metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# score (statistics computed over non-aborted requests only)
"critic/score/mean": score_mean,
"critic/score/max": score_max,
"critic/score/min": score_min,
# reward (statistics computed over non-aborted requests only)
"critic/rewards/mean": reward_mean,
"critic/rewards/max": reward_max,
"critic/rewards/min": reward_min,
Contributor: Does it make sense to apply the same non_aborted_mask logic to the response_length metrics as well (mean/max/min response length)? Otherwise, those metrics might still include padded responses from aborted requests.

Collaborator Author: Makes sense.

Collaborator Author: I changed it accordingly, adding the response-length mean before and after abort, and I added a drop-rate metric.

# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
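Following up on the review thread above, a small self-contained sketch (not the PR's actual change; tensor values and metric keys are illustrative) of applying the same non-aborted mask to the response-length statistics and reporting a drop rate:

import torch

# A zero response length marks an aborted (padded) request, matching aborted_mask above.
response_length = torch.tensor([37.0, 0.0, 121.0, 54.0])
aborted_mask = (response_length == 0)
non_aborted_mask = ~aborted_mask
non_aborted_len = response_length[non_aborted_mask]

metrics = {
    # response length over non-aborted requests only (hypothetical key names)
    "response_length/non_aborted/mean": torch.mean(non_aborted_len).item(),
    "response_length/non_aborted/max": torch.max(non_aborted_len).item(),
    "response_length/non_aborted/min": torch.min(non_aborted_len).item(),
    # fraction of requests dropped (aborted) by over-sampling
    "rollout/drop_rate": aborted_mask.float().mean().item(),
}
print(metrics)
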
242 changes: 225 additions & 17 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -83,6 +83,8 @@
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

OVER_SAMPLE_RATE = float(os.getenv("OVER_SAMPLE_RATE", 1))


# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723
def _set_envs_and_config(server_args: ServerArgs):
@@ -170,6 +172,21 @@ async def update_weights_from_tensor(self, update_weights_request: UpdateWeights
async def flush_cache(self):
return await self.tokenizer_manager.flush_cache()

async def abort_request(self, rid: str = "", abort_all: bool = False):
"""Abort a specific request or all requests.

Args:
rid: The request ID to abort. If empty and abort_all is False, no action is taken.
abort_all: If True, abort all running requests regardless of rid.
"""
try:
result = self.tokenizer_manager.abort_request(rid=rid, abort_all=abort_all)
print(f"🔍 Abort result: {result}")
return result if result is not None else {"status": "aborted"}
except Exception as e:
logger.error(f"Failed to abort requests: {e}")
raise


# NOTE(sgm): add for verl. We can optimize it by making
# the dataloader yield List[int] without padding.
@@ -872,7 +889,9 @@ async def _async_rollout_a_request(
# Only continue the conversation if the prompt length is not greater than max_model_len - 1,
# since SGLang raises an error when max_new_tokens + 1 is greater than max_model_len (the extra
# token accounts for the EOS token).
if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len:
prompt_length = len(_req.get_generation_prompt_ids(self.processing_class))

if prompt_length + 1 >= self.config.max_model_len:
finish_reason_type = FinishReasonTypeEnum.LENGTH
break

@@ -1024,9 +1043,11 @@ async def _handle_engine_generate(
self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None
) -> dict:
max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)

kwargs = sampling_params.copy()
kwargs["max_new_tokens"] = max_new_tokens
kwargs["n"] = 1 # group size is supported in preprocess

output = await self._engine.async_generate(
input_ids=generation_prompt_ids,
sampling_params=kwargs,
@@ -1059,16 +1080,6 @@ async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRollout
interaction = self.interaction_map[interaction_name]
await interaction.start_interaction(_req.request_id, **interaction_kwargs)

@GPUMemoryLogger(role="sglang rollout", logger=logger)
@torch.no_grad()
def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
logger.warning(
"`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`",
DeprecationWarning,
stacklevel=2,
)
return self._req_level_generate_sequences(prompts, **kwargs)

@GPUMemoryLogger(role="sglang rollout", logger=logger)
@torch.no_grad()
def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
@@ -1082,16 +1093,125 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
do_sample = prompts.meta_info.get("do_sample", True)
is_validate = prompts.meta_info.get("validate", False)
tgt_device = prompts.batch["input_ids"].device

if self._tp_rank == 0:
req_list = self._preprocess_prompt_to_async_rollout_requests(
prompts,
)
loop = asyncio.get_event_loop()
output_req_list = loop.run_until_complete(
asyncio.gather(
*[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
)
)

# Add progress monitoring and abort support
total_requests = len(req_list)
target_completion = int(total_requests * OVER_SAMPLE_RATE) # abort once this many requests have completed
print(f"🎯 Training mode: over sampling target {target_completion}/{total_requests}")
completed_count = 0
aborted_requests = []

# Distinguish between the training and validation phases
if is_validate:
print(f"🔍 Validation mode: processing all {total_requests} requests without abort")

# Validation phase: process all requests without aborting any
async def process_all_requests():
return await asyncio.gather(
*[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
)

loop = asyncio.get_event_loop()
output_req_list = loop.run_until_complete(process_all_requests())
else:
print(f"🎯 Training mode: over sampling target {target_completion}/{total_requests}")

completion_lock = asyncio.Lock()
all_tasks = []

async def process_request_with_monitoring(req):
nonlocal completed_count
try:
result = await self._async_rollout_a_request(req, do_sample, is_validate, **kwargs)

async with completion_lock:
if completed_count < target_completion:
completed_count += 1
print(f"✅ Request {req.request_id} completed ({completed_count}/{total_requests})")
return result
else:
# This request did finish, but the target has already been met, so return padding
logger.info(f"Request {req.request_id} finished after target met, creating padding")
return self._create_padding_request(req)
except asyncio.CancelledError:
# The request was cancelled, so return padding
logger.info(f"Request {req.request_id} was cancelled, creating padding")
aborted_requests.append(req.request_id)
return self._create_padding_request(req)
except Exception as e:
# The request failed; count it as completed as well
logger.warning(f"Request {req.request_id} failed: {e}")
aborted_requests.append(req.request_id)
async with completion_lock:
if completed_count < target_completion:
completed_count += 1
return self._create_padding_request(req)

async def monitor_and_cancel():
nonlocal completed_count
while completed_count < target_completion:
await asyncio.sleep(0.1)

print(f"🎯 Target reached: {completed_count}/{total_requests} completed!")
print("🚫 Cancelling remaining requests and sending abort to engine...")

# Cancel the remaining tasks
cancelled_count = 0
for task in all_tasks:
if not task.done():
task.cancel()
cancelled_count += 1
Contributor: Where is the usage of this variable cancelled_count?

Collaborator Author: Oh. Not used.

Contributor: Fixed.


print(f"📋 Cancelled {cancelled_count} remaining tasks")

# Send an abort signal to the engine to interrupt all in-flight requests
try:
abort_result = await self._engine.abort_request(abort_all=True)
print(f"✅ Abort signal sent to engine: {abort_result}")
except Exception as e:
print(f"❌ Failed to send abort signal to engine: {e}")

async def run_with_cancellation():
nonlocal all_tasks

# Create all tasks
all_tasks = [asyncio.create_task(process_request_with_monitoring(req)) for req in req_list]

# Start the monitoring task
monitor_task = asyncio.create_task(monitor_and_cancel())

try:
# Wait for all tasks to finish (including cancelled ones)
results = await asyncio.gather(*all_tasks, return_exceptions=True)

# Process the results, converting exceptions into padding
output_req_list = []
for i, result in enumerate(results):
if isinstance(result, Exception):
# If the result is an exception (including CancelledError), create padding
logger.warning(f"Task {i} resulted in exception: {result}")
output_req_list.append(self._create_padding_request(req_list[i]))
else:
output_req_list.append(result)

return output_req_list
finally:
# Cancel the monitoring task
monitor_task.cancel()
try:
await monitor_task
except asyncio.CancelledError:
pass

# Run the async tasks
loop = asyncio.get_event_loop()
output_req_list = loop.run_until_complete(run_with_cancellation())
Contributor: Personally, I think encapsulating this part of the code for the training mode is better; currently the whole function is too long.

Collaborator Author: We can do this, but the current design works for me. 😂 Just wrapping these three functions (process_request_with_monitoring, monitor_and_cancel, run_with_cancellation) should be enough.

Contributor: Sure, it varies from person to person.
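For reference, a self-contained miniature of the monitor-and-cancel pattern that these three inner coroutines implement, independent of verl internals (every name and value below is hypothetical); wrapping the real coroutines into a similar helper is roughly what the suggestion above amounts to.

import asyncio
import random

async def fake_request(i: int) -> str:
    await asyncio.sleep(random.uniform(0.05, 0.5))  # stand-in for generation latency
    return f"response-{i}"

async def over_sample(n_requests: int = 10, over_sample_rate: float = 0.8) -> list[str]:
    target = int(n_requests * over_sample_rate)
    done = 0
    lock = asyncio.Lock()

    async def run_one(i: int) -> str:
        nonlocal done
        try:
            result = await fake_request(i)
        except asyncio.CancelledError:
            return "<padding>"  # an aborted request becomes padding
        async with lock:
            done += 1
        return result

    tasks = [asyncio.create_task(run_one(i)) for i in range(n_requests)]

    async def monitor() -> None:
        while done < target:
            await asyncio.sleep(0.01)
        for t in tasks:  # cancel the stragglers once the target is met
            if not t.done():
                t.cancel()

    monitor_task = asyncio.create_task(monitor())
    results = await asyncio.gather(*tasks, return_exceptions=True)
    monitor_task.cancel()
    return ["<padding>" if isinstance(r, Exception) else r for r in results]

print(asyncio.run(over_sample()))
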


sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))
else:
sorted_output_req_list = None
@@ -1279,6 +1399,94 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
non_tensor_batch=non_tensor_batch,
)

def _create_padding_request(self, original_req: AsyncRolloutRequest) -> AsyncRolloutRequest:
"""创建一个padding请求,用于替代被abort的请求。

这个padding请求的特点是:
1. 状态为COMPLETED,但包含空的response
2. response_loss_mask全为0,确保在loss计算中被忽略
3. 保持原始请求的结构,但内容为空
"""
# Build the padded response_ids (all pad_token_id)
padding_response_length = self.config.response_length
padding_response_ids = torch.full(
(1, padding_response_length),
self.pad_token_id,
dtype=torch.long,
device=original_req.input_ids.device if original_req.input_ids is not None else "cpu",
)

# Build the padded attention_mask (all zeros)
padding_response_attention_mask = torch.zeros(
(1, padding_response_length),
dtype=torch.long,
device=original_req.attention_mask.device if original_req.attention_mask is not None else "cpu",
)

# Build the padded position_ids
if original_req.position_ids is not None:
prompt_length = original_req.prompt_ids.shape[-1] if original_req.prompt_ids is not None else 0
padding_response_position_ids = torch.arange(
prompt_length, prompt_length + padding_response_length, dtype=torch.long
).unsqueeze(0)
if original_req.position_ids.dim() == 2:
# If this is a 2D tensor (e.g., qwen2vl)
padding_response_position_ids = padding_response_position_ids.repeat(
original_req.position_ids.shape[0], 1
)
else:
padding_response_position_ids = None

# Build the padded loss_mask (all zeros, so these tokens are ignored)
padding_response_loss_mask = torch.zeros(
(1, padding_response_length),
dtype=torch.long,
device=original_req.loss_mask.device if original_req.loss_mask is not None else "cpu",
)

# Build a new request that keeps the original structure but carries padding data
padding_req = AsyncRolloutRequest(
batch_data_id=original_req.batch_data_id,
rollout_offset=original_req.rollout_offset,
request_id=original_req.request_id + "_padding",
Collaborator Author: It's hard for us to profile.

Contributor: Done; this no longer affects profiling.

state=AsyncRolloutRequestStateEnum.COMPLETED,
Collaborator Author: The state should be changed to abort.

messages=original_req.messages, # keep the original messages
multi_modal_keys=original_req.multi_modal_keys,
multi_modal_data=original_req.multi_modal_data,
multi_modal_inputs=original_req.multi_modal_inputs,
tool_schemas=original_req.tool_schemas,
tools_kwargs=original_req.tools_kwargs,
interaction_kwargs=original_req.interaction_kwargs,
input_ids=original_req.input_ids, # keep the original input_ids
prompt_ids=original_req.prompt_ids, # keep the original prompt_ids
response_ids=padding_response_ids, # use the padded response_ids
attention_mask=original_req.attention_mask, # keep the original attention_mask
prompt_attention_mask=original_req.prompt_attention_mask, # keep the original prompt_attention_mask
response_attention_mask=padding_response_attention_mask, # use the padded response_attention_mask
position_ids=original_req.position_ids, # keep the original position_ids
prompt_position_ids=original_req.prompt_position_ids, # keep the original prompt_position_ids
response_position_ids=padding_response_position_ids, # use the padded response_position_ids
loss_mask=original_req.loss_mask, # keep the original loss_mask
prompt_loss_mask=original_req.prompt_loss_mask, # keep the original prompt_loss_mask
response_loss_mask=padding_response_loss_mask, # use the padded response_loss_mask (all zeros)
reward_scores={}, # empty reward_scores
max_prompt_len=original_req.max_prompt_len,
max_response_len=original_req.max_response_len,
max_model_len=original_req.max_model_len,
metrics={}, # empty metrics
output_token_ids=None, # no output_token_ids
rollout_log_probs=None, # no rollout_log_probs
use_inference_chat_template=original_req.use_inference_chat_template,
tokenization_sanity_check_mode=original_req.tokenization_sanity_check_mode,
generation_prompt_ids=original_req.generation_prompt_ids,
base_conv_wo_gen_prompt_end_pos=original_req.base_conv_wo_gen_prompt_end_pos,
base_conv_with_gen_prompt_end_pos=original_req.base_conv_with_gen_prompt_end_pos,
processing_class=self.processing_class, # pass the required processing_class argument
)

logger.info(f"Created padding request for aborted request {original_req.request_id}")
return padding_req

def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]:
assert "raw_prompt" in prompts.non_tensor_batch, (
"need data.return_raw_chat=True, due to no official way do parse_messages"
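Finally, a minimal self-contained check (not from the PR) of why padding requests are safe for training: because their response_loss_mask is all zeros, a mask-weighted loss aggregation gives them zero weight.

import torch

per_token_loss = torch.tensor([[0.5, 0.2, 0.1],   # a normal request
                               [9.9, 9.9, 9.9]])  # a padding request; its values are irrelevant
response_loss_mask = torch.tensor([[1.0, 1.0, 0.0],
                                   [0.0, 0.0, 0.0]])  # all zeros for the padded row

masked = per_token_loss * response_loss_mask
loss = masked.sum() / response_loss_mask.sum().clamp(min=1.0)  # token-mean over valid tokens only
print(loss)  # tensor(0.3500): only the non-aborted request's tokens contribute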