Skip to content

Commit bc9062d

Browse files
authored
[sglang] Fix tool format and response position ids padding in AsyncSGLangRollout (#1475)
### Checklist Before Starting - [x] Search for similar PR(s). ### What does this PR do? > Add one-line overview of what this PR aims to achieve or accomplish. Resolved the tool formatting issue: Previously, arguments were stored as strings, causing iterative addition of `\\` due to multiple calls to `json.dumps`. Fixed the `response_position_ids` mismatch between `generate_sequences` and `generate_sequences_with_tools`: In the earlier implementation, `generate_sequences_with_tools` used zero padding for positions where `attention mask == 0`, which resulted in NaN values during the training phase. ### Specific Changes > List the specific changes. - Introduced a new schema, `OpenAIFunctionCallSchema`, to store converted tool calls. - Updated the `AsyncSGLangRollout` tool to skip non-dict type arguments instead of handling any string at the arguments position. - Aligned `response_position_ids` in `generate_sequences_with_tools` with the behavior of `generate_sequences`. - Enhanced tool descriptions to prevent misleading parse errors, as returning 0.0 caused the model to incorrectly modify answers. ### API > Demonstrate how the API changes if any. - Revise the `execute` interface of the tool to directly accept `dict[str, Any]` instead of a JSON string. ### Usage Example > Provide usage example(s) for easier usage. ```python # Add code snippet or script demonstrating how to use this ``` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluatuion results, etc. ### Additional Info. - **Issue Number**: Fixes issue # or discussion # if any. - **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none] - **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none] ### Checklist Before Submitting - [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [x] Add `[BREAKING]` to the PR title if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add CI test(s) if neccessary.
1 parent db83855 commit bc9062d

File tree

5 files changed

+70
-29
lines changed

5 files changed

+70
-29
lines changed

examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ tools:
55
type: "function"
66
function:
77
name: "calc_gsm8k_reward"
8-
description: "A tool for calculating the reward of gsm8k. (1.0 if your answer is correct, 0.0 if your answer is incorrect)"
8+
description: "A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)"
99
parameters:
1010
type: "object"
1111
properties:

verl/tools/base_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Optional, Tuple
15+
from typing import Any, Optional, Tuple
1616
from uuid import uuid4
1717

1818
from .schemas import OpenAIFunctionToolSchema
@@ -52,7 +52,7 @@ async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:
5252
else:
5353
return instance_id
5454

55-
async def execute(self, instance_id: str, parameters: str, **kwargs) -> Tuple[str, float, dict]:
55+
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
5656
"""Execute the tool.
5757
5858
Args:

verl/tools/gsm8k_tool.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import json
1716
import logging
1817
import os
19-
from typing import Optional, Tuple
18+
from typing import Any, Optional, Tuple
2019
from uuid import uuid4
2120

2221
from verl.utils.reward_score import gsm8k
@@ -74,26 +73,22 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional
7473
}
7574
return instance_id
7675

77-
async def execute(self, instance_id: str, parameters: str, **kwargs) -> Tuple[str, float, dict]:
78-
try:
79-
_parameters = json.loads(parameters)
80-
except json.JSONDecodeError:
81-
_parameters = {}
82-
if isinstance(_parameters, dict):
83-
answer = _parameters.get("answer", "")
84-
if not isinstance(answer, str):
85-
answer = str(answer)
86-
else:
87-
answer = ""
76+
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
77+
answer = parameters.get("answer", "")
78+
if not isinstance(answer, str):
79+
answer = str(answer)
80+
8881
if answer.startswith("#### "):
8982
self._instance_dict[instance_id]["response"] = answer
9083
else:
9184
self._instance_dict[instance_id]["response"] = "#### " + answer
85+
9286
reward = await self.calc_reward(instance_id)
9387
# penalty for non improved answer submission
9488
tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05
9589
# update the reward
9690
self._instance_dict[instance_id]["reward"] = reward
91+
9792
return f"Current parsed {answer=} {reward=}", tool_reward, {}
9893

9994
async def calc_reward(self, instance_id: str, **kwargs) -> float:

verl/tools/schemas.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Literal
15+
import json
16+
from typing import Any, Literal
1617

1718
from pydantic import BaseModel
1819

@@ -56,9 +57,31 @@ class OpenAIFunctionParsedSchema(BaseModel):
5657
arguments: str # JSON string
5758

5859

60+
class OpenAIFunctionCallSchema(BaseModel):
61+
"""The parsed schema of a tool in OpenAI format."""
62+
63+
name: str
64+
arguments: dict[str, Any]
65+
66+
@staticmethod
67+
def from_openai_function_parsed_schema(parsed_schema: OpenAIFunctionParsedSchema) -> tuple["OpenAIFunctionCallSchema", bool]:
68+
has_decode_error = False
69+
try:
70+
arguments = json.loads(parsed_schema.arguments)
71+
except json.JSONDecodeError:
72+
arguments = {}
73+
has_decode_error = True
74+
# If the arguments is not a dict, it means the arguments is not a valid JSON string
75+
if not isinstance(arguments, dict):
76+
arguments = {}
77+
has_decode_error = True
78+
79+
return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error
80+
81+
5982
class OpenAIFunctionToolCall(BaseModel):
6083
"""The tool call in OpenAI format."""
6184

6285
id: str
6386
type: Literal["function"] = "function"
64-
function: OpenAIFunctionParsedSchema
87+
function: OpenAIFunctionCallSchema

verl/workers/rollout/sglang_rollout/async_sglang_rollout.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from verl import DataProto
4242
from verl.third_party.sglang import parallel_state as sglang_ps
4343
from verl.tools.base_tool import BaseTool
44-
from verl.tools.schemas import OpenAIFunctionParsedSchema, OpenAIFunctionToolCall
44+
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall
4545
from verl.utils.debug import GPUMemoryLogger
4646
from verl.utils.model import compute_position_id_with_mask
4747
from verl.utils.net_utils import is_ipv6
@@ -93,6 +93,7 @@ def __init__(
9393
"""
9494
super().__init__()
9595
self.config = config
96+
os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
9697

9798
tool_list = None
9899
if config.multi_turn.tool_config_path is not None:
@@ -216,6 +217,7 @@ def initialize_tools(tools_config) -> list:
216217
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
217218

218219
if first_rank_in_node:
220+
rank = dist.get_rank()
219221
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
220222
self._engine = Engine(
221223
model_path=actor_module,
@@ -230,6 +232,16 @@ def initialize_tools(tools_config) -> list:
230232
load_format=load_format,
231233
dist_init_addr=dist_init_addr,
232234
trust_remote_code=trust_remote_code,
235+
# NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new
236+
# when random.seed is being set during training
237+
port=30000 + rank,
238+
# NOTE(Chenyang): if you want to debug the SGLang engine output
239+
# please set the following parameters
240+
# Otherwise, it will make the engine run too slow
241+
# log_level="INFO",
242+
# log_requests=True,
243+
# log_requests_level=2,
244+
# max_running_requests=1,
233245
)
234246
else:
235247
self._engine = None
@@ -271,7 +283,7 @@ def update_sampling_params(self, **kwargs):
271283
for key, value in old_sampling_params_args.items():
272284
self.sampling_params[key] = value
273285

274-
@GPUMemoryLogger(role="sglang rollout", logger=logger)
286+
@GPUMemoryLogger(role="sglang async rollout", logger=logger)
275287
@torch.no_grad()
276288
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
277289
# if self.config.free_cache_engine:
@@ -508,13 +520,18 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo
508520
except AttributeError:
509521
normed_content = content
510522
tool_calls = []
511-
parsed_tool_calls = [
512-
OpenAIFunctionToolCall(
513-
id=str(tool_call.tool_index),
514-
function=OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters),
523+
parsed_tool_calls = []
524+
for tool_call in tool_calls:
525+
function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters))
526+
# Drop the tool call if its arguments has decode error
527+
if has_decode_error:
528+
continue
529+
parsed_tool_calls.append(
530+
OpenAIFunctionToolCall(
531+
id=str(tool_call.tool_index),
532+
function=function,
533+
)
515534
)
516-
for tool_call in tool_calls
517-
]
518535
if len(parsed_tool_calls) > 0:
519536
_req.add_assistant_message(
520537
self.tokenizer,
@@ -550,6 +567,7 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool):
550567

551568
return _req
552569

570+
@GPUMemoryLogger(role="sglang async rollout", logger=logger)
553571
@torch.no_grad()
554572
def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
555573
# Async rollout with tools support
@@ -632,9 +650,10 @@ def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataPro
632650
prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left")
633651
if prompt_position_ids.shape[1] < self.config.prompt_length:
634652
prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True)
635-
response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0)
636-
if response_position_ids.shape[1] < self.config.response_length:
637-
response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0)
653+
response_length = response_ids.size(1)
654+
delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device)
655+
delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1)
656+
response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
638657
prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
639658
if prompt_loss_mask.shape[1] < self.config.prompt_length:
640659
prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
@@ -660,6 +679,10 @@ def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataPro
660679
batch_size=len(sorted_output_req_list),
661680
)
662681

682+
# free cache engine
683+
if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0:
684+
self._engine.tokenizer_manager.flush_cache()
685+
663686
return DataProto(batch=batch, non_tensor_batch={"messages": np.array(messages), "reward_scores": np.array(reward_scores)})
664687

665688
def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:

0 commit comments

Comments
 (0)