@@ -58,21 +58,23 @@ def metric_score(
58
58
preds = to_int_label_array (preds ) if preds is not None else None
59
59
60
60
# Optionally filter out examples (e.g., abstain predictions or unknown labels)
61
- label_dict = {"golds" : golds , "preds" : preds , "probs" : probs }
61
+ label_dict : Dict [ str , Optional [ np . ndarray ]] = {"golds" : golds , "preds" : preds , "probs" : probs }
62
62
if filter_dict :
63
63
if set (filter_dict .keys ()).difference (set (label_dict .keys ())):
64
64
raise ValueError (
65
65
"filter_dict must only include keys in ['golds', 'preds', 'probs']"
66
66
)
67
- label_dict = filter_labels (label_dict , filter_dict )
67
+ # Reassign filtered label_dict to a new variable to avoid
68
+ # mypy error regarding change variable of invariant type
69
+ label_dict_filtered : Dict [str , np .ndarray ] = filter_labels (label_dict , filter_dict )
68
70
69
71
# Confirm that required label sets are available
70
72
func , label_names = METRICS [metric ]
71
73
for label_name in label_names :
72
- if label_dict [label_name ] is None :
74
+ if label_dict_filtered [label_name ] is None :
73
75
raise ValueError (f"Metric { metric } requires access to { label_name } ." )
74
76
75
- label_sets = [label_dict [label_name ] for label_name in label_names ]
77
+ label_sets = [label_dict_filtered [label_name ] for label_name in label_names ]
76
78
return func (* label_sets , ** kwargs )
77
79
78
80
0 commit comments