Custom callbacks in graphgym.train #10386

@henkvanvoorst92

Description

🚀 The feature, motivation and pitch

The PyTorch Lightning implementation of torch_geometric.graphgym.train works nicely. The only problem is that there is currently no way to pass my own callbacks through trainer_config. Below is a simple suggested change. Would it be possible to implement this?

My suggested insertion

# trainer_config may still be None at this point, so guard for that;
# pop the key so 'callbacks' is not also forwarded to pl.Trainer(**trainer_config)
if trainer_config and 'callbacks' in trainer_config:
    callbacks.extend(trainer_config.pop('callbacks'))

Current train function:

def train(
    model: GraphGymModule,
    datamodule: GraphGymDataModule,
    logger: bool = True,
    trainer_config: Optional[Dict[str, Any]] = None,
):
    callbacks = []
    if logger:
        callbacks.append(LoggerCallback())
    if cfg.train.enable_ckpt:
        ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir())
        callbacks.append(ckpt_cbk)
   
    # <-- insert suggestion here

    trainer_config = trainer_config or {}
    trainer = pl.Trainer(
        **trainer_config,
        enable_checkpointing=cfg.train.enable_ckpt,
        callbacks=callbacks,
        default_root_dir=cfg.out_dir,
        max_epochs=cfg.optim.max_epoch,
        accelerator=cfg.accelerator,
        devices='auto' if not torch.cuda.is_available() else cfg.devices,
    )

With this change, it would be possible to pass custom callbacks to the train function:

trainer_config = {
    'callbacks': [my_callback]
}
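To illustrate the intended behavior in isolation, here is a minimal sketch of the merging logic in plain Python (no Lightning required). The helper name `merge_callbacks` and the string stand-ins for callback objects are hypothetical, used only to show that user callbacks are appended to the built-in ones and the `'callbacks'` key is removed before the remaining config is unpacked into `pl.Trainer(**trainer_config)`:

```python
def merge_callbacks(callbacks, trainer_config):
    # Copy so the caller's dict is not mutated; handle trainer_config=None.
    trainer_config = dict(trainer_config or {})
    if 'callbacks' in trainer_config:
        # Append user callbacks and drop the key so it is not passed
        # to pl.Trainer a second time via **trainer_config.
        callbacks = callbacks + list(trainer_config.pop('callbacks'))
    return callbacks, trainer_config


# Stand-ins for the real callback objects built inside train():
builtin = ['LoggerCallback', 'ModelCheckpoint']
user_config = {'callbacks': ['my_callback'], 'log_every_n_steps': 10}

merged, remaining = merge_callbacks(builtin, user_config)
print(merged)     # ['LoggerCallback', 'ModelCheckpoint', 'my_callback']
print(remaining)  # {'log_every_n_steps': 10}
```

Popping the key is the important part: pl.Trainer raises a TypeError if `callbacks` is supplied both positionally in the call and again through the unpacked `trainer_config`.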
