 from time import time
-from typing import Callable, Optional
 
 import ray
 import torch
+import torch.nn.functional as F
 
 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.ppo_types import PPORLElement
@@ -24,8 +24,6 @@ def __init__(
         self,
         trainer: BaseRLTrainer,
         pipeline: BasePipeline,
-        reward_fn: Callable,
-        metric_fn: Optional[Callable] = None,
         chunk_size: int = 512,
     ):
         self.pipeline = pipeline
@@ -43,8 +41,6 @@ def __init__(
             self.ref_model.to(self.trainer.accelerator.device)
 
         self.trainer.orch = self
-        self.trainer.reward_fn = reward_fn
-        self.trainer.metric_fn = metric_fn
 
         self.running = RunningMoments()
         self.ref_mean = self.trainer.config.method.ref_mean
@@ -65,9 +61,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
         stats = {}
         clock = Clock()
         while len(ppo_rl_elements) < num_rollouts:
-            if self.trainer.accelerator.is_main_process:
-                print(f"Making experience {len(ppo_rl_elements)} / {num_rollouts}")
-
             # Get next batch in prompt dataset and refresh if exhausted
             try:
                 batch: PromptBatch = next(self.pipeline_iterator)
@@ -79,30 +72,38 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             samples = self.trainer.generate(**batch)
             stats["time/exp_generate"] = time() - exp_generate_time
 
-            if self.trainer.config.model.model_arch_type == "seq2seq":
-                response_tensors = samples
-            else:
-                query_tensors = batch.input_ids
-                response_tensors = samples[:, query_tensors.shape[1] :]
-
-            texts = self.trainer.tokenizer.batch_decode(
-                samples, skip_special_tokens=True
+            query_tensors = batch.input_ids
+            device = samples.device
+            str_samples, str_prompts, str_outputs = self.trainer.decode(
+                query_tensors, samples
             )
 
-            if self.trainer.config.model.model_arch_type == "seq2seq":
-                articles = self.trainer.tokenizer.batch_decode(
-                    batch.input_ids, skip_special_tokens=True
+            # Convert trimmed samples back into tensors for another head pass
+            # This can be deferred, instead letting the pass be made over the original samples,
+            # once the unbinding and truncating operations below are fixed
+            outputs = self.trainer.tokenizer(str_outputs).input_ids
+            outputs = list(map(torch.LongTensor, outputs))
+            maxsize = max(map(len, outputs))
+            outputs = [
+                F.pad(
+                    output,
+                    (0, maxsize - len(output)),
+                    value=self.trainer.tokenizer.pad_token_id,
                 )
-            sep_token = self.trainer.tokenizer.sep_token
-            texts = [
-                f"{article} {sep_token} {response}"
-                for article, response in zip(articles, texts)
-            ]
+                for output in outputs
+            ]
+            response_tensors = torch.vstack(outputs).to(device)
 
             exp_score_time = time()
+
             scores = torch.tensor(
-                self.score(texts), device=samples.device, dtype=torch.float
-            )
+                self.trainer.reward_fn(
+                    samples=str_samples,
+                    prompts=str_prompts,
+                    outputs=str_outputs,
+                ),
+                dtype=float,
+            ).to(device)
             stats["time/exp_score"] = time() - exp_score_time
 
             # store statistics of the initial rollout as reference
@@ -125,9 +126,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
 
             # Precompute logprobs, values
             if self.trainer.config.model.model_arch_type == "seq2seq":
-                response_tensors = response_tensors
-                attention_mask = batch.attention_mask.to(response_tensors.device)
-                query_tensors = batch.input_ids.to(response_tensors.device)
+                attention_mask = batch.attention_mask.to(device)
+                query_tensors = batch.input_ids.to(device)
                 with torch.no_grad():
                     outputs = self.trainer.model(
                         input_ids=query_tensors,
@@ -150,12 +150,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                 ).logits
             else:
                 all_tokens = torch.cat(
-                    (query_tensors.to(response_tensors.device), response_tensors), dim=1
+                    (query_tensors.to(device), response_tensors), dim=1
                 )
                 attention_mask = (
                     all_tokens.not_equal(self.trainer.tokenizer.pad_token_id)
                     .long()
-                    .to(all_tokens.device)
+                    .to(device)
                 )
                 with torch.no_grad():
                     logits, *_, values = self.trainer.model(
@@ -175,7 +175,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                         attention_mask=attention_mask,
                         return_dict=False,
                     )
-                ref_logits = ref_logits.to(self.trainer.accelerator.device)
+                ref_logits = ref_logits.to(device)
 
             if self.trainer.config.model.model_arch_type == "seq2seq":
                 logprobs = logprobs_from_logits(
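
For reference, here is a minimal standalone sketch of the right-padding pattern introduced in this diff: the variable-length tokenized outputs are turned into LongTensors, right-padded to the longest output with F.pad, and stacked with torch.vstack. The token ids and pad_token_id below are made up for illustration; in the orchestrator they come from self.trainer.tokenizer.

import torch
import torch.nn.functional as F

# Stand-in for self.trainer.tokenizer(str_outputs).input_ids: ragged lists of token ids
output_ids = [[5, 6, 7], [8, 9], [10]]
pad_token_id = 0  # assumption; in the diff this is self.trainer.tokenizer.pad_token_id

outputs = list(map(torch.LongTensor, output_ids))
maxsize = max(map(len, outputs))
outputs = [
    # (0, n) pads only on the right, so every output reaches length `maxsize`
    F.pad(output, (0, maxsize - len(output)), value=pad_token_id)
    for output in outputs
]
response_tensors = torch.vstack(outputs)  # shape: (batch, maxsize)
# tensor([[ 5,  6,  7],
#         [ 8,  9,  0],
#         [10,  0,  0]])

Note also that scoring now goes through self.trainer.reward_fn(samples=..., prompts=..., outputs=...) with decoded strings, rather than through the orchestrator's own self.score(texts).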