Skip to content

feat: Add support for multiple histories #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/art/local/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def packed_tensors_from_tokenized_results(
assistant_mask[-1].extend(result.assistant_mask)
logprobs[-1].extend(result.logprobs)
advantages[-1].extend([result.advantage] * len(result.token_ids))
weights[-1].extend(
[1 / (sum(result.assistant_mask) + 1e-6)] * len(result.token_ids)
)
weights[-1].extend([result.weight] * len(result.token_ids))
if truncate_long_results:
token_ids[-1] = token_ids[-1][:seq_len]
group_ids[-1] = group_ids[-1][:seq_len]
Expand Down
62 changes: 35 additions & 27 deletions src/art/local/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,32 @@
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from typing import cast, Generator

from ..trajectories import Trajectory, TrajectoryGroup
from ..trajectories import get_messages, History, TrajectoryGroup


@dataclass
class TokenizedResult:
    """One tokenized history ready for packing into training tensors.

    Produced by ``tokenize_trajectory`` and consumed by
    ``packed_tensors_from_tokenized_results`` (see pack.py hunk), which
    broadcasts ``advantage`` and ``weight`` across ``token_ids``.
    """

    # Source trajectory this result was tokenized from.
    trajectory: Trajectory
    # Group-relative advantage applied to every token of this result.
    advantage: float
    # Rendered chat-template string for the (truncated) conversation.
    chat: str
    # Decoded token strings, parallel to token_ids.
    tokens: list[str]
    token_ids: list[int]
    # Position index of each token within the sequence.
    input_pos: list[int]
    # 1 where the token belongs to a trainable assistant span, else 0.
    assistant_mask: list[int]
    # Sampled logprob per token; NaN where no logprob is available.
    logprobs: list[float]
    # Per-token loss weight; set by tokenize_trajectory_groups to
    # 1 / (total assistant tokens across the trajectory's results + 1e-6).
    weight: float = 0.0
    # Random id shared by results with a common prompt prefix.
    prompt_id: int = 0
    # Number of leading tokens that form the shared prompt prefix.
    prompt_length: int = 0

    def without_prompt(self) -> "TokenizedResult":
        """Return a copy with the first ``prompt_length`` tokens sliced off
        every per-token field and ``prompt_length`` reset to 0."""
        return TokenizedResult(
            trajectory=self.trajectory,
            advantage=self.advantage,
            chat=self.chat,
            tokens=self.tokens[self.prompt_length :],
            token_ids=self.token_ids[self.prompt_length :],
            input_pos=self.input_pos[self.prompt_length :],
            assistant_mask=self.assistant_mask[self.prompt_length :],
            logprobs=self.logprobs[self.prompt_length :],
            weight=self.weight,
            prompt_id=self.prompt_id,
            prompt_length=0,
        )
