5656"""
5757 augment_indices(indices, n)
5858
59- Strip batch indices and return inidices for batch augmented by n samples.
59+ Strip batch indices and return indices for batch augmented by n samples.
6060
6161## Example
6262```julia-repl
@@ -83,11 +83,20 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
8383end
8484
8585"""
86- NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
87- NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
86+ NoiseAugmentation(analyzer, n)
87+ NoiseAugmentation(analyzer, n, std::Real)
88+ NoiseAugmentation(analyzer, n, distribution::Sampleable)
8889
89- A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
90+ A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from a scalar `distribution`.
9091This input augmentation is then averaged to return an `Explanation`.
92+
93+ Defaults to the normal distribution `Normal(0, std^2)` with `std=1.0f0`.
94+ For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of each sample,
95+ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
96+
97+ ## Keyword arguments
98+ - `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
99+ Defaults to `GLOBAL_RNG`.
91100"""
92101struct NoiseAugmentation{A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG } < :
93102 AbstractXAIMethod
@@ -96,11 +105,11 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
96105 distribution:: D
97106 rng:: R
98107end
99- function NoiseAugmentation (analyzer, n, distr :: Sampleable , rng= GLOBAL_RNG)
100- return NoiseAugmentation (analyzer, n, distr :: Sampleable , rng)
108+ function NoiseAugmentation (analyzer, n, distribution :: Sampleable , rng= GLOBAL_RNG)
109+ return NoiseAugmentation (analyzer, n, distribution :: Sampleable , rng)
101110end
102- function NoiseAugmentation (analyzer, n, σ :: Real = 0.1f0 , args ... )
103- return NoiseAugmentation (analyzer, n, Normal (0.0f0 , Float32 (σ) ^ 2 ), args ... )
111+ function NoiseAugmentation (analyzer, n, std :: T = 1.0f0 , rng = GLOBAL_RNG) where {T <: Real }
112+ return NoiseAugmentation (analyzer, n, Normal (zero (T), std ^ 2 ), rng )
104113end
105114
106115function call_analyzer (input, aug:: NoiseAugmentation , ns:: AbstractOutputSelector ; kwargs... )
0 commit comments