
Commit 502db83

Merge pull request #105 from MantisAI/feat/adding-min-overlap-threshold
feat: defining a min ground truth percentage to be considered an overlap
2 parents fb2630f + a6c9abd commit 502db83
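
The change, in brief: overlap-based matching previously accepted any non-zero span overlap, whereas this PR lets callers require that a prediction cover at least min_overlap_percentage of the ground-truth span (default 1.0, effectively the old behaviour) before it counts as an overlap; anything below the threshold is scored as spurious. A minimal usage sketch, assuming the package's top-level Evaluator export and the dict loader format used in the tests below:

from nervaluate import Evaluator  # assumed top-level export

true = [[{"label": "PER", "start": 0, "end": 9}]]  # one gold entity covering 10 tokens
pred = [[{"label": "PER", "start": 0, "end": 2}]]  # covers 3 of those 10 tokens (30%)

# Require at least half of the gold span to be covered before a
# prediction is treated as overlapping at all.
evaluator = Evaluator(true, pred, tags=["PER"], loader="dict", min_overlap_percentage=50.0)
results = evaluator.evaluate()
print(results["overall"]["partial"].spurious)  # 1, because 30% < 50%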

File tree

5 files changed: +487 -17 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ disable = [
     "R0801", # duplicate-code
     "W9020", # bad-option-value
     "W0621", # redefined-outer-name
+    "W0212", # protected-access
 ]

 [tool.pylint.'DESIGN']

src/nervaluate/evaluator.py

Lines changed: 10 additions & 6 deletions
@@ -17,7 +17,9 @@
 class Evaluator:
     """Main evaluator class for NER evaluation."""

-    def __init__(self, true: Any, pred: Any, tags: List[str], loader: str = "default") -> None:
+    def __init__(
+        self, true: Any, pred: Any, tags: List[str], loader: str = "default", min_overlap_percentage: float = 1.0
+    ) -> None:
         """
         Initialize the evaluator.

@@ -26,8 +28,10 @@ def __init__(self, true: Any, pred: Any, tags: List[str], loader: str = "default
             pred: Predicted entities in any supported format
             tags: List of valid entity tags
             loader: Name of the loader to use
+            min_overlap_percentage: Minimum overlap percentage for partial matches (1-100)
         """
         self.tags = tags
+        self.min_overlap_percentage = min_overlap_percentage
         self._setup_loaders()
         self._load_data(true, pred, loader)
         self._setup_evaluation_strategies()
@@ -37,12 +41,12 @@ def _setup_loaders(self) -> None:
         self.loaders: Dict[str, DataLoader] = {"conll": ConllLoader(), "list": ListLoader(), "dict": DictLoader()}

     def _setup_evaluation_strategies(self) -> None:
-        """Setup evaluation strategies."""
+        """Setup evaluation strategies with overlap threshold."""
         self.strategies: Dict[str, EvaluationStrategy] = {
-            "strict": StrictEvaluation(),
-            "partial": PartialEvaluation(),
-            "ent_type": EntityTypeEvaluation(),
-            "exact": ExactEvaluation(),
+            "strict": StrictEvaluation(self.min_overlap_percentage),
+            "partial": PartialEvaluation(self.min_overlap_percentage),
+            "ent_type": EntityTypeEvaluation(self.min_overlap_percentage),
+            "exact": ExactEvaluation(self.min_overlap_percentage),
         }

     def _load_data(self, true: Any, pred: Any, loader: str) -> None:
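
Because _setup_evaluation_strategies passes the threshold into every strategy constructor, a single Evaluator applies one min_overlap_percentage uniformly across all four modes. A hypothetical spot-check, relying only on the strategies dict and attribute visible in this diff:

from nervaluate import Evaluator  # assumed top-level export, as above

evaluator = Evaluator(
    true=[[{"label": "PER", "start": 0, "end": 9}]],
    pred=[[{"label": "PER", "start": 0, "end": 4}]],
    tags=["PER"],
    loader="dict",
    min_overlap_percentage=75.0,
)
# every strategy instance carries the same threshold
assert all(s.min_overlap_percentage == 75.0 for s in evaluator.strategies.values())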

src/nervaluate/strategies.py

Lines changed: 47 additions & 9 deletions
@@ -7,6 +7,45 @@
 class EvaluationStrategy(ABC):
     """Abstract base class for evaluation strategies."""

+    def __init__(self, min_overlap_percentage: float = 1.0):
+        """
+        Initialize strategy with minimum overlap threshold.
+
+        Args:
+            min_overlap_percentage: Minimum overlap percentage required (1-100)
+        """
+        if not 1.0 <= min_overlap_percentage <= 100.0:
+            raise ValueError("min_overlap_percentage must be between 1.0 and 100.0")
+        self.min_overlap_percentage = min_overlap_percentage
+
+    @staticmethod
+    def _calculate_overlap_percentage(pred: Entity, true: Entity) -> float:
+        """
+        Calculate the percentage overlap between predicted and true entities.
+
+        Returns:
+            Overlap percentage based on true entity span (0-100)
+        """
+        # Check if there's any overlap first
+        if pred.start > true.end or pred.end < true.start:
+            return 0.0
+
+        # Calculate overlap boundaries
+        overlap_start = max(pred.start, true.start)
+        overlap_end = min(pred.end, true.end)
+
+        # Calculate spans (adding 1 because end is inclusive)
+        overlap_span = overlap_end - overlap_start + 1
+        true_span = true.end - true.start + 1
+
+        # Calculate percentage based on true entity span
+        return (overlap_span / true_span) * 100.0
+
+    def _has_sufficient_overlap(self, pred: Entity, true: Entity) -> bool:
+        """Check if entities have sufficient overlap based on threshold."""
+        overlap_percentage = EvaluationStrategy._calculate_overlap_percentage(pred, true)
+        return overlap_percentage >= self.min_overlap_percentage
+
     @abstractmethod
     def evaluate(
         self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0
@@ -50,8 +89,8 @@ def evaluate(
                     matched_true.add(true_idx)
                     found_match = True
                     break
-                # Check for any overlap
-                if pred.start <= true.end and pred.end >= true.start:
+                # Check for sufficient overlap with min threshold
+                if self._has_sufficient_overlap(pred, true):
                     result.incorrect += 1
                     indices.incorrect_indices.append((instance_index, pred_idx))
                     matched_true.add(true_idx)
@@ -97,8 +136,8 @@ def evaluate(
                 if true_idx in matched_true:
                     continue

-                # Check for overlap
-                if pred.start <= true.end and pred.end >= true.start:
+                # Check for sufficient overlap with min threshold
+                if self._has_sufficient_overlap(pred, true):
                     if pred.start == true.start and pred.end == true.end:
                         result.correct += 1
                         indices.correct_indices.append((instance_index, pred_idx))
@@ -135,7 +174,6 @@ class EntityTypeEvaluation(EvaluationStrategy):
     If there's a predicted entity that doesn't match any true entity, we mark it as spurious.
     If there's a true entity that doesn't match any predicted entity, we mark it as missed.

-    # ToDo: define a minimum overlap threshold - see: https://github.com/MantisAI/nervaluate/pull/83
     """

     def evaluate(
@@ -153,8 +191,8 @@ def evaluate(
                 if true_idx in matched_true:
                     continue

-                # Check for any overlap (perfect or minimum)
-                if pred.start <= true.end and pred.end >= true.start:
+                # Check for sufficient overlap with min threshold
+                if self._has_sufficient_overlap(pred, true):
                     if pred.label == true.label:
                         result.correct += 1
                         indices.correct_indices.append((instance_index, pred_idx))
@@ -216,8 +254,8 @@ def evaluate(
                     matched_true.add(true_idx)
                     found_match = True
                     break
-                # Check for any overlap
-                if pred.start <= true.end and pred.end >= true.start:
+                # Check for sufficient overlap with min threshold
+                if self._has_sufficient_overlap(pred, true):
                     result.incorrect += 1
                     indices.incorrect_indices.append((instance_index, pred_idx))
                     matched_true.add(true_idx)
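
To make the arithmetic concrete, here is _calculate_overlap_percentage traced by hand on the 30%-overlap case used throughout the tests; Entity is stood in by a namedtuple purely for illustration:

from collections import namedtuple

Entity = namedtuple("Entity", ["label", "start", "end"])  # illustrative stand-in for nervaluate's Entity

true = Entity("PER", 0, 9)  # tokens 0..9 -> span of 10, since end is inclusive
pred = Entity("PER", 0, 2)  # tokens 0..2 -> span of 3

overlap_start = max(pred.start, true.start)     # 0
overlap_end = min(pred.end, true.end)           # 2
overlap_span = overlap_end - overlap_start + 1  # 3
true_span = true.end - true.start + 1           # 10
print(overlap_span / true_span * 100.0)         # 30.0 -> spurious under a 50.0 threshold

Note that the percentage is measured against the true span only, so a prediction that fully covers a gold entity scores 100% even if it also sprawls far beyond it.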

tests/test_evaluator.py

Lines changed: 187 additions & 0 deletions
@@ -166,3 +166,190 @@ def test_results_to_csv(sample_data, tmp_path):
     # test invalid scenario for entities mode
     with pytest.raises(ValueError, match="Invalid scenario: must be one of"):
         evaluator.results_to_csv(mode="entities", scenario="invalid")
+
+
+def test_evaluator_with_min_overlap_percentage():
+    """Test Evaluator class with minimum overlap percentage parameter."""
+
+    # Test data: true entity spans positions 0-9 (10 tokens)
+    true_entities = [[{"label": "PER", "start": 0, "end": 9}]]  # 10-token entity
+
+    # Predicted entities with different overlap percentages
+    pred_entities = [[{"label": "PER", "start": 0, "end": 2}]]  # 30% overlap
+
+    # Test with default 1% threshold - should be partial match
+    evaluator_default = Evaluator(true=true_entities, pred=pred_entities, tags=["PER"], loader="dict")
+    results_default = evaluator_default.evaluate()
+    partial_default = results_default["overall"]["partial"]
+    assert partial_default.partial == 1
+    assert partial_default.spurious == 0
+
+    # Test with 50% threshold - should be spurious
+    evaluator_50 = Evaluator(
+        true=true_entities, pred=pred_entities, tags=["PER"], loader="dict", min_overlap_percentage=50.0
+    )
+    results_50 = evaluator_50.evaluate()
+    partial_50 = results_50["overall"]["partial"]
+    assert partial_50.partial == 0
+    assert partial_50.spurious == 1
+
+
+def test_evaluator_min_overlap_validation():
+    """Test that Evaluator validates minimum overlap percentage."""
+    true_entities = [[{"label": "PER", "start": 0, "end": 5}]]
+    pred_entities = [[{"label": "PER", "start": 0, "end": 5}]]
+
+    # Valid values should work
+    Evaluator(true_entities, pred_entities, ["PER"], "dict", min_overlap_percentage=1.0)
+    Evaluator(true_entities, pred_entities, ["PER"], "dict", min_overlap_percentage=50.0)
+    Evaluator(true_entities, pred_entities, ["PER"], "dict", min_overlap_percentage=100.0)
+
+    # Invalid values should raise ValueError during strategy initialization
+    with pytest.raises(ValueError, match="min_overlap_percentage must be between 1.0 and 100.0"):
+        Evaluator(true_entities, pred_entities, ["PER"], "dict", min_overlap_percentage=0.5)
+
+    with pytest.raises(ValueError, match="min_overlap_percentage must be between 1.0 and 100.0"):
+        Evaluator(true_entities, pred_entities, ["PER"], "dict", min_overlap_percentage=101.0)
+
+
+def test_evaluator_min_overlap_affects_all_strategies():
+    """Test that minimum overlap percentage affects all evaluation strategies."""
+    true_entities = [[{"label": "PER", "start": 0, "end": 9}]]  # 10 tokens
+
+    pred_entities = [[{"label": "PER", "start": 0, "end": 2}]]  # 30% overlap
+
+    evaluator = Evaluator(
+        true=true_entities, pred=pred_entities, tags=["PER"], loader="dict", min_overlap_percentage=50.0
+    )
+
+    results = evaluator.evaluate()
+
+    # All strategies should respect the 50% threshold
+    # 30% overlap < 50% threshold, so should be spurious for all strategies
+
+    # Partial strategy
+    partial_result = results["overall"]["partial"]
+    assert partial_result.spurious == 1
+    assert partial_result.correct == 0
+    assert partial_result.partial == 0
+
+    # Strict strategy
+    strict_result = results["overall"]["strict"]
+    assert strict_result.spurious == 1
+    assert strict_result.correct == 0
+    assert strict_result.incorrect == 0
+
+    # Entity type strategy
+    ent_type_result = results["overall"]["ent_type"]
+    assert ent_type_result.spurious == 1
+    assert ent_type_result.correct == 0
+    assert ent_type_result.incorrect == 0
+
+    # Exact strategy
+    exact_result = results["overall"]["exact"]
+    assert exact_result.spurious == 1
+    assert exact_result.correct == 0
+    assert exact_result.incorrect == 0
+
+
+def test_evaluator_min_overlap_with_different_thresholds():
+    """Test Evaluator with different overlap thresholds."""
+    true_entities = [[{"label": "PER", "start": 0, "end": 9}]]  # 10 tokens
+
+    # Test cases with different predicted entities
+    test_cases = [
+        # (pred_entities, threshold, expected_result_type)
+        ([{"label": "PER", "start": 0, "end": 4}], 50.0, "partial"),  # 50% overlap = 50%
+        ([{"label": "PER", "start": 0, "end": 4}], 51.0, "spurious"),  # 50% overlap < 51%
+        ([{"label": "PER", "start": 0, "end": 6}], 75.0, "spurious"),  # 70% overlap < 75%
+        ([{"label": "PER", "start": 0, "end": 7}], 75.0, "partial"),  # 80% overlap > 75%
+        ([{"label": "PER", "start": 0, "end": 9}], 100.0, "correct"),  # 100% overlap = exact match
+    ]
+
+    for pred_data, threshold, expected_type in test_cases:
+        pred_entities = [pred_data]
+
+        evaluator = Evaluator(
+            true=true_entities, pred=pred_entities, tags=["PER"], loader="dict", min_overlap_percentage=threshold
+        )
+
+        results = evaluator.evaluate()
+        partial_results = results["overall"]["partial"]
+
+        if expected_type == "correct":
+            assert partial_results.correct == 1, f"Failed for {pred_data} with threshold {threshold}%"
+            assert partial_results.partial == 0
+            assert partial_results.spurious == 0
+        elif expected_type == "partial":
+            assert partial_results.partial == 1, f"Failed for {pred_data} with threshold {threshold}%"
+            assert partial_results.correct == 0
+            assert partial_results.spurious == 0
+        elif expected_type == "spurious":
+            assert partial_results.spurious == 1, f"Failed for {pred_data} with threshold {threshold}%"
+            assert partial_results.correct == 0
+            assert partial_results.partial == 0
+
+
+def test_evaluator_min_overlap_with_multiple_entities():
+    """Test Evaluator with multiple entities and minimum overlap threshold."""
+    true_entities = [
+        [
+            {"label": "PER", "start": 0, "end": 4},  # 5 tokens
+            {"label": "ORG", "start": 10, "end": 14},  # 5 tokens
+            {"label": "LOC", "start": 20, "end": 24},  # 5 tokens
+        ]
+    ]
+
+    pred_entities = [
+        [
+            {"label": "PER", "start": 0, "end": 1},  # 40% overlap (2/5 tokens)
+            {"label": "ORG", "start": 10, "end": 12},  # 60% overlap (3/5 tokens)
+            {"label": "LOC", "start": 20, "end": 24},  # 100% overlap (exact match)
+            {"label": "MISC", "start": 30, "end": 32},  # No overlap (spurious)
+        ]
+    ]
+
+    # Test with 50% threshold
+    evaluator = Evaluator(
+        true=true_entities,
+        pred=pred_entities,
+        tags=["PER", "ORG", "LOC", "MISC"],
+        loader="dict",
+        min_overlap_percentage=50.0,
+    )
+
+    results = evaluator.evaluate()
+    partial_results = results["overall"]["partial"]
+
+    assert partial_results.correct == 1  # LOC exact match
+    assert partial_results.partial == 1  # ORG 60% overlap > 50%
+    assert partial_results.spurious == 2  # PER 40% < 50% and MISC no overlap
+    assert partial_results.missed == 1  # PER entity not sufficiently matched
+
+
+def test_evaluator_min_overlap_backward_compatibility():
+    """Test that the new feature maintains backward compatibility."""
+    true_entities = [[{"label": "PER", "start": 0, "end": 9}]]
+
+    pred_entities = [[{"label": "PER", "start": 9, "end": 9}]]  # 10% overlap (1 token out of 10)
+
+    # Without specifying min_overlap_percentage (should default to 1.0)
+    evaluator_default = Evaluator(true=true_entities, pred=pred_entities, tags=["PER"], loader="dict")
+
+    # With explicitly setting to 1.0
+    evaluator_explicit = Evaluator(
+        true=true_entities, pred=pred_entities, tags=["PER"], loader="dict", min_overlap_percentage=1.0
+    )
+
+    results_default = evaluator_default.evaluate()
+    results_explicit = evaluator_explicit.evaluate()
+
+    # Results should be identical
+    for strategy in ["strict", "partial", "ent_type", "exact"]:
+        default_result = results_default["overall"][strategy]
+        explicit_result = results_explicit["overall"][strategy]
+
+        assert default_result.correct == explicit_result.correct
+        assert default_result.partial == explicit_result.partial
+        assert default_result.spurious == explicit_result.spurious
+        assert default_result.missed == explicit_result.missed
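
A quick way to exercise only the new tests, assuming a standard pytest setup (the keyword filter matches the test names added above; shown via pytest.main to keep the snippet in Python):

import pytest

# Equivalent to the CLI invocation: pytest tests/test_evaluator.py -k min_overlap
pytest.main(["tests/test_evaluator.py", "-k", "min_overlap"])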