Expand All @@ -57,13 +57,27 @@ def tokenize_trajectory_groups(
# Skip trajectories with no advantage
if advantage == 0:
continue
if result := tokenize_trajectory(
tokenizer,
trajectory,
advantage,
allow_training_without_logprobs,
):
results.append(result)
trajectory_results: list[TokenizedResult] = []
for history in [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably doesn't matter but we could make Trajectory inherit from History to avoid having to create a new one here.

History(
messages_and_choices=trajectory.messages_and_choices,
tools=trajectory.tools,
),
*trajectory.additional_histories,
]:
if result := tokenize_trajectory(
tokenizer,
history,
advantage,
allow_training_without_logprobs,
):
trajectory_results.append(result)
weight = 1 / (
sum(sum(result.assistant_mask) for result in trajectory_results) + 1e-6
)
for result in trajectory_results:
result.weight = weight
results.extend(trajectory_results)
# Choose a random prompt id
prompt_id = random.randint(-(2**63), 2**63 - 1)
# Find the longest shared prefix
Expand All @@ -88,7 +102,7 @@ def tokenize_trajectory_groups(

def tokenize_trajectory(
tokenizer: "PreTrainedTokenizerBase",
trajectory: Trajectory,
history: History,
advantage: float,
allow_training_without_logprobs: bool,
) -> TokenizedResult | None:
Expand All @@ -97,7 +111,7 @@ def tokenize_trajectory(
"""
# Find the index of the last assistant message
last_assistant_index = -1
for i, message_or_choice in enumerate(trajectory.messages_and_choices):
for i, message_or_choice in enumerate(history.messages_and_choices):
if (
isinstance(message_or_choice, dict)
and message_or_choice["role"] == "assistant"
Expand All @@ -111,26 +125,21 @@ def tokenize_trajectory(
# If there are no trainable assistant messages, return None
if last_assistant_index == -1:
return None
trajectory = trajectory.model_copy(
update={
"messages_and_choices": trajectory.messages_and_choices[
: last_assistant_index + 1
]
}
)
messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
messages = get_messages(messages_and_choices)
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], trajectory.messages()),
tools=trajectory.tools, # type: ignore
cast(list[dict], messages),
tools=history.tools, # type: ignore
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], trajectory.messages()),
tools=trajectory.tools, # type: ignore
cast(list[dict], messages),
tools=history.tools, # type: ignore
),
)
sentinal_token_id = max(
Expand All @@ -151,10 +160,10 @@ def tokenize_trajectory(
"content": sentinal_token,
}
)
for message_or_choice in trajectory.messages_and_choices
for message_or_choice in messages_and_choices
],
),
tools=trajectory.tools, # type: ignore
tools=history.tools, # type: ignore
return_dict=True,
return_assistant_token_mask=allow_training_without_logprobs,
),
Expand All @@ -166,7 +175,7 @@ def tokenize_trajectory(
else [0] * len(token_ids)
)
logprobs = [float("nan")] * len(token_ids)
for message_or_choice in trajectory.messages_and_choices:
for message_or_choice in messages_and_choices:
if isinstance(message_or_choice, dict):
continue
choice = message_or_choice
Expand All @@ -185,7 +194,6 @@ def tokenize_trajectory(
)
assistant_mask[sentinal_index : sentinal_index + 1] = [1] * len(token_logprobs)
return TokenizedResult(
trajectory=trajectory,
advantage=advantage,
chat=chat,
tokens=[tokenizer.decode(token_id) for token_id in token_ids],
Expand Down
66 changes: 38 additions & 28 deletions src/art/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@ class PydanticException(pydantic.BaseModel):
traceback: str


class History(pydantic.BaseModel):
    """A single conversation history: chat messages/choices plus the tool
    definitions (if any) that were available during that conversation.

    Used both as the base unit tokenized by ``tokenize_trajectory`` and as
    the element type of ``Trajectory.additional_histories``.
    """

    # Mixed sequence of plain chat messages and Choice objects.
    messages_and_choices: MessagesAndChoices
    # Tool schemas passed to the chat template; None when no tools were used.
    tools: Tools | None = None


class Trajectory(pydantic.BaseModel):
messages_and_choices: MessagesAndChoices
tools: Tools | None = None
additional_histories: list[History] = []
reward: float
metrics: dict[str, float | int | bool] = {}
metadata: dict[str, MetadataValue] = {}
Expand Down Expand Up @@ -54,34 +60,7 @@ def __str__(self) -> str:
return f"Trajectory(reward={self.reward}, metrics={self.metrics}, metadata={self.metadata})"

def messages(self) -> Messages:
return [
(
{
"role": "assistant",
"content": message_or_choice.message.content,
**(
{
"tool_calls": [
{
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
for tool_call in message_or_choice.message.tool_calls
]
}
if message_or_choice.message.tool_calls
else {}
), # type: ignore
}
if isinstance(message_or_choice, Choice)
else message_or_choice
)
for message_or_choice in self.messages_and_choices
]
return get_messages(self.messages_and_choices)

# Used for logging to console
def for_logging(self) -> dict[str, Any]:
Expand All @@ -102,6 +81,37 @@ def for_logging(self) -> dict[str, Any]:
return loggable_dict


def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
    """Normalize a mixed list of messages and ``Choice`` objects into plain
    chat messages.

    Plain message dicts pass through unchanged. Each ``Choice`` is converted
    to an assistant message carrying its content, plus a ``tool_calls`` entry
    when the choice's message includes tool calls.
    """
    messages: Messages = []
    for message_or_choice in messages_and_choices:
        if not isinstance(message_or_choice, Choice):
            # Already a plain message — keep as-is.
            messages.append(message_or_choice)
            continue
        choice_message = message_or_choice.message
        converted = {
            "role": "assistant",
            "content": choice_message.content,
        }
        if choice_message.tool_calls:
            # Re-serialize tool calls into the plain-dict message format.
            converted["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
                for tool_call in choice_message.tool_calls
            ]
        messages.append(converted)  # type: ignore
    return messages


class TrajectoryGroup(pydantic.BaseModel):
trajectories: list[Trajectory]
metadata: dict[str, MetadataValue] = {}
Expand Down