from abc import ABC
- from typing import TYPE_CHECKING, Callable, Union, Tuple, Optional
- from typing_extensions import Literal
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union

import numpy as np
- from tqdm import tqdm
-
from alibi.api.interfaces import Explainer
from alibi.explainers.similarity.backends import _select_backend
- from alibi.utils.frameworks import Framework
+ from alibi.utils.frameworks import Framework, has_pytorch, has_tensorflow
+ from alibi.utils.missing_optional_dependency import import_optional
+ from tqdm import tqdm
+ from typing_extensions import Literal
+
+ _TfTensor = import_optional('tensorflow', ['Tensor'])
+ _PtTensor = import_optional('torch', ['Tensor'])
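+ # The has_tensorflow / has_pytorch flags guard the isinstance checks below, so the
+ # optionally imported Tensor types are only touched when the framework is installed.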

if TYPE_CHECKING:
    import tensorflow
@@ -63,7 +66,7 @@ def __init__(self,
        super().__init__(meta=meta)

    def fit(self,
-             X_train: np.ndarray,
+             X_train: Union[np.ndarray, List[Any]],
              Y_train: np.ndarray) -> "Explainer":
        """Fit the explainer. If ``self.precompute_grads == True`` then the gradients are precomputed and stored.

@@ -79,21 +82,42 @@ def fit(self,
        self
            Returns self.
        """
-         self.X_train: np.ndarray = X_train
-         self.Y_train: np.ndarray = Y_train
-         self.X_dims: Tuple = self.X_train.shape[1:]
-         self.Y_dims: Tuple = self.Y_train.shape[1:]
-         self.grad_X_train: np.ndarray = np.array([])
+         self.X_train = X_train
+         self.Y_train = Y_train
+         self.X_dims = self.X_train.shape[1:] if isinstance(self.X_train, np.ndarray) else None
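+         # Non-array training data (e.g. a list of raw inputs) has no fixed per-instance
+         # shape, so X_dims stays None in that case.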
+         self.Y_dims = self.Y_train.shape[1:]
+         self.grad_X_train = np.array([])

        # compute and store gradients
        if self.precompute_grads:
            grads = []
+             X: Union[np.ndarray, List[Any]]
            for X, Y in tqdm(zip(self.X_train, self.Y_train), disable=not self.verbose):
-                 grad_X_train = self._compute_grad(X[None], Y[None])
+                 grad_X_train = self._compute_grad(self._format(X), Y[None])
                grads.append(grad_X_train[None])
+
            self.grad_X_train = np.concatenate(grads, axis=0)
        return self

+     @staticmethod
+     def _is_tensor(x: Any) -> bool:
+         """Checks if an object is a tensor."""
+         if has_tensorflow and isinstance(x, _TfTensor):
+             return True
+         if has_pytorch and isinstance(x, _PtTensor):
+             return True
+         if isinstance(x, np.ndarray):
+             return True
+         return False
+
+     @staticmethod
+     def _format(x: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, Any]'
+                 ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]':
+         """Adds a batch dimension."""
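+         # Tensor-like inputs (numpy / tensorflow / torch) get a leading batch axis;
+         # any other object is wrapped in a single-element list.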
+         if BaseSimilarityExplainer._is_tensor(x):
+             return x[None]
+         return [x]
+
    def _verify_fit(self) -> None:
        """Verify that the explainer has been fitted.

@@ -102,14 +126,15 @@ def _verify_fit(self) -> None:
        ValueError
            If the explainer has not been fitted.
        """
-
        if not hasattr(self, 'X_train') or not hasattr(self, 'Y_train'):
            raise ValueError('Training data not set. Call `fit` and pass training data first.')

    def _match_shape_to_data(self,
-                              data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]',
-                              target_type: Literal['X', 'Y']) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]':
-         """Verify the shape of `data` against the shape of the training data.
+                              data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, Any, List[Any]]',
+                              target_type: Literal['X', 'Y']
+                              ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]':
+         """
+         Verify the shape of `data` against the shape of the training data.

        Used to ensure input is correct shape for gradient methods implemented in the backends. `data` will be the
        features or label of the instance being explained. If the `data` is not a batch, reshape to be a single batch
@@ -131,6 +156,15 @@ def _match_shape_to_data(self,
            If the shape of `data` does not match the shape of the training data, or fit has not been called prior to
            calling this method.
        """
+         if self._is_tensor(data):
+             return self._match_shape_to_data_tensor(data, target_type)
+         return self._match_shape_to_data_any(data)
+
+     def _match_shape_to_data_tensor(self,
+                                     data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]',
+                                     target_type: Literal['X', 'Y']
+                                     ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]':
+         """Verify the shape of `data` against the shape of the training data for tensor-like data."""
        target_shape = getattr(self, f'{target_type}_dims')
        if data.shape == target_shape:
            data = data[None]
@@ -139,6 +173,13 @@ def _match_shape_to_data(self,
                                  f' but training data has shape {target_shape}'))
        return data

+     @staticmethod
+     def _match_shape_to_data_any(data: Union[Any, List[Any]]) -> list:
+         """Ensures that any non-tensor data is returned as a list."""
+         if isinstance(data, list):
+             return data
+         return [data]
+
    def _compute_adhoc_similarity(self, grad_X: np.ndarray) -> np.ndarray:
        """
        Computes the similarity between the gradients of the test instances and all the training instances. The method
@@ -149,18 +190,18 @@ def _compute_adhoc_similarity(self, grad_X: np.ndarray) -> np.ndarray:
        grad_X
            Gradients of the test instances.
        """
-         scores = np.zeros((grad_X.shape[0], self.X_train.shape[0]))
+         scores = np.zeros((len(grad_X), len(self.X_train)))
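+         # len() rather than .shape[0], so list-valued training data is supported as well.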
+         X: Union[np.ndarray, List[Any]]
        for i, (X, Y) in tqdm(enumerate(zip(self.X_train, self.Y_train)), disable=not self.verbose):
-             grad_X_train = self._compute_grad(X[None], Y[None])
+             grad_X_train = self._compute_grad(self._format(X), Y[None])
            scores[:, i] = self.sim_fn(grad_X, grad_X_train[None])[:, 0]
        return scores

    def _compute_grad(self,
-                       X: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]',
+                       X: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]',
                      Y: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]') \
            -> np.ndarray:
        """Computes predictor parameter gradients and returns a flattened `numpy` array."""
-
        X = self.backend.to_tensor(X) if isinstance(X, np.ndarray) else X
        Y = self.backend.to_tensor(Y) if isinstance(Y, np.ndarray) else Y
        return self.backend.get_grads(self.predictor, X, Y, self.loss_fn)
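To illustrate the batching behaviour this change introduces, here is a minimal, NumPy-only sketch (the name toy_format and the simplified tensor check are assumptions for illustration, not part of the commit): tensor-like instances get a leading batch axis before being passed to _compute_grad, while any other object is wrapped in a single-element list.

import numpy as np

def toy_format(x):
    # Simplified stand-in for BaseSimilarityExplainer._format: arrays get a leading
    # batch axis; any other object is wrapped in a single-element list.
    if isinstance(x, np.ndarray):  # the real helper also checks tf/torch tensors
        return x[None]
    return [x]

print(toy_format(np.zeros((3, 2))).shape)  # (1, 3, 2) -- batch dimension added
print(toy_format("a raw text instance"))   # ['a raw text instance'] -- wrapped in a list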