Merged
17 commits
f0ce4d0
Don't create an optimizer for the transformer if there is no learning…
AngledLuffa Feb 14, 2024
0a0f172
Add a flag to set a different sized batch for the 2nd optimizer. Let…
AngledLuffa Feb 14, 2024
eb8558b
Keep a map of dependency optimizers and iterate them in loops instead…
AngledLuffa Feb 4, 2024
0dcf712
Make checkpoints less often - every 500 by default, with occasional o…
AngledLuffa Feb 5, 2024
0687c37
Add a warmup scheduler to finetuning the depparse transformer
AngledLuffa Feb 9, 2024
96fb54e
Fix --no_checkpoint option
AngledLuffa Feb 11, 2024
e75cb31
Oops, fix usage of weight_decay in the common function to build the b…
AngledLuffa Feb 11, 2024
331d3ae
Add flags for using weight decay in the first round of optimizer in t…
AngledLuffa Feb 11, 2024
c087fcf
Move --use_peft and the checking of --use_peft vs --bert_finetune to …
AngledLuffa Feb 12, 2024
df37c25
Set the default number of hidden layers used from the transformer to 4
AngledLuffa Feb 13, 2024
19fa8c0
Force bert saved when loading / saving a model
AngledLuffa Feb 14, 2024
b9ee08f
Add a PEFT wrapper to the dependency parser
AngledLuffa Feb 14, 2024
f99e7aa
Saved depparse models were missing the last score update
AngledLuffa Feb 14, 2024
59085d3
Refactor the code which runs the predictions in depparse. Use this t…
AngledLuffa Feb 14, 2024
4c7d46f
Add a linear warmup scheduler for the 2nd optimizer pass
AngledLuffa Feb 14, 2024
72a3cd5
Continue training from the current global_step count rather than rewi…
AngledLuffa Feb 14, 2024
cd50fc6
Refactor a method that builds the LoRA wrapper around a bert model, s…
AngledLuffa Feb 15, 2024
6 changes: 1 addition & 5 deletions stanza/models/classifier.py
@@ -230,7 +230,6 @@ def build_argparse():
parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
parser.add_argument('--bert_finetune', default=False, action='store_true', help="Finetune the Bert model")
parser.add_argument('--use_peft', default=False, action='store_true', help="Finetune Bert using peft")
parser.add_argument('--bert_learning_rate', default=0.01, type=float, help='Scale the learning rate for transformer finetuning by this much')
parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')

@@ -307,7 +306,7 @@ def parse_args(args=None):
"""
parser = build_argparse()
args = parser.parse_args(args)
resolve_peft_args(args)
resolve_peft_args(args, tlogger)

if args.wandb_name:
args.wandb = True
@@ -319,9 +318,6 @@
args.momentum = DEFAULT_MOMENTUM.get(args.optim, None)
if args.learning_rate is None:
args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim, None)
if args.use_peft and not args.bert_finetune:
logger.info("--use_peft set. setting --bert_finetune as well")
args.bert_finetune = True

return args

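For context, a minimal standalone sketch of the argument flow after this refactor: the --use_peft flag and the "use_peft implies bert_finetune" coupling now live in stanza.models.common.peft_config rather than being repeated in each model's parse_args. The toy parser below is an assumption for illustration, not the classifier's real argument list.

import argparse
import logging

from stanza.models.common.peft_config import add_peft_args, resolve_peft_args

logger = logging.getLogger('stanza')

parser = argparse.ArgumentParser()
# resolve_peft_args only acts on models that actually expose a transformer flag
parser.add_argument('--bert_model', type=str, default="bert-base-cased")
parser.add_argument('--bert_finetune', default=False, action='store_true')
add_peft_args(parser)            # registers --use_peft, --lora_rank, --lora_alpha, ...

args = parser.parse_args(["--use_peft"])
resolve_peft_args(args, logger)  # now also turns on bert_finetune when --use_peft is set
assert args.bert_finetune
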
14 changes: 2 additions & 12 deletions stanza/models/classifiers/cnn_classifier.py
@@ -17,6 +17,7 @@
from stanza.models.common.bert_embedding import extract_bert_embeddings
from stanza.models.common.data import get_long_tensor, sort_all
from stanza.models.common.foundation_cache import load_bert
from stanza.models.common.peft_config import build_peft_wrapper
from stanza.models.common.vocab import PAD_ID, UNK_ID

"""
@@ -131,18 +132,7 @@ def __init__(self, pretrain, extra_vocab, labels,
raise ValueError("Got a forward charlm as a backward charlm!")

if self.config.use_peft:
# Hide import so that the peft dependency is optional
from peft import LoraConfig, get_peft_model
logger.info("Creating lora adapter with rank %d and alpha %d", self.config.lora_rank, self.config.lora_alpha)
peft_config = LoraConfig(inference_mode=False,
r=self.config.lora_rank,
target_modules=self.config.lora_target_modules,
lora_alpha=self.config.lora_alpha,
lora_dropout=self.config.lora_dropout,
modules_to_save=self.config.lora_modules_to_save,
bias="none")

bert_model = get_peft_model(bert_model, peft_config)
bert_model = build_peft_wrapper(bert_model, vars(self.config), tlogger)
# we use a peft-specific pathway for saving peft weights
self.add_unsaved_module('bert_model', bert_model)
self.bert_model.train()
23 changes: 22 additions & 1 deletion stanza/models/common/peft_config.py
@@ -28,8 +28,9 @@ def add_peft_args(parser):
parser.add_argument('--lora_target_modules', type=str, default=None, help="Comma separated list of LoRA targets. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_TARGETS)
parser.add_argument('--lora_modules_to_save', type=str, default=None, help="Comma separated list of modules to save (eg, fully tune) when using LoRA. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_SAVE)

parser.add_argument('--use_peft', default=False, action='store_true', help="Finetune Bert using peft")

def resolve_peft_args(args):
def resolve_peft_args(args, logger):
if not hasattr(args, 'bert_model'):
return

@@ -55,3 +56,23 @@ def resolve_peft_args(args):
args.lora_modules_to_save = []
else:
args.lora_modules_to_save = args.lora_modules_to_save.split(",")

if hasattr(args, 'bert_finetune'):
if args.use_peft and not args.bert_finetune:
logger.info("--use_peft set. setting --bert_finetune as well")
args.bert_finetune = True

def build_peft_wrapper(bert_model, args, logger):
# Hide import so that the peft dependency is optional
from peft import LoraConfig, get_peft_model
logger.info("Creating lora adapter with rank %d and alpha %d", args['lora_rank'], args['lora_alpha'])
peft_config = LoraConfig(inference_mode=False,
r=args['lora_rank'],
target_modules=args['lora_target_modules'],
lora_alpha=args['lora_alpha'],
lora_dropout=args['lora_dropout'],
modules_to_save=args['lora_modules_to_save'],
bias="none")

bert_model = get_peft_model(bert_model, peft_config)
return bert_model
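
Taken together, add_peft_args / resolve_peft_args / build_peft_wrapper give each model the same small PEFT surface. A hedged usage sketch of the wrapper on its own; the model name and LoRA numbers are illustrative, and build_peft_wrapper expects a dict-like args, which is why cnn_classifier above passes vars(self.config):

import logging
from transformers import AutoModel

from stanza.models.common.peft_config import build_peft_wrapper

logger = logging.getLogger('stanza')

bert_model = AutoModel.from_pretrained("bert-base-cased")   # any HF encoder; the name is an assumption
lora_args = {
    "lora_rank": 64,
    "lora_alpha": 128,
    "lora_dropout": 0.1,
    "lora_target_modules": ["query", "value"],
    "lora_modules_to_save": [],
}
# wraps the encoder in a LoRA adapter; the base transformer weights stay frozen
bert_model = build_peft_wrapper(bert_model, lora_args, logger)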
4 changes: 2 additions & 2 deletions stanza/models/common/utils.py
@@ -249,9 +249,9 @@ def get_split_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=
optimizers = {
"general_optimizer": dispatch_optimizer(name, parameters, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
}
if bert_parameters is not None:
if bert_parameters is not None and bert_learning_rate > 0.0:
if bert_weight_decay is not None:
bert_parameters['weight_decay'] = bert_weight_decay
extra_args['weight_decay'] = bert_weight_decay
optimizers["bert_optimizer"] = dispatch_optimizer(name, bert_parameters, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
return optimizers
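
get_split_optimizer now returns a dict of optimizers rather than a single object, and skips the transformer optimizer entirely when the bert learning rate is zero. A rough sketch of the calling pattern, assuming model is an already constructed Stanza model; the optimizer name and the numbers are illustrative, not project defaults:

from stanza.models.common import utils

optimizers = utils.get_split_optimizer("adam", model, 0.003,
                                       bert_learning_rate=0.01,   # relative rate for the transformer (0 disables the bert optimizer)
                                       bert_weight_decay=1e-6)

# always contains "general_optimizer"; "bert_optimizer" only appears when the model
# has transformer parameters and the transformer rate is positive, so callers loop
for opt in optimizers.values():
    opt.zero_grad()
# ... forward and backward pass ...
for opt in optimizers.values():
    opt.step()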

3 changes: 3 additions & 0 deletions stanza/models/depparse/data.py
@@ -205,6 +205,9 @@ def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)

def set_batch_size(self, batch_size):
self.batch_size = batch_size

def reshuffle(self):
data = [y for x in self.data for y in x]
self.data = self.chunk_batches(data)
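
set_batch_size exists so the trainer can switch to a differently sized batch when the second optimizer kicks in (the new second-stage batch flag from the commit list). A hedged sketch of the intended call pattern; the helper and the flag name second_batch_size are assumptions for illustration:

def switch_to_second_stage(train_batch, args):
    # hypothetical helper; 'second_batch_size' stands in for whatever flag the PR adds
    new_size = args.get('second_batch_size') or args['batch_size']
    train_batch.set_batch_size(new_size)
    train_batch.reshuffle()   # chunk_batches will pick up the new self.batch_size
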
17 changes: 13 additions & 4 deletions stanza/models/depparse/model.py
@@ -12,14 +12,15 @@
from stanza.models.common.foundation_cache import load_bert, load_charlm
from stanza.models.common.hlstm import HighwayLSTM
from stanza.models.common.dropout import WordDropout
from stanza.models.common.peft_config import build_peft_wrapper
from stanza.models.common.vocab import CompositeVocab
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
from stanza.models.common import utils

logger = logging.getLogger('stanza')

class Parser(nn.Module):
def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None):
def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, force_bert_saved=False):
super().__init__()

self.vocab = vocab
@@ -83,7 +84,14 @@ def add_unsaved_module(name, module):
# an average of layers 2, 3, 4 will be used
# (for historic reasons)
self.bert_layer_mix = None
if self.args.get('bert_finetune', False):
if self.args.get('use_peft', False):
bert_model, bert_tokenizer = load_bert(self.args['bert_model'], foundation_cache)
bert_model = build_peft_wrapper(bert_model, self.args, logger)
# we use a peft-specific pathway for saving peft weights
add_unsaved_module('bert_model', bert_model)
add_unsaved_module('bert_tokenizer', bert_tokenizer)
self.bert_model.train()
elif self.args.get('bert_finetune', False) or force_bert_saved:
bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
self.bert_model = bert_model
add_unsaved_module('bert_tokenizer', bert_tokenizer)
@@ -127,7 +135,7 @@ def add_unsaved_module(name, module):
def log_norms(self):
utils.log_norms(self)

def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text, detach=True):
def pack(x):
return pack_padded_sequence(x, sentlens, batch_first=True)

@@ -181,9 +189,10 @@ def pack(x):

if self.bert_model is not None:
device = next(self.parameters()).device
detach = detach or not self.args.get('bert_finetune', False) or not self.training
processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=True,
num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
detach=not self.args.get('bert_finetune', False))
detach=detach)
if self.bert_layer_mix is not None:
# add the average so that the default behavior is to
# take an average of the N layers, and anything else
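
The new detach argument is what lets the trainer drop transformer gradients whenever there is no bert optimizer: the line above keeps detach True unless bert_finetune is on and the module is training. A minimal generic illustration of what detaching buys (plain PyTorch, not the actual extract_bert_embeddings internals):

import torch

encoder = torch.nn.Linear(4, 4)   # stands in for the transformer
head = torch.nn.Linear(4, 1)      # stands in for the parser layers on top

def embed(x, detach):
    feats = encoder(x)
    # detaching cuts the autograd graph here: the encoder gets no gradients and its
    # activations can be freed, which is the GPU memory saving the trainer relies on
    return feats.detach() if detach else feats

x = torch.randn(2, 4)
head(embed(x, detach=True)).sum().backward()
print(encoder.weight.grad)        # None: the frozen transformer accumulates nothing
print(head.weight.grad.shape)     # torch.Size([1, 4]): the layers on top still train
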
78 changes: 66 additions & 12 deletions stanza/models/depparse/trainer.py
@@ -8,6 +8,11 @@
import torch
from torch import nn

try:
import transformers
except ImportError:
pass

from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common import utils, loss
from stanza.models.common.foundation_cache import NoTransformerFoundationCache
@@ -67,14 +72,36 @@ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
wandb.watch(self.model, log_freq=4, log="all", log_graph=True)

def __init_optim(self):
# TODO: can get rid of args.get when models are rebuilt
if (self.args.get("second_stage", False) and self.args.get('second_optim')):
self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model,
self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0))
self.optimizer = utils.get_split_optimizer(self.args['second_optim'], self.model,
self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0),
is_peft=self.args.get('use_peft', False))
else:
self.optimizer = utils.get_split_optimizer(self.args['optim'], self.model,
self.args['lr'], betas=(0.9, self.args['beta2']),
eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0),
weight_decay=self.args.get('weight_decay', None),
bert_weight_decay=self.args.get('bert_weight_decay', 0.0),
is_peft=self.args.get('use_peft', False))
self.scheduler = {}
if self.args.get("second_stage", False) and self.args.get('second_optim'):
if self.args.get('second_warmup_steps', None):
for name, optimizer in self.optimizer.items():
name = name + "_scheduler"
warmup_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, self.args['second_warmup_steps'])
self.scheduler[name] = warmup_scheduler
else:
self.optimizer = utils.get_optimizer(self.args['optim'], self.model,
self.args['lr'], betas=(0.9, self.args['beta2']),
eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0))
if "bert_optimizer" in self.optimizer:
zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer["bert_optimizer"], factor=0, total_iters=self.args['bert_start_finetuning'])
warmup_scheduler = transformers.get_constant_schedule_with_warmup(
self.optimizer["bert_optimizer"],
self.args['bert_warmup_steps'])
self.scheduler["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR(
self.optimizer["bert_optimizer"],
schedulers=[zero_scheduler, warmup_scheduler],
milestones=[self.args['bert_start_finetuning']])
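
The effect of this composition is a transformer learning rate pinned at zero for bert_start_finetuning steps, followed by a linear warmup to the full rate. A self-contained sketch with made-up step counts and learning rate, just to show the shape of the schedule:

import torch
import transformers

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-5)

start_finetuning, warmup_steps = 100, 200   # illustrative values, not the project defaults
zero = torch.optim.lr_scheduler.ConstantLR(opt, factor=0, total_iters=start_finetuning)
warmup = transformers.get_constant_schedule_with_warmup(opt, warmup_steps)
sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[zero, warmup],
                                              milestones=[start_finetuning])

for step in range(1, 401):
    opt.step()
    sched.step()
    if step in (50, 100, 200, 300, 400):
        # lr stays 0.0 through step 100, then ramps linearly toward 1e-5
        print(step, opt.param_groups[0]["lr"])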

def update(self, batch, eval=False):
device = next(self.model.parameters()).device
@@ -85,15 +112,21 @@ def update(self, batch, eval=False):
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
for opt in self.optimizer.values():
opt.zero_grad()
# if there is no bert optimizer, we will tell the model to detach bert so it uses less GPU
detach = "bert_optimizer" not in self.optimizer
loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text, detach=detach)
loss_val = loss.data.item()
if eval:
return loss_val

loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
Review comment (Member):

I wonder if this causes problems with PEFT

Reply (Collaborator Author):

The detach setting just controls whether or not the gradients of the transformer embedding are kept. I don't think it should affect peft in any way if the model is frozen.

self.optimizer.step()
for opt in self.optimizer.values():
opt.step()
for scheduler in self.scheduler.values():
scheduler.step()
return loss_val

def predict(self, batch, unsort=True):
@@ -127,9 +160,14 @@ def save(self, filename, skip_modules=True, save_optimizer=False):
'last_best_step': self.last_best_step,
'dev_score_history': self.dev_score_history,
}
if self.args.get('use_peft', False):
# Hide import so that peft dependency is optional
from peft import get_peft_model_state_dict
params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model)

if save_optimizer and self.optimizer is not None:
params['optimizer_state_dict'] = self.optimizer.state_dict()
params['optimizer_state_dict'] = {k: opt.state_dict() for k, opt in self.optimizer.items()}
params['scheduler_state_dict'] = {k: scheduler.state_dict() for k, scheduler in self.scheduler.items()}

try:
torch.save(params, filename, _use_new_zipfile_serialization=False)
Expand Down Expand Up @@ -160,15 +198,31 @@ def load(self, filename, pretrain, args=None, foundation_cache=None, device=None
if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
foundation_cache = NoTransformerFoundationCache(foundation_cache)
self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache)

# if we are set to not finetune bert, but there is an existing
# bert in the model, we need to respect that and force it to
# be resaved next time the model is saved
force_bert_saved = any(x.startswith("bert_model") for x in checkpoint['model'].keys())

self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, force_bert_saved=force_bert_saved)
self.model.load_state_dict(checkpoint['model'], strict=False)
if self.args.get('use_peft', False):
# hide import so that the peft dependency is optional
from peft import set_peft_model_state_dict
set_peft_model_state_dict(self.model.bert_model, checkpoint['bert_lora'])
if device is not None:
self.model = self.model.to(device)

self.__init_optim()
optim_state_dict = checkpoint.get("optimizer_state_dict")
if optim_state_dict:
self.optimizer.load_state_dict(optim_state_dict)
for k, state in optim_state_dict.items():
self.optimizer[k].load_state_dict(state)

scheduler_state_dict = checkpoint.get("scheduler_state_dict")
if scheduler_state_dict:
for k, state in scheduler_state_dict.items():
self.scheduler[k].load_state_dict(state)

self.global_step = checkpoint.get("global_step", 0)
self.last_best_step = checkpoint.get("last_best_step", 0)
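
Because the LoRA weights are not part of the regular state dict (the wrapped transformer is registered as an unsaved module), the checkpoint carries them separately under the bert_lora key, and load() rebuilds the PEFT wrapper before restoring them. A condensed, self-contained sketch of that round trip using the same peft calls; the tiny model name and LoRA settings are assumptions:

import torch
from peft import (LoraConfig, get_peft_model,
                  get_peft_model_state_dict, set_peft_model_state_dict)
from transformers import AutoModel

# stand-in for the parser's wrapped transformer
bert = get_peft_model(AutoModel.from_pretrained("prajjwal1/bert-tiny"),
                      LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"]))

# save: only the adapter weights go into the checkpoint, mirroring trainer.save()
torch.save({"bert_lora": get_peft_model_state_dict(bert)}, "lora_adapter.pt")

# load: rebuild the wrapped model the same way, then restore the adapter on top of it,
# mirroring trainer.load()
checkpoint = torch.load("lora_adapter.pt", map_location="cpu")
set_peft_model_state_dict(bert, checkpoint["bert_lora"])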