9 changes: 7 additions & 2 deletions paddle/framework/pybind.cc
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/string/to_string.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
@@ -205,9 +206,13 @@ All parameter, weight, gradient are variables in Paddle.
   });
   // clang-format on

-  py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
+  py::class_<platform::GPUPlace>(m, "GPUPlace")
+      .def(py::init<int>())
+      .def("__str__", string::to_string<const platform::GPUPlace &>);

-  py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
+  py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
+      .def(py::init<>())
+      .def("__str__", string::to_string<const platform::CPUPlace &>);

   py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
       m, "Operator");
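The added __str__ bindings are what allow a place object to be formatted directly into the gradient checker's failure message below. A minimal usage sketch, not part of this diff: it assumes the extension module is importable as paddle.v2.framework.core (as the test code does) and that the binary was built with GPU support for the GPUPlace line.

    import paddle.v2.framework.core as core

    cpu = core.CPUPlace()
    gpu = core.GPUPlace(0)  # requires a GPU-enabled build

    # With the new __str__ bindings, a place can be dropped straight into a message.
    print("Gradient Check On %s" % str(cpu))
    print("Gradient Check On %s" % str(gpu))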
41 changes: 23 additions & 18 deletions python/paddle/v2/framework/tests/gradient_checker.py
@@ -92,15 +92,26 @@ def product(dim):


 class GradientChecker(unittest.TestCase):
-    def __is_close(self, numeric_grads, scope, max_relative_error):
+    def __is_close(self, numeric_grads, scope, max_relative_error, msg_prefix):
         for name in numeric_grads:
-            op_grad = numpy.array(
-                scope.find_var(grad_var_name(name)).get_tensor())
-            is_close = numpy.allclose(
-                numeric_grads[name], op_grad, rtol=max_relative_error, atol=100)
-            if not is_close:
-                return False
-        return True
+            b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
+            a = numeric_grads[name]
+
+            abs_a = numpy.abs(a)
+            # if abs_a is nearly zero, then use abs error for a, not relative
+            # error.
+            abs_a[abs_a < 1e-3] = 1
+
+            diff_mat = numpy.abs(a - b) / abs_a
+            max_diff = numpy.max(diff_mat)
+
+            def err_msg():
+                offset = numpy.argmax(diff_mat > max_relative_error)
+                return "%s Variable %s max gradient diff %f over limit %f, the first " \
+                       "error element is %d" % (
+                           msg_prefix, name, max_diff, max_relative_error, offset)
+
+            self.assertLessEqual(max_diff, max_relative_error, err_msg())

     def check_grad(self,
                    forward_op,
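The rewritten __is_close above replaces numpy.allclose with an explicit maximum relative error, falling back to absolute error wherever the numeric gradient is nearly zero, so tiny reference values cannot blow up the ratio. A self-contained sketch of just that comparison rule; the function name and the example values are illustrative, not part of this diff.

    import numpy

    def max_gradient_diff(numeric, actual, zero_threshold=1e-3):
        # Max relative error, switching to absolute error where |numeric| is nearly zero.
        a = numpy.asarray(numeric, dtype=float)
        b = numpy.asarray(actual, dtype=float)
        abs_a = numpy.abs(a)
        # Dividing by a near-zero reference would inflate the ratio, so use a
        # denominator of 1 there, i.e. plain absolute error.
        abs_a[abs_a < zero_threshold] = 1.0
        return numpy.max(numpy.abs(a - b) / abs_a)

    # A 1e-4 absolute error near zero and a 1% relative error both pass a 5% limit.
    print(max_gradient_diff([0.0, 2.0], [1e-4, 2.02]) <= 0.05)  # True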
@@ -145,7 +156,8 @@ def check_grad(self,
         # get numeric gradient
         for check_name in inputs_to_check:
             numeric_grad[check_name] = \
-                get_numeric_gradient(forward_op, input_vars, output_name, check_name)
+                get_numeric_gradient(forward_op, input_vars, output_name,
+                                     check_name)

         # get operator gradient according to different device
         for place in places:
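get_numeric_gradient, defined earlier in this file, perturbs each element of the input being checked and re-runs the forward op to estimate the gradient numerically. A simplified sketch of that idea for a plain scalar-valued NumPy function f; the helper name, signature, and delta below are illustrative and do not mirror the real helper's API.

    import numpy

    def numeric_gradient(f, x, delta=1e-5):
        # Central-difference estimate of d f(x) / d x for a scalar-valued f.
        x = numpy.array(x, dtype=float)
        grad = numpy.zeros_like(x)
        for i in range(x.size):
            x.flat[i] += delta
            f_plus = f(x)
            x.flat[i] -= 2 * delta
            f_minus = f(x)
            x.flat[i] += delta  # restore the original value
            grad.flat[i] = (f_plus - f_minus) / (2 * delta)
        return grad

    # d/dx of sum(x**2) is 2*x, so this prints approximately [2. 6.]
    print(numeric_gradient(lambda v: numpy.sum(v ** 2), [1.0, 3.0]))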
@@ -187,15 +199,8 @@ def check_grad(self,
             backward_op.infer_shape(scope)
             backward_op.run(scope, ctx)

-            if isinstance(place, core.CPUPlace):
-                msg = "CPU kernel gradient is not close to numeric gradient"
-            else:
-                if isinstance(place, core.GPUPlace):
-                    msg = "GPU kernel gradient is not close to numeric gradient"
-                else:
-                    raise ValueError("unknown place " + type(place))
-            self.assertTrue(
-                self.__is_close(numeric_grad, scope, max_relative_error), msg)
+            self.__is_close(numeric_grad, scope, max_relative_error,
+                            "Gradient Check On %s" % str(place))


 if __name__ == '__main__':

Inline review thread on the self.__is_close(...) call:

Member: How about changing this to self.AssertIsClose? The function now actually performs the assert itself, rather than just checking whether the gradients are close.

Collaborator (author): OK

Collaborator (author): Done