train.lua: 26 changes (14 additions & 12 deletions)
@@ -120,7 +120,7 @@ else
if opt.gpuid == -1 then cnn_backend = 'nn' end -- override to nn if gpu is disabled
local cnn_raw = loadcaffe.load(opt.cnn_proto, opt.cnn_model, cnn_backend)
protos.cnn = net_utils.build_cnn(cnn_raw, {encoding_size = opt.input_encoding_size, backend = cnn_backend})
-- initialize a special FeatExpander module that "corrects" for the batch number discrepancy
-- where we have multiple captions per one image in a batch. This is done for efficiency
-- because doing a CNN forward pass is expensive. We expand out the CNN features for each sentence
protos.expander = nn.FeatExpander(opt.seq_per_img)
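For reference, a minimal sketch of the expansion nn.FeatExpander performs, assuming NxD feature tensors; expand_feats is a hypothetical stand-in, not the module's real internals:

require 'torch'
local function expand_feats(feats, seq_per_img)
  local N, D = feats:size(1), feats:size(2)
  local out = feats.new(N * seq_per_img, D)
  for i = 1, N do
    -- repeat feature row i once per caption of image i, preserving order
    out:narrow(1, (i-1)*seq_per_img + 1, seq_per_img):copy(feats[{{i,i}}]:expand(seq_per_img, D))
  end
  return out
end
print(expand_feats(torch.randn(2, 512), 3):size()) -- 6x512: rows 1,1,1,2,2,2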
@@ -133,7 +133,7 @@ if opt.gpuid >= 0 then
for k,v in pairs(protos) do v:cuda() end
end

-- flatten and prepare all model parameters to a single vector.
-- Keep CNN params separate in case we want to try to get fancy with different optims on LM/CNN
local params, grad_params = protos.lm:getParameters()
local cnn_params, cnn_grad_params = protos.cnn:getParameters()
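Keeping the two flat tensors separate means each component can get its own optimizer and learning rate. A sketch assuming the standard optim package is used; the optimizer choices and rates here are made up:

local optim = require 'optim'
local lm_state = {learningRate = 4e-4}
local cnn_state = {learningRate = 1e-5}
-- after a forward/backward pass has filled grad_params and cnn_grad_params:
--   optim.adam(function() return loss, grad_params end, params, lm_state)
--   optim.sgd(function() return loss, cnn_grad_params end, cnn_params, cnn_state)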
@@ -146,15 +146,17 @@ assert(cnn_params:nElement() == cnn_grad_params:nElement())
-- modules. These thin modules will have no intermediates and will be used
-- for checkpointing to write significantly smaller checkpoint files
local thin_lm = protos.lm:clone()
-thin_lm.core:share(protos.lm.core, 'weight', 'bias') -- TODO: we are assuming that LM has specific members! figure out clean way to get rid of, not modular.
-thin_lm.lookup_table:share(protos.lm.lookup_table, 'weight', 'bias')
+for k,v in pairs(thin_lm) do
+  if type(v) == 'table' and v.share then
+    v:share(protos.lm[k], 'weight', 'bias')
+    net_utils.sanitize_gradients(v)
+  end
+end
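-- note: Module:share(src, 'weight', 'bias') repoints the clone's weight/bias
-- tensors at src's storages, so the thin clone adds no parameter memory of
-- its own; only gradient storage remains, which sanitize_gradients strips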
local thin_cnn = protos.cnn:clone('weight', 'bias')
-- sanitize all modules of gradient storage so that we don't save big checkpoints
net_utils.sanitize_gradients(thin_cnn)
local lm_modules = thin_lm:getModulesList()
for k,v in pairs(lm_modules) do net_utils.sanitize_gradients(v) end
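The effect of sanitize_gradients, sketched under the assumption that it recursively drops gradient tensors (net_utils.lua has the actual implementation):

local function sanitize_gradients_sketch(net)
  net:apply(function(m)
    -- parameters stay shared with the live model; only gradient storage goes
    if m.weight and m.gradWeight then m.gradWeight = nil end
    if m.bias and m.gradBias then m.gradBias = nil end
  end)
end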

-- create clones and ensure parameter sharing. we have to do this
-- all the way here at the end because calls such as :cuda() and
-- :getParameters() reshuffle memory around.
protos.lm:createClones()
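-- why the ordering matters: :cuda() and :getParameters() allocate fresh
-- storage and repoint every module's tensors into it, so clones made before
-- those calls would still reference the old, orphaned storage. (createClones()
-- presumably clones the core per timestep with shared weights and gradients;
-- that detail is an assumption about LanguageModel internals.)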
@@ -236,10 +238,10 @@ local function lossFun()
-----------------------------------------------------------------------------
-- Forward pass
-----------------------------------------------------------------------------
-- get batch of data
local data = loader:getBatch{batch_size = opt.batch_size, split = 'train', seq_per_img = opt.seq_per_img}
data.images = net_utils.prepro(data.images, true, opt.gpuid >= 0) -- preprocess in place, do data augmentation
-- data.images: Nx3x224x224
-- data.seq: LxM where L is sequence length upper bound, and M = N*seq_per_img
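-- (worked example, hypothetical numbers: batch_size = 16 and seq_per_img = 5
-- give data.images of size 16x3x224x224 and data.seq of size Lx80, M = 16*5)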

-- forward the ConvNet on images (most work happens here)
@@ -250,7 +252,7 @@ local function lossFun()
local logprobs = protos.lm:forward{expanded_feats, data.labels}
-- forward the language model criterion
local loss = protos.crit:forward(logprobs, data.labels)

-----------------------------------------------------------------------------
-- Backward pass
-----------------------------------------------------------------------------
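-- a sketch of the chain this section computes, assuming backward simply
-- mirrors the forward chain above (the elided code may differ in detail):
--   local dlogprobs = protos.crit:backward(logprobs, data.labels)
--   local dinputs = protos.lm:backward({expanded_feats, data.labels}, dlogprobs)
--   -- dinputs[1], the gradient w.r.t. expanded_feats, then flows back
--   -- through protos.expander and (when finetuning) protos.cnn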
@@ -291,7 +293,7 @@ local loss_history = {}
local val_lang_stats_history = {}
local val_loss_history = {}
local best_score
while true do

-- eval loss/gradient
local losses = lossFun()
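-- the parameter update itself is elided here; the simplest step consistent
-- with the flattened params would be vanilla SGD (an assumption, the elided
-- code may use a fancier optimizer): params:add(-opt.learning_rate, grad_params)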
@@ -341,7 +343,7 @@ while true do
save_protos.lm = thin_lm -- these are shared clones, and point to correct param storage
save_protos.cnn = thin_cnn
checkpoint.protos = save_protos
-- also include the vocabulary mapping so that we can use the checkpoint
-- alone to run on arbitrary images without the data loader
checkpoint.vocab = loader:getVocab()
torch.save(checkpoint_path .. '.t7', checkpoint)
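A sketch of consuming such a checkpoint standalone, assuming the field names match what is saved above and that the language model exposes the sample() used elsewhere in this repo:

local checkpoint = torch.load(checkpoint_path .. '.t7')
local vocab = checkpoint.vocab -- index -> word mapping
local lm, cnn = checkpoint.protos.lm, checkpoint.protos.cnn
-- local feats = cnn:forward(preprocessed_image)
-- local seq = lm:sample(feats) -- map indices through vocab to get words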