from ppfleetx.utils.log import convert_timestamp_to_data, get_timestamp, logger
from ppfleetx.utils.version import version_check

+def use_new_executor():
+    new_executor_micro_batching = os.environ.get(
+        'FLAGS_new_executor_micro_batching', None
+    )
+    return new_executor_micro_batching in [1, '1', True, 'True', 'true']

class AutoEngine(BasicEngine):
    def __init__(self, configs, module=None, mode="train"):
@@ -152,7 +163,10 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):

        total_train_batch = self._max_steps if self._run_mode == "step" else len(train_data_loader)
        total_train_step = self._max_steps if self._run_mode == "step" else total_train_batch * self._num_train_epochs
-        total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
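+        # the new executor's dataloader supports len(); the legacy generator-based
+        # loader only exposes its per-epoch step count via the private _steps attribute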
+        if use_new_executor():
+            total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
+        else:
+            total_eval_batch = valid_data_loader._steps if valid_data_loader is not None else 0
        valid_data_loader = valid_data_loader if valid_data_loader is not None else None
        eval_finished_step = 0

@@ -163,26 +177,41 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):
            if step < self._load_recovery["step"]:
                continue

-            batches = self._validate_batch(batch)

            fetch_list = None
            if self._strategy.amp.enable:
                # fetch_list = ["find_infinite_scale.tmp_0", "loss_scaling_0"]
                fetch_list = []

            final_loss = None
-            for micro_batch in batches:
-                with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
-                    outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
-                # pp: some devices don't have loss in outs
-                if "loss" in outs:
-                    if final_loss is None:
-                        final_loss = np.sum(outs["loss"])
-                    else:
-                        final_loss += np.sum(outs["loss"])
-
-            if final_loss is not None and self._accumulate_steps > 1:
-                final_loss /= self._accumulate_steps
+            if use_new_executor():
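+                # new executor: split the global batch into micro-batches in Python
+                # and run the engine once per micro-batch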
+                batches = self._validate_batch(batch)
+                for micro_batch in batches:
+                    with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
+                        outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
+                    # pp: some devices don't have loss in outs
+                    if "loss" in outs:
+                        if final_loss is None:
+                            final_loss = np.sum(outs["loss"])
+                        else:
+                            final_loss += np.sum(outs["loss"])
+
+                if final_loss is not None and self._accumulate_steps > 1:
+                    final_loss /= self._accumulate_steps
+            else:
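+                # legacy executor: the whole batch is passed through, and when
+                # gradient merge is enabled (pp_degree == 1, accumulate_steps > 1)
+                # the engine is simply run accumulate_steps times per step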
+                if self._pp_degree == 1 and self._accumulate_steps > 1:  # gradient merge
+                    local_steps = self._accumulate_steps
+                else:
+                    local_steps = 1
+                for _ in range(local_steps):
+                    with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
+                        outs = self._auto_engine.run(batch, fetch_list=fetch_list, mode="train")
+                    # pp: some devices don't have loss in outs
+                    if "loss" in outs:
+                        if final_loss is None:
+                            final_loss = np.sum(outs["loss"])
+                        else:
+                            final_loss += np.sum(outs["loss"])
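+                # note: unlike the new-executor branch, the accumulated loss is not
+                # divided by accumulate_steps here (presumably averaged by the engine)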

            if final_loss is not None:
                train_losses.append(final_loss)
@@ -267,27 +296,49 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):

        train_data_loader, valid_data_loader = None, None
        if train_dataset:
-            train_data_loader = self._auto_engine.dataloader(
-                dataset=train_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=train_dataset.collate_fn,
-                num_workers=1,
-                sample_split=train_dataset.sample_split,
-                mode="train",
-            )
+            if use_new_executor():
+                train_data_loader = self._auto_engine.dataloader(
+                    dataset=train_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=train_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=train_dataset.sample_split,
+                    mode="train",
+                )
+            else:
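+                # the legacy executor requires a generator-fed dataloader
+                # (dataloader_from_generator) instead of the plain dataloader API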
+                train_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=train_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=train_dataset.collate_fn,
+                    sample_split=train_dataset.sample_split,
+                    mode="train",
+                )
        if valid_dataset and self._eval_freq <= self._max_steps:
-            valid_data_loader = self._auto_engine.dataloader(
-                dataset=valid_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=valid_dataset.collate_fn,
-                num_workers=1,
-                sample_split=valid_dataset.sample_split,
-                mode="eval",
-            )
+            if use_new_executor():
+                valid_data_loader = self._auto_engine.dataloader(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
+            else:
+                valid_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )

        for epoch_index in range(start_epoch, epoch):
            train_epoch_start = get_timestamp()
@@ -320,6 +371,8 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):
                    convert_timestamp_to_data(get_timestamp() - train_start)
                )
            )
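+            # the generator-fed eval loader keeps an internal cursor, so it must be
+            # reset before it can be iterated again in the next epoch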
+            if valid_data_loader and not use_new_executor():
+                valid_data_loader._inner_dataloader.reset()

        if self.profiler:
            self._profiler_done()
@@ -328,16 +381,28 @@ def evaluate(self, epoch=1, valid_dataset=None):

        valid_data_loader = None
        if valid_dataset:
-            valid_data_loader = self._auto_engine.dataloader(
-                dataset=valid_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=valid_dataset.collate_fn,
-                num_workers=1,
-                sample_split=valid_dataset.sample_split,
-                mode="eval",
-            )
+            if use_new_executor():
+                valid_data_loader = self._auto_engine.dataloader(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
+            else:
+                valid_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )

        for epoch_index in range(epoch):
            eval_epoch_start = get_timestamp()
@@ -388,16 +453,28 @@ def predict(self, epoch=1, test_dataset=None):

        test_data_loader = None
        if test_dataset:
-            test_data_loader = self._auto_engine.dataloader(
-                dataset=test_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=test_dataset.collate_fn,
-                num_workers=1,
-                sample_split=test_dataset.sample_split,
-                mode="predict",
-            )
+            if use_new_executor():
+                test_data_loader = self._auto_engine.dataloader(
+                    dataset=test_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=test_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=test_dataset.sample_split,
+                    mode="predict",
+                )
+            else:
+                test_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=test_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=test_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=test_dataset.sample_split,
+                    mode="predict",
+                )

        test_start = get_timestamp()
        test_losses = []
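
Usage sketch: the branch taken in all of the paths above depends only on the
FLAGS_new_executor_micro_batching environment variable checked by use_new_executor().
Assuming the usual ppfleetx launcher script and a placeholder config path, the
new-executor path would be exercised with:

    FLAGS_new_executor_micro_batching=1 python tools/train.py -c ppfleetx/configs/<model>.yaml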