83 changes: 73 additions & 10 deletions paddle/function/CrossMapNormalOp.cpp
@@ -112,11 +112,51 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
}

/**
* \brief {o_0, o_1} = calc(i_0)
* \brief Normalization across feature maps (channels).
*
* \param inputs[0] input value.
* \param outputs[0] output value.
* \param outputs[1] denoms.
* This Function comes from the paper
* "ImageNet Classification with Deep Convolutional Neural Networks".
*
* The original formula is:
*
*                                Input(i, x, y)
* Output(i, x, y) = ----------------------------------------------
*                                 -- upper
*                    (k + alpha * >  (Input(j, x, y))^2) ^ (beta)
*                                 -- j = lower
*
* upper is `min(C, i + N/2)`
* lower is `max(0, i - N/2)`
*
* Function implementation:
*
* inputs and outputs are in NCHW format, and input.shape.ndims() is
* equal to 4. The meaning of each dimension (0-3) is, respectively,
* batch size, feature maps, rows, and columns.
*
* Input and Output in the above formula refer to map i of one image;
* Input(i, x, y) and Output(i, x, y) each represent one element of that map.
*
* C is the number of feature maps of one image, and N is a hyper-parameter
* configured when the Function is initialized. The sum in the denominator
* runs over the same position (x, y) in the neighboring maps.
*
* In the implementation of this Function, k is equal to 1, so the
* Function has no argument for k.
*
* Function Arguments:
*
Contributor

1. From this formula alone, one cannot tell that the normalization is across channels (or maps).
2. Input/Output is a 4D tensor. If the comment ignores the first dimension, shouldn't Input(x, y) be written as Input(c, x, y) (or Input[, , , ])?
3. For the summation sign, wouldn't using `sum` directly be clearer? Or:
   --
   \
   /
   --
4. The paper's formula uses +k, not +1.
5. alpha does not need to be divided by N, right?

Also, if a formula is fairly complex, would it read better broken up across multiple lines?

Contributor Author

Done.
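For reference, the local response normalization formula from the cited paper, "ImageNet Classification with Deep Convolutional Neural Networks" (Krizhevsky et al.), which this exchange discusses, can be written in LaTeX as follows; here a is the input activity, b the output, N the total number of feature maps, and n the window size:

```latex
b^{i}_{x,y} = \frac{a^{i}_{x,y}}
  {\left( k + \alpha \sum_{j=\max(0,\, i - n/2)}^{\min(N-1,\, i + n/2)}
    \left( a^{j}_{x,y} \right)^{2} \right)^{\beta}}
```

In this Function's naming, n corresponds to size_, alpha to scale_, and beta to pow_; the implementation fixes k = 1.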

* \param size_ represents N
* \param scale_ represents alpha
* \param pow_ represents beta
* \param inputs[0] represents Input
* \param outputs[0] represents Output
* \param outputs[1] represents the denominator in the formula (without
*                   the beta exponent applied)
*
* Note:
* outputs[1] is saved in order to simplify the backward calculation.
* TODO: if only the forward calculation is needed, outputs[1] can be
* removed as an optimization.
*/
template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
@@ -161,13 +201,36 @@ class CrossMapNormalFunc : public FunctionBase {
};
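To make the forward definition above concrete, here is a minimal, self-contained C++ sketch of the computation the comment describes, with k fixed to 1 and denoms saved without the beta exponent applied. The function name crossMapNormalFwd, the flat std::vector interface, and the exclusive upper bound are illustrative assumptions, not the actual Paddle Function API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// in, out, denoms are NCHW tensors flattened to one dimension
// (batch * C * H * W). out and denoms must be pre-sized to in.size().
void crossMapNormalFwd(const std::vector<float>& in,
                       std::vector<float>& out,
                       std::vector<float>& denoms,
                       size_t batch, size_t C, size_t H, size_t W,
                       size_t N,       // window size, size_ in the comment
                       float alpha,    // scale_
                       float beta) {   // pow_
  const size_t plane = H * W;
  const size_t image = C * plane;
  for (size_t b = 0; b < batch; ++b) {
    for (size_t i = 0; i < C; ++i) {
      // lower = max(0, i - N/2); upper = min(C, i + N/2 + 1), exclusive.
      const size_t lower = (i >= N / 2) ? i - N / 2 : 0;
      const size_t upper = std::min(C, i + N / 2 + 1);
      for (size_t p = 0; p < plane; ++p) {
        // Sum of squares at the same (x, y) across neighboring maps.
        float sum = 0.0f;
        for (size_t j = lower; j < upper; ++j) {
          const float v = in[b * image + j * plane + p];
          sum += v * v;
        }
        const size_t idx = b * image + i * plane + p;
        denoms[idx] = 1.0f + alpha * sum;  // k is fixed to 1
        out[idx] = in[idx] * std::pow(denoms[idx], -beta);
      }
    }
  }
}
```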

/**
* \brief {o_0} = calc(i_0, i_1, i_2, i_3)
* \brief Backward calculation for normalization across feature maps.
*
* Function implementation:
*
* The implementation of this Function is derived from the
* CrossMapNormalFunc implementation.
*
* InputGrad = OutputGrad * denoms ^ (-beta)
*                -- upper
*             +  >  (OutputGrad * OutputValue * (-2 * alpha * beta) / denoms) * InputValue
*                -- j = lower
*
* The format of the inputs/outputs data is the same as in the forward
* interface: NCHW.
*
* upper and lower are the same as in the forward calculation, and the
* logic of the sum is the same as well.
*
* Function Arguments:
*
* \param inputs[0] input value.
* \param inputs[1] output value.
* \param inputs[2] output grad.
* \param inputs[3] denoms.
* \param outputs[0] input grad.
* \param size_ represents N
* \param scale_ represents alpha
* \param pow_ represents beta
* \param inputs[0] represents InputValue, the inputs[0] of CrossMapNormalFunc
* \param inputs[1] represents OutputValue, the outputs[0] of CrossMapNormalFunc
* \param inputs[2] represents OutputGrad
* \param inputs[3] represents denoms, the outputs[1] of CrossMapNormalFunc;
*                  this is an intermediate result preserved from the
*                  forward calculation
* \param outputs[0] represents InputGrad
*/
template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase {
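Following the conventions of the forward sketch above, here is a corresponding C++ sketch of the backward formula quoted in the comment. Again, the name crossMapNormalBwd and the interface are illustrative assumptions, not the Paddle implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// All tensors are NCHW, flattened to batch * C * H * W elements.
void crossMapNormalBwd(const std::vector<float>& inV,     // inputs[0], InputValue
                       const std::vector<float>& outV,    // inputs[1], OutputValue
                       const std::vector<float>& outG,    // inputs[2], OutputGrad
                       const std::vector<float>& denoms,  // inputs[3], saved by forward
                       std::vector<float>& inG,           // outputs[0], InputGrad
                       size_t batch, size_t C, size_t H, size_t W,
                       size_t N, float alpha, float beta) {
  const size_t plane = H * W;
  const size_t image = C * plane;
  for (size_t b = 0; b < batch; ++b) {
    for (size_t i = 0; i < C; ++i) {
      const size_t lower = (i >= N / 2) ? i - N / 2 : 0;
      const size_t upper = std::min(C, i + N / 2 + 1);
      for (size_t p = 0; p < plane; ++p) {
        const size_t idx = b * image + i * plane + p;
        // First term: OutputGrad * denoms ^ (-beta).
        float g = outG[idx] * std::pow(denoms[idx], -beta);
        // Second term: the sum over neighboring maps j, times the input
        // value at the current position.
        for (size_t j = lower; j < upper; ++j) {
          const size_t jdx = b * image + j * plane + p;
          g += (outG[jdx] * outV[jdx] * (-2.0f * alpha * beta) / denoms[jdx]) *
               inV[idx];
        }
        inG[idx] = g;
      }
    }
  }
}
```

A quick way to sanity-check such a sketch is to compare inG against a finite-difference gradient of the forward sketch on a small random tensor.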