🚀 The feature, motivation and pitch
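The PyTorch Lightning implementation of `torch_geometric.graphgym.train` is pretty nice. The only problem is that I cannot pass my own callbacks through `trainer_config`. `train()` already passes `callbacks=callbacks` to `pl.Trainer`, so a `callbacks` key in `trainer_config` collides with that argument and the call fails with a `TypeError` (multiple values for keyword argument `callbacks`). Below is a simple suggested change. Would this be possible to implement?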
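The collision can be reproduced without Lightning installed. The sketch below uses a hypothetical `fake_trainer` function as a stand-in for `pl.Trainer` and plain strings as placeholder callbacks. It mirrors the `pl.Trainer(**trainer_config, callbacks=callbacks, ...)` call in `train()` and shows that Python rejects the duplicated keyword before the callee even runs:

```python
def fake_trainer(**kwargs):
    """Hypothetical stand-in for pl.Trainer: just returns its keyword arguments."""
    return kwargs


default_callbacks = ["LoggerCallback"]          # placeholder for what train() builds itself
trainer_config = {"callbacks": ["my_callback"]}  # what the user would like to pass

collided = False
message = ""
try:
    # Same shape as pl.Trainer(**trainer_config, callbacks=callbacks, ...) in train()
    fake_trainer(**trainer_config, callbacks=default_callbacks)
except TypeError as err:
    # Python refuses a keyword argument supplied both via ** and explicitly
    collided = True
    message = str(err)

print(collided)  # True
```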
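My suggested insertion:

    # Move user-supplied callbacks out of trainer_config and append them to the
    # ones train() builds itself. Copy first so the caller's dict is not mutated,
    # and fall back to {} because trainer_config defaults to None.
    trainer_config = dict(trainer_config or {})
    callbacks.extend(trainer_config.pop('callbacks', []))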
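The same logic can be packaged as a small standalone function so its behaviour is easy to check in isolation. `merge_callbacks` is a hypothetical helper, not part of PyG. It tolerates `trainer_config=None` (the default of `train()`), leaves the caller's dict untouched, and appends the user's callbacks after the ones `train()` creates, so the returned config can be unpacked next to an explicit `callbacks=` argument without a collision:

```python
from typing import Any, Dict, List, Optional, Tuple


def merge_callbacks(
    callbacks: List[Any],
    trainer_config: Optional[Dict[str, Any]] = None,
) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical helper: move trainer_config['callbacks'] into callbacks.

    Returns the merged callback list and a copy of trainer_config that no
    longer contains a 'callbacks' key.
    """
    config = dict(trainer_config or {})  # shallow copy; the caller's dict is untouched
    merged = list(callbacks) + list(config.pop("callbacks", []))
    return merged, config


user_config = {"callbacks": ["my_callback"], "log_every_n_steps": 10}
merged, remaining = merge_callbacks(["LoggerCallback"], user_config)
print(merged)     # ['LoggerCallback', 'my_callback']
print(remaining)  # {'log_every_n_steps': 10}
```

Keeping `train()`'s own callbacks first preserves the current behaviour when no user callbacks are given.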
Current train function:
    def train(
        model: GraphGymModule,
        datamodule: GraphGymDataModule,
        logger: bool = True,
        trainer_config: Optional[Dict[str, Any]] = None,
    ):
        callbacks = []
        if logger:
            callbacks.append(LoggerCallback())
        if cfg.train.enable_ckpt:
            ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir())
            callbacks.append(ckpt_cbk)

        trainer_config = trainer_config or {}
        # insert suggestion here (after the None fallback, so trainer_config is a dict)
        trainer = pl.Trainer(
            **trainer_config,
            enable_checkpointing=cfg.train.enable_ckpt,
            callbacks=callbacks,
            default_root_dir=cfg.out_dir,
            max_epochs=cfg.optim.max_epoch,
            accelerator=cfg.accelerator,
            devices='auto' if not torch.cuda.is_available() else cfg.devices,
        )
With this change, it would be possible to pass customized callbacks to the `train` function:

    trainer_config = {
        'callbacks': [my_callback],
    }
    train(model, datamodule, trainer_config=trainer_config)