79 changes: 53 additions & 26 deletions stanza/models/classifier.py
@@ -35,6 +35,7 @@ class DevScoring(Enum):
WEIGHTED_F1 = 'WF'

logger = logging.getLogger('stanza')
tlogger = logging.getLogger('stanza.classifiers.trainer')

DEFAULT_TRAIN='data/sentiment/en_sstplus.train.txt'
DEFAULT_DEV='data/sentiment/en_sst3roots.dev.txt'
@@ -178,6 +179,12 @@ def parse_args(args=None):
parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")

parser.add_argument('--bilstm', dest='bilstm', action='store_true', help="Use a bilstm after the inputs, before the convs")
parser.add_argument('--bilstm_hidden_dim', type=int, default=200, help="Dimension of the bilstm to use")
parser.add_argument('--no_bilstm', dest='bilstm', action='store_false', help="Don't use a bilstm after the inputs, before the convs")

parser.add_argument('--maxpool_width', type=int, default=1, help="Width of the maxpool kernel to use")
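
Note (illustrative sketch, not part of the PR): a minimal way to exercise the new options through the existing parse_args entry point. The flag names come from this diff; the values and the reliance on the remaining argument defaults are assumptions.

    # hypothetical invocation of parse_args with the new bilstm/maxpool options
    args = parse_args([
        "--bilstm",                    # add a bilstm between the inputs and the convs
        "--bilstm_hidden_dim", "200",
        "--maxpool_width", "2",        # width of the maxpool kernel
    ])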

parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')

@@ -369,6 +376,8 @@ def log_param_sizes(model):
logger.debug(" Total size: %d", total_size)

def train_model(model, model_file, args, train_set, dev_set, labels):
tlogger.setLevel(logging.DEBUG)

# TODO: separate this into a trainer like the other models.
# TODO: possibly reuse the trainer code other models have
# TODO: use a (torch) dataloader to possibly speed up the GPU usage
@@ -516,6 +525,48 @@ def print_args(args):
log_lines = ['%s: %s' % (k, args[k]) for k in keys]
logger.info('ARGS USED AT TRAINING TIME:\n%s\n' % '\n'.join(log_lines))

def load_model(args):
"""
Load the pretrained embedding and other pieces specified by the args, as well as the model itself
"""
pretrain = load_pretrain(args)
charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)

if os.path.exists(args.load_name):
load_name = args.load_name
else:
load_name = os.path.join(args.save_dir, args.load_name)
if not os.path.exists(load_name):
raise FileNotFoundError("Could not find model to load in either %s or %s" % (args.load_name, load_name))
return cnn_classifier.load(load_name, pretrain, charmodel_forward, charmodel_backward)

def build_new_model(args, train_set):
"""
Load pretrained pieces and then build a new model
"""
if train_set is None:
raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")

pretrain = load_pretrain(args)
charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)

labels = dataset_labels(train_set)
extra_vocab = dataset_vocab(train_set)

bert_model, bert_tokenizer = load_bert(args.bert_model)

return cnn_classifier.CNNClassifier(pretrain=pretrain,
extra_vocab=extra_vocab,
labels=labels,
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
args=args)


def main(args=None):
args = parse_args(args)
seed = utils.set_random_seed(args.seed, args.cuda)
@@ -537,34 +588,10 @@ def main(args=None):
else:
train_set = None

pretrain = load_pretrain(args)

charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)

if args.load_name:
if os.path.exists(args.load_name):
load_name = args.load_name
else:
load_name = os.path.join(args.save_dir, args.load_name)
if not os.path.exists(load_name):
raise FileNotFoundError("Could not find model to load in either %s or %s" % (args.load_name, load_name))
model = cnn_classifier.load(load_name, pretrain, charmodel_forward, charmodel_backward)
model = load_model(args)
else:
assert train_set is not None
labels = dataset_labels(train_set)
extra_vocab = dataset_vocab(train_set)

bert_model, bert_tokenizer = load_bert(args.bert_model)

model = cnn_classifier.CNNClassifier(pretrain=pretrain,
extra_vocab=extra_vocab,
labels=labels,
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
args=args)
model = build_new_model(args, train_set)

if args.cuda:
model.cuda()
2 changes: 1 addition & 1 deletion stanza/models/classifiers/classifier_args.py
@@ -36,7 +36,7 @@ def add_pretrain_args(parser):
parser.add_argument('--wordvec_type', type=lambda x: WVType[x.upper()], default='word2vec', help='Different vector types have different options, such as google 300d replacing numbers with #')
parser.add_argument('--shorthand', type=str, default='en_ewt', help="Treebank shorthand, eg 'en' for English")
parser.add_argument('--extra_wordvec_dim', type=int, default=0, help="Extra dim of word vectors - will be trained")
parser.add_argument('--extra_wordvec_method', type=lambda x: ExtraVectors[x.upper()], default='none', help='How to train extra dimensions of word vectors, if at all')
parser.add_argument('--extra_wordvec_method', type=lambda x: ExtraVectors[x.upper()], default='sum', help='How to train extra dimensions of word vectors, if at all')
parser.add_argument('--extra_wordvec_max_norm', type=float, default=None, help="Max norm for initializing the extra vectors")

def add_device_args(parser):
96 changes: 76 additions & 20 deletions stanza/models/classifiers/cnn_classifier.py
@@ -1,4 +1,6 @@
import logging
import math
import os
import random
import re
from types import SimpleNamespace
@@ -21,10 +23,17 @@

https://arxiv.org/abs/1408.5882

Also included are maxpool 2d, conv 2d, and a bilstm, as in

Text Classification Improved by Integrating Bidirectional LSTM
with Two-dimensional Max Pooling
https://aclanthology.org/C16-1329.pdf

The architecture is simple:

- Embedding at the bottom layer
- separate learnable entry for UNK, since many of the embeddings we have use 0 for UNK
- maybe a bilstm layer, as per a command line flag
- Some number of conv2d layers over the embedding
- Maxpool layers over small windows, window size being a parameter
- FC layer to the classification layer
@@ -39,6 +48,7 @@
"""

logger = logging.getLogger('stanza')
tlogger = logging.getLogger('stanza.classifiers.trainer')

class CNNClassifier(nn.Module):
def __init__(self, pretrain, extra_vocab, labels,
@@ -73,6 +83,9 @@ def __init__(self, pretrain, extra_vocab, labels,
char_lowercase = args.char_lowercase,
charlm_projection = args.charlm_projection,
bert_model = args.bert_model,
bilstm = args.bilstm,
bilstm_hidden_dim = args.bilstm_hidden_dim,
maxpool_width = args.maxpool_width,
model_type = 'CNNClassifier')

self.char_lowercase = args.char_lowercase
@@ -86,12 +99,12 @@ def __init__(self, pretrain, extra_vocab, labels,

self.add_unsaved_module('forward_charlm', charmodel_forward)
if charmodel_forward is not None:
logger.debug("Got forward char model of dimension {}".format(charmodel_forward.hidden_dim()))
tlogger.debug("Got forward char model of dimension {}".format(charmodel_forward.hidden_dim()))
if not charmodel_forward.is_forward_lm:
raise ValueError("Got a backward charlm as a forward charlm!")
self.add_unsaved_module('backward_charlm', charmodel_backward)
if charmodel_backward is not None:
logger.debug("Got backward char model of dimension {}".format(charmodel_backward.hidden_dim()))
tlogger.debug("Got backward char model of dimension {}".format(charmodel_backward.hidden_dim()))
if charmodel_backward.is_forward_lm:
raise ValueError("Got a forward charlm as a backward charlm!")

@@ -125,7 +138,7 @@ def __init__(self, pretrain, extra_vocab, labels,
embedding_dim = self.config.extra_wordvec_dim,
max_norm = self.config.extra_wordvec_max_norm,
padding_idx = 0)
logger.debug("Extra embedding size: {}".format(self.extra_embedding.weight.shape))
tlogger.debug("Extra embedding size: {}".format(self.extra_embedding.weight.shape))
else:
self.extra_vocab = None
self.extra_vocab_map = None
@@ -165,21 +178,53 @@ def __init__(self, pretrain, extra_vocab, labels,
self.bert_dim = self.bert_model.config.hidden_size
total_embedding_dim += self.bert_dim

self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1,
out_channels=self.config.filter_channels,
kernel_size=(filter_size, total_embedding_dim))
for filter_size in self.config.filter_sizes])
if self.config.bilstm:
conv_input_dim = self.config.bilstm_hidden_dim * 2
self.bilstm = nn.LSTM(batch_first=True,
input_size=total_embedding_dim,
hidden_size=self.config.bilstm_hidden_dim,
num_layers=2,
bidirectional=True,
dropout=0.2)
else:
conv_input_dim = total_embedding_dim
self.bilstm = None
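
Note (illustrative, not part of the PR): in the bilstm branch above, conv_input_dim is bilstm_hidden_dim * 2 because a bidirectional LSTM concatenates the forward and backward hidden states at every position.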

self.fc_input_size = 0
self.conv_layers = nn.ModuleList()
self.max_window = 0
for filter_size in self.config.filter_sizes:
if isinstance(filter_size, int):
self.max_window = max(self.max_window, filter_size)
fc_delta = self.config.filter_channels // self.config.maxpool_width
tlogger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, self.config.filter_channels, fc_delta)
self.fc_input_size += fc_delta
self.conv_layers.append(nn.Conv2d(in_channels=1,
out_channels=self.config.filter_channels,
kernel_size=(filter_size, conv_input_dim)))
elif isinstance(filter_size, tuple) and len(filter_size) == 2:
filter_height, filter_width = filter_size
self.max_window = max(self.max_window, filter_width)
filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))
fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width
tlogger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
self.fc_input_size += fc_delta
self.conv_layers.append(nn.Conv2d(in_channels=1,
out_channels=filter_channels,
stride=(1, filter_width),
kernel_size=(filter_height, filter_width)))
else:
raise ValueError("Expected int or 2d tuple for conv size")
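
Note (illustrative sketch, not part of the PR): a worked example of the fc_input_size bookkeeping above, using assumed values rather than project defaults.

    # assume filter_channels = 100, maxpool_width = 2, conv_input_dim = 400
    # full-width filter 3:  fc_delta = 100 // 2 = 50
    # full-width filter 4:  fc_delta = 100 // 2 = 50
    # 2d filter (2, 25):    filter_channels = max(1, 100 // (400 // 25)) = 6
    #                       fc_delta = 6 * (400 // 25) // 2 = 48
    # fc_input_size = 50 + 50 + 48 = 148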

previous_layer_size = len(self.config.filter_sizes) * self.config.filter_channels
tlogger.debug("Input dim to FC layers: %d", self.fc_input_size)
fc_layers = []
previous_layer_size = self.fc_input_size
for shape in self.config.fc_shapes:
fc_layers.append(nn.Linear(previous_layer_size, shape))
previous_layer_size = shape
fc_layers.append(nn.Linear(previous_layer_size, self.config.num_classes))
self.fc_layers = nn.ModuleList(fc_layers)

self.max_window = max(self.config.filter_sizes)

self.dropout = nn.Dropout(self.config.dropout)

def add_unsaved_module(self, name, module):
@@ -325,12 +370,23 @@ def map_word(word):
# still works even if there's just one item
input_vectors = torch.cat(all_inputs, dim=2)

if self.config.bilstm:
input_vectors, _ = self.bilstm(self.dropout(input_vectors))

# reshape to fit the input tensors
x = input_vectors.unsqueeze(1)

conv_outs = [self.dropout(F.relu(conv(x).squeeze(3)))
for conv in self.conv_layers]
pool_outs = [F.max_pool1d(out, out.shape[2]).squeeze(2) for out in conv_outs]
conv_outs = []
for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):
# TODO: non-int filter sizes
if isinstance(filter_size, int):
conv_out = self.dropout(F.relu(conv(x).squeeze(3)))
conv_outs.append(conv_out)
else:
conv_out = conv(x).transpose(2, 3).flatten(1, 2)
conv_out = self.dropout(F.relu(conv_out))
conv_outs.append(conv_out)
pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]
pooled = torch.cat(pool_outs, dim=1)
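
Note (illustrative sketch, not part of the PR): a shape walkthrough of the convolution and pooling above, with B = batch size, T = tokens, D = conv input dim, all symbolic.

    # x after unsqueeze(1)                       : (B, 1, T, D)
    # full-width conv + squeeze(3)               : (B, filter_channels, T - filter_size + 1)
    # 2d conv + transpose(2, 3) + flatten(1, 2)  : (B, channels * (D // filter_width), T')
    # max_pool2d((maxpool_width, length)) + squeeze(2) collapses the length dimension
    #   and keeps one value per group of maxpool_width channels
    # torch.cat(pool_outs, dim=1)                : (B, fc_input_size)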

previous_layer = pooled
@@ -346,6 +402,8 @@ def map_word(word):

# TODO: all this code is basically the same as for POS and NER. Should refactor
def save(filename, model, skip_modules=True):
save_dir = os.path.split(filename)[0]
os.makedirs(save_dir, exist_ok=True)
model_state = model.state_dict()
# skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
if skip_modules:
@@ -358,13 +416,8 @@ def save(filename, model, skip_modules=True):
'labels': model.labels,
'extra_vocab': model.extra_vocab,
}
try:
torch.save(params, filename, _use_new_zipfile_serialization=False)
logger.info("Model saved to {}".format(filename))
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as e:
logger.warning("Saving failed to {}... continuing anyway. Error: {}".format(filename, e))
torch.save(params, filename, _use_new_zipfile_serialization=False)
logger.info("Model saved to {}".format(filename))

def load(filename, pretrain, charmodel_forward, charmodel_backward, foundation_cache=None):
try:
@@ -378,6 +431,9 @@ def load(filename, pretrain, charmodel_forward, charmodel_backward, foundation_c
setattr(checkpoint['config'], 'char_lowercase', getattr(checkpoint['config'], 'char_lowercase', False))
setattr(checkpoint['config'], 'charlm_projection', getattr(checkpoint['config'], 'charlm_projection', None))
setattr(checkpoint['config'], 'bert_model', getattr(checkpoint['config'], 'bert_model', None))
setattr(checkpoint['config'], 'bilstm', getattr(checkpoint['config'], 'bilstm', False))
setattr(checkpoint['config'], 'bilstm_hidden_dim', getattr(checkpoint['config'], 'bilstm_hidden_dim', 0))
setattr(checkpoint['config'], 'maxpool_width', getattr(checkpoint['config'], 'maxpool_width', 1))

# TODO: the getattr is not needed when all models have this baked into the config
model_type = getattr(checkpoint['config'], 'model_type', 'CNNClassifier')
Expand Down