1
1
from itertools import chain
2
- from typing import List , Tuple , Union
2
+ from typing import DefaultDict , Dict , List , NamedTuple , Tuple , Union
3
3
4
4
import numpy as np
5
5
from tqdm import tqdm
11
11
RowData = List [Tuple [int , int , int ]]
12
12
13
13
14
+ class ApplierMetadata (NamedTuple ):
15
+ """Metadata about Applier call."""
16
+
17
+ # Map from LF name to number of faults in apply call
18
+ faults : Dict [str , int ]
19
+
20
+
21
+ class _FunctionCaller :
22
+ def __init__ (self , fault_tolerant : bool ):
23
+ self .fault_tolerant = fault_tolerant
24
+ self .fault_counts : DefaultDict [str , int ] = DefaultDict (int )
25
+
26
+ def __call__ (self , f : LabelingFunction , x : DataPoint ) -> int :
27
+ if not self .fault_tolerant :
28
+ return f (x )
29
+ try :
30
+ return f (x )
31
+ except Exception :
32
+ self .fault_counts [f .name ] += 1
33
+ return - 1
34
+
35
+
14
36
class BaseLFApplier :
15
37
"""Base class for LF applier objects.
16
38
@@ -60,7 +82,7 @@ def __repr__(self) -> str:
60
82
61
83
62
84
def apply_lfs_to_data_point (
63
- x : DataPoint , index : int , lfs : List [LabelingFunction ]
85
+ x : DataPoint , index : int , lfs : List [LabelingFunction ], f_caller : _FunctionCaller
64
86
) -> RowData :
65
87
"""Label a single data point with a set of LFs.
66
88
@@ -72,6 +94,8 @@ def apply_lfs_to_data_point(
72
94
Index of the data point
73
95
lfs
74
96
Set of LFs to label ``x`` with
97
+ f_caller
98
+ A ``_FunctionCaller`` to record failed LF executions
75
99
76
100
Returns
77
101
-------
@@ -80,7 +104,7 @@ def apply_lfs_to_data_point(
80
104
"""
81
105
labels = []
82
106
for j , lf in enumerate (lfs ):
83
- y = lf ( x )
107
+ y = f_caller ( lf , x )
84
108
if y >= 0 :
85
109
labels .append ((index , j , y ))
86
110
return labels
@@ -114,8 +138,12 @@ class LFApplier(BaseLFApplier):
114
138
"""
115
139
116
140
def apply (
117
- self , data_points : Union [DataPoints , np .ndarray ], progress_bar : bool = True
118
- ) -> np .ndarray :
141
+ self ,
142
+ data_points : Union [DataPoints , np .ndarray ],
143
+ progress_bar : bool = True ,
144
+ fault_tolerant : bool = False ,
145
+ return_meta : bool = False ,
146
+ ) -> Union [np .ndarray , Tuple [np .ndarray , ApplierMetadata ]]:
119
147
"""Label list of data points or a NumPy array with LFs.
120
148
121
149
Parameters
@@ -124,13 +152,23 @@ def apply(
124
152
List of data points or NumPy array to be labeled by LFs
125
153
progress_bar
126
154
Display a progress bar?
155
+ fault_tolerant
156
+ Output ``-1`` if LF execution fails?
157
+ return_meta
158
+ Return metadata from apply call?
127
159
128
160
Returns
129
161
-------
130
162
np.ndarray
131
163
Matrix of labels emitted by LFs
164
+ ApplierMetadata
165
+ Metadata, such as fault counts, for the apply call
132
166
"""
133
167
labels = []
168
+ f_caller = _FunctionCaller (fault_tolerant )
134
169
for i , x in tqdm (enumerate (data_points ), disable = (not progress_bar )):
135
- labels .append (apply_lfs_to_data_point (x , i , self ._lfs ))
136
- return self ._numpy_from_row_data (labels )
170
+ labels .append (apply_lfs_to_data_point (x , i , self ._lfs , f_caller ))
171
+ L = self ._numpy_from_row_data (labels )
172
+ if return_meta :
173
+ return L , ApplierMetadata (f_caller .fault_counts )
174
+ return L
0 commit comments