trlx/data/ppo_types.py — 12 changes: 5 additions & 7 deletions
@@ -14,11 +14,9 @@ class PPORLElement:
Should be a long tensor.
:type response_tensor: torch.Tensor

- :param logprobs: The log probabilities over all tokens in the vocabulary for
-     each token generated from the policy network
-     (i.e. the autoregressive model).
-     Should be a float tensor of same size as tokens,
-     with a dimension across the vocabulary.
+ :param logprobs: The log probabilities over the response tokens generated
+     by the policy network (i.e. the autoregressive model).
+     Should be a float tensor of same size as tokens.
:type logprobs: torch.Tensor

:param values: The values for each token generated from the value network or value head.
@@ -32,7 +30,7 @@ class PPORLElement:

query_tensor: TensorType["query_size"]
response_tensor: TensorType["response_size"]
logprobs: TensorType["response_size", "vocab_size"]
logprobs: TensorType["response_size"]
values: TensorType["response_size"]
rewards: TensorType["response_size"]

@@ -60,6 +58,6 @@ class PPORLBatch:

query_tensors: TensorType["batch_size", "query_size"]
response_tensors: TensorType["batch_size", "response_size"]
logprobs: TensorType["batch_size", "response_size", "vocab_size"]
logprobs: TensorType["batch_size", "response_size"]
values: TensorType["batch_size", "response_size"]
rewards: TensorType["batch_size", "response_size"]