Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions stanza/models/langid/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,15 @@ def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_di
def build_lang_mask(self, use_gpu=None):
"""
Build language mask if a lang subset is specified (e.g. ["en", "fr"])

The mask will be added to the results to set the prediction scores of illegal languages to -inf
"""
device = torch.device("cuda") if use_gpu else None
lang_mask_list = [int(lang in self.lang_subset) for lang in self.idx_to_tag] if self.lang_subset else \
[1 for lang in self.idx_to_tag]
self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
if self.lang_subset:
lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag]
self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
else:
self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float)

def loss(self, Y_hat, Y):
return self.loss_train(Y_hat, Y)
Expand All @@ -87,7 +91,7 @@ def prediction_scores(self, x):
if self.lang_subset:
prediction_batch_size = prediction_probs.size()[0]
batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)])
prediction_probs = prediction_probs * batch_mask
prediction_probs = prediction_probs + batch_mask
return torch.argmax(prediction_probs, dim=1)

def save(self, path):
Expand Down
17 changes: 17 additions & 0 deletions stanza/tests/langid/test_langid.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,23 @@ def test_lang_subset():
nlp(docs)
assert [doc.lang for doc in docs] == ["en", "en"]

def test_lang_subset_unlikely_language():
"""
Test that the language subset masking chooses a legal language, even if all legal languages are supa unlikely
"""
sentences = ["你好" * 200]
docs = [Document([], text=text) for text in sentences]
nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"])
nlp(docs)
assert [doc.lang for doc in docs] == ["en"]

processor = nlp.processors['langid']
model = processor._model
text_tensor = processor._text_to_tensor(sentences)
en_idx = model.tag_to_idx['en']
predictions = model(text_tensor)
assert predictions[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. Update the test by picking a different combination of languages & input"

def test_multilingual_pipeline():
"""
Basic test of multilingual pipeline
Expand Down