@@ -8,80 +8,6 @@ struct AugmentationSelector{I} <: AbstractOutputSelector
88end
99(s:: AugmentationSelector )(out) = s. indices
1010
11- """
12- augment_batch_dim(input, n)
13-
14- Repeat each sample in input batch n-times along batch dimension.
15- This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`.
16-
17- ## Example
18- ```julia-repl
19- julia> A = [1 2; 3 4]
20- 2×2 Matrix{Int64}:
21- 1 2
22- 3 4
23-
24- julia> augment_batch_dim(A, 3)
25- 2×6 Matrix{Int64}:
26- 1 1 1 2 2 2
27- 3 3 3 4 4 4
28- ```
29- """
30- function augment_batch_dim (input:: AbstractArray{T,N} , n) where {T,N}
31- return repeat (input; inner= (ntuple (Returns (1 ), N - 1 )... , n))
32- end
33-
34- """
35- reduce_augmentation(augmented_input, n)
36-
37- Reduce augmented input batch by averaging the explanation for each augmented sample.
38- """
39- function reduce_augmentation (input:: AbstractArray{T,N} , n) where {T<: AbstractFloat ,N}
40- # Allocate output array
41- in_size = size (input)
42- in_size[end ] % n != 0 &&
43- throw (ArgumentError (" Can't reduce augmented batch size of $(in_size[end ]) by $n " ))
44- out_size = (in_size[1 : (end - 1 )]. .. , div (in_size[end ], n))
45- out = similar (input, eltype (input), out_size)
46-
47- axs = axes (input, N)
48- colons = ntuple (Returns (:), N - 1 )
49- for (i, ax) in enumerate (first (axs): n: last (axs))
50- view (out, colons... , i) .=
51- dropdims (sum (view (input, colons... , ax: (ax + n - 1 )); dims= N); dims= N) / n
52- end
53- return out
54- end
55-
56- """
57- augment_indices(indices, n)
58-
59- Strip batch indices and return indices for batch augmented by n samples.
60-
61- ## Example
62- ```julia-repl
63- julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)]
64- 2-element Vector{CartesianIndex{2}}:
65- CartesianIndex(5, 1)
66- CartesianIndex(3, 2)
67-
68- julia> augment_indices(inds, 3)
69- 6-element Vector{CartesianIndex{2}}:
70- CartesianIndex(5, 1)
71- CartesianIndex(5, 2)
72- CartesianIndex(5, 3)
73- CartesianIndex(3, 4)
74- CartesianIndex(3, 5)
75- CartesianIndex(3, 6)
76- ```
77- """
78- function augment_indices (inds:: Vector{CartesianIndex{N}} , n) where {N}
79- indices_wo_batch = [i. I[1 : (end - 1 )] for i in inds]
80- return map (enumerate (repeat (indices_wo_batch; inner= n))) do (i, idx)
81- CartesianIndex {N} (idx... , i)
82- end
83- end
84-
8511"""
8612 NoiseAugmentation(analyzer, n)
8713 NoiseAugmentation(analyzer, n, std::Real)
@@ -104,38 +30,53 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
10430 n:: Int
10531 distribution:: D
10632 rng:: R
107- end
108- function NoiseAugmentation (analyzer, n, distribution:: Sampleable , rng= GLOBAL_RNG)
109- return NoiseAugmentation (analyzer, n, distribution:: Sampleable , rng)
33+
34+ function NoiseAugmentation (
35+ analyzer:: A , n:: Int , distribution:: D , rng:: R
36+ ) where {A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG }
37+ n < 2 &&
38+ throw (ArgumentError (" Number of noise samples `n` needs to be larger than one." ))
39+ return new {A,D,R} (analyzer, n, distribution, rng)
40+ end
11041end
11142function NoiseAugmentation (analyzer, n, std:: T = 1.0f0 , rng= GLOBAL_RNG) where {T<: Real }
11243 return NoiseAugmentation (analyzer, n, Normal (zero (T), std^ 2 ), rng)
11344end
45+ function NoiseAugmentation (analyzer, n, distribution:: Sampleable , rng= GLOBAL_RNG)
46+ return NoiseAugmentation (analyzer, n, distribution, rng)
47+ end
11448
11549function call_analyzer (input, aug:: NoiseAugmentation , ns:: AbstractOutputSelector ; kwargs... )
11650 # Regular forward pass of model
11751 output = aug. analyzer. model (input)
11852 output_indices = ns (output)
119-
120- # Call regular analyzer on augmented batch
121- augmented_input = add_noise (augment_batch_dim (input, aug. n), aug. distribution, aug. rng)
122- augmented_indices = augment_indices (output_indices, aug. n)
123- augmented_expl = aug. analyzer (augmented_input, AugmentationSelector (augmented_indices))
53+ output_selector = AugmentationSelector (output_indices)
54+
55+ # First augmentation
56+ input_aug = similar (input)
57+ input_aug = sample_noise! (input_aug, input, aug)
58+ expl_aug = aug. analyzer (input_aug, output_selector)
59+ sum_val = expl_aug. val
60+
61+ # Further augmentations
62+ for _ in 2 : (aug. n)
63+ input_aug = sample_noise! (input_aug, input, aug)
64+ expl_aug = aug. analyzer (input_aug, output_selector)
65+ sum_val += expl_aug. val
66+ end
12467
12568 # Average explanation
69+ val = sum_val / aug. n
70+
12671 return Explanation (
127- reduce_augmentation (augmented_expl. val, aug. n),
128- input,
129- output,
130- output_indices,
131- augmented_expl. analyzer,
132- augmented_expl. heatmap,
133- nothing ,
72+ val, input, output, output_indices, expl_aug. analyzer, expl_aug. heatmap, nothing
13473 )
13574end
13675
137- function add_noise (A:: AbstractArray{T} , distr:: Distribution , rng:: AbstractRNG ) where {T}
138- return A + T .(rand (rng, distr, size (A)))
76+ function sample_noise! (
77+ out:: A , input:: A , aug:: NoiseAugmentation
78+ ) where {T,A<: AbstractArray{T} }
79+ out .= input .+ rand (aug. rng, aug. distribution, size (input))
13980end
14081
14182"""
@@ -149,6 +90,13 @@ difference between the input and the reference input.
14990struct InterpolationAugmentation{A<: AbstractXAIMethod } <: AbstractXAIMethod
15091 analyzer:: A
15192 n:: Int
93+
94+ function InterpolationAugmentation (analyzer:: A , n:: Int ) where {A<: AbstractXAIMethod }
95+ n < 2 && throw (
96+ ArgumentError (" Number of interpolation steps `n` needs to be larger than one." ),
97+ )
98+ return new {A} (analyzer, n)
99+ end
152100end
153101
154102function call_analyzer (
@@ -160,57 +108,25 @@ function call_analyzer(
160108 # Regular forward pass of model
161109 output = aug. analyzer. model (input)
162110 output_indices = ns (output)
163-
164- # Call regular analyzer on augmented batch
165- augmented_input = interpolate_batch (input, input_ref, aug. n)
166- augmented_indices = augment_indices (output_indices, aug. n)
167- augmented_expl = aug. analyzer (augmented_input, AugmentationSelector (augmented_indices))
111+ output_selector = AugmentationSelector (output_indices)
112+
113+ # First augmentations
114+ input_aug = input_ref
115+ expl_aug = aug. analyzer (input_aug, output_selector)
116+ sum_val = expl_aug. val
117+
118+ # Further augmentations
119+ input_delta = (input - input_ref) / (aug. n - 1 )
120+ for _ in 1 : (aug. n)
121+ input_aug += input_delta
122+ expl_aug = aug. analyzer (input_aug, output_selector)
123+ sum_val += expl_aug. val
124+ end
168125
169126 # Average gradients and compute explanation
170- expl = (input - input_ref) .* reduce_augmentation (augmented_expl . val, aug. n)
127+ val = (input - input_ref) .* sum_val / aug. n
171128
172129 return Explanation (
173- expl,
174- input,
175- output,
176- output_indices,
177- augmented_expl. analyzer,
178- augmented_expl. heatmap,
179- nothing ,
130+ val, input, output, output_indices, expl_aug. analyzer, expl_aug. heatmap, nothing
180131 )
181132end
182-
183- """
184- interpolate_batch(x, x0, nsamples)
185-
186- Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`.
187-
188- ## Example
189- ```julia-repl
190- julia> x = Float16.(reshape(1:4, 2, 2))
191- 2×2 Matrix{Float16}:
192- 1.0 3.0
193- 2.0 4.0
194-
195- julia> x0 = zero(x)
196- 2×2 Matrix{Float16}:
197- 0.0 0.0
198- 0.0 0.0
199-
200- julia> interpolate_batch(x, x0, 5)
201- 2×10 Matrix{Float16}:
202- 0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0
203- 0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0
204- ```
205- """
206- function interpolate_batch (
207- x:: AbstractArray{T,N} , x0:: AbstractArray{T,N} , nsamples
208- ) where {T,N}
209- in_size = size (x)
210- outs = similar (x, (in_size[1 : (end - 1 )]. .. , in_size[end ] * nsamples))
211- colons = ntuple (Returns (:), N - 1 )
212- for (i, t) in enumerate (range (zero (T), oneunit (T); length= nsamples))
213- outs[colons... , i: nsamples: end ] .= x0 + t * (x - x0)
214- end
215- return outs
216- end
0 commit comments