
Conversation

@lcy-seso
Contributor

@lcy-seso lcy-seso commented Sep 17, 2017

fixes #3866 #4366

  • Add a softmax with cross-entropy operator; only hard labels are supported.
  • Support soft labels.
  • One problem is that this PR conflicts with this one: Add soft-label support for cross-entropy operator. #4081. Whichever is merged first, the other will have to resolve the conflicts manually, and the conflicts are not easy to resolve ...

@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch from c969be9 to 8f8ea00 Compare September 18, 2017 02:28
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
class SoftmaxFunctor {
Contributor

Since the whole template class is defined in the .h file, do we still need to specialize it in the .cc and .cu files?
Also, I see that im2col.h/.cc/.cu also define Im2ColFunctor, and those file names do not carry a _function suffix. Should the naming be unified?

Contributor Author

@lcy-seso lcy-seso Sep 18, 2017

The separate specializations in .cc and .cu are there to solve a problem we ran into earlier: for the GPU build, both the CPU and GPU versions need to be compiled. The CMake currently looks like this:

nv_library(softmax_function SRCS softmax_function.cc softmax_function.cu 
    DEPS operator)

The softmax function is currently implemented with Eigen, and compiling it for GPU also requires a special macro. If we go with a macro, how do we get the same implementation compiled twice, similar to what the HOSTDEVICE macro does today? I ran into some problems with that before and did not succeed. I can look into it again.
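For reference, a minimal sketch of the explicit-instantiation split (file layout and types are simplified and are not the exact code in this PR):

    // softmax_function.h -- templated definition shared by both builds.
    template <typename Place, typename T>
    class SoftmaxFunctor {
     public:
      void operator()(const framework::Tensor* X, framework::Tensor* Y,
                      const framework::ExecutionContext& context);
    };

    // softmax_function.cc -- compiled by the host compiler for the CPU build.
    template class SoftmaxFunctor<platform::CPUPlace, float>;

    // softmax_function.cu -- compiled by nvcc, so Eigen emits the GPU kernels.
    template class SoftmaxFunctor<platform::GPUPlace, float>;

With this split, nv_library builds the .cc file with the host compiler and the .cu file with nvcc, so Eigen generates both a CPU and a GPU version of the same functor.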

Contributor

@Xreki Xreki Sep 18, 2017

After discussion: since this Functor calls Eigen, Eigen must be explicitly directed to compile both a CPU and a GPU version, so the specialization is needed.

See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Contributor

The file name utils.h is too generic. Could we use a more meaningful name?

Contributor Author

This PR will be updated accordingly after the unified adjustment in #4081.

"Store the outputs of softmax function, "
"which will be used in backward calculation.")
.AsIntermediate();
AddOutput("Out", "A 1-D tensor<float> with shape N.");
Contributor

For the descriptive text of the inputs and outputs, should we follow the convention in https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md, i.e. the "(type, default value) usage" form?
Also, I understand Softmax to be the result of the softmax computation; is Out the result of the cross-entropy computation? Would Loss or Cost be a better name?

Contributor Author

Agreed, Loss would be better.
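For example, in the "(type, default value) usage" style, the descriptions could read roughly as follows (illustrative wording, not necessarily the final text of the PR):

    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K].");
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");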

using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
Contributor

Please check whether every input/output Variable is NULL, because the constructor of OperatorBase does not check them.

Contributor Author

  • Do you mean checking whether each Var's name is framework::kEmptyVarName?
  • Is it really necessary to add this check for every Var involved in both the forward and backward InferShape? When I tried adding it, my gut feeling was that it makes the code very "redundant".
  • If this check really is required, could we consider a better mechanism to guarantee it?

Contributor

  • When checking in InferShape, you can directly check whether the InputVar / OutputVar is NULL; PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "comments") is enough (see the sketch after this list). In composite Ops the Variable cannot be obtained, so there we check whether the name is Empty.
  • In the forward pass, whenever a Variable's Tensor needs to be accessed, it is best to add the check. I am not sure about the backward pass... The code does get quite redundant; more than half of InferShape ends up being PADDLE_ENFORCE.
  • If we want a general mechanism, it should support features like:
    • some input/output variables are allowed to be NULL
    • variable B is allowed to be NULL, but must not be NULL when variable A is not NULL
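A trimmed sketch of what such checks look like in the forward InferShape of this op (the variable names follow this operator; the exact messages are illustrative):

    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
                            "Input(Logits) should be not null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                            "Input(Label) should be not null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
                            "Output(Loss) should be not null.");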

Contributor Author

Understood. Thanks.

@lcy-seso
Contributor Author

lcy-seso commented Sep 18, 2017

Haven't finished yet. Please DO NOT review this. Thanks, everyone.

@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch 4 times, most recently from 6890163 to d044ac3 Compare September 22, 2017 02:08
@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch from d044ac3 to f1d5fb3 Compare September 22, 2017 02:10
.reshape(batch_by_one)
.broadcast(one_by_class));
}
};
Contributor

Could we consider adding a SoftmaxGradFunctor? Though it is also fine if I add it later.

Contributor Author

  • SoftmaxGradFunctor is not needed by the functionality in this PR, and I want to keep the changes here as small as possible (this PR has dragged on too long...); I will continue in a new PR.

Contributor

Yes, I also agree to merge this PR first 😄. I will need SoftmaxGradFunctor, so I can just add it then.
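For context, a SoftmaxGradFunctor in the same Eigen style as the forward functor could look roughly like this; the interface and the details are illustrative assumptions, not the code that was eventually merged:

    template <typename Place, typename T>
    class SoftmaxGradFunctor {
     public:
      void operator()(const framework::Tensor* y, const framework::Tensor* y_grad,
                      framework::Tensor* x_grad,
                      const framework::ExecutionContext& context) {
        auto softmax = EigenMatrix<T>::From(*y);
        auto softmax_grad = EigenMatrix<T>::From(*y_grad);
        auto logits_grad = EigenMatrix<T>::From(*x_grad);

        const int batch_size = softmax.dimension(0);
        const int num_classes = softmax.dimension(1);
        Eigen::DSizes<int, 1> along_class(1);
        Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
        Eigen::DSizes<int, 2> one_by_class(1, num_classes);

        // dot_ij = sum_k dY_ik * Y_ik, broadcast back to [N x K].
        auto dot = (softmax * softmax_grad)
                       .sum(along_class)
                       .eval()
                       .reshape(batch_by_one)
                       .broadcast(one_by_class);
        // dX = (dY - dot) * Y, the standard softmax backward rule.
        logits_grad.device(context.GetEigenDevice<Place>()) =
            (softmax_grad - dot) * softmax;
      }
    };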

framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
//(TODO caoying) replace int with boolean
AddAttr<int>("soft_label",
Contributor

Shouldn't the attribute name be softLabel?
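Combined with the TODO above about switching from int to bool, the declaration could become something like this (the default value is an assumption):

    AddAttr<bool>("softLabel",
                  "(bool, default: false), A flag to indicate whether to "
                  "interpret the given labels as soft labels.")
        .SetDefault(false);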

AddInput(
"Label",
"(Tensor, default Tensor<int>), The ground truth which is "
"a 1-D or 2-D tensor. "
Contributor

[N, 1] is also 2-D, so Label should always be a 2-D Tensor.

"which will be used in backward calculation.")
.AsIntermediate();
AddOutput("Loss",
"(Tensor, default Tensor<float>), A 1-D tensor. The cross "
Contributor

Loss should also be a 2-D Tensor.

"The label should be a 1-d tensor.");

ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
ctx.Output<framework::LoDTensor>("Loss")->Resize({logits->dims()[0], 1});
Contributor

This looks like it needs to be updated: LoDTensor -> Tensor.

Contributor Author

done.

class SoftmaxFunctor {
public:
void operator()(const framework::Tensor* X, framework::Tensor* Y,
const framework::ExecutionContext& context) {
Contributor

To be consistent with the other Functors, please make context the first parameter. Also, could platform::DeviceContext& be used as the type here?

Contributor Author

  • The parameter order is adjusted first in this commit (as sketched below).
  • How to obtain the EigenDevice from a DeviceContext in a generic way still has some problems; I will keep working on it in a follow-up.
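A sketch of the reordered call site (the platform::DeviceContext& variant is left for the follow-up mentioned above):

    // Before: math::SoftmaxFunctor<platform::CPUPlace, T>()(logits, softmax, context);
    // After, with the context moved to the front:
    math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);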

// Calculate ths softmax outputs.
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
softmax->mutable_data<T>(context.GetPlace());
Contributor

Change Line 59 to T* softmax_out = softmax->mutable_data<T>(context.GetPlace()); and then remove Line 61.

Contributor Author

done.

const int* label_data = context.Input<Tensor>("Label")->data<int>();
Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>();
Contributor

Merge Lines 66 - 67 into T* loss_data = loss->mutable_data<T>(context.GetPlace());

Contributor Author

done.

math::SoftmaxFunctor<platform::CPUPlace, T>()(logits, softmax, context);

// Calculate the cross entropy loss based on hard labels.
T* softmax_out = softmax->data<T>();
Contributor

Same for the .cu file.

Contributor Author

done.
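For reference, the hard-label branch under review boils down to taking, per row, the negative log of the probability assigned to the ground-truth class. A simplified sketch (class_num is a hypothetical local name, and std::log needs <cmath>):

    const int batch_size = logits->dims()[0];
    const int class_num = logits->dims()[1];
    for (int i = 0; i < batch_size; ++i) {
      // loss_i = -log(softmax_i[label_i])
      loss_data[i] = -std::log(softmax_out[i * class_num + label_data[i]]);
    }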


Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>();
Contributor

Same for the .cu file.

@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch 3 times, most recently from 64dc863 to 937b12a Compare September 26, 2017 03:04
@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch from 937b12a to 8b8ad6b Compare September 26, 2017 03:06
@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch from 873dd2c to 360bde9 Compare September 26, 2017 08:27
@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch from 09066fd to 3d77360 Compare September 26, 2017 09:56

AddOutput("Loss",
"(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
"entropy loss with shape [N x 1].");
AddComment(R"DOC(
Contributor

Order: Input, Output, Attr, Comment.

Contributor Author

done.


const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE(
Contributor

PADDLE_ENFORCE_EQ can be used here.

Contributor Author

done.

PADDLE_ENFORCE(
logits->dims().size() == 2UL,
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
Contributor

PADDLE_ENFORCE_EQ can be used here.

Contributor Author

done.
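The suggested rewrite compares the two values directly, e.g.:

    PADDLE_ENFORCE_EQ(
        logits->dims().size(), 2UL,
        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
    // The Label rank check on the following line gets the same treatment.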

PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
Contributor

typo: Lable -> Label

Contributor Author

done.

"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
Contributor

In the OpMaker, Logits is declared NotInGradient, yet here the gradient of Logits is fetched. What exactly does NotInGradient mean? I see that the GradOp fetches the gradient variables of Loss and Logits; shouldn't Softmax and Label be the ones declared NotInGradient?

Contributor Author

@lcy-seso lcy-seso Sep 26, 2017

  • My previous understanding of NotInGradient is that the input is not used when computing the gradients.
    • For example, Logits does not take part in the gradient computation, but the GRAD of Logits still needs to be computed; in that case Logits is a NotInGradient input. I will double-check.
  • Softmax and Label both take part directly in the gradient computation now, so they are not declared NotInGradient.

Contributor

I printed DebugString() to take a look; the result is as follows:

Op(softmax_with_cross_entropy_grad), inputs:{Label[Label], Loss[Loss], Loss@GRAD[Loss@GRAD], Softmax[Softmax], Softmax@GRAD[Softmax@GRAD]}, outputs:{Label@GRAD[Label@GRAD], Logits@GRAD[Logits@GRAD]}

Indeed, as you said, NotInGradient means the variable is not needed inside the GradOp.
Label and Softmax do not seem to need GRAD variables at all; I feel there should be some flag to state that a variable does not need a corresponding GRAD variable created. But that is a separate matter.
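For illustration only: assuming the OpMaker's variable builder exposes a NotInGradient() toggle (which the discussion above implies), declaring an input that the grad op never reads would look roughly like this:

    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities.")
        .NotInGradient();  // assumption: only meaningful if the grad op never reads Logits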

Contributor

@Xreki Could the DebugString() method be mentioned in the "how to write a new Operator" doc? I find it very useful.


const Tensor* softmax = ctx.Input<Tensor>("Softmax");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
Contributor

PADDLE_ENFORCE_EQ can be used here.

Contributor Author

done.

"When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
Contributor

1 -> 1UL

Contributor Author

done.

@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch 7 times, most recently from 77ef4c3 to fa86dd6 Compare September 27, 2017 01:02
@lcy-seso lcy-seso force-pushed the softmax_with_cross_entropy_op branch from fa86dd6 to 97509b6 Compare September 27, 2017 01:03
Contributor

@Xreki Xreki left a comment

almost LGTM

@Xreki Xreki merged commit 29cb856 into PaddlePaddle:develop Sep 27, 2017
@lcy-seso lcy-seso deleted the softmax_with_cross_entropy_op branch September 27, 2017 02:41