@@ -33,8 +33,6 @@ def update(self, current, n_steps):
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult

-# Cell
-
class FixedKLController:
    """Fixed KL controller."""
    def __init__(self, kl_coef):
@@ -48,20 +46,21 @@ class AcceleratePPOModel(AccelerateRLModel):
    def __init__(self, config, train_mode=True):
        super().__init__(config, train_mode)

-        self.store = PPORolloutStorage()
+        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)

        rollout_loader = self.store.create_loader(
            self.config.train.batch_size, shuffle=True
        )
+
        self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
            self.model, self.opt, self.scheduler, rollout_loader
        )
-        self.store.clear_history()

        self.dummy_input = self.tokenize("dummy input")[
            "input_ids"
        ]  # Hack to make accelerate distributed work with model generation

+        self.store.clear_history()
        if config.method.target is not None:
            self.kl_ctl = AdaptiveKLController(
                config.method.init_kl_coef,
@@ -78,6 +77,7 @@ def get_arch(self, config: TRLConfig):
    def loss(
        self, query_tensors, response_tensors, all_logprobs, all_values, all_rewards
    ):
+
        lastgaelam = 0
        advantages_reversed = []
        gen_len = response_tensors.shape[1]
@@ -99,7 +99,11 @@ def loss(
        advantages = advantages.detach()

        all_tokens = torch.cat((query_tensors, response_tensors), dim=1)
-        logits, _, vpred = self.model(all_tokens)
+        attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long()
+        position_ids = attention_mask.cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask.eq(0), 0)
+
+        logits, _, vpred = self.model(all_tokens, attention_mask, position_ids=position_ids)
        logprob = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])

        # only the generation part of the values/logprobs is needed
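
For reference, a minimal standalone sketch of what the new masking lines compute, on a toy right-padded batch with an assumed pad token id of 0 (the toy tensors and pad id are illustrative, not part of the commit):

import torch

pad_token_id = 0  # assumed pad id for this illustration
all_tokens = torch.tensor([[5, 6, 7, 0, 0],
                           [5, 6, 7, 8, 9]])

# 1 for real tokens, 0 for padding
attention_mask = all_tokens.not_equal(pad_token_id).long()
# positions count only real tokens; padded slots are clamped back to 0
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 0)

print(attention_mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(position_ids)    # tensor([[0, 1, 2, 0, 0], [0, 1, 2, 3, 4]])
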
@@ -111,9 +115,12 @@ def loss(
            all_values + self.config.method.cliprange_value,
        )

+        vf_mask = attention_mask[:, -gen_len - 1 : -1]
+        pg_mask = attention_mask[:, -gen_len:]
+
        vf_losses1 = (vpred - returns) ** 2
        vf_losses2 = (vpredclipped - returns) ** 2
-        vf_loss = 0.5 * torch.mean(torch.max(vf_losses1, vf_losses2))
+        vf_loss = 0.5 * torch.sum(torch.max(vf_losses1, vf_losses2) * vf_mask) / vf_mask.sum()

        kl = logprob - all_logprobs
        # Record mean_kl for kl coef adjustment
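
As a sanity check on the switch from torch.mean to a masked average, a small sketch with toy numbers (not from the diff) shows that padded positions no longer dilute the loss:

import torch

losses = torch.tensor([[1.0, 2.0, 3.0, 0.0]])   # last slot is a padded position
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])     # 1 = real token, 0 = padding

plain_mean = torch.mean(losses)                       # 1.5, padding drags the average down
masked_mean = torch.sum(losses * mask) / mask.sum()   # 2.0, averaged over real tokens only

print(plain_mean.item(), masked_mean.item())
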
@@ -127,57 +134,58 @@ def loss(
            1.0 + self.config.method.cliprange,
        )

-        pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
+        pg_loss = torch.sum(torch.max(pg_losses, pg_losses2) * pg_mask) / pg_mask.sum()

        model_loss = pg_loss + self.config.method.vf_coef * vf_loss
        return model_loss, pg_loss, vf_loss

    def post_epoch_callback(self):
-        # TODO(dahoas): are experiences being made for dataloaders on each process or same dataloader
        self.epoch += 1
        self.store.clear_history()
        self.orch.make_experience(
            self.config.method.num_rollouts, self.iter_count
        )  # Collect more rollouts for training

    def post_backward_callback(self):
-        batch = self.logs["batch"]
        # Update kl_coefficient
        self.kl_ctl.update(self.mean_kl, self.config.train.batch_size)
-        # Run evaluation
+
+        all_samples = []
+        for prompts in self.eval_dataloader:
+            query, response, _ = self.act(prompts)
+            pad_token = self.tokenizer.eos_token_id if self.tokenizer else 0
+            samples = torch.hstack((query, response))
+            all_samples.append(F.pad(samples, (0, self.max_length - samples.shape[1]), value=pad_token))
+
+        samples = self.accelerator.gather(torch.vstack(all_samples))
+
        if self.accelerator.is_main_process:
-            if (
-                self.iter_count % self.config.train.eval_interval == 0
-                or self.iter_count <= self.config.method.ppo_epochs
-            ):
-                text = self.tokenizer.batch_decode(batch.query_tensors)
-                eval_batch: PromptBatch = PromptBatch(
-                    text=text, tokens=batch.query_tensors
-                )
-                query_tensors, response_tensors, response_text = self.act(eval_batch)
-                gen_texts = [q + r for q, r in zip(eval_batch.text, response_text)]
-                scores = self.orch.score(gen_texts)
-                mean_score = torch.mean(scores).item()
-                rows = list(zip(gen_texts, scores.tolist()))
-                stats = {
-                    "mean_score": mean_score,
-                    "responses": wandb.Table(columns=["response", "score"], rows=rows),
-                    "pg_loss": self.logs["pg_loss"],
-                    "vf_loss": self.logs["vf_loss"],
-                    "kl_coef": self.kl_ctl.value,
-                }
-                self.accelerator.log(stats, step=self.iter_count)
-                self.accelerator.print(
-                    "Step: {}, Mean score: {}, pg_loss: {}, vf_loss: {}, kl_coef: {}".format(
-                        self.iter_count, mean_score, stats["pg_loss"], stats["vf_loss"], self.kl_ctl.value,
-                    )
+            samples = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
+            scores = self.orch.score(samples)
+            mean_score = torch.mean(torch.as_tensor(scores)).item()
+            rows = list(zip(samples, scores))
+            stats = {
+                "mean_score": mean_score,
+                "responses": wandb.Table(columns=["response", "score"], rows=rows),
+                "pg_loss": self.logs["pg_loss"],
+                "vf_loss": self.logs["vf_loss"],
+                "kl_coef": self.kl_ctl.value,
+            }
+
+            self.accelerator.log(stats, step=self.iter_count)
+            self.accelerator.print(
+                "Step: {}, Mean score: {}, pg_loss: {}, vf_loss: {}, kl_coef: {}".format(
+                    self.iter_count, mean_score, stats["pg_loss"], stats["vf_loss"], self.kl_ctl.value,
                )
+            )

    def learn(self):
+        self.eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
+
        rollout_loader = self.store.create_loader(
            self.config.train.batch_size, shuffle=True
        )
-        rollout_loader = self.accelerator.prepare(rollout_loader)
+        rollout_loader, self.eval_dataloader = self.accelerator.prepare(rollout_loader, self.eval_dataloader)

        self.iter_count = 0
        self.epoch = 0
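
The new evaluation loop pads every (query, response) row out to self.max_length before calling self.accelerator.gather; gathering across processes generally expects tensors of matching shape, so the padding keeps each rank's batch compatible. A minimal sketch of just that padding step, with made-up sizes and an assumed pad id (neither taken from the commit):

import torch
import torch.nn.functional as F

max_length = 8
pad_token = 0  # assumed; the diff falls back to 0 when no tokenizer is set

samples = torch.tensor([[11, 12, 13, 14, 15]])  # one (query, response) row of length 5
padded = F.pad(samples, (0, max_length - samples.shape[1]), value=pad_token)

print(padded)  # tensor([[11, 12, 13, 14, 15,  0,  0,  0]])
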
@@ -204,8 +212,7 @@ def learn(self):
                    "batch": batch,
                    "rewards": rewards,
                }
-                # self.post_backward_callback()
-                # exit()
+
                self.opt.zero_grad()
                self.accelerator.backward(loss)
                self.opt.step()