54 changes: 32 additions & 22 deletions stanza/models/charlm.py
@@ -14,7 +14,7 @@
import numpy as np
import torch

from stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel
from stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel, CharacterLanguageModelTrainer
from stanza.models.common.vocab import CharVocab
from stanza.models.common import utils
from stanza.models import _training_logging
@@ -96,6 +96,8 @@ def parse_args(args=None):
parser.add_argument('--eval_steps', type=int, default=100000, help="Update step interval to run eval on dev; set to -1 to eval after each epoch")
parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
parser.add_argument('--vocab_save_name', type=str, default=None, help="File name to save the vocab")
parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
parser.add_argument('--save_dir', type=str, default='saved_models/charlm', help="Directory to save models in")
parser.add_argument('--summary', action='store_true', help='Use summary writer to record progress.')
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
@@ -156,15 +158,16 @@ def evaluate_epoch(args, vocab, data, model, criterion):
total_loss += data.size(1) * loss.data.item()
return total_loss / batches.size(1)

def evaluate_and_save(args, vocab, data, model, criterion, scheduler, best_loss, global_step, model_file, writer=None):
def evaluate_and_save(args, vocab, data, trainer, best_loss, global_step, model_file, checkpoint_file, writer=None):
"""
Run an evaluation over entire dataset, print progress and save the model if necessary.
"""
start_time = time.time()
loss = evaluate_epoch(args, vocab, data, model, criterion)
loss = evaluate_epoch(args, vocab, data, trainer.model, trainer.criterion)
ppl = math.exp(loss)
elapsed = int(time.time() - start_time)
scheduler.step(loss)
# TODO: step the scheduler less often when the eval frequency is higher
trainer.scheduler.step(loss)
logger.info(
"| eval checkpoint @ global step {:10d} | time elapsed {:6d}s | loss {:5.2f} | ppl {:8.2f}".format(
global_step,
@@ -175,18 +178,27 @@ def evaluate_and_save(args, vocab, data, model, criterion, scheduler, best_loss,
)
if best_loss is None or loss < best_loss:
best_loss = loss
model.save(model_file)
logger.info('new best model saved at step {:10d}.'.format(global_step))
trainer.save(model_file, full=False)
logger.info('new best model saved at step {:10d}'.format(global_step))
if writer:
writer.add_scalar('dev_loss', loss, global_step=global_step)
writer.add_scalar('dev_ppl', ppl, global_step=global_step)
if checkpoint_file:
trainer.save(checkpoint_file, full=True)
logger.info('new checkpoint saved at step {:10d}'.format(global_step))

return loss, ppl, best_loss

def train(args):
model_file = args['save_dir'] + '/' + args['save_name'] if args['save_name'] is not None \
else '{}/{}_{}_charlm.pt'.format(args['save_dir'], args['shorthand'], args['direction'])
vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \
else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand'])
if args['checkpoint']:
checkpoint_file = os.path.join(args['save_dir'], args['checkpoint_save_name']) if args['checkpoint_save_name'] \
else os.path.join(args['save_dir'], '{}_{}_charlm_checkpoint.pt'.format(args['shorthand'], args['direction']))
else:
checkpoint_file = None

if os.path.exists(vocab_file):
logger.info('Loading existing vocab file')
@@ -197,12 +209,10 @@ def train(args):
torch.save(vocab['char'].state_dict(), vocab_file)
logger.info("Training model with vocab size: {}".format(len(vocab['char'])))

model = CharacterLanguageModel(args, vocab, is_forward_lm=True if args['direction'] == 'forward' else False)
if args['cuda']: model = model.cuda()
params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
criterion = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
if checkpoint_file and os.path.exists(checkpoint_file):
trainer = CharacterLanguageModelTrainer.load(args, checkpoint_file, finetune=True)
else:
trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab)

writer = None
if args['summary']:
@@ -243,7 +253,7 @@ def train(args):
iteration, i = 0, 0
# over the data chunk
while i < batches.size(1) - 1 - 1:
model.train()
trainer.model.train()
global_step += 1
start_time = time.time()
bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size']/ 2.
@@ -257,14 +267,14 @@
data = data.cuda()
target = target.cuda()

optimizer.zero_grad()
output, hidden, decoded = model.forward(data, lens, hidden)
loss = criterion(decoded.view(-1, len(vocab['char'])), target)
trainer.optimizer.zero_grad()
output, hidden, decoded = trainer.model.forward(data, lens, hidden)
loss = trainer.criterion(decoded.view(-1, len(vocab['char'])), target)
total_loss += loss.data.item()
loss.backward()

torch.nn.utils.clip_grad_norm_(params, args['max_grad_norm'])
optimizer.step()
torch.nn.utils.clip_grad_norm_(trainer.params, args['max_grad_norm'])
trainer.optimizer.step()

hidden = repackage_hidden(hidden)

@@ -290,15 +300,15 @@

# evaluate if necessary
if eval_within_epoch and global_step % args['eval_steps'] == 0:
_, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, model, criterion, scheduler, best_loss, \
global_step, model_file, writer)
_, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, \
global_step, model_file, checkpoint_file, writer)
if args['wandb']:
wandb.log({'ppl': ppl, 'best_loss': best_loss}, step=global_step)

# if eval_interval isn't provided, run evaluation after each epoch
if not eval_within_epoch:
_, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, model, criterion, scheduler, best_loss, \
epoch, model_file, writer) # use epoch in place of global_step for logging
_, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, \
epoch, model_file, checkpoint_file, writer) # use epoch in place of global_step for logging
if args['wandb']:
wandb.log({'ppl': ppl, 'best_loss': best_loss}, step=global_step)

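Note (not part of the diff): with the new flags above, checkpointing is driven entirely through parse_args/train. Below is a minimal sketch of the intended usage, mirroring the updated test further down; the file names are illustrative assumptions.

    from stanza.models import charlm

    # Illustrative paths; --checkpoint_save_name stores the full training state,
    # while --save_name keeps only the best model seen so far.
    args = charlm.parse_args([
        '--train_file', 'train.txt',
        '--eval_file', 'dev.txt',
        '--shorthand', 'en_test',
        '--save_dir', 'saved_models/charlm',
        '--save_name', 'en_test.forward.pt',
        '--vocab_save_name', 'en_test.vocab.pt',
        '--checkpoint_save_name', 'en_test.checkpoint.pt',
    ])
    charlm.train(args)  # if the checkpoint file already exists, training resumes from it
    # passing --no_checkpoint instead disables checkpoint saving entirely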
76 changes: 72 additions & 4 deletions stanza/models/common/char_model.py
@@ -199,23 +199,91 @@ def train(self, mode=True):
if self.finetune: # only set to training mode in finetune status
super().train(mode)

def save(self, filename):
os.makedirs(os.path.split(filename)[0], exist_ok=True)
def full_state(self):
state = {
'vocab': self.vocab['char'].state_dict(),
'args': self.args,
'state_dict': self.state_dict(),
'pad': self.pad,
'is_forward_lm': self.is_forward_lm
}
return state

def save(self, filename):
os.makedirs(os.path.split(filename)[0], exist_ok=True)
state = self.full_state()
torch.save(state, filename, _use_new_zipfile_serialization=False)

@classmethod
def load(cls, filename, finetune=False):
state = torch.load(filename, lambda storage, loc: storage)
def from_full_state(cls, state, finetune=False):
vocab = {'char': CharVocab.load_state_dict(state['vocab'])}
model = cls(state['args'], vocab, state['pad'], state['is_forward_lm'])
model.load_state_dict(state['state_dict'])
model.eval()
model.finetune = finetune # set finetune status
return model

@classmethod
def load(cls, filename, finetune=False):
state = torch.load(filename, lambda storage, loc: storage)
# allow saving just the Model object,
# and allow for old charlms to still work
if 'state_dict' in state:
return cls.from_full_state(state, finetune)
return cls.from_full_state(state['model'], finetune)

class CharacterLanguageModelTrainer():
def __init__(self, model, params, optimizer, criterion, scheduler):
self.model = model
self.params = params
self.optimizer = optimizer
self.criterion = criterion
self.scheduler = scheduler

def save(self, filename, full=True):
os.makedirs(os.path.split(filename)[0], exist_ok=True)
state = {
'model': self.model.full_state()
}
if full and self.optimizer is not None:
state['optimizer'] = self.optimizer.state_dict()
if full and self.criterion is not None:
state['criterion'] = self.criterion.state_dict()
if full and self.scheduler is not None:
state['scheduler'] = self.scheduler.state_dict()
torch.save(state, filename, _use_new_zipfile_serialization=False)

@classmethod
def from_new_model(cls, args, vocab):
model = CharacterLanguageModel(args, vocab, is_forward_lm=True if args['direction'] == 'forward' else False)
if args['cuda']: model = model.cuda()
params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
criterion = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
return cls(model, params, optimizer, criterion, scheduler)


@classmethod
def load(cls, args, filename, finetune=False):
"""
Load the model along with any other saved state for training

Note that you MUST set finetune=True if planning to continue training
Otherwise the only benefit you will get will be a warm GPU
"""
state = torch.load(filename, lambda storage, loc: storage)
model = CharacterLanguageModel.from_full_state(state['model'], finetune)
if args['cuda']: model = model.cuda()

params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
if 'optimizer' in state: optimizer.load_state_dict(state['optimizer'])

criterion = torch.nn.CrossEntropyLoss()
if 'criterion' in state: criterion.load_state_dict(state['criterion'])

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
if 'scheduler' in state: scheduler.load_state_dict(state['scheduler'])
return cls(model, params, optimizer, criterion, scheduler)
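
Note (not part of the diff): the two save paths differ in what they restore. A short sketch, under the assumption that args is the dict returned by charlm.parse_args() and vocab is the {'char': CharVocab} dict built during training:

    from stanza.models.common.char_model import CharacterLanguageModel, CharacterLanguageModelTrainer

    trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab)

    trainer.save('best.pt', full=False)        # model weights and vocab only, as written for the best dev loss
    trainer.save('checkpoint.pt', full=True)   # also stores optimizer, criterion and scheduler state

    # Either file can be read back as a plain model, since
    # CharacterLanguageModel.load() accepts both layouts
    model = CharacterLanguageModel.load('checkpoint.pt')

    # Only the full checkpoint restores the training state; finetune=True is
    # required if training is meant to continue (see the docstring above)
    trainer = CharacterLanguageModelTrainer.load(args, 'checkpoint.pt', finetune=True)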

12 changes: 10 additions & 2 deletions stanza/tests/common/test_char_model.py
@@ -111,6 +111,7 @@ def test_build_model():
fout.write("\n")
save_name = 'en_test.forward.pt'
vocab_save_name = 'en_text.vocab.pt'
checkpoint_save_name = 'en_text.checkpoint.pt'
args = ['--train_file', train_file,
'--eval_file', eval_file,
'--eval_steps', '0', # eval once per epoch
@@ -121,15 +122,22 @@
'--shorthand', 'en_test',
'--save_dir', tempdir,
'--save_name', save_name,
'--vocab_save_name', vocab_save_name]
'--vocab_save_name', vocab_save_name,
'--checkpoint_save_name', checkpoint_save_name]
args = charlm.parse_args(args)
charlm.train(args)

assert os.path.exists(os.path.join(tempdir, vocab_save_name))
# test that saving & loading worked

# test that saving & loading of the model worked
assert os.path.exists(os.path.join(tempdir, save_name))
model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, save_name))

# test that saving & loading of the checkpoint worked
assert os.path.exists(os.path.join(tempdir, checkpoint_save_name))
model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, checkpoint_save_name))
trainer = char_model.CharacterLanguageModelTrainer.load(args, os.path.join(tempdir, checkpoint_save_name))

@pytest.fixture
def english_forward():
# eg, stanza_test/models/en/forward_charlm/1billion.pt