1 change: 1 addition & 0 deletions README.md
@@ -11,6 +11,7 @@ Based on this framework, we recorded the 1st place of [ICDAR2013 focused scene t
The difference between our paper and ICDAR challenge is summarized [here](https://github.com/clovaai/deep-text-recognition-benchmark/issues/13).

## Updates
+**Aug 3, 2020**: added a [guideline to use Baidu warpctc](https://github.com/clovaai/deep-text-recognition-benchmark/pull/209), which reproduces the CTC results of our paper. <br>
**Dec 27, 2019**: added [FLOPS](https://github.com/clovaai/deep-text-recognition-benchmark/issues/125) in our paper, and minor updates such as log_dataset.txt and [ICDAR2019-NormalizedED](https://github.com/clovaai/deep-text-recognition-benchmark/blob/86451088248e0490ff8b5f74d33f7d014f6c249a/test.py#L139-L165). <br>
**Oct 22, 2019**: added [confidence score](https://github.com/clovaai/deep-text-recognition-benchmark/issues/82), and arranged the output form of training logs. <br>
**Jul 31, 2019**: The paper is accepted at International Conference on Computer Vision (ICCV), Seoul 2019, as an oral talk. <br>
16 changes: 14 additions & 2 deletions test.py
@@ -23,6 +23,10 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
    eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                      'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

+    # # To easily compute the total accuracy of our paper.
+    # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
+    #                   'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
+
    if calculate_infer_time:
        evaluation_batch_size = 1  # batch_size should be 1 to calculate the GPU inference time per image.
    else:
@@ -100,10 +104,17 @@ def validation(model, criterion, evaluation_loader, converter, opt):
            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format
-            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
+            if opt.baiduCTC:
+                cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
+            else:
+                cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            # Select max probability (greedy decoding) then decode index to character
-            _, preds_index = preds.max(2)
+            if opt.baiduCTC:
+                _, preds_index = preds.max(2)
+                preds_index = preds_index.view(-1)
+            else:
+                _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
@@ -246,6 +257,7 @@ def test(opt):
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
+    parser.add_argument('--baiduCTC', action='store_true', help='to use CTC loss from Baidu warpctc')
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
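A note on the decoding split above, with a small sketch of my own (not part of the PR): both branches pick the best class per time step, but `CTCLabelConverterForBaiduWarpctc.decode` (added in utils.py below) walks one flat index stream, slicing `preds_size[i]` entries per image, which is why the Baidu branch flattens with `view(-1)`. The toy dimensions here are assumptions.

```python
import torch

# Toy dimensions (assumptions, not from the PR).
batch_size, T, num_class = 2, 5, 4
preds = torch.randn(batch_size, T, num_class)   # recognizer output
preds_size = torch.IntTensor([T] * batch_size)  # every image emits T steps

_, preds_index = preds.max(2)  # (batch, T): best class per time step

# Baidu path: one flat stream of batch * T indices; decode() then slices
# preds_size[i] entries for image i (see CTCLabelConverterForBaiduWarpctc.decode).
preds_index_flat = preds_index.view(-1)
assert preds_index_flat.shape == (batch_size * T,)
```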
23 changes: 18 additions & 5 deletions train.py
@@ -12,7 +12,7 @@
import torch.utils.data
import numpy as np

-from utils import CTCLabelConverter, AttnLabelConverter, Averager
+from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
@@ -45,7 +45,10 @@ def train(opt):

""" model configuration """
if 'CTC' in opt.Prediction:
converter = CTCLabelConverter(opt.character)
if opt.baiduCTC:
converter = CTCLabelConverterForBaiduWarpctc(opt.character)
else:
converter = CTCLabelConverter(opt.character)
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
@@ -86,7 +89,12 @@ def train(opt):

""" setup loss """
if 'CTC' in opt.Prediction:
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
if opt.baiduCTC:
# need to install warpctc. see our guideline.
from warpctc_pytorch import CTCLoss
criterion = CTCLoss()
else:
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
# loss averager
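Importing `warpctc_pytorch` inside the branch keeps it an optional dependency: users who never pass `--baiduCTC` do not need it built. If a friendlier failure is wanted when it is missing, a guarded import along these lines (my sketch, not part of the diff) would do:

```python
try:
    from warpctc_pytorch import CTCLoss  # built separately; see the linked guideline
except ImportError as e:
    raise ImportError(
        'warpctc_pytorch is not installed; build it following the Baidu warpctc '
        'guideline linked from the README before passing --baiduCTC') from e
```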
@@ -144,8 +152,12 @@
        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
-            preds = preds.log_softmax(2).permute(1, 0, 2)
-            cost = criterion(preds, text, preds_size, length)
+            if opt.baiduCTC:
+                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
+                cost = criterion(preds, text, preds_size, length) / batch_size
+            else:
+                preds = preds.log_softmax(2).permute(1, 0, 2)
+                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
@@ -232,6 +244,7 @@ def train(opt):
    parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
    parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
+    parser.add_argument('--baiduCTC', action='store_true', help='to use CTC loss from Baidu warpctc')
    """ Data processing """
    parser.add_argument('--select_data', type=str, default='MJ-ST',
                        help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
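Why the two branches differ, in a self-contained sketch (toy shapes are my assumptions): `torch.nn.CTCLoss` expects log-probabilities in time-major `(T, batch, num_class)` layout and applies its own batch reduction, while Baidu's `warpctc_pytorch.CTCLoss` takes raw pre-softmax activations, applies softmax internally, and returns a single summed loss, hence the explicit `/ batch_size` in the diff.

```python
import torch

# Toy shapes (assumptions): 4 images, 26 time steps, 37 character classes.
batch_size, T, num_class = 4, 26, 37
preds = torch.randn(batch_size, T, num_class)                  # raw recognizer output
text = torch.randint(1, num_class, (20,), dtype=torch.int32)   # concatenated labels, 0 = blank
length = torch.IntTensor([5, 5, 5, 5])                         # per-image label lengths
preds_size = torch.IntTensor([T] * batch_size)                 # per-image output lengths

# Built-in CTC: log-probabilities, (T, batch, num_class), reduces over the batch itself.
criterion = torch.nn.CTCLoss(zero_infinity=True)
cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text, preds_size, length)

# Baidu warpctc (if built): raw activations in, summed loss out --
# divide by batch_size to get a per-image average.
# from warpctc_pytorch import CTCLoss
# cost = CTCLoss()(preds.permute(1, 0, 2), text, preds_size, length) / batch_size
```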
47 changes: 47 additions & 0 deletions utils.py
@@ -52,6 +52,53 @@ def decode(self, text_index, length):
        return texts


+class CTCLabelConverterForBaiduWarpctc(object):
+    """ Convert between text-label and text-index for baidu warpctc """
+
+    def __init__(self, character):
+        # character (str): set of the possible characters.
+        dict_character = list(character)
+
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
+            self.dict[char] = i + 1
+
+        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)
+
+    def encode(self, text, batch_max_length=25):
+        """convert text-label into text-index.
+        input:
+            text: text labels of each image. [batch_size]
+        output:
+            text: concatenated text index for CTCLoss.
+                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+            length: length of each text. [batch_size]
+        """
+        length = [len(s) for s in text]
+        text = ''.join(text)
+        text = [self.dict[char] for char in text]
+
+        return (torch.IntTensor(text), torch.IntTensor(length))
+
+    def decode(self, text_index, length):
+        """ convert text-index into text-label. """
+        texts = []
+        index = 0
+        for l in length:
+            t = text_index[index:index + l]
+
+            char_list = []
+            for i in range(l):
+                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank.
+                    char_list.append(self.character[t[i]])
+            text = ''.join(char_list)
+
+            texts.append(text)
+            index += l
+        return texts
+
+
class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

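To see the new converter end to end, here is a quick round-trip (toy character set, my example, not from the PR): `encode` concatenates label indices with 0 reserved for the blank, and `decode` strips blanks and merges adjacent repeats, so a doubled letter with no blank between its frames is collapsed.

```python
from utils import CTCLabelConverterForBaiduWarpctc  # assumes the repo root is on PYTHONPATH

converter = CTCLabelConverterForBaiduWarpctc('abcdefghijklmnopqrstuvwxyz')
text, length = converter.encode(['hello', 'world'])
print(text)    # concatenated indices of 'helloworld', shape [10]
print(length)  # tensor([5, 5], dtype=torch.int32)

# decode() removes blanks (index 0) and merges adjacent repeats, so feeding
# the labels straight back collapses the doubled 'l' in 'hello':
print(converter.decode(text, length))  # ['helo', 'world']
```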