Skip to content

Commit dc08470

Browse files
committed
add arg num_workers for ernie3.0
1 parent c2f7e31 commit dc08470

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tests/test_tipc/benchmark/modules/ernie3_for_sequence_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def create_data_loader(self, args, **kwargs):
145145
train_loader = DataLoader(
146146
dataset=train_ds,
147147
batch_sampler=train_batch_sampler,
148-
num_workers=4, # when paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0
148+
num_workers=args.num_workers, # when paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0
149149
)
150150

151151
self.num_batch = len(train_loader)

tests/test_tipc/benchmark/options.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ def get_parser():
136136
parser.add_argument("--epoch", type=int, default=10, help="Number of epochs. ")
137137

138138
parser.add_argument("--generated_inputs", action="store_true", help="Use generated inputs. ")
139+
parser.add_argument(
140+
"--num_workers",
141+
type=int,
142+
default=4,
143+
help="num_workers of dataloader. When paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0 ",
144+
)
139145

140146
# For benchmark.
141147
parser.add_argument(

0 commit comments

Comments
 (0)