@@ -2,7 +2,7 @@
 
 import torch
 
-from pytorch_widedeep.wdtypes import List, Tensor, Optional, Literal
+from pytorch_widedeep.wdtypes import List, Tensor, Literal, Optional
 from pytorch_widedeep.utils.hf_utils import get_model_class, get_config_and_model
 from pytorch_widedeep.models.tabular.mlp._layers import MLP
 from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
10 changes: 9 additions & 1 deletion pytorch_widedeep/training/trainer.py

@@ -290,6 +290,7 @@ def fit( # noqa: C901
         eval_dataloader: Optional[CustomDataLoader] = None,
         feature_importance_sample_size: Optional[int] = None,
         finetune: bool = False,
+        stop_after_finetuning: bool = False,
         **kwargs,
     ):
         r"""Fit method.
@@ -374,6 +375,8 @@ def fit( # noqa: C901
             For details on how these routines work, please see the Examples
             section in this documentation and the Examples folder in the repo. <br/>
             Param Alias: `warmup`
+        stop_after_finetuning: bool, default=False
+            Stop the training process right after the finetuning routine.
 
         Other Parameters
         ----------------
@@ -451,14 +454,19 @@ def fit( # noqa: C901
         if finetune:
             self.with_finetuning: bool = True
             self._do_finetune(train_loader, **finetune_args)
-            if self.verbose:
+            if self.verbose and not stop_after_finetuning:
                 print(
                     "Fine-tuning (or warmup) of individual components completed. "
                     "Training the whole model for {} epochs".format(n_epochs)
                 )
         else:
             self.with_finetuning = False
 
+        if stop_after_finetuning:
+            if self.verbose:
+                print("Stopping after finetuning")
+            return
+
         self.callback_container.on_train_begin(
             {
                 "batch_size": batch_size,
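For context, a minimal usage sketch of the new flag (not part of this diff). It assumes a Wide/TabMlp model built as in the test further down; X_wide, X_tab, target, column_idx, embed_input and continuous_cols are placeholders for already-preprocessed pytorch_widedeep inputs. With stop_after_finetuning=True, fit runs only the per-component finetuning/warmup routine and returns before the main training loop.

# Illustrative sketch only; the data objects are assumed to exist already.
import numpy as np

from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.training import Trainer

wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
    column_idx=column_idx,
    cat_embed_input=embed_input,
    continuous_cols=continuous_cols,
    mlp_hidden_dims=[32, 16],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)
trainer = Trainer(model, objective="binary")

# Run only the finetuning/warmup of the individual components, then return
# immediately instead of entering the main training loop.
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    batch_size=32,
    finetune=True,
    finetune_epochs=4,
    stop_after_finetuning=True,
)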
8 changes: 7 additions & 1 deletion pytorch_widedeep/training/trainer_from_folder.py
@@ -220,6 +220,7 @@ def fit( # type: ignore[override] # noqa: C901
         n_epochs: int = 1,
         validation_freq: int = 1,
         finetune: bool = False,
+        stop_after_finetuning: bool = False,
         **kwargs,
     ):
 
@@ -230,14 +231,19 @@ def fit( # type: ignore[override] # noqa: C901
         if finetune:
             self.with_finetuning: bool = True
             self._do_finetune(train_loader, **finetune_args)
-            if self.verbose:
+            if self.verbose and not stop_after_finetuning:
                 print(
                     "Fine-tuning (or warmup) of individual components completed. "
                     "Training the whole model for {} epochs".format(n_epochs)
                 )
         else:
             self.with_finetuning = False
 
+        if stop_after_finetuning:
+            if self.verbose:
+                print("Stopping after finetuning")
+            return
+
         self.callback_container.on_train_begin(
             {
                 "batch_size": train_loader.batch_size,
34 changes: 34 additions & 0 deletions tests/test_model_functioning/test_fit_methods.py
@@ -315,3 +315,37 @@ def test_multiclass_warning():

     with pytest.raises(ValueError):
         trainer = Trainer(model, loss="multiclass", verbose=0)  # noqa: F841
+
+
+##############################################################################
+# Test stop_after_finetuning
+##############################################################################
+
+
+def test_stop_after_finetuning():
+    wide = Wide(np.unique(X_wide).shape[0], 1)
+    deeptabular = TabMlp(
+        column_idx=column_idx,
+        cat_embed_input=embed_input,
+        continuous_cols=colnames[-5:],
+        mlp_hidden_dims=[32, 16],
+        mlp_dropout=[0.5, 0.5],
+    )
+    model = WideDeep(wide=wide, deeptabular=deeptabular)
+
+    trainer = Trainer(model, objective="binary", verbose=0)
+
+    trainer.fit(
+        X_wide=X_wide,
+        X_tab=X_tab,
+        target=target_binary,
+        batch_size=4,
+        finetune=True,
+        finetune_epochs=4,
+        stop_after_finetuning=True,
+    )
+
+    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab)
+
+    assert preds.shape[0] == 32
+    assert trainer.with_finetuning
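A possible follow-up, not asserted by the test: since the finetuned weights stay in model, a later fit call without finetune should pick up from them and run the regular training loop.

# Hedged sketch: resume full training after a finetuning-only run.
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target_binary,
    n_epochs=5,
    batch_size=4,
)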