@@ -147,6 +147,7 @@ def evaluate(self):
         """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
         stats = {}
         all_samples = []
+        prompts_sizes = []
         generate_time = time()
         for prompts in self.eval_dataloader:
             if isinstance(prompts, torch.Tensor):
@@ -165,34 +166,54 @@ def evaluate(self):
                         value=pad_token,
                     )
                 )
-        stats["generate_time"] = time() - generate_time
+            sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(
+                len(prompts.input_ids)
+            )
+            prompts_sizes.append(sizes.to(samples.device))
+
+        stats["time/generate"] = time() - generate_time

         samples = self.accelerator.gather(torch.vstack(all_samples))
+        prompts_sizes = self.accelerator.gather(torch.hstack(prompts_sizes))

         if self.accelerator.is_main_process:
             if self.tokenizer:
-                samples = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
+                str_samples = self.tokenizer.batch_decode(
+                    samples, skip_special_tokens=True
+                )
+
+                prompts, responses = [], []
+                for sample, prompt_size in zip(samples, prompts_sizes):
+                    prompts.append(sample[:prompt_size])
+                    responses.append(sample[prompt_size:])
+
+                str_prompts = self.tokenizer.batch_decode(
+                    prompts, skip_special_tokens=True
+                )
+                str_responses = self.tokenizer.batch_decode(
+                    responses, skip_special_tokens=True
+                )

-            if isinstance(samples[0], str):
-                columns_data = [samples]
+            if isinstance(str_samples[0], str):
+                columns_data = [str_prompts, str_responses]
             else:
                 columns_data = [samples.tolist()]
-            columns = ["samples"]
+            columns = ["prompt", "response"]

             # in online setting, compute the reward for validation
             if self.reward_fn:
-                rewards = torch.as_tensor(self.reward_fn(samples), dtype=torch.float)
+                rewards = torch.tensor(self.reward_fn(str_samples), dtype=torch.float)
                 mean_reward = rewards.mean()
                 columns.append("reward")
                 columns_data.append(rewards)
-                stats["mean_reward"] = mean_reward
+                stats["reward/mean"] = mean_reward
                 print(f"{mean_reward=}")

             # additionally log any other metrics
             if self.metric_fn:
                 metric_time = time()
-                metrics = self.metric_fn(samples)
-                stats["metric_time"] = time() - metric_time
+                metrics = self.metric_fn(str_samples)
+                stats["time/metric"] = time() - metric_time

                 mean_metrics = {
                     f"metrics/{k}": torch.as_tensor(xs).mean(-1)
@@ -258,8 +279,8 @@ def learn(self):
             if self.iter_count % self.config.train.checkpoint_interval == 0:
                 self.save()

-            stats["forward_time"] = forward_time
-            stats["backward_time"] = backward_time
+            stats["time/forward"] = forward_time
+            stats["time/backward"] = backward_time

             if self.iter_count % self.config.train.eval_interval == 0:
                 results = self.evaluate()
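The renamed keys follow a `group/name` convention (`time/...`, `reward/mean`, alongside the existing `metrics/...`), which most experiment trackers render as grouped panels. A rough sketch of the resulting stats dict after an evaluation step, with purely illustrative values:

```python
# Illustrative only: keys taken from the diff above, values made up.
stats = {
    "time/generate": 12.4,   # seconds spent sampling eval_prompts
    "time/metric": 0.8,      # seconds spent inside metric_fn
    "time/forward": 0.05,    # forward-pass time logged from learn()
    "time/backward": 0.11,   # backward-pass time logged from learn()
    "reward/mean": 0.37,     # mean reward over the decoded samples
}
```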