From 4d82318520e08f3c56e10d14116ed45bf542716a Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sun, 8 Mar 2020 07:49:43 +0000 Subject: [PATCH] fix clip scalar --- python/mxnet/numpy/multiarray.py | 5 +++++ tests/python/unittest/test_numpy_op.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index ea0d51c65d0e..7fba4a6af088 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -6200,6 +6200,11 @@ def clip(a, a_min, a_max, out=None): >>> np.clip(a, 3, 6, out=a) array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) """ + from numbers import Number + if isinstance(a, Number): + # In case input is a scalar, the computation would fall back to native numpy. + # The value returned would be a python scalar. + return _np.clip(a, a_min, a_max, out=None) return _mx_nd_np.clip(a, a_min, a_max, out=out) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d1b91c1efd80..48bdeb55fa5f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3487,6 +3487,16 @@ def __init__(self, a_min=None, a_max=None): def hybrid_forward(self, F, x): return x.clip(self._a_min, self._a_max) + + # Test scalar case + for _, a_min, a_max, throw_exception in workloads: + a = _np.random.uniform() # A scalar + if throw_exception: + # No need to test the exception case here. + continue + mx_ret = np.clip(a, a_min, a_max) + np_ret = _np.clip(a, a_min, a_max) + assert_almost_equal(mx_ret, np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) for shape, a_min, a_max, throw_exception in workloads: for dtype in dtypes: