12 changes: 6 additions & 6 deletions verl/workers/actor/megatron_actor.py
@@ -229,7 +229,7 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
log_probs = [o[0]["log_probs"] for o in output["output"]] # (bs, seq_size)
else:
log_probs = [o["log_probs"] for o in output["output"]] # (bs, seq_size)
- log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
+ log_probs = torch.cat(log_probs, dim=0)
if use_dynamic_bsz:
indices = output["indices"]
indices = list(itertools.chain.from_iterable(indices))
@@ -238,7 +238,7 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
log_probs = log_probs[revert_indices]
else:
log_probs = torch.empty(
- size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
+ size=(batch_size, response_length), device=input_ids.device
)
log_probs = log_probs.to(get_device_id())
# broadcast across pp ranks
@@ -253,7 +253,7 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
# Note that o[0] is metrics, o[1] is entropy
if mpu.is_pipeline_last_stage(ignore_virtual=True):
entropys = torch.cat([o[1] for o in output["output"]], dim=0)
- entropys = entropys.to(torch.float32)
+ entropys = entropys
if use_dynamic_bsz:
indices = output["indices"]
indices = list(itertools.chain.from_iterable(indices))
@@ -262,7 +262,7 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
entropys = entropys[revert_indices]
else:
entropys = torch.empty(
- size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
+ size=(batch_size, response_length), device=input_ids.device
)
# broadcast across pp ranks
entropys = entropys.to(get_device_id())
@@ -295,10 +295,10 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that
responses = input_ids[:, -response_length:]

- ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability
+ ``old_log_probs``: tensor of shape [batch_size, response_length]. The log probability
of responses.

- ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of
+ ``advantages``: tensor of shape [batch_size, response_length]. The advantages of
responses.
See PPO paper for details. https://arxiv.org/abs/1707.06347

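Below the diff, a minimal self-contained sketch (not part of this PR) of the tensor contract the docstring describes and of the dtype behavior the change touches: responses are the last response_length tokens of input_ids, per-token log-probs share the [batch_size, response_length] shape, and once the explicit float32 casts are gone, torch.cat keeps whatever dtype the inputs already have while torch.empty falls back to the global default dtype. The toy sizes, the random bf16 logits, and the stand-alone script shape are illustrative assumptions, not verl code.

import torch

batch_size, seq_len, response_length, vocab_size = 2, 8, 4, 16
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
responses = input_ids[:, -response_length:]           # [batch_size, response_length], int64

# Pretend these logits came from a bf16 forward pass (illustrative only).
logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16)
resp_logits = logits[:, -response_length - 1:-1, :]   # logits that predict each response token
log_probs = torch.log_softmax(resp_logits, dim=-1).gather(
    -1, responses.unsqueeze(-1)
).squeeze(-1)                                          # [batch_size, response_length], bf16

# Without an explicit .to(torch.float32), torch.cat preserves the existing dtype;
# torch.empty without dtype= uses torch.get_default_dtype() rather than a hard-coded float32.
merged = torch.cat([log_probs, log_probs], dim=0)
assert merged.dtype == torch.bfloat16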