Skip to content

Commit ada4ea1

Browse files
committed
implement det sign
1 parent b3be92f commit ada4ea1

File tree

7 files changed

+78
-0
lines changed

7 files changed

+78
-0
lines changed

python/mxnet/ndarray/ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,6 +1246,14 @@ def sign(self, *args, **kwargs):
12461246
"""
12471247
return op.sign(self, *args, **kwargs)
12481248

1249+
def det_sign(self, *args, **kwargs):
1250+
"""Convenience fluent method for :py:func:`det_sign`.
1251+
1252+
The arguments are the same as for :py:func:`det_sign`, with
1253+
this array as data.
1254+
"""
1255+
return op.det_sign(self, *args, **kwargs)
1256+
12491257
def flatten(self, *args, **kwargs):
12501258
"""Convenience fluent method for :py:func:`flatten`.
12511259

python/mxnet/symbol/symbol.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,6 +1974,14 @@ def sign(self, *args, **kwargs):
19741974
"""
19751975
return op.sign(self, *args, **kwargs)
19761976

1977+
def det_sign(self, *args, **kwargs):
1978+
"""Convenience fluent method for :py:func:`det_sign`.
1979+
1980+
The arguments are the same as for :py:func:`det_sign`, with
1981+
this array as data.
1982+
"""
1983+
return op.det_sign(self, *args, **kwargs)
1984+
19771985
def flatten(self, *args, **kwargs):
19781986
"""Convenience fluent method for :py:func:`flatten`.
19791987

smd_hpi/tests/test_functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import mxnet as mx
2+
import numpy as np
3+
from mxnet import autograd
4+
from mxnet.test_utils import assert_almost_equal
5+
6+
7+
def test_det_sign():
8+
exp_y = np.array([1.0, 1.0, -1.0])
9+
exp_grad = np.array([1.0, 1.0, 1.0])
10+
11+
x = mx.nd.array([0.0, 0.6, -0.3])
12+
x.attach_grad()
13+
with autograd.record():
14+
y = x.det_sign()
15+
assert_almost_equal(exp_y, y.asnumpy())
16+
y.backward()
17+
assert_almost_equal(exp_grad, x.grad.asnumpy())

src/operator/mshadow_op.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,21 @@ struct sign : public mxnet_op::tunable {
299299

300300
MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
301301

302+
/*! \brief used for generate element of sign */
303+
struct det_sign : public mxnet_op::tunable {
304+
template<typename DType>
305+
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
306+
Map(DType a) {
307+
if (a < DType(0)) return DType(-DType(1));
308+
return DType(1);
309+
}
310+
template<typename DType>
311+
MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
312+
Map(DType a) {
313+
return DType(1);
314+
}
315+
};
316+
302317
/*! \brief used for generate element of power */
303318
MXNET_BINARY_MATH_OP(power, math::pow(a, b));
304319

src/operator/operator_tune.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs); // NOLINT()
270270
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign); // NOLINT()
271271
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign); // NOLINT()
272272
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad); // NOLINT()
273+
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
274+
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
273275
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round); // NOLINT()
274276
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor); // NOLINT()
275277
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc); // NOLINT()

src/operator/tensor/elemwise_unary_op_basic.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,25 @@ The storage type of ``sign`` output depends upon the input storage type:
687687

688688
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sign, unary_bwd<mshadow_op::sign_grad>);
689689

690+
// det_sign
691+
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP(det_sign, cpu, mshadow_op::det_sign)
692+
MXNET_ADD_SPARSE_OP_ALIAS(det_sign)
693+
.describe(R"code(Returns element-wise sign of the input (but with +1 for 0 values and Straigth Through Estimator).
694+
695+
Example::
696+
697+
det_sign([-2, 0, 3]) = [-1, 1, 1]
698+
699+
The storage type of ``det_sign`` output depends upon the input storage type:
700+
701+
- det_sign(default) = default
702+
- det_sign(row_sparse) = row_sparse
703+
704+
)code" ADD_FILELINE)
705+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_det_sign"});
706+
707+
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_det_sign, unary_bwd<mshadow_op::identity_grad>);
708+
690709
// round
691710
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(round, cpu, mshadow_op::round)
692711
.describe(R"code(Returns element-wise rounded value to the nearest integer of the input.

src/operator/tensor/elemwise_unary_op_basic.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,15 @@ NNVM_REGISTER_OP(_backward_sign)
159159
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
160160
gpu, unary_bwd<mshadow_op::sign_grad> >);
161161

162+
// det_sign
163+
NNVM_REGISTER_OP(det_sign)
164+
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::det_sign>)
165+
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::det_sign>);
166+
167+
NNVM_REGISTER_OP(_backward_det_sign)
168+
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
169+
gpu, unary_bwd<mshadow_op::identity_grad> >);
170+
162171
// round
163172
NNVM_REGISTER_OP(round)
164173
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)

0 commit comments

Comments
 (0)