diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc index 2284beec11cd..af4abbee468e 100644 --- a/src/common/cuda/rtc.cc +++ b/src/common/cuda/rtc.cc @@ -150,13 +150,16 @@ CUfunction get_function(const std::string ¶meters, std::string(fp16_support_string) + "\n" + type_support_string + "\n" + util_string + "\n" + + limits + "\n" + special_functions_definitions + '\n' + vectorization_support_string + "\n" + function_definitions_util + "\n" + function_definitions_binary + "\n" + function_definitions_unary + "\n" + backward_function_definitions + "\n" + - reducer + "\n"; + grad_function_definitions + "\n" + + reducer + "\n" + + logic_reducer + "\n"; std::string code_with_header = common_header + parameters + code; // If verbose mode, output kernel source, though not including the common header if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) { diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h index 64ec2515f44c..cb1bae86fc88 100644 --- a/src/common/cuda/rtc/backward_functions-inl.h +++ b/src/common/cuda/rtc/backward_functions-inl.h @@ -237,6 +237,98 @@ backward_square(const DTypeGrad grad, const DType val) { return 2 * val * grad; } +template +__device__ inline DType div_rgrad(const DType val, + const DType2 val2) { + return -val / (val2 * val2); +} + +template +__device__ inline mixed_type +backward_clip(const DTypeGrad grad, const DType val, + const float a_min, const float a_max) { + if (val > a_max || val < a_min) { + return 0; + } else { + return grad; + } +} + +template +__device__ inline mixed_type +backward_reciprocal(const DTypeGrad grad, const DType val) { + return -grad / (val * val); +} + +template +__device__ inline mixed_type +backward_erf(const DTypeGrad grad, const DType val) { + using type = mixed_type; + const type v = val; + constexpr type my_pi = pi; + return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad; +} + +template +__device__ inline mixed_type +backward_erfinv(const DTypeGrad grad, const DType val) { + using type = mixed_type; + constexpr type my_pi = pi; + const type g = grad; + const type v = val; + return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g; +} + +template +__device__ inline mixed_type +backward_gamma(const DTypeGrad grad, const DType val) { + using type = mixed_type; + const type v = val; + if (type_util::is_same::value) { + return grad * op::gamma(v) * op::special_functions::cephes::psi(v); + } else { + return grad * op::gamma(v) * op::special_functions::cephes::psi(v); + } +} + +template +__device__ inline mixed_type +backward_gammaln(const DTypeGrad grad, const DType val) { + using type = mixed_type; + const type v = val; + if (type_util::is_same::value) { + return grad * op::special_functions::cephes::psi(v); + } else { + return grad * op::special_functions::cephes::psi(v); + } +} + +template +__device__ inline mixed_type +backward_digamma(const DTypeGrad grad, const DType val) { + using type = mixed_type; + const type v = val; + if (type_util::is_same::value) { + return grad * op::special_functions::trigamma(v); + } else { + return grad * op::special_functions::trigamma(v); + } +} + +template +__device__ inline mixed_type +backward_gelu(const DTypeGrad grad, const DType val) { + return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) + + val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f)); +} + +} // namespace op + +)code"; + +const char grad_function_definitions[] = R"code( +namespace op { + template __device__ inline mixed_type rdiv_grad(const DType val, @@ -252,12 +344,6 @@ div_grad(const DType val, return op::reciprocal(temp); } -template -__device__ inline DType div_rgrad(const DType val, - const DType2 val2) { - return -val / (val2 * val2); -} - template __device__ inline DType mod_grad(const DType val, const DType2 val2) { @@ -368,80 +454,6 @@ rldexp_grad(const DType val, return val2 * op::power(static_cast(2), val) * op::log(static_cast(2)); } -template -__device__ inline mixed_type -backward_clip(const DTypeGrad grad, const DType val, - const float a_min, const float a_max) { - if (val > a_max || val < a_min) { - return 0; - } else { - return grad; - } -} - -template -__device__ inline mixed_type -backward_reciprocal(const DTypeGrad grad, const DType val) { - return -grad / (val * val); -} - -template -__device__ inline mixed_type -backward_erf(const DTypeGrad grad, const DType val) { - const mixed_type v = val; - constexpr mixed_type my_pi = pi; - return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad; -} - -template -__device__ inline mixed_type -backward_erfinv(const DTypeGrad grad, const DType val) { - constexpr mixed_type my_pi = pi; - const mixed_type g = grad; - const mixed_type v = val; - return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g; -} - -template -__device__ inline mixed_type -backward_gamma(const DTypeGrad grad, const DType val) { - const mixed_type v = val; - if (type_util::is_same::value) { - return grad * op::gamma(v) * op::special_functions::cephes::psi(v); - } else { - return grad * op::gamma(v) * op::special_functions::cephes::psi(v); - } -} - -template -__device__ inline mixed_type -backward_gammaln(const DTypeGrad grad, const DType val) { - const mixed_type v = val; - if (type_util::is_same::value) { - return grad * op::special_functions::cephes::psi(v); - } else { - return grad * op::special_functions::cephes::psi(v); - } -} - -template -__device__ inline mixed_type -backward_digamma(const DTypeGrad grad, const DType val) { - const mixed_type v = val; - if (type_util::is_same::value) { - return grad * op::special_functions::trigamma(v); - } else { - return grad * op::special_functions::trigamma(v); - } -} - -template -__device__ inline mixed_type -backward_gelu(const DTypeGrad grad, const DType val) { - return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) + - val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f)); -} - template __device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) { auto bsq = scalar * scalar; @@ -467,8 +479,74 @@ __device__ inline DType prelu_grad(const DType val, return (val > 0) ? 0 : val; } -} // namespace op +template +__device__ inline mixed_type +gamma_implicit_grad(const DType a_in, const DType2 x_in) { + using OType = mixed_type; + const OType a = a_in; + const OType x = x_in; + if (x < 0.8f) { + OType numer = 1; + OType denom = a; + OType series1 = numer / denom; + OType series2 = numer / (denom * denom); +#pragma unroll + for (int i = 1; i <= 5; i++) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + OType pow_x_alpha = op::power(x, a); + OType gamma_pdf = op::power(x, a - 1) * op::exp(-x); + OType gamma_cdf = pow_x_alpha * series1; + OType gamma_cdf_alpha = + (op::log(x) - OType(special_functions::cephes::psi(a))) * + gamma_cdf - + pow_x_alpha * series2; + OType result = -gamma_cdf_alpha / gamma_pdf; + return op::isnan(result) ? 0.f : result; + } + if (a > 8.0f) { + if (0.9f * a <= x && x <= 1.1f * a) { + OType numer_1 = 1 + 24 * a * (1 + 12 * a); + OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) - + 65 * x * x / a + a * (107 + 3600 * x); + OType denom = 1244160 * (a * a) * (a * a); + return numer_1 * numer_2 / denom; + } + OType denom = op::sqrt(8 * a); + OType term2 = denom / (a - x); + OType term3 = + op::power(x - a - a * op::log(x / a), static_cast(-1.5)); + OType term23 = (x < a) ? term2 - term3 : term2 + term3; + OType term1 = op::log(x / a) * term23 - + op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x)); + OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a)); + OType numer = x * term1; + return -stirling * numer / denom; + } + OType u = op::log(x / a); + OType v = op::log(a); + OType coef_uv[3][8] = { + {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115, + 0.10406089, 0.0014179084}, + {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465, + 0.020070113, -0.0035938915, -0.00058392623}, + {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642, + -0.0021309326, 0.00085092367, -1.5247877e-07}, + }; + OType coef_v[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { + coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]); + } + OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + return op::exp(p / q); +} +} // namespace op )code"; } // namespace rtc diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h index 9018a5d435d6..4f87db6f72d4 100644 --- a/src/common/cuda/rtc/forward_functions-inl.h +++ b/src/common/cuda/rtc/forward_functions-inl.h @@ -696,6 +696,10 @@ __device__ inline DType log_sigmoid(const DType val) { template __device__ inline DType softrelu(const DType val) { + // Avoid overflow of exp for large inputs. + // The threshold 20 is chosen such that softrelu(a) = a + // for a > 20 using floating precision. + if (val > 20) return val; if (type_util::has_double_or_integral::value) { return ::log(1 + ::exp(val)); } else { @@ -936,6 +940,11 @@ __device__ inline bool_t np_logical_not(const DType val) { return !static_cast(val); } +template +__device__ inline bool_t NonZero(const DType val) { + return val != 0; +} + #undef DEFINE_UNARY_MATH_FUNC template diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h index 259d0e060a57..f5b70d832594 100644 --- a/src/common/cuda/rtc/reducer-inl.h +++ b/src/common/cuda/rtc/reducer-inl.h @@ -27,11 +27,10 @@ namespace common { namespace cuda { namespace rtc { -const char reducer[] = R"code( +const char reducer[] = R"code( namespace red { -/*! \brief sum reducer */ struct sum { /*! \brief do reduction into dst */ template @@ -95,103 +94,6 @@ struct sum { } }; -/*! \brief maximum reducer */ -struct maximum { - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*) - if (!util::isnan(dst)) { - if (!(dst >= src)) dst = src; - } - } - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, - volatile DType& none) { - Reduce(dst, src); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, - volatile DType& src_val, volatile DType& src_residual) { - Reduce(dst_val, src_val); - } - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst) {} - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv) { - initv = -2*DBL_MAX; - } - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv, DType &none) { - SetInitValue(initv); - } -}; - -/*! \brief minimum reducer */ -struct minimum { - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { - if (!util::isnan(dst)) { - if (!(dst <= src)) dst = src; - } - } - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, - volatile DType& none) { - Reduce(dst, src); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, - volatile DType& src_val, volatile DType& src_residual) { - Reduce(dst_val, src_val); - } - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst) {} - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv) { - initv = 2*DBL_MAX; - } - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv, DType &none) { - SetInitValue(initv); - } -}; - -/*! \brief product reducer */ struct product { /*! \brief do reduction into dst */ template @@ -237,7 +139,6 @@ struct product { } }; -/*! \brief sum reducer that ignores NaN values in the input */ struct nansum { /*! \brief do reduction into dst */ template @@ -293,7 +194,6 @@ struct nansum { } }; -/*! \brief product reducer that ignores NaN values in the input */ struct nanprod { /*! \brief do reduction into dst */ template @@ -493,10 +393,222 @@ struct nrmlp { scale = 0; } }; -} // namespace red +} // namespace red )code"; +const char logic_reducer[] = R"code( +namespace red { + +struct maximum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*) + if (!util::isnan(dst)) { + if (!(dst >= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = limits::NegInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +struct minimum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (!util::isnan(dst)) { + if (!(dst <= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = limits::PosInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +struct argmax { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& dst, volatile DType src) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief do stable reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& dst, volatile DType src, + volatile DType&) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst, volatile DType& src) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst, volatile DType&, + volatile DType& src, volatile DType&) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType&) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv.num = limits::NegInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &) { + initv.num = limits::NegInfValue(); + } +}; + +struct argmin { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& dst, volatile DType src) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief do stable reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& dst, volatile DType src, + volatile DType& residual) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst, volatile DType& src) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst, volatile DType&, + volatile DType& src, volatile DType&) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& residual) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv.num = limits::PosInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &residual) { + initv.num = limits::PosInfValue(); + } +}; +} // namespace red +)code"; } // namespace rtc } // namespace cuda } // namespace common diff --git a/src/common/cuda/rtc/special_functions-inl.h b/src/common/cuda/rtc/special_functions-inl.h index d64afb51e2c1..7110e4737d0f 100644 --- a/src/common/cuda/rtc/special_functions-inl.h +++ b/src/common/cuda/rtc/special_functions-inl.h @@ -50,11 +50,6 @@ namespace rtc { // Direct inquiries to 30 Frost Street, Cambridge, MA 02140 // const char special_functions_definitions[] = R"code( -constexpr double DBL_MAX = 1.7976931348623157081e+308; -constexpr float FLT_MAX = 3.4028234663852885981e+38; -#define inf ((float)1e50) -#define nan (inf - inf) - namespace op { namespace special_functions { diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h index b4266030be1f..bafa8cf3f7e5 100644 --- a/src/common/cuda/rtc/util-inl.h +++ b/src/common/cuda/rtc/util-inl.h @@ -446,6 +446,144 @@ __device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int } // namespace util )code"; + +const char limits[] = R"code( +constexpr double DBL_MAX = 1.7976931348623157081e+308; +constexpr float FLT_MAX = 3.4028234663852885981e+38; +#define inf ((float)1e50) +#define nan (inf - inf) + +namespace limits { + +template +__device__ inline DType MinValue(void); + +template<> +__device__ inline float MinValue(void) { + return -FLT_MAX; +} +/*! \brief minimum value of double */ +template<> +__device__ inline double MinValue(void) { + return -DBL_MAX; +} +/*! \brief minimum value of uint8 */ +template<> +__device__ inline uint8 MinValue(void) { + return 0; +} +/*! \brief minimum value of int8_t */ +template<> +__device__ inline int8 MinValue(void) { + return -128; +} +/*! \brief minimum value of int32 */ +template<> +__device__ inline int32 MinValue(void) { + return -2147483648; +} +/*! \brief minimum value of int64_t */ +template<> +__device__ inline int64 MinValue(void) { + return -9223372036854775808LL; +} +/*! \brief minimum value of bool */ +template<> +__device__ inline bool MinValue(void) { + return false; +} +/*! \brief minimum value of bool_t */ +template<> +__device__ inline bool_t MinValue(void) { + return MinValue(); +} + +/*! + * \brief negative infinity of certain types + * \tparam DType data type + */ +template +__device__ inline DType NegInfValue(void) { + return MinValue(); +} +/*! \brief negative infinity value of float */ +template<> +__device__ inline float NegInfValue(void) { + return -inf; +} +/*! \brief negative infinity value of double */ +template<> +__device__ inline double NegInfValue(void) { + return -inf; +} + +/*! + * \brief maximum value of certain types + * \tparam DType data type + */ +template +__device__ inline DType MaxValue(void); +/*! \brief maximum value of float */ +template<> +__device__ inline float MaxValue(void) { + return FLT_MAX; +} +/*! \brief maximum value of double */ +template<> +__device__ inline double MaxValue(void) { + return DBL_MAX; +} +/*! \brief maximum value of uint8 */ +template<> +__device__ inline uint8 MaxValue(void) { + return 255; +} +/*! \brief maximum value of int8 */ +template<> +__device__ inline int8 MaxValue(void) { + return 127; +} +/*! \brief maximum value of int32 */ +template<> +__device__ inline int32 MaxValue(void) { + return 2147483647; +} +/*! \brief maximum value of int64 */ +template<> +__device__ inline int64 MaxValue(void) { + return 9223372036854775807LL; +} +/*! \brief maximum value of bool */ +template<> +__device__ inline bool MaxValue(void) { + return true; +} +/*! \brief maximum value of bool_t */ +template<> +__device__ inline bool_t MaxValue(void) { + return MaxValue(); +} +/*! + * \brief positive infinity of certain types + * \tparam DType data type + */ +template +__device__ inline DType PosInfValue(void) { + return MaxValue(); +} +/*! \brief positive infinity value of float */ +template<> +__device__ inline float PosInfValue(void) { + return inf; +} +/*! \brief positive infinity value of double */ +template<> +__device__ inline double PosInfValue(void) { + return inf; +} + +} // namespace limits +)code"; } // namespace rtc } // namespace cuda } // namespace common diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index c33dad4601d7..be49f0fb326c 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -940,7 +940,7 @@ template<> MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map (mshadow::half::half_t a, mshadow::half::half_t b) { - return mshadow::half::half_t(-::floorf(static_cast(a/b))); + return mshadow::half::half_t(-::floorf(static_cast(a)/static_cast(b))); } struct rmod : public mxnet_op::tunable { @@ -1573,7 +1573,7 @@ struct argmax { /*! \brief do reduction into dst */ template MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*) - if (dst.num < src.num) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { dst.num = src.num; dst.idx = src.idx; } @@ -1581,7 +1581,7 @@ struct argmax { /*! \brief do stable reduction into dst */ template MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*) - if (dst.num < src.num) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { dst.num = src.num; dst.idx = src.idx; } @@ -1589,7 +1589,7 @@ struct argmax { /*! \brief combine the results of two reducers */ template MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) - if (dst_val.num < src_val.num) { + if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) { dst_val.num = src_val.num; dst_val.idx = src_val.idx; } @@ -1597,7 +1597,7 @@ struct argmax { /*! \brief combine the results of two reducers */ template MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) - if (dst_val.num < src_val.num) { + if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) { dst_val.num = src_val.num; dst_val.idx = src_val.idx; } diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h index da30192231c7..0df0db28fc65 100644 --- a/src/operator/nn/group_norm-inl.h +++ b/src/operator/nn/group_norm-inl.h @@ -113,24 +113,29 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, Tensor workspace; - size_t workspace_size = 0; - MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { - workspace_size = - broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0], - red_src_shape, sizeof(DType)); - }); + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0], + red_src_shape); workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); // Calculate mean +#if !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( s, mean_, req[0], workspace, data_); - Tensor mean_data_tensor = mean_.FlatTo1D(s); - mean_data_tensor /= scalar(channel_size); }); }); +#else + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, mean_, req[0], workspace, + data_, "red::sum{}", NDim, "identity"); + }); +#endif // !defined(__CUDACC__) + MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { + Tensor mean_data_tensor = mean_.FlatTo1D(s); + mean_data_tensor /= scalar(channel_size); + }); TBlob data_grp = data.reshape(temp_data_shape); const TBlob& mean_grp = mean.reshape(moments_shape); @@ -150,15 +155,25 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, // Calculate std const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape); +#if !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( s, std_, req[0], workspace, centered_out); - Tensor std_data_tensor = std_.FlatTo1D(s); - std_data_tensor = F(std_data_tensor / scalar(channel_size) - + scalar(param.eps)); }); }); +#else + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, std_, req[0], + workspace, centered_out, + "red::sum{}", NDim, "square"); + }); +#endif + MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, { + Tensor std_data_tensor = std_.FlatTo1D(s); + std_data_tensor = F(std_data_tensor / scalar(channel_size) + + scalar(param.eps)); + }); // Calculate data = data / std #if !defined(__CUDACC__) @@ -263,26 +278,17 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, // Initialize the workspace + Construct the temporary TBlobs Tensor workspace; - size_t reduce_workspace_size = 0; - size_t data_size = 0; - size_t red_out_size = 0; - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - data_size = sizeof(DType) * data.Size(); - red_out_size = sizeof(DType) * mean.Size(); - // There are two types of reduction workloads: reduce over axis and reduce exclude axis - // We take the maximum of the workspace sizes required by these workloads. - // Also, we explicitly set the req_type=kAddto in case we want to use it. - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_dst_shape, - kAddTo, red_src_shape, - sizeof(DType))); - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, - red_exclude_src_shape, - sizeof(DType))); - }); + size_t dtype_size = common::mshadow_type_info(outputs[0].type_flag_).size; + size_t data_size = data.Size() * dtype_size; + size_t red_out_size = mean.Size() * dtype_size; + // There are two types of reduction workloads: reduce over axis and reduce exclude axis + // We take the maximum of the workspace sizes required by these workloads. + // Also, we explicitly set the req_type=kAddto in case we want to use it. + size_t reduce_workspace_size = + std::max(broadcast::ReduceWorkspaceSize(s, red_dst_shape, + kAddTo, red_src_shape), + broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, + red_exclude_src_shape)); workspace = ctx.requested[0].get_space_typed( Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); const TBlob normalized_data = @@ -300,14 +306,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute(attrs, ctx, {normalized_data, std_}, {kWriteTo}, {normalized_data}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {data_, mean_}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {normalized_data, std_}, - {kWriteTo}, {normalized_data}); -#endif // !defined(__CUDACC__) // Calculate grad_beta if (req[2] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { @@ -319,13 +317,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, }); } // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) -#if !defined(__CUDACC__) ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, {kWriteTo}, {ograd_mult}); -#else - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) if (req[1] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { @@ -335,6 +328,32 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, }); }); } +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data_, mean_}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std_}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), + req[2], workspace, ograd.reshape(red_exclude_src_shape), + "red::sum{}", NDim, "identity"); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), + req[1], workspace, ograd_mult.reshape(red_exclude_src_shape), + "red::sum{}", NDim, "identity"); + }); + } +#endif // !defined(__CUDACC__) // Calculate grad_data: // ograd_mult = ograd * gamma / std @@ -350,15 +369,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute(attrs, ctx, {ograd_mult, std_}, {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {inputs[0], gamma}, - {kWriteTo}, - {ograd_mult.reshape(data.shape_)}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {ograd_mult, std_}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( @@ -368,19 +378,11 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(N); }); -#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {ograd_mult, red_out}, {req[0]}, {output_}); ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {output_}); - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( @@ -390,11 +392,38 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(-N); }); -#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {output_}); #else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {inputs[0], gamma}, + {kWriteTo}, + {ograd_mult.reshape(data.shape_)}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std_}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(N); + }); + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {output_}); + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(-N); + }); BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {output_}); diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index d8c8dbc7a2f5..79d09063ee6c 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -38,6 +38,7 @@ #include "../operator_common.h" #include "../mxnet_op.h" #include "../tensor/broadcast_reduce_op.h" +#include "mxnet/tuple.h" namespace mxnet { namespace op { @@ -115,14 +116,11 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, int channel_size = red_src_shape.Size() / red_dst_shape.Size(); // Initialize the workspace Tensor workspace; - size_t workspace_size = 0; - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - workspace_size = - broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0], - in_data.shape_, sizeof(DType)); - }); + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0], + in_data.shape_); workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) { common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs for LayerNorm. " @@ -145,15 +143,9 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, }); }); // Calculate data = data - mean -#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {inputs[0], outputs[layernorm::kMean]}, {kWriteTo}, {outputs[0]}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {inputs[0], outputs[layernorm::kMean]}, - {kWriteTo}, {outputs[0]}); -#endif // !defined(__CUDACC__) // Calculate std const TBlob centered_out = outputs[0].reshape(red_src_shape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { @@ -170,7 +162,6 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, + scalar(param.eps)); }); }); -#if !defined(__CUDACC__) // Calculate data = data / std BinaryBroadcastCompute(attrs, ctx, {outputs[0], outputs[layernorm::kStd]}, @@ -184,6 +175,30 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, {outputs[0], beta}, {kWriteTo}, {outputs[0]}); #else + // Calculate mean + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, mean_data, req[0], workspace, in_data, + "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor mean_data_tensor = mean_data.FlatTo1D(s); + mean_data_tensor /= scalar(channel_size); + }); + // Calculate data = data - mean + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {inputs[0], outputs[layernorm::kMean]}, + {kWriteTo}, {outputs[0]}); + // Calculate std + const TBlob centered_out = outputs[0].reshape(red_src_shape); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, std_data, req[0], workspace, centered_out, + "red::sum{}", NDim, "square"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor std_data_tensor = std_data.FlatTo1D(s); + std_data_tensor = F(std_data_tensor / scalar(channel_size) + + scalar(param.eps)); + }); // Calculate data = data / std BinaryBroadcastRTCCompute {"div"}(attrs, ctx, {outputs[0], outputs[layernorm::kStd]}, @@ -196,7 +211,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, BinaryBroadcastRTCCompute {"add"}(attrs, ctx, {outputs[0], beta}, {kWriteTo}, {outputs[0]}); -#endif // !defined(__CUDACC__) +#endif } template @@ -205,6 +220,26 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs); +template +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size); + /* Calculate the gradient of layer normalization. We have the following gradient for gamma, beta and x: @@ -250,26 +285,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, int channel_size = red_src_shape.Size() / red_dst_shape.Size(); // Initialize the workspace + Construct the temporary TBlobs Tensor workspace; - size_t reduce_workspace_size = 0; - size_t data_size = 0; - size_t red_out_size = 0; - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - data_size = sizeof(DType) * data.Size(); - red_out_size = sizeof(DType) * mean.Size(); - // There are two types of reduction workloads: reduce over axis and reduce exclude axis - // We take the maximum of the workspace sizes required by these workloads. - // Also, we explicitly set the req_type=kAddto in case we want to use it. - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_dst_shape, - kAddTo, red_src_shape, - sizeof(DType))); - reduce_workspace_size = - std::max(reduce_workspace_size, + size_t dtype_size = common::mshadow_type_info(outputs[0].type_flag_).size; + size_t data_size = data.Size() * dtype_size; + size_t red_out_size = mean.Size() * dtype_size; + // There are two types of reduction workloads: reduce over axis and reduce exclude axis + // We take the maximum of the workspace sizes required by these workloads. + // Also, we explicitly set the req_type=kAddto in case we want to use it. + size_t reduce_workspace_size = + std::max(broadcast::ReduceWorkspaceSize(s, red_dst_shape, + kAddTo, red_src_shape), broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, - red_exclude_src_shape, - sizeof(DType))); - }); + red_exclude_src_shape)); workspace = ctx.requested[0].get_space_typed( Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size, @@ -278,135 +304,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id()); const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); - // Compute normalized_data = (data - mean) / std -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, - {data, mean}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastCompute(attrs, ctx, - {normalized_data, std}, - {kWriteTo}, {normalized_data}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {data, mean}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {normalized_data, std}, - {kWriteTo}, {normalized_data}); -#endif // !defined(__CUDACC__) - // Calculate grad_beta - bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); - if (req[2] != kNullOp) { - MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, - ograd.reshape(red_exclude_src_shape)); - } else { - broadcast::Reduce( - s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, - ograd.reshape(red_exclude_src_shape)); - } - }); - }); - } - // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) -#if !defined(__CUDACC__) - ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); -#else - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) - if (req[1] != kNullOp) { - MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, - ograd_mult.reshape(red_exclude_src_shape)); - } else { - broadcast::Reduce( - s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, - ograd_mult.reshape(red_exclude_src_shape)); - } - }); - }); - } - // Calculate grad_data: - // ograd_mult = ograd * gamma / std - // grad_data = ograd_mult - mean(ograd_mult, axis) - // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) - if (req[0] != kNullOp) { -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, - {ograd, gamma}, - {kWriteTo}, {ograd_mult}); - BinaryBroadcastCompute(attrs, ctx, - {ograd_mult, std}, - {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {ograd, gamma}, - {kWriteTo}, {ograd_mult}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {ograd_mult, std}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } else { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } - }); - Tensor red_out_tensor = red_out.FlatTo1D(s); - red_out_tensor /= scalar(channel_size); - }); -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {outputs[0]}); - ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {outputs[0]}); - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } else { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } - }); - Tensor red_out_tensor = red_out.FlatTo1D(s); - red_out_tensor /= scalar(- channel_size); - }); -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, - {normalized_data, red_out}, - {kAddTo}, {outputs[0]}); -#else - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {normalized_data, red_out}, - {kAddTo}, {outputs[0]}); -#endif // !defined(__CUDACC__) - } + + LayerNormGradComputeGeneralImpl(attrs, ctx, ograd, data, gamma, mean, std, normalized_data, + ograd_mult, red_out, req, outputs, workspace, red_dst_shape, + red_src_shape, red_exclude_dst_shape, red_exclude_src_shape, + channel_size); } } // namespace op diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 08847205f155..1a040fa6f7d0 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -268,6 +268,122 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs, LayerNormComputeGeneral(attrs, ctx, inputs, req, outputs); } +template <> +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastCompute(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); + if (req[2] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + } else { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + } + }); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + } else { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + } + }); + }); + } + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + BinaryBroadcastCompute(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, std}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } else { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(channel_size); + }); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {outputs[0]}); + ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } else { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(- channel_size); + }); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {outputs[0]}); + } +} + template<> void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu index a60df412299a..9a33e0665ff4 100644 --- a/src/operator/nn/layer_norm.cu +++ b/src/operator/nn/layer_norm.cu @@ -29,6 +29,89 @@ using namespace mshadow::cuda; namespace mxnet { namespace op { +template <> +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape), "red::sum{}", NDim, "identity"); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape), "red::sum{}", NDim, + "identity"); + }); + } + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(channel_size); + }); + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {outputs[0]}); + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(- channel_size); + }); + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {outputs[0]}); + } +} template __device__ __forceinline__ DType warp_shfl(DType value, int src_lane, int width = 32, unsigned int mask = 0xffffffff) { diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h index ca78b65bf1ec..78c7e4a1cd44 100644 --- a/src/operator/nn/moments-inl.h +++ b/src/operator/nn/moments-inl.h @@ -126,7 +126,12 @@ inline void MomentsForwardImpl(const OpContext& ctx, small = ReduceAxesShapeImpl(inputs[0].shape_, axes, true, false); } +#if !defined(__CUDACC__) ReduceAxesComputeImpl(ctx, {data}, {req[0]}, {mean}, small); +#else + ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", nullptr, true); +#endif + TBlob temp; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { Shape<6> data_shape, mean_shape; for (int i = 0; i < 6; ++i) { @@ -137,9 +142,15 @@ inline void MomentsForwardImpl(const OpContext& ctx, ctx.requested[0].get_space_typed(Shape1(data.shape_.Size()), s);; Kernel::Launch(s, data.shape_.Size(), temp_data.dptr_, data.dptr(), mean.dptr(), data_shape, mean_shape); - ReduceAxesComputeImpl( - ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); + temp = TBlob(temp_data); }); +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, {temp.reshape(data.shape_)}, {kWriteTo}, {var}, small); +#else + ReduceAxesRTCComputeImpl(ctx, {temp.reshape(data.shape_)}, + {kWriteTo}, {var}, small, "red::sum{}", nullptr, true); +#endif } template diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh deleted file mode 100644 index d4374edc9828..000000000000 --- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh +++ /dev/null @@ -1,415 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015-2020 by Contributors - * \file broadcast_reduce_customized-inl.cuh - * \brief Customized CUDA implementations for binary broadcast and reduce - * \author MXNet contributors -*/ -#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ -#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ - -#include "../../tensor/broadcast_reduce-inl.cuh" - -using namespace mshadow::cuda; - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel_wr(const int N, const int M, const bool addto, - const DType* __restrict big, OType *small, - const Shape big_shape0, const Shape small_shape, - const Shape big_shape, const Shape big_stride, - const int Mnext, const bool do_transpose, - Reducer* reducer) { - extern __shared__ char shTileChar[]; - AType* shTile = (AType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - // bool need_clean = !reducer; - // reducer = reducer ? reducer : new Reducer(); - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = unravel(idx, small_shape); - int idx_big0 = ravel(coord, big_shape0); - - AType val, residual; - reducer->SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP::Map(big[idx_big[u]]); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) reducer->Reduce(val, AType(tmp[u]), residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - AType tmp, tmp_residual; - reducer->SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); - } - } else { - if (idx < N) { - reducer->Finalize(val, residual); - assign(&small[idx + m0*N], addto, OType(val)); - } - } - } - } - // if (need_clean) { - // delete reducer; - // } -} - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel_wr(const int N, const int M, const bool addto, - const DType* __restrict big, const DType* __restrict lhs, - const DType* __restrict rhs, DType *small, - const Shape big_shape0, const Shape lhs_shape0, - const Shape rhs_shape0, const Shape small_shape, - const Shape big_shape, const Shape lhs_shape, - const Shape rhs_shape, const Shape big_stride, - const Shape lhs_stride, const Shape rhs_stride, - const int Mnext, const bool do_transpose, - Reducer* reducer) { - extern __shared__ char shTileChar[]; - DType* shTile = (DType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - // bool need_clean = !reducer; - // reducer = reducer ? reducer : new Reducer(); - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = unravel(idx, small_shape); - int idx_big0 = ravel(coord, big_shape0); - int idx_lhs0 = ravel(coord, lhs_shape0); - int idx_rhs0 = ravel(coord, rhs_shape0); - - DType val, residual; - reducer->SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - int idx_lhs[unroll]; - int idx_rhs[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); - idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride); - idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]])); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) reducer->Reduce(val, tmp[u], residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - DType tmp, tmp_residual; - reducer->SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, shTile[tidx * 2]); - } - } else { - if (idx < N) { - reducer->Finalize(val, residual); - assign(&small[idx + m0*N], addto, val); - } - } - } - } - // if (need_clean) { - // delete reducer; - // } -} - -// Simple reduction of lines when M is small -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_lines_kernel_wr(const int N, const int M, const bool addto, - const int small_in_stride, const DType* __restrict small_in, DType *small_out, - Reducer* reducer) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - - DType val, residual; - reducer->SetInitValue(val, residual); - for (int k = 0; k < M; k++) { - reducer->Reduce(val, small_in[idx + k*small_in_stride], residual); - } - - if (idx < N) { - reducer->Finalize(val, residual); - assign(&small_out[idx], addto, val); - } - - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1_wr(const int N, const bool addto, - const DType* __restrict big, OType *small, const Shape bshape, - const Shape sshape, Reducer* reducer) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = unravel(idx, sshape); - int j = ravel(coord, bshape); - AType val, residual; - reducer->SetInitValue(val, residual); - reducer->Reduce(val, AType(OP::Map(big[j])), residual); - reducer->Finalize(val, residual); - assign(&small[idx], addto, OType(val)); - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1_wr(const int N, const bool addto, - const DType* __restrict big, - const DType* __restrict lhs, - const DType* __restrict rhs, - DType *small, - const Shape big_shape, - const Shape lhs_shape, - const Shape rhs_shape, - const Shape small_shape, - Reducer* reducer) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = unravel(idx, small_shape); - int idx_big = ravel(coord, big_shape); - int idx_lhs = ravel(coord, lhs_shape); - int idx_rhs = ravel(coord, rhs_shape); - DType val, residual; - reducer->SetInitValue(val, residual); - reducer->Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); - reducer->Finalize(val, residual); - assign(&small[idx], addto, val); - } -} - -#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \ - if (do_unroll) { \ - const int unrollVar = unrollAmount; \ - {__VA_ARGS__} \ - } else { \ - const int unrollVar = 1; \ - {__VA_ARGS__} \ - } - -template -void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqType req, - const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config, - Reducer* reducer = nullptr) { - bool need_clean = !reducer; - reducer = reducer ? reducer : new Reducer(); - if (config.M == 1) { - reduce_kernel_M1_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), - small.shape_.get(), reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr); - } else { - OType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? - config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), - small.shape_.get(), config.rshape.get(), config.rstride.get(), - config.Mnext, config.kernel_1.do_transpose, reducer); - }); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); - - if (config.Mnext > 1) { - reduce_lines_kernel_wr - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr(), reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr); - } - } - if (need_clean) { - delete reducer; - } -} - -template -void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, - const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config, Reducer* reducer = nullptr) { - bool need_clean = !reducer; - reducer = reducer ? reducer : new Reducer(); - if (config.M == 1) { - reduce_kernel_M1_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), lhs.dptr(), rhs.dptr(), - small.dptr(), big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr); - } else { - DType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? - config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), - small_dptr, big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get(), config.rshape, config.lhs_shape, - config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext, - config.kernel_1.do_transpose, reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); - }); - - if (config.Mnext > 1) { - reduce_lines_kernel_wr - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr(), reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr); - } - } - if (need_clean) { - delete reducer; - } -} - -#undef KERNEL_UNROLL_SWITCH - -template -void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big, Reducer* reducer = nullptr) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - bool need_clean = !reducer; - reducer = reducer ? reducer : new Reducer(); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); - if (safe_acc) { - MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { - typedef typename std::conditional::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { - typedef typename std::conditional::type OutType; - config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, sizeof(AccType)); - ReduceImplWithReducer( - stream, small, req, big, workspace, config, reducer); - }); - }); - } else { - ReduceImplWithReducer(stream, small, req, big, workspace, config, reducer); - } - if (need_clean) { - delete reducer; - } -} - -#endif // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h index 0226df45f960..2941d54fb56c 100644 --- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h +++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h @@ -54,12 +54,6 @@ MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, con assign(&small[idx], addto, OType(val)); } -#ifdef __CUDACC__ -#include "broadcast_reduce_customized-inl.cuh" -#include "../../tensor/broadcast_reduce-inl.cuh" - -#else - template void seq_reduce_compute_wr(const size_t N, const size_t M, const bool addto, const DType *big, OType *small, const Shape bshape, @@ -177,7 +171,6 @@ void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, reducer); } -#endif } // namespace broadcast } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h index 8e1c0b3db18d..976991f30f88 100644 --- a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h +++ b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h @@ -46,19 +46,17 @@ void ReduceAxesComputeImplWithReducer(const OpContext& ctx, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { - const TBlob in_data = inputs[0].reshape(src_shape); - const TBlob out_data = outputs[0].reshape(dst_shape); - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_, sizeof(OType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - broadcast::ReduceWithReducer( - s, out_data, req[0], workspace, in_data, reducer); - // no normalization - }); + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + broadcast::ReduceWithReducer( + s, out_data, req[0], workspace, in_data, reducer); + // no normalization }); }); } diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h index b26e68086852..60dee6ac3492 100644 --- a/src/operator/numpy/linalg/np_norm-inl.h +++ b/src/operator/numpy/linalg/np_norm-inl.h @@ -285,18 +285,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs, } else if (param.ord == std::numeric_limits::infinity()) { // inf norm LOG(FATAL) << "inf norm handled in front-end."; } else { +#ifndef __CUDACC__ mshadow_op::nrmlp host_reducer(param.ord); mshadow_op::nrmlp *reducer_instance = nullptr; -#ifdef __CUDACC__ - Stream *s = ctx.get_stream(); - cudaStream_t copy_stream = mshadow::Stream::GetStream(s); - cudaMalloc(reinterpret_cast(&reducer_instance), sizeof(mshadow_op::nrmlp)); - cudaMemcpyAsync(reducer_instance, &host_reducer, sizeof(mshadow_op::nrmlp), - cudaMemcpyHostToDevice, copy_stream); - cudaStreamSynchronize(copy_stream); -#else reducer_instance = &host_reducer; -#endif if (safe_acc) { ReduceAxesComputeImplWithReducer( ctx, inputs, req, outputs, small, reducer_instance); @@ -304,8 +296,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImplWithReducer( ctx, inputs, req, outputs, small, reducer_instance); } -#ifdef __CUDACC__ - cudaFree(reducer_instance); +#else + ReduceAxesRTCComputeImpl( + ctx, inputs, req, outputs, small, "red::nrmlp{" + std::to_string(param.ord) + "}", + nullptr, false, "abs"); #endif } } @@ -443,8 +437,13 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, } if (param.flag == 1) { // Frobenius norm - ReduceAxesComputeImplWithReducer( +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( ctx, inputs, req, outputs, reduced_shape); +#else + ReduceAxesRTCComputeImpl( + ctx, inputs, req, outputs, reduced_shape, "red::nrm2{}", nullptr, false, "identity"); +#endif return; } @@ -453,19 +452,29 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, if (param.ord != 2 && param.ord != -2) { // row norm or col norm TShape sum_shape = inputs[0].shape_; sum_shape[mat_axis[!(param.ord == 1 || param.ord == -1)]] = 1; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - TBlob temp = outputs[1].reshape(sum_shape); - std::vector sum_output({temp}); - ReduceAxesComputeImpl( - ctx, inputs, req, sum_output, sum_shape); - if (param.ord > 0) { - ReduceAxesComputeImpl( - ctx, sum_output, req, outputs, reduced_shape); - } else { - ReduceAxesComputeImpl( - ctx, sum_output, req, outputs, reduced_shape); - } - }); + TBlob temp = outputs[1].reshape(sum_shape); + std::vector sum_output({temp}); +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, inputs, req, sum_output, sum_shape); + if (param.ord > 0) { + ReduceAxesComputeImpl( + ctx, sum_output, req, outputs, reduced_shape); + } else { + ReduceAxesComputeImpl( + ctx, sum_output, req, outputs, reduced_shape); + } +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape, + "red::sum{}", nullptr, false, "abs"); + if (param.ord > 0) { + ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, + "red::maximum{}", nullptr, false); + } else { + ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, + "red::minimum{}", nullptr, false); + } +#endif return; } @@ -500,6 +509,7 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, L_trans[mat_axis[1]] = 1; } + std::vector eigen; MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { Tensor UT = outputs[1].get_with_shape(Shape3(batch_dim, row_dim, row_dim), s); @@ -523,32 +533,46 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, Tensor svd_input = workspace.get_with_shape(Shape3(batch_dim, row_dim, col_dim), s); gesvd::op(svd_input, UT, L, V, ctx, attrs, &svd_workspace); - TBlob workspace0(reinterpret_cast(temp.dptr_), L_trans, temp.dev_mask(), temp.dev_id()); TransposeImpl(ctx.run_ctx, TBlob(L).reshape(L_shape), workspace0, reduce_axes); - std::vector eigen({ workspace0 }); - if (param.flag == 2) { // nuclear norm - ReduceAxesComputeImpl( + eigen.emplace_back(workspace0); + }); + +#if !defined(__CUDACC__) + if (param.flag == 2) { // nuclear norm + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) { + if (ord == 2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (ord == -2) { + ReduceAxesComputeImpl( ctx, eigen, req, outputs, reduced_shape); - } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) { - if (ord == 2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } else if (ord == -2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } - } else { - if (ord == 2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } else if (ord == -2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } } - }); + } else { + if (ord == 2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (ord == -2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } + } +#else + if (param.flag == 2) { // nuclear norm + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, "red::sum{}", nullptr, false); + } else { + if (ord == 2) { + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, + "red::maximum{}", nullptr, false, "abs"); + } else if (ord == -2) { + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, + "red::minimum{}", nullptr, false, "abs"); + } + } +#endif } template @@ -784,8 +808,13 @@ void NumpyNormComputeForward(const nnvm::NodeAttrs& attrs, std::vector flat_outputs({ outputs[0].reshape(TShape(1, 1)) }); - ReduceAxesComputeImplWithReducer( +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( ctx, flat_inputs, req, flat_outputs, TShape(1, 1)); +#else + ReduceAxesRTCComputeImpl( + ctx, flat_inputs, req, flat_outputs, TShape(1, 1), "red::nrm2{}", nullptr, false, "identity"); +#endif return; } diff --git a/src/operator/numpy/np_broadcast_reduce_op.cc b/src/operator/numpy/np_broadcast_reduce_op.cc new file mode 100644 index 000000000000..4b64a1a29169 --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_broadcast_reduce_op.cc + * \brief Function definitions of NumPy-compatible + * broadcast and reduce operators + */ + +#include "np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { +#if MXNET_USE_CUDA + +void NumpyArgMinMaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + if (req[0] == kNullOp) return; + // parse param + const auto& param = nnvm::get(attrs.parsed); + mshadow::Stream *s = ctx.get_stream(); + TBlob out = outputs[0]; + TBlob in = inputs[0]; + // do some shape checks + if (in.shape_.ndim() != 0) { + if (param.axis.has_value()) { + // cannot do argmax in an empty dimension + int axis = param.axis.value(); + axis = CheckAxis(axis, in.shape_.ndim()); + CHECK_NE(in.shape_[axis], 0) + << "searching input tensor of shape " << inputs[0].shape_ + << " along axis = " << axis << " of zero dim-size is not allowed"; + } else { + // cannot do argmax on an empty array + CHECK_NE(in.shape_.Size(), 0U) << "attempt to search an empty sequence"; + } + } + if (in.shape_.Size() == 0U) return; // zero-size tensor + // prepare shape + dmlc::optional> axes; + if (param.axis.has_value()) { + mxnet::Tuple t({param.axis.value()}); + axes = dmlc::optional>(t); + } + TShape small; + small = NumpyReduceAxesShapeImpl(in.shape_, axes, true); + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape); + const TBlob in_data = in.reshape(src_shape); + // request a work space + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[0].reshape(dst_shape), req[0], workspace, in_data, + reducer, NDim, "identity", true); + }); +} + +#endif // MXNET_USE_CUDA + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op.cuh b/src/operator/numpy/np_broadcast_reduce_op.cuh deleted file mode 100644 index f97aa7831516..000000000000 --- a/src/operator/numpy/np_broadcast_reduce_op.cuh +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015-2020 by Contributors - * \file np_broadcast_reduce-inl.cuh - * \brief GPU implementations for numpy binary broadcast ops - * \author Zhaoqi Zhu -*/ -#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_ -#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_ - -using namespace mshadow::cuda; -using namespace mshadow; -using namespace broadcast; - -template -void NumpyArgMinMaxReduce(Stream *s, const TBlob& in_data, const TBlob& out_data, - const Tensor& workspace) { - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(out_data.shape_, in_data.shape_, nullptr, nullptr, sizeof(OType)); - - ReduceImpl> - (stream, out_data, kWriteTo, in_data, workspace, config); -} - -#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_ diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 80714fb6d770..9ce3967f4797 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -298,7 +298,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; if (req[0] == kNullOp) return; - const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); + const auto& param = nnvm::get(attrs.parsed); if (param.initial.has_value()) { LOG(FATAL) << "initial is not supported yet"; } @@ -494,10 +494,6 @@ void NumpyArgMinMaxReduce(mshadow::Stream *s, const TBlob& in_data, const T in_data.shape_.get(), out_data.shape_.get(), rshape, rstride); } -#ifdef __CUDACC__ -#include "np_broadcast_reduce_op.cuh" -#endif - template void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -508,7 +504,7 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; if (req[0] == kNullOp) return; // parse param - const ReduceAxisParam& param = nnvm::get(attrs.parsed); + const auto& param = nnvm::get(attrs.parsed); mshadow::Stream *s = ctx.get_stream(); TBlob out = outputs[0]; TBlob in = inputs[0]; @@ -537,36 +533,52 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, small = NumpyReduceAxesShapeImpl(in.shape_, axes, true); mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape); + const TBlob in_data = in.reshape(src_shape); + // request a work space + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, DType, { // define OType typedef mxnet::op::mshadow_op::IndexedNum OType; - // request a work space - size_t workspace_size = sizeof(OType) * out.shape_.Size(); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - // set up intermediate output - TBlob intermediate = out; - intermediate.dptr_ = reinterpret_cast(workspace.dptr_); - // reshape the input and intermediate output tensor - const TBlob in_data = in.reshape(src_shape); - const TBlob intermediate_out_data = intermediate.reshape(dst_shape); // switch dim BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, intermediate_out_data.shape_, req[0], in_data.shape_, sizeof(OType)); + constexpr size_t align_size = 1024; + const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size) + * align_size; + workspace_size = aligned_first_workspace_size + + sizeof(OType) * out.shape_.Size(); Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + // set up intermediate output + TBlob intermediate = out; + intermediate.dptr_ = reinterpret_cast(workspace.dptr_ + + aligned_first_workspace_size); + // reshape the input and intermediate output tensor + const TBlob intermediate_out_data = intermediate.reshape(dst_shape); NumpyArgMinMaxReduce(s, in_data, intermediate_out_data, workspace); + // parse the indices from the intermediate tensor back to the actual output tensor + using namespace mxnet_op; + Kernel::Launch( + s, out.shape_.Size(), outputs[0].dptr(), + static_cast(intermediate_out_data.dptr_)); }); - // parse the indices from the intermediate tensor back to the actual output tensor - using namespace mxnet_op; - Kernel::Launch( - s, out.shape_.Size(), outputs[0].dptr(), - static_cast(intermediate_out_data.dptr_)); }); } +#if MXNET_USE_CUDA + +struct NumpyArgMinMaxRTCCompute { + std::string reducer; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif + template inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -663,36 +675,6 @@ struct NumpyMomentsParam : public dmlc::Parameter { } }; -template -void ReduceAxesComputeWithWorkspaceImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mshadow::Tensor& workspace, - const mxnet::TShape& src_shape, - const mxnet::TShape& dst_shape, - const int ddof = 0) { - using namespace mshadow; - using namespace mshadow::expr; - - Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { - const TBlob in_data = inputs[0].reshape(src_shape); - const TBlob out_data = outputs[0].reshape(dst_shape); - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, out_data, req[0], workspace, in_data); - if (normalize) { - auto out = out_data.FlatTo2D(s); - out /= scalar(src_shape.Size()/dst_shape.Size() - ddof); - } - }); - }); - }); -} - struct NumpyWeightedAverageParam : public dmlc::Parameter { dmlc::optional> axis; bool returned; @@ -871,13 +853,6 @@ struct avg_grad_w_1D_kernel { } }; -// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH -#ifndef __CUDACC__ -#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastCompute -#else -#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastRTCCompute {#OP} -#endif - template void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -914,6 +889,9 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, weights = weights.reshape(new_w_shape); small2 = TShape(new_w_shape.ndim(), 1); } + TBlob wa; + TBlob sum_of_wa; + Tensor workspace; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { // Get temp space size_t temp_data_size = data.shape_.Size() * sizeof(DType); @@ -922,38 +900,53 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, BroadcastReduceShapeCompact(data.shape_, small1, &src_shape, &dst_shape); size_t workspace_size = 0; workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, {kWriteTo}, src_shape, sizeof(DType)); + s, dst_shape, {kWriteTo}, src_shape); size_t temp_mem_size = temp_data_size + temp_sum_size + workspace_size; Tensor temp_mem = ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); - DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); - DType *temp_sum_ptr = reinterpret_cast(temp_mem.dptr_ + temp_data_size); + auto *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); + auto *temp_sum_ptr = reinterpret_cast(temp_mem.dptr_ + temp_data_size); char *workspace_ptr = temp_mem.dptr_ + temp_data_size + temp_sum_size; - Tensor workspace(workspace_ptr, Shape1(workspace_size), s); + workspace = Tensor(workspace_ptr, Shape1(workspace_size), s); // Compute weighted data - TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); - NP_BROADCAST_REDUCE_OP_BROADCAST(mul)( - attrs, ctx, {data, weights}, {kWriteTo}, {wa}); - - // Compute sum of weighted data - TBlob sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape); - if (!back) { - const TBlob& avg = outputs[0]; - const TBlob& sum_of_weights = outputs[1]; - TShape w_src_shape, w_dst_shape; - BroadcastReduceShapeCompact(weights.shape_, small2, &w_src_shape, &w_dst_shape); - // Compute sum of weight - TBlob scl = sum_of_weights.reshape(small2); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape); - - // Compute avg and assign output - NP_BROADCAST_REDUCE_OP_BROADCAST(div)( - attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); - } else { + wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); + sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); + }); +#if !defined(__CUDACC__) + BinaryBroadcastCompute( + attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + + // Compute sum of weighted data + ReduceAxesComputeImpl( + ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, &workspace); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + + // Compute sum of weighted data + ReduceAxesRTCComputeImpl(ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, "red::sum{}", + &workspace, false, "identity"); +#endif + if (!back) { + const TBlob& avg = outputs[0]; + const TBlob& sum_of_weights = outputs[1]; + // Compute sum of weight + TBlob scl = sum_of_weights.reshape(small2); +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, {weights}, {kWriteTo}, {scl}, small2, &workspace); + // Compute avg and assign output + BinaryBroadcastCompute( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); +#else + ReduceAxesRTCComputeImpl(ctx, {weights}, {kWriteTo}, {scl}, small2, "red::sum{}", + &workspace, false, "identity"); + // Compute avg and assign output + BinaryBroadcastRTCCompute {"div"}( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); +#endif + } else { + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { // Compute and assign the derivatives of a and weights const TBlob& igrad_a = outputs[0]; const TBlob& igrad_w = outputs[1]; @@ -992,12 +985,10 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, } }); }) - } - }); + }); + } } -#undef NP_BROADCAST_REDUCE_OP_BROADCAST - template void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -1010,21 +1001,29 @@ void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, CHECK_NE(req[0], kWriteInplace) << "Average does not support write in-place"; const auto& param = nnvm::get(attrs.parsed); const TBlob& data = inputs[0]; + TShape small; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { if (!param.weighted) { - TShape small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true); + small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true); // Compute sum of weights which equals to the product of sizes of reduced axes Stream* s = ctx.get_stream(); auto ret = outputs[1].FlatTo1D(s); ret = scalar(data.shape_.Size()/small.Size()); - // Compute mean - ReduceAxesComputeImpl( - ctx, inputs, req, {outputs[0]}, small); - } else { - NumpyWeightedAverageComputeImpl( - attrs, ctx, inputs, req, outputs, param.axis); } }); + if (!param.weighted) { + // Compute mean +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, inputs, req, {outputs[0]}, small); +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0]}, small, + "red::sum{}", nullptr, true); +#endif + } else { + NumpyWeightedAverageComputeImpl( + attrs, ctx, inputs, req, outputs, param.axis); + } } template @@ -1090,40 +1089,58 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(data.shape_, small, &src_shape, &dst_shape); + // Get workspace and temp space for data - mean + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); + size_t temp_data_size = data.shape_.Size() * common::mshadow_type_info(inputs[0].type_flag_).size; + size_t temp_mem_size = temp_data_size + workspace_size; + Tensor temp_mem = + ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); + char *workspace_ptr = temp_mem.dptr_ + temp_data_size; + Tensor workspace(workspace_ptr, Shape1(workspace_size), s); + // Compute mean +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, inputs, {kWriteTo}, {mean}, small, &workspace); +#else + ReduceAxesRTCComputeImpl(ctx, inputs, {kWriteTo}, {mean}, small, "red::sum{}", + &workspace, true, "identity"); +#endif + // Compute data - mean + Shape<6> data_shape, mean_shape; + for (int i = 0; i < 6; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + mean_shape[i] = (i < small.ndim()) ? small[i] : 1; + } +#if !defined(__CUDACC__) MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { - // Get workspace and temp space for data - mean - size_t workspace_size = 0; - workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, req[0], src_shape, sizeof(DType)); - size_t temp_data_size = data.shape_.Size() * sizeof(DType); - size_t temp_mem_size = temp_data_size + workspace_size; - Tensor temp_mem = - ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); - char *workspace_ptr = temp_mem.dptr_ + temp_data_size; - Tensor workspace(workspace_ptr, Shape1(workspace_size), s); - // Compute mean - ReduceAxesComputeWithWorkspaceImpl( - ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape); - // Compute data - mean - Shape<6> data_shape, mean_shape; - for (int i = 0; i < 6; ++i) { - data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; - mean_shape[i] = (i < small.ndim()) ? small[i] : 1; - } Kernel::Launch(s, data_shape.Size(), temp_data_ptr, data.dptr(), mean.dptr(), data_shape, mean_shape); Tensor temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s); TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof); - if (sqrt) { + ReduceAxesComputeImpl( + ctx, {temp_data_blob}, {req[0]}, {moment}, small, &workspace, param.ddof); + if (sqrt && req[0] != kNullOp) { Tensor moment_tensor = moment.FlatTo1D(s); moment_tensor = F(moment_tensor); } }); }); +#else + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); + Kernel::Launch(s, data_shape.Size(), temp_data_ptr, + data.dptr(), mean.dptr(), data_shape, mean_shape); + Tensor temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s); + TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_); + ReduceAxesRTCComputeImpl(ctx, {temp_data_blob}, {req[0]}, {moment}, small, + "red::sum{}", &workspace, true, "identity", param.ddof); + if (sqrt && req[0] != kNullOp) { + UnaryRTCCompute {"sqrt"}({}, ctx, {moment}, {kWriteInplace}, {moment}); + } + }); +#endif } template @@ -1159,6 +1176,7 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs, for (int i = 0; i < igrad_shape.ndim(); ++i) { expanded_igrad_shape[i + ndim_delta] = igrad_shape[i]; } +#if !defined(__CUDACC__) if (NeedSafeAcc(inputs[0].type_flag_, outputs[0].type_flag_)) { ReduceAxesComputeImpl( ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape); @@ -1166,6 +1184,10 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl( ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape); } +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, + expanded_igrad_shape, "red::sum{}", nullptr, false); +#endif } template diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu index d3247b743bc5..405ae4b38eb7 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu @@ -29,12 +29,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_any) -.set_attr("FCompute", NumpyReduceAxesBoolCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"NonZero", "red::sum{}", false}); NNVM_REGISTER_OP(_npi_all) -.set_attr("FCompute", NumpyReduceAxesBoolCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"NonZero", "red::product{}", false}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu index 892d04679422..eb6086c9d7fe 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_index.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu @@ -28,10 +28,10 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_argmax) -.set_attr("FCompute", NumpyArgMinMaxCompute); +.set_attr("FCompute", NumpyArgMinMaxRTCCompute{"red::argmax{}"}); NNVM_REGISTER_OP(_npi_argmin) -.set_attr("FCompute", NumpyArgMinMaxCompute); +.set_attr("FCompute", NumpyArgMinMaxRTCCompute{"red::argmin{}"}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index 422097d20181..602057324af2 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -27,25 +27,30 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_sum) -.set_attr("FCompute", NumpyReduceAxesCompute); +.set_attr("FCompute", + ReduceAxesRTCCompute{"identity", "red::sum{}", false}); NNVM_REGISTER_OP(_backward_npi_sum) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); NNVM_REGISTER_OP(_npi_max) -.set_attr("FCompute", NumpyReduceAxesNoDTypeCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::maximum{}", false}); NNVM_REGISTER_OP(_backward_npi_max) .set_attr("FCompute", NumpyReduceAxesNoDTypeBackward); NNVM_REGISTER_OP(_npi_min) -.set_attr("FCompute", NumpyReduceAxesNoDTypeCompute); +.set_attr("FCompute", + ReduceAxesRTCCompute{"identity", + "red::minimum{}", false}); NNVM_REGISTER_OP(_backward_npi_min) .set_attr("FCompute", NumpyReduceAxesNoDTypeBackward); NNVM_REGISTER_OP(_npi_prod) -.set_attr("FCompute", NumpyReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::product{}", false}); NNVM_REGISTER_OP(_backward_npi_prod) .set_attr("FCompute", NumpyReduceAxesBackwardUseInOut); @@ -57,7 +62,8 @@ NNVM_REGISTER_OP(_backward_np_average) .set_attr("FCompute", NumpyWeightedAverageBackward); NNVM_REGISTER_OP(_npi_mean) -.set_attr("FCompute", NumpyReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", true}); NNVM_REGISTER_OP(_backward_np_mean) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h index 80beaa3a0bf5..01c54b650616 100644 --- a/src/operator/numpy/np_constraint_check.h +++ b/src/operator/numpy/np_constraint_check.h @@ -56,9 +56,14 @@ void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, CHECK_EQ(outputs.size(), 1U); const ConstraintCheckParam& param = nnvm::get(attrs.parsed); +#if !defined(__CUDACC__) ReduceAxesComputeImpl(ctx, inputs, req, outputs, outputs[0].shape_); +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, + outputs[0].shape_, "red::product{}"); +#endif std::string msg = param.msg; bool red_output = true; GetReduceOutput(ctx.get_stream(), outputs[0], &red_output); diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h index ab64564dc85d..7d5dd30d813b 100644 --- a/src/operator/numpy/np_cross-inl.h +++ b/src/operator/numpy/np_cross-inl.h @@ -677,8 +677,7 @@ struct ReduceImplWrap { std::vector reduce_axis = GetReduceAxis(out_move_shape, in_move_shape); if (reduce_axis.empty() || req == kNullOp) { return 0U; } ws_reduce = broadcast::ReduceWorkspaceSize(ctx.get_stream(), - out_shape, req, in_shape, - sizeof(DType)); + out_shape, req, in_shape); return ws_reduce; } @@ -690,10 +689,17 @@ struct ReduceImplWrap { const Tensor workspace_tensor) { Stream *s = ctx.get_stream(); // Reduce work_in to work_out. +#if !defined(__CUDACC__) SUM_NDIM_SWITCH(work_out.ndim(), NDim, { op::broadcast::Reduce( s, work_out, kWriteTo, workspace_tensor, work_in); }); +#else + SUM_NDIM_SWITCH(work_out.ndim(), NDim, { + op::broadcast::RTCReduce(ctx, work_out, kWriteTo, workspace_tensor, work_in, + "red::sum{}", NDim, "identity"); + }); +#endif // Copy work_out to out_data. MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { mxnet_op::Kernel, xpu>::Launch( diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index 1fa58908a113..be19b3876a40 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -413,9 +413,9 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, { if (need_bc) { workspace_size_l = ReduceWorkspaceSize( - s, new_lshape, req[0], new_oshape, new_lshape, new_rshape, sizeof(OType)); + s, new_lshape, req[0], new_oshape, new_lshape, new_rshape); workspace_size_r = ReduceWorkspaceSize( - s, new_rshape, req[1], new_oshape, new_lshape, new_rshape, sizeof(OType)); + s, new_rshape, req[1], new_oshape, new_lshape, new_rshape); } size_t workspace_size = std::max(workspace_size_l, workspace_size_r); size_t cast_tensor_size = tensor_size * sizeof(OType); diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h index f357ceccb6cc..bf0983f10157 100644 --- a/src/operator/numpy/np_kron-inl.h +++ b/src/operator/numpy/np_kron-inl.h @@ -188,6 +188,14 @@ void KronOpForwardImpl(const OpContext& ctx, }); } +#if !defined(__CUDACC__) +#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \ + ReduceAxesComputeImpl(__VA_ARGS__, &workspace) +#else +#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \ + ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}", &workspace) +#endif + template void KronOpBackwardImpl(const OpContext& ctx, const std::vector& req, @@ -226,12 +234,23 @@ void KronOpBackwardImpl(const OpContext& ctx, const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1]; ASSIGN_DISPATCH(tensor_grad_, tensor_req, broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(ograd.shape_.Size()), s); - ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_); - - ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(ograd.shape_, scalar_grad_.shape_, &src_shape, &dst_shape); + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, + {scalar_req}, src_shape); + constexpr size_t align_size = 1024; + const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size) + * align_size; + workspace_size = aligned_first_workspace_size + ograd.shape_.Size() * sizeof(DType); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor temp(reinterpret_cast(workspace.dptr_ + + aligned_first_workspace_size), + Shape1(ograd.shape_.Size()), s); + ASSIGN_DISPATCH(temp, kWriteTo, tensor_ * ograd_); + + NP_KRON_REDUCE_AXES(true, workspace, ctx, {TBlob(temp)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { MXNET_NDIM_SWITCH(oshape.ndim(), ndim, { Shape ashape_ = oshape.get(); @@ -276,6 +295,8 @@ void KronOpBackwardImpl(const OpContext& ctx, }); } +#undef NP_KRON_REDUCE_AXES + template inline void KronOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h index 1bfc6d10dbac..47e42d01d6a6 100644 --- a/src/operator/numpy/np_tensordot_op-inl.h +++ b/src/operator/numpy/np_tensordot_op-inl.h @@ -370,6 +370,14 @@ inline mxnet::TShape GetReverseShape(const mxnet::Tuple& shape) { return shape2; } +#if !defined(__CUDACC__) +#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \ + ReduceAxesComputeImpl(__VA_ARGS__) +#else +#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \ + ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}") +#endif + /** * calculates tensordot derivative. */ @@ -424,8 +432,8 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, workspace.stream_); ASSIGN_DISPATCH(dtypespace, kWriteTo, tensor_ * out_grad_); - ReduceAxesComputeImpl( - ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(dtypespace)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { // Two tensors of at least 1 dimensions. Tuple a_axes_remained; @@ -734,8 +742,8 @@ void TensordotIntAxesBackwardImpl(const int axes, ctx.requested[0].get_space_typed(Shape1(out_grad.shape_.Size()), s); ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_); - ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { // Two tensors of at least 1 dimensions. Tuple a_axes_summed; @@ -759,6 +767,8 @@ void TensordotIntAxesBackwardImpl(const int axes, }); } +#undef NP_TENSORDOT_REDUCE_AXES + /** * backward function. */ diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index 10ec081b2a8f..43af21def5ae 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -175,6 +175,13 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs, }); } +#if !defined(__CUDACC__) +#define NP_WHERE_REDUCE_AXES(safe_acc, ...) \ + ReduceAxesComputeImpl(__VA_ARGS__) +#else +#define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}") +#endif + template inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -226,9 +233,9 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, size_t ws_size = 0; if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) { size_t ws_size1 = broadcast::ReduceWorkspaceSize( - s, expanded_lshape, req[0], expanded_oshape, sizeof(DType)); + s, expanded_lshape, req[0], expanded_oshape); size_t ws_size2 = broadcast::ReduceWorkspaceSize( - s, expanded_rshape, req[1], expanded_oshape, sizeof(DType)); + s, expanded_rshape, req[1], expanded_oshape); ws_size = std::max(ws_size1, ws_size2); } // process left output @@ -246,10 +253,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, s, ograd.Size(), req[0], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]}, {dx.reshape(expanded_lshape)}, expanded_lshape); } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]}, {dx.reshape(expanded_lshape)}, expanded_lshape); } } @@ -268,10 +275,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, s, ograd.Size(), req[1], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); if (NeedSafeAcc(dy.type_flag_, dy.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, + NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[1]}, {dy.reshape(expanded_rshape)}, expanded_rshape); } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, + NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[1]}, {dy.reshape(expanded_rshape)}, expanded_rshape); } } @@ -367,7 +374,7 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, size_t ws_size = 0; if (ograd.shape_ != dx.shape_) { ws_size = broadcast::ReduceWorkspaceSize(s, expanded_lshape, req[0], - expanded_oshape, sizeof(DType)); + expanded_oshape); } // If lscalar, then process right output, `is_left` should be false if (ograd.shape_ == dx.shape_) { @@ -384,10 +391,10 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, s, ograd.Size(), req[0], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]}, {dx.reshape(expanded_lshape)}, expanded_lshape); } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]}, {dx.reshape(expanded_lshape)}, expanded_lshape); } } @@ -395,6 +402,8 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, }); } +#undef NP_WHERE_REDUCE_AXES + template inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index ab8afe95f0b1..1e43134e6e96 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -277,6 +277,86 @@ inline bool TwoparamsDistOpConcatShape(const nnvm::NodeAttrs &attrs, return true; } +template +inline void CommonReparamBackwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& new_lshape, + const mxnet::TShape& new_rshape, + const mxnet::TShape& new_oshape) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + Stream *s = ctx.get_stream(); + const TBlob lgrad = outputs[0].reshape(new_lshape); + const TBlob rgrad = outputs[1].reshape(new_rshape); + const TBlob ograd = inputs[0].reshape(new_oshape); + // Mean + const TBlob lhs = inputs[2].reshape(new_lshape); + // Scale + const TBlob rhs = inputs[3].reshape(new_rshape); + const TBlob samples = inputs[4].reshape(new_oshape); + const TBlob noise = inputs[5].reshape(new_oshape); + size_t workspace_size_l = ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_); + size_t workspace_size = std::max(workspace_size_l, workspace_size_r); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) + Reduce( + s, lgrad, req[0], workspace, ograd); + Reduce( + s, rgrad, req[1], workspace, ograd, noise, rhs); +#else + RTCReduce(ctx, lgrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity"); + RTCReduce(ctx, rgrad, req[1], workspace, ograd, noise, rhs, "red::sum{}", ndim, "mul", "left"); +#endif +} + +template +inline void CommonScalarReparamBackwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& new_ishape, + const mxnet::TShape& new_oshape, + const bool loc_is_tensor = false) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + Stream *s = ctx.get_stream(); + const TBlob igrad = outputs[0].reshape(new_ishape); + // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, + // samples, noise] + const TBlob ograd = inputs[0].reshape(new_oshape); + const TBlob itensor = inputs[2].reshape(new_ishape); + const TBlob samples = inputs[3].reshape(new_oshape); + const TBlob noise = inputs[4].reshape(new_oshape); + size_t workspace_size = + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) + if (loc_is_tensor) { + Reduce(s, igrad, req[0], + workspace, ograd); + } else { + Reduce( + s, igrad, req[0], workspace, ograd, noise, noise); + } +#else + if (loc_is_tensor) { + RTCReduce(ctx, igrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity"); + } else { + RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise, "red::sum{}", + ndim, "mul", "left"); + } +#endif +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h index 374b3b428eba..0f0d4623800e 100644 --- a/src/operator/numpy/random/np_exponential_op.h +++ b/src/operator/numpy/random/np_exponential_op.h @@ -171,11 +171,16 @@ inline void ExponentialReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) Reduce( s, igrad, req[0], workspace, ograd, noise, noise); +#else + RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise, + "red::sum{}", ndim, "mul", "left"); +#endif } template diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index a0f3299f4d84..55041a93b1f7 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -420,14 +420,19 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, const TBlob alpha = inputs[1].reshape(new_ishape); TBlob samples = inputs[2].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); // Convert samples to standard gamma Kernel, xpu>::Launch( s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) Reduce( s, igrad, req[0], workspace, ograd, alpha, samples); +#else + RTCReduce(ctx, igrad, req[0], workspace, ograd, alpha, samples, "red::sum{}", ndim, + "mul", "gamma_implicit_grad"); +#endif Kernel, xpu>::Launch( s, igrad.Size(), igrad.dptr(), igrad.dptr(), DType(scale)); // Convert samples back, otherwise the output would be corrupted. diff --git a/src/operator/numpy/random/np_location_scale_op.h b/src/operator/numpy/random/np_location_scale_op.h index 73403f37f1f0..0179a572bf3f 100644 --- a/src/operator/numpy/random/np_location_scale_op.h +++ b/src/operator/numpy/random/np_location_scale_op.h @@ -275,72 +275,6 @@ void NumpyLocationScaleForward(const nnvm::NodeAttrs &attrs, } } -template -inline void LocationScaleReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_lshape, - const mxnet::TShape& new_rshape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob lgrad = outputs[0].reshape(new_lshape); - const TBlob rgrad = outputs[1].reshape(new_rshape); - const TBlob ograd = inputs[0].reshape(new_oshape); - // Mean - const TBlob lhs = inputs[2].reshape(new_lshape); - // Scale - const TBlob rhs = inputs[3].reshape(new_rshape); - const TBlob samples = inputs[4].reshape(new_oshape); - const TBlob noise = inputs[5].reshape(new_oshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size = std::max(workspace_size_l, workspace_size_r); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, lgrad, req[0], workspace, ograd); - Reduce( - s, rgrad, req[1], workspace, ograd, noise, rhs); -} - -template -inline void ScalarLocationScaleReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape, - const bool loc_is_tensor) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - if (loc_is_tensor) { - Reduce(s, igrad, req[0], - workspace, ograd); - } else { - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } -} - // Allow logistic and gumbel sampling to be differentiable, // using reparameterization trick described in: // Auto-encoding variational bayes. @@ -359,7 +293,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs, if (outputs.size() == 0U) { return; } - const NumpyLocationScaleParam ¶m = nnvm::get(attrs.parsed); + const auto ¶m = nnvm::get(attrs.parsed); // [tensor tensor] case if (inputs.size() == 6U) { mxnet::TShape new_lshape, new_rshape, new_oshape; @@ -367,7 +301,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs, &new_lshape, &new_rshape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - LocationScaleReparamBackwardImpl( + CommonReparamBackwardImpl( ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape); }); }); @@ -380,7 +314,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs, bool loc_is_tensor = !param.loc.has_value(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarLocationScaleReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor); }); }); diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index e43d98de0168..06b5bfaabf05 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -161,7 +161,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { using namespace mshadow; using namespace mxnet_op; - const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + const auto ¶m = nnvm::get(attrs.parsed); Stream *s = ctx.get_stream(); // Generate base random number. Random *prnd = ctx.requested[0].get_random(s); @@ -240,72 +240,6 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, } } -template -inline void NormalReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_lshape, - const mxnet::TShape& new_rshape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob lgrad = outputs[0].reshape(new_lshape); - const TBlob rgrad = outputs[1].reshape(new_rshape); - const TBlob ograd = inputs[0].reshape(new_oshape); - // Mean - const TBlob lhs = inputs[2].reshape(new_lshape); - // Variance - const TBlob rhs = inputs[3].reshape(new_rshape); - const TBlob samples = inputs[4].reshape(new_oshape); - const TBlob noise = inputs[5].reshape(new_oshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size = std::max(workspace_size_l, workspace_size_r); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce(s, - lgrad, req[0], workspace, ograd); - Reduce( - s, rgrad, req[1], workspace, ograd, noise, rhs); -} - -template -inline void ScalarNormalReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape, - const bool loc_is_tensor) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - if (loc_is_tensor) { - Reduce(s, igrad, req[0], - workspace, ograd); - } else { - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } -} - // Allow normal sampling to be differentiable, // using reparameterization trick described in: // Auto-encoding variational bayes. @@ -324,7 +258,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs, if (outputs.size() == 0U) { return; } - const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + const auto ¶m = nnvm::get(attrs.parsed); // [tensor tensor] case if (inputs.size() == 6U) { mxnet::TShape new_lshape, new_rshape, new_oshape; @@ -332,7 +266,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs, &new_lshape, &new_rshape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - NormalReparamBackwardImpl( + CommonReparamBackwardImpl( ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape); }); }); @@ -345,7 +279,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs, bool loc_is_tensor = !param.loc.has_value(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarNormalReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor); }); }); diff --git a/src/operator/numpy/random/np_pareto_op.h b/src/operator/numpy/random/np_pareto_op.h index 5e5d26aae4d2..16731c126324 100644 --- a/src/operator/numpy/random/np_pareto_op.h +++ b/src/operator/numpy/random/np_pareto_op.h @@ -155,32 +155,6 @@ void NumpyParetoForward(const nnvm::NodeAttrs &attrs, } } -template -inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } - template void ParetoReparamBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -202,7 +176,7 @@ if (inputs.size() == 5U) { &new_ishape, &new_ishape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarParetoReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, reqs, outputs, new_ishape, new_oshape); }); }); diff --git a/src/operator/numpy/random/np_rayleigh_op.h b/src/operator/numpy/random/np_rayleigh_op.h index 0f940e511a32..75c4784a515e 100644 --- a/src/operator/numpy/random/np_rayleigh_op.h +++ b/src/operator/numpy/random/np_rayleigh_op.h @@ -153,32 +153,6 @@ void NumpyRayleighForward(const nnvm::NodeAttrs &attrs, } } -template -inline void ScalarRayleighReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); -} - template void RayleighReparamBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -200,7 +174,7 @@ void RayleighReparamBackward(const nnvm::NodeAttrs& attrs, &new_ishape, &new_ishape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarRayleighReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, req, outputs, new_ishape, new_oshape); }); }); diff --git a/src/operator/numpy/random/np_weibull_op.h b/src/operator/numpy/random/np_weibull_op.h index 970dc859b97b..a7d6d5d2c405 100644 --- a/src/operator/numpy/random/np_weibull_op.h +++ b/src/operator/numpy/random/np_weibull_op.h @@ -155,32 +155,6 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs, } } -template -inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } - template void WeibullReparamBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -202,7 +176,7 @@ if (inputs.size() == 5U) { &new_ishape, &new_ishape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarWeibullReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, reqs, outputs, new_ishape, new_oshape); }); }); diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h index 2c5c1ebe1fd3..0d89570f05a6 100644 --- a/src/operator/quantization/quantization_utils.h +++ b/src/operator/quantization/quantization_utils.h @@ -184,7 +184,7 @@ inline size_t ConfigReduce(mshadow::Stream* s, CHECK_EQ(src_shape->ndim(), NDim); CHECK_EQ(dst_shape->ndim(), NDim); - return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape, sizeof(DType)); + return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape); } enum QuantizeOutType { kAuto = 0, kInt8, kUint8 }; diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index d8814cc6cb20..cfbdb7f8e0ab 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -205,10 +205,17 @@ class QuantizeV2Operator { dev_id); Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, Shape1(temp_reduce_size), s); +#if !defined(__CUDACC__) broadcast::Reduce( s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); broadcast::Reduce( s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); +#else + broadcast::RTCReduce(ctx, in_min_t.reshape(dst_shape), kWriteTo, workspace, + inputs[0].reshape(src_shape), "red::minimum{}", 2, "identity"); + broadcast::RTCReduce(ctx, in_max_t.reshape(dst_shape), kWriteTo, workspace, + inputs[0].reshape(src_shape), "red::maximum{}", 2, "identity"); +#endif if (out_type == mshadow::kUint8) { Kernel::Launch( s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h index 2bdc3a712961..56686708dba4 100644 --- a/src/operator/quantization/requantize-inl.h +++ b/src/operator/quantization/requantize-inl.h @@ -148,6 +148,7 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs, temp_space.dptr_ + 8) + 1, Shape1(1), xpu::kDevMask, dev_id); Tensor workspace( temp_space.dptr_+2*actual_float_size+2*actual_quantized_size, Shape1(temp_reduce_size), s); +#if !defined(__CUDACC__) broadcast::Reduce( s, actual_min_quantized.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); @@ -158,6 +159,18 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs, broadcast::Reduce( s, actual_max_quantized.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); +#else + broadcast::RTCReduce(ctx, actual_min_quantized.reshape(dst_shape), + kWriteTo, workspace, inputs[0].reshape(src_shape), + "red::minimum{}", 2, "identity"); + Kernel::Launch(s, 1, + actual_min_float.dptr_, actual_min_quantized.dptr(), + inputs[1].dptr(), inputs[2].dptr()); + + broadcast::RTCReduce(ctx, actual_max_quantized.reshape(dst_shape), + kWriteTo, workspace, inputs[0].reshape(src_shape), + "red::maximum{}", 2, "identity"); +#endif Kernel::Launch(s, 1, actual_max_float.dptr_, actual_max_quantized.dptr(), inputs[1].dptr(), inputs[2].dptr()); diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h index f6dc77718704..f53d3a6dd6e0 100644 --- a/src/operator/random/pdf_op.h +++ b/src/operator/random/pdf_op.h @@ -592,10 +592,11 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs, const PdfParam& param = nnvm::get(attrs.parsed); const size_t N(outputs[1].Size()); const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1)); + const size_t red_work_size(broadcast::ReduceWorkspaceSize( + s, dst_shape, kAddTo, src_shape)); +#if !defined(__CUDACC__) // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf. MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const size_t red_work_size(broadcast::ReduceWorkspaceSize( - s, dst_shape, kAddTo, src_shape, sizeof(DType))); const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size); Tensor tmp_space = ctx.requested[0].get_space_typed(Shape1(tmp_size), s); @@ -620,6 +621,35 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs, s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape)); } }); +#else + // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf. + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size); + Tensor tmp_space = + ctx.requested[0].get_space_typed(Shape1(tmp_size), s); + std::vector grads = {outputs[0]}; + grads.push_back(TBlob(tmp_space.dptr_, outputs[0].shape_, + outputs[1].dev_mask(), outputs[1].type_flag_, outputs[1].dev_id())); + if (pnum == 2) { + grads.push_back(TBlob(tmp_space.dptr_ + outputs[0].Size() * sizeof(DType), outputs[0].shape_, + outputs[2].dev_mask(), outputs[2].type_flag_, outputs[2].dev_id())); + } + if (param.is_log) { + PdfGradCaller, pnum, vparm>::op(inputs, req, grads, s); + } else { + PdfGradCaller, pnum, vparm>::op(inputs, req, grads, s); + } + Tensor red_work( + tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s); + broadcast::RTCReduce(ctx, outputs[1].reshape(dst_shape), req[1], red_work, + grads[1].reshape(src_shape), "red::sum{}", 2, "identity"); + if (pnum == 2) { + broadcast::RTCReduce(ctx, outputs[2].reshape(dst_shape), req[2], red_work, + grads[2].reshape(src_shape), "red::sum{}", 2, "identity"); + } + }); + +#endif } } // namespace op diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh deleted file mode 100644 index 53b0ad80c4e7..000000000000 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ /dev/null @@ -1,414 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015-2020 by Contributors - * \file broadcast_reduce-inl.cuh - * \brief CUDA implementations for binary broadcast and reduce - * \author Antti-Pekka Hynninen, Przemyslaw Tredak -*/ -#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ -#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ - -using namespace mshadow::cuda; - -template> -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel(const int N, const int M, const bool addto, - const DType* __restrict big, OType *small, - const Shape big_shape0, const Shape small_shape, - const Shape big_shape, const Shape big_stride, - const int Mnext, const bool do_transpose) { - extern __shared__ char shTileChar[]; - AType* shTile = (AType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = mxnet_op::unravel(idx, small_shape); - int idx_big0 = mxnet_op::ravel(coord, big_shape0); - - AType val, residual; - Reducer::SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride); - } - AType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP::Map(big[idx_big[u]]); - // argmin/max, set IndexedNum.idx - if (IndexOP::do_op) - IndexOP::Op(&tmp[u], k+u*by); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - AType tmp, tmp_residual; - Reducer::SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); - } - } else { - if (idx < N) { - Reducer::Finalize(val, residual); - assign(&small[idx + m0*N], addto, OType(val)); - } - } - } - } -} - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel(const int N, const int M, const bool addto, - const DType* __restrict big, const DType* __restrict lhs, - const DType* __restrict rhs, DType *small, - const Shape big_shape0, const Shape lhs_shape0, - const Shape rhs_shape0, const Shape small_shape, - const Shape big_shape, const Shape lhs_shape, - const Shape rhs_shape, const Shape big_stride, - const Shape lhs_stride, const Shape rhs_stride, - const int Mnext, const bool do_transpose) { - extern __shared__ char shTileChar[]; - DType* shTile = (DType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = mxnet_op::unravel(idx, small_shape); - int idx_big0 = mxnet_op::ravel(coord, big_shape0); - int idx_lhs0 = mxnet_op::ravel(coord, lhs_shape0); - int idx_rhs0 = mxnet_op::ravel(coord, rhs_shape0); - - DType val, residual; - Reducer::SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - int idx_lhs[unroll]; - int idx_rhs[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride); - idx_lhs[u] = idx_lhs0 + mxnet_op::unravel_dot(k + u*by, lhs_shape, lhs_stride); - idx_rhs[u] = idx_rhs0 + mxnet_op::unravel_dot(k + u*by, rhs_shape, rhs_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]])); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - DType tmp, tmp_residual; - Reducer::SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, shTile[tidx * 2]); - } - } else { - if (idx < N) { - Reducer::Finalize(val, residual); - assign(&small[idx + m0*N], addto, val); - } - } - } - } -} - -// Simple reduction of lines when M is small -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_lines_kernel(const int N, const int M, const bool addto, - const int small_in_stride, const DType* __restrict small_in, DType *small_out) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - - DType val, residual; - Reducer::SetInitValue(val, residual); - for (int k = 0; k < M; k++) { - Reducer::Reduce(val, small_in[idx + k*small_in_stride], residual); - } - - if (idx < N) { - Reducer::Finalize(val, residual); - assign(&small_out[idx], addto, val); - } - - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1(const int N, const bool addto, - const DType* __restrict big, OType *small, const Shape bshape, - const Shape sshape) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = mxnet_op::unravel(idx, sshape); - int j = mxnet_op::ravel(coord, bshape); - AType val, residual, temp = OP::Map(big[j]); - Reducer::SetInitValue(val, residual); - Reducer::Reduce(val, temp, residual); - Reducer::Finalize(val, residual); - assign(&small[idx], addto, OType(val)); - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1(const int N, const bool addto, - const DType* __restrict big, - const DType* __restrict lhs, - const DType* __restrict rhs, - DType *small, - const Shape big_shape, - const Shape lhs_shape, - const Shape rhs_shape, - const Shape small_shape) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = mxnet_op::unravel(idx, small_shape); - int idx_big = mxnet_op::ravel(coord, big_shape); - int idx_lhs = mxnet_op::ravel(coord, lhs_shape); - int idx_rhs = mxnet_op::ravel(coord, rhs_shape); - DType val, residual; - Reducer::SetInitValue(val, residual); - Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); - Reducer::Finalize(val, residual); - assign(&small[idx], addto, val); - } -} - -#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \ - if (do_unroll) { \ - const int unrollVar = unrollAmount; \ - {__VA_ARGS__} \ - } else { \ - const int unrollVar = 1; \ - {__VA_ARGS__} \ - } - -template> -void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, - const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config) { - if (config.M == 1) { - reduce_kernel_M1 - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), reinterpret_cast(small.dptr_), - big.shape_.get(), small.shape_.get()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1); - } else { - OType* small_dptr = reinterpret_cast(small.dptr_); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? - config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), - small.shape_.get(), config.rshape.get(), config.rstride.get(), - config.Mnext, config.kernel_1.do_transpose); - }); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); - - if (config.Mnext > 1) { - reduce_lines_kernel - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, - reinterpret_cast(small.dptr_)); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel); - } - } -} - -template -void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, - const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config) { - if (config.M == 1) { - reduce_kernel_M1 - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), lhs.dptr(), rhs.dptr(), - small.dptr(), big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1); - } else { - DType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? - config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), - small_dptr, big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get(), config.rshape.get(), - config.lhs_shape.get(), config.rhs_shape.get(), config.rstride.get(), - config.lhs_stride.get(), config.rhs_stride.get(), config.Mnext, - config.kernel_1.do_transpose); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); - }); - - if (config.Mnext > 1) { - reduce_lines_kernel - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel); - } - } -} - -#undef KERNEL_UNROLL_SWITCH - -template -void Reduce(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); - if (safe_acc) { - MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { - typedef typename std::conditional::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { - typedef typename std::conditional::type OutType; - config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, - sizeof(AccType)); - ReduceImpl( - stream, small, req, big, workspace, config); - }); - }); - } else { - ReduceImpl(stream, small, req, big, workspace, config); - } -} - -template -void ReduceBool(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); - ReduceImpl(stream, small, req, big, workspace, config); -} - -template -void ReduceWithExtraMem(Stream* s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big) {}; - -template -void Reduce(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big, - const TBlob& lhs, const TBlob& rhs) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, sizeof(DType)); - ReduceImpl(stream, small, lhs, rhs, req, big, workspace, config); -} - -#endif //MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index a7d5d6746003..1907c02897c9 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -445,13 +445,13 @@ void ReduceWithExtraMem(Stream* s, const TBlob& small, const OpReqType req, } inline size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, - const mxnet::TShape& big, const int type_size) { + const mxnet::TShape& big) { return 0; } inline size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, const mxnet::TShape& big, const mxnet::TShape& lhs, - const mxnet::TShape& rhs, const int type_size) { + const mxnet::TShape& rhs) { return 0; } @@ -539,11 +539,13 @@ struct ReduceImplConfig { inline ReduceImplConfig(const ::mxnet::TShape& small, const ::mxnet::TShape& big, const ::mxnet::TShape* lhs, - const ::mxnet::TShape* rhs, - const size_t type_size) : + const ::mxnet::TShape* rhs) : rshape(small.ndim(), 1), rstride(small.ndim(), 1), lhs_shape(small.ndim(), 1), lhs_stride(small.ndim(), 1), rhs_shape(small.ndim(), 1), rhs_stride(small.ndim(), 1) { + // The largest reduction type currently is (index_t, double) struct + // aligned to 16B + constexpr size_t max_type_size = 2 * sizeof(double); constexpr int maxLoopPerTB = 64; int ndim = small.ndim(); @@ -646,7 +648,7 @@ struct ReduceImplConfig { by++; } kernel_1.shMemSize = (kernel_1.blockDim.x > 1) ? - kernel_1.blockDim.x*by*type_size * 2 : 0; + kernel_1.blockDim.x*by*max_type_size * 2 : 0; // Maximum number of times we want TB to loop in M // Max size of M-block each TB can handle int maxMblock = kernel_1.blockDim.x*maxLoopPerTB; @@ -657,7 +659,7 @@ struct ReduceImplConfig { ceil_idiv(N, kernel_1.blockDim.x)); kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext); kernel_1.shMemSize = (kernel_1.blockDim.y > 1) ? - kernel_1.blockDim.x*kernel_1.blockDim.y*type_size * 2 : 0; + kernel_1.blockDim.x*kernel_1.blockDim.y*max_type_size * 2 : 0; // Maximum number of times we want TB to loop in M // Max size of M-block each TB can handle int maxMblock = kernel_1.blockDim.y*maxLoopPerTB; @@ -666,7 +668,7 @@ struct ReduceImplConfig { if (Mnext > 1) { // small_dptr[] is N*Mnext*type_size bytes - workspace_size += N*Mnext*sizeof(double); + workspace_size += N * Mnext * max_type_size; // Set gridDim.y to Mnext kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext); } @@ -681,24 +683,20 @@ struct ReduceImplConfig { }; inline size_t ReduceWorkspaceSize(Stream *s, const ::mxnet::TShape& small, const OpReqType req, - const ::mxnet::TShape& big, const int type_size) { + const ::mxnet::TShape& big) { if (req == kNullOp) return 0; - ReduceImplConfig config(small, big, nullptr, nullptr, type_size); + ReduceImplConfig config(small, big, nullptr, nullptr); return config.workspace_size; } inline size_t ReduceWorkspaceSize(Stream *s, const ::mxnet::TShape& small, const OpReqType req, const ::mxnet::TShape& big, const ::mxnet::TShape& lhs, - const ::mxnet::TShape& rhs, const int type_size) { + const ::mxnet::TShape& rhs) { if (req == kNullOp) return 0; - ReduceImplConfig config(small, big, &lhs, &rhs, type_size); + ReduceImplConfig config(small, big, &lhs, &rhs); return config.workspace_size; } -#ifdef __CUDACC__ -#include "broadcast_reduce-inl.cuh" -#endif - #endif // MXNET_USE_CUDA template @@ -784,7 +782,8 @@ void RTCReduce(const OpContext& ctx, const TBlob& big, const std::string& reducer, int ndim, - const std::string& OP); + const std::string& OP, + const bool use_index = false); void RTCReduce(const OpContext& ctx, const TBlob& small, diff --git a/src/operator/tensor/broadcast_reduce_minmax_value.cu b/src/operator/tensor/broadcast_reduce_minmax_value.cu index baf79feb5c60..c8cb757cd9a3 100644 --- a/src/operator/tensor/broadcast_reduce_minmax_value.cu +++ b/src/operator/tensor/broadcast_reduce_minmax_value.cu @@ -28,13 +28,15 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(max) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::maximum{}", false}); NNVM_REGISTER_OP(_backward_max) .set_attr("FCompute", ReduceAxesBackwardUseInOut); NNVM_REGISTER_OP(min) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::minimum{}", false}); NNVM_REGISTER_OP(_backward_min) .set_attr("FCompute", ReduceAxesBackwardUseInOut); diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc new file mode 100644 index 000000000000..483787ec7b0a --- /dev/null +++ b/src/operator/tensor/broadcast_reduce_op.cc @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "broadcast_reduce_op.h" +#include +#include "../numpy/np_broadcast_reduce_op.h" +#include "elemwise_binary_scalar_op.h" +#include "mxnet/tuple.h" + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUDA + +void ReduceAxesRTCComputeImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& small, + const std::string& reducer, + const mshadow::Tensor* workspace, + const bool normalize, + const std::string& OP, + const int ddof) { + using namespace mshadow; + + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); + Stream* s = ctx.get_stream(); + Tensor w; + if (workspace == nullptr) { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, dst_shape, req[0], src_shape); + w = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + workspace = &w; + } + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, out_data, req[0], *workspace, in_data, reducer, NDim, OP); + }); + if (normalize) { + NumpyBinaryScalarParam p{}; + p.scalar = static_cast(src_shape.Size()/dst_shape.Size() - ddof); + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"div"}(a, ctx, {out_data}, {kWriteInplace}, {out_data}); + } +} + +namespace { +template +void PrepareReduce(const Param& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* shape, int* ddof); + +template <> +void PrepareReduce(const ReduceAxesParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude); + } + + *ddof = 0; +} + +template <> +void PrepareReduce(const NumpyReduceAxesNoDTypeParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.initial.has_value()) { + LOG(FATAL) << "initial is not supported yet"; + } + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + *ddof = 0; +} + +template <> +void PrepareReduce(const NumpyReduceAxesParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.initial.has_value()) { + LOG(FATAL) << "initial is not supported yet"; + } + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + *ddof = 0; +} + +template <> +void PrepareReduce(const NumpyReduceAxesBoolParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + *ddof = 0; +} + +} // namespace + +template +void ReduceAxesRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) return; + mxnet::TShape small; + int ddof; + const auto& param = nnvm::get(attrs.parsed); + CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place"; + PrepareReduce(param, inputs, outputs, &small, &ddof); + if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor + if (inputs[0].shape_.Size() == 0) { + if (normalize && mxnet::common::is_float(outputs[0].type_flag_)) { + LOG(WARNING) << "WARNING: Mean of empty slice."; + NumpyBinaryScalarParam p{}; + p.scalar = std::numeric_limits::quiet_NaN(); + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {kWriteTo}, outputs); + } else { + if (normalize) { + LOG(WARNING) << "WARNING: nan is outside the range of"<< + "representable values of type 'int'"; + } + if (init == 0 && req[0] == kAddTo) return; + NumpyBinaryScalarParam p{}; + p.scalar = init; + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {req[0]}, outputs); + } + return; + } + + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, reducer, nullptr, normalize, OP, ddof); +} + +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; + +#endif + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 872a89859a60..7baf25552f7f 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -658,7 +658,9 @@ void ReduceAxesComputeImpl(const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs, - const mxnet::TShape& small) { + const mxnet::TShape& small, + const mshadow::Tensor* workspace = nullptr, + const int ddof = 0) { using namespace mshadow; using namespace mshadow::expr; @@ -670,15 +672,18 @@ void ReduceAxesComputeImpl(const OpContext& ctx, const TBlob in_data = inputs[0].reshape(src_shape); const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_, sizeof(OType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor w; + if (workspace == nullptr) { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_); + w = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + workspace = &w; + } broadcast::Reduce( - s, out_data, req[0], workspace, in_data); + s, out_data, req[0], *workspace, in_data); if (normalize) { auto out = out_data.FlatTo2D(s); - out /= scalar(src_shape.Size()/dst_shape.Size()); + out /= scalar(src_shape.Size()/dst_shape.Size() - ddof); } }); }); @@ -704,7 +709,7 @@ void ReduceAxesComputeBoolImpl(const OpContext& ctx, const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_, sizeof(OType)); + s, out_data.shape_, req[0], in_data.shape_); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); broadcast::ReduceBool( @@ -736,6 +741,35 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); } +#if MXNET_USE_CUDA + +template +struct ReduceAxesRTCCompute { + std::string OP; + std::string reducer; + bool normalize; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +void ReduceAxesRTCComputeImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& small, + const std::string& reducer, + const mshadow::Tensor* workspace = nullptr, + + const bool normalize = false, + const std::string& OP = "identity", + const int ddof = 0); + +#endif + template struct ReduceCsrKernel; @@ -1516,7 +1550,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, } else { small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false); } - bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); +#if !defined(__CUDACC__) + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) { common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. " "See https://mxnet.apache.org/api/faq/env_var " @@ -1539,6 +1574,15 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, ctx, inputs, req, outputs, small); } } +#else + const std::string &red = param.ord == 1 + ? "red::sum{}" + : "red::nrm2{}"; + const std::string &op = param.ord == 1 + ? "abs" + : "identity"; + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, nullptr, false, op); +#endif } template diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu index 35b3c0272db8..f7c28341fed5 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cu +++ b/src/operator/tensor/broadcast_reduce_op_value.cu @@ -37,7 +37,8 @@ NNVM_REGISTER_OP(broadcast_like) .set_attr("FCompute", BroadcastCompute); NNVM_REGISTER_OP(_broadcast_backward) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", false}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce_prod_value.cu b/src/operator/tensor/broadcast_reduce_prod_value.cu index 5731de308064..7e7a95b50677 100644 --- a/src/operator/tensor/broadcast_reduce_prod_value.cu +++ b/src/operator/tensor/broadcast_reduce_prod_value.cu @@ -28,13 +28,15 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(prod) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::product{}", false}); NNVM_REGISTER_OP(_backward_prod) .set_attr("FCompute", ReduceAxesBackwardUseInOut); NNVM_REGISTER_OP(nanprod) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::nanprod{}", false}); NNVM_REGISTER_OP(_backward_nanprod) .set_attr("FCompute", ReduceAxesBackwardUseInOut); diff --git a/src/operator/tensor/broadcast_reduce_sum_value.cu b/src/operator/tensor/broadcast_reduce_sum_value.cu index 2385d36f35b0..40a8ed8d17bf 100644 --- a/src/operator/tensor/broadcast_reduce_sum_value.cu +++ b/src/operator/tensor/broadcast_reduce_sum_value.cu @@ -28,19 +28,22 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(sum) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", false}); NNVM_REGISTER_OP(_backward_sum) .set_attr("FCompute", ReduceAxesBackwardUseNone); NNVM_REGISTER_OP(mean) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", true}); NNVM_REGISTER_OP(_backward_mean) .set_attr("FCompute", ReduceAxesBackwardUseNone); NNVM_REGISTER_OP(nansum) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::nansum{}", false}); NNVM_REGISTER_OP(_backward_nansum) .set_attr("FCompute", ReduceAxesBackwardUseInOut); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.cc b/src/operator/tensor/elemwise_binary_broadcast_op.cc index 2f9832a173f6..9a682dc649c3 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op.cc @@ -376,10 +376,10 @@ void BinaryBroadcastRTCBackwardUseNone::operator()(const nnvm::NodeAttrs& attrs, if (out.shape_.Size() != 0) { broadcast::RTCReduce(ctx, lhs, req[0], workspace, out, - "red::sum", NDim, LOP); + "red::sum{}", NDim, LOP); broadcast::RTCReduce(ctx, rhs, req[1], workspace, out, - "red::sum", NDim, ROP); + "red::sum{}", NDim, ROP); } else { using namespace common::cuda::rtc::util; if (lhs.shape_.Size() != 0) { @@ -425,21 +425,21 @@ void BinaryBroadcastRTCBackwardUseIn::operator()(const nnvm::NodeAttrs& attrs, const TBlob rhs = inputs[2].reshape(new_rshape); size_t workspace_size_l = broadcast::ReduceWorkspaceSize( s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, - rhs.shape_, common::mshadow_type_info(outputs[0].type_flag_).size); + rhs.shape_); size_t workspace_size_r = broadcast::ReduceWorkspaceSize( s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, - rhs.shape_, common::mshadow_type_info(outputs[1].type_flag_).size); + rhs.shape_); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); if (req[0] != kNullOp) { broadcast::RTCReduce(ctx, lgrad, req[0], workspace, - ograd, lhs, rhs, "red::sum", NDim, + ograd, lhs, rhs, "red::sum{}", NDim, "mul", LOP); } if (req[1] != kNullOp) { broadcast::RTCReduce(ctx, rgrad, req[1], workspace, - ograd, lhs, rhs, "red::sum", NDim, + ograd, lhs, rhs, "red::sum{}", NDim, "mul", ROP); } }); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index a5bfdd77050c..b1700c7a3882 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -629,9 +629,9 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx, const TBlob lhs = inputs[1].reshape(new_lshape); const TBlob rhs = inputs[2].reshape(new_rshape); size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_); size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 49a3ed27e7a2..24d6ca8af6d2 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2046,8 +2046,13 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; +#if !defined(__CUDACC__) ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); +#else + ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, + "red::sum{}", nullptr, false); +#endif } struct TileParam : public dmlc::Parameter { @@ -2238,8 +2243,13 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; +#if !defined(__CUDACC__) ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); +#else + ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, + "red::sum{}", nullptr, false); +#endif } struct ReverseParam : public dmlc::Parameter { diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc index 9e2d6d3f2a53..ac39f6f09dc7 100644 --- a/src/operator/tensor/reduce_rtc.cc +++ b/src/operator/tensor/reduce_rtc.cc @@ -49,12 +49,39 @@ struct reduce_kernel_params { const char reduce_function_code[] = R"code( #define FUNC OP(IType0::from(big[idx_big[u]])) +using AType = typename AccType::type; )code"; const char reduce_function_use_input_code[] = R"code( #define FUNC OP1(IType0::from(big[idx_big[u]]), \ OP2(IType1::from(lhs[idx_lhs[u]]), \ IType2::from(rhs[idx_rhs[u]]))) +using AType = typename AccType::type; +)code"; + +const char reduce_function_index_code[] = R"code( +#define FUNC AType(OP(IType0::from(big[idx_big[u]])), index) + +template +struct AccIndex { + index_t idx; + T num; + + __device__ inline AccIndex() {} + __device__ inline AccIndex(const T& val, const index_t idx) : num(val), idx(idx) {} + + __device__ inline operator index_t() const volatile { + return idx; + } + + __device__ inline AccIndex& operator=(const AccIndex& other) { + idx = other.idx; + num = other.num; + return *this; + } +}; + +using AType = AccIndex::type>; )code"; const char reduce_kernel_code[] = R"code( @@ -71,21 +98,107 @@ struct reduce_kernel_params { index_t rhs_shape[util::MAX_DIM]; }; -__launch_bounds__(kRTCMaxThreadsPerBlock) -__global__ void reduce_kernel(const int N, const int M, const bool addto, - const InputType0* __restrict big, - const InputType1* __restrict lhs, - const InputType2* __restrict rhs, - OutputType0 *small, - const reduce_kernel_params params, - const int Mnext) { +inline __device__ AType reduce(const index_t idx, const int tidx, + const int tidy, const int N, + const index_t Mstart, const index_t Mend, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + const reduce_kernel_params& params) { extern __shared__ char shTileChar[]; using IType0 = AccType; using IType1 = AccType; using IType2 = AccType; using OType = AccType; - using AType = typename IType0::type; AType* shTile = (AType*)(shTileChar); + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + index_t coord[ndim]; + util::unravel(idx, params.small_shape, coord); + index_t idx_big0, idx_lhs0, idx_rhs0; + idx_big0 = util::ravel(coord, params.big_shape); + if (use_input) { + idx_lhs0 = util::ravel(coord, params.lhs_shape0); + idx_rhs0 = util::ravel(coord, params.rhs_shape0); + } + + AType val, residual; + REDUCER.SetInitValue(val, residual); + if (idx < N) { + for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) { + index_t idx_big[UNROLL]; + index_t idx_lhs[UNROLL]; + index_t idx_rhs[UNROLL]; + #pragma unroll + for (int u=0;u < UNROLL;u++) { + idx_big[u] = idx_big0 + util::unravel_dot(k + u*by, params.rshape, + params.rstride); + if (use_input) { + idx_lhs[u] = idx_lhs0 + util::unravel_dot(k + u*by, params.lhs_shape, + params.lhs_stride); + idx_rhs[u] = idx_rhs0 + util::unravel_dot(k + u*by, params.rhs_shape, + params.rhs_stride); + } + } + AType tmp[UNROLL]; + #pragma unroll + for (int u=0;u < UNROLL;u++) { + if (k + u*by < Mend) { + const index_t index = k + u*by; + tmp[u] = FUNC; + } + } + #pragma unroll + for (int u=0;u < UNROLL;u++) { + if (k + u*by < Mend) REDUCER.Reduce(val, tmp[u], residual); + } + } + } + + // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 + if (by > 1) { + // Fix bx to avoid bank conflicts. Assumes warpSize number of banks + const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; + const int it0 = tidx + tidy*fbx; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; + __syncthreads(); + for (int t=1;t < by;t <<= 1) { + AType tmp, tmp_residual; + REDUCER.SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } + __syncthreads(); + REDUCER.Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); + __syncthreads(); + } + if (idx < N && tidy == 0) { + REDUCER.Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + return shTile[tidx * 2]; + } else { + return AType(); + } + } else { + if (idx < N) { + REDUCER.Finalize(val, residual); + return val; + } else { + return AType(); + } + } +} + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_kernel_single(const int N, const int M, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + OutputType0 *small, + const reduce_kernel_params params, + const int Mnext) { + using OType = AccType; const int tid = threadIdx.x + threadIdx.y*blockDim.x; const int bx = (do_transpose) ? blockDim.y : blockDim.x; const int by = (do_transpose) ? blockDim.x : blockDim.y; @@ -96,117 +209,74 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext); const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext); for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - index_t coord[ndim]; - util::unravel(idx, params.small_shape, coord); - index_t idx_big0, idx_lhs0, idx_rhs0; - idx_big0 = util::ravel(coord, params.big_shape); - if (use_input) { - idx_lhs0 = util::ravel(coord, params.lhs_shape0); - idx_rhs0 = util::ravel(coord, params.rhs_shape0); - } - - AType val, residual; - REDUCER::SetInitValue(val, residual); - if (idx < N) { - for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) { - index_t idx_big[UNROLL]; - index_t idx_lhs[UNROLL]; - index_t idx_rhs[UNROLL]; - #pragma unroll - for (int u=0;u < UNROLL;u++) { - idx_big[u] = idx_big0 + util::unravel_dot(k + u*by, params.rshape, - params.rstride); - if (use_input) { - idx_lhs[u] = idx_lhs0 + util::unravel_dot(k + u*by, params.lhs_shape, - params.lhs_stride); - idx_rhs[u] = idx_rhs0 + util::unravel_dot(k + u*by, params.rhs_shape, - params.rhs_stride); - } - } - typename OType::type tmp[UNROLL]; - #pragma unroll - for (int u=0;u < UNROLL;u++) { - if (k + u*by < Mend) { - tmp[u] = FUNC; - } - } - #pragma unroll - for (int u=0;u < UNROLL;u++) { - if (k + u*by < Mend) REDUCER::Reduce(val, tmp[u], residual); - } + const index_t idx = idx0 + tidx; + AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params); + if (idx < N && (by == 1 || tidy == 0)) { + if (req == OpReqType::kAddTo) { + small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), + static_cast(val))); + } else { + small[idx + m0 * N] = OType::to(val); } } + } + } +} - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - AType tmp, tmp_residual; - REDUCER::SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - REDUCER::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - REDUCER::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - if (addto) { - small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), - shTile[tidx * 2])); - } else { - small[idx + m0 * N] = OType::to(shTile[tidx * 2]); - } - } - } else { - if (idx < N) { - REDUCER::Finalize(val, residual); - if (addto) { - small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), - val)); - } else { - small[idx + m0 * N] = OType::to(val); - } - } +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_kernel_multi(const int N, const int M, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + AType *small, + const reduce_kernel_params params, + const int Mnext) { + const int tid = threadIdx.x + threadIdx.y*blockDim.x; + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + const int tidx = (do_transpose) ? tid / by : threadIdx.x; + const int tidy = (do_transpose) ? tid % by : threadIdx.y; + for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { + // This TB handles M range [Mstart, ...., Mend - 1] + const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext); + const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext); + for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { + const index_t idx = idx0 + tidx; + AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params); + if (idx < N && (by == 1 || tidy == 0)) { + small[idx + m0 * N] = val; } } } } + )code"; const char reduce_lines_kernel_code[] = R"code( __launch_bounds__(kRTCMaxThreadsPerBlock) __global__ void reduce_lines_kernel(const index_t N, const index_t M, const index_t small_in_stride, - const OutputType0* __restrict small_in, + const AType* __restrict small_in, OutputType0 *small_out) { using OType = AccType; for (index_t idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - typename OType::type val, residual; - REDUCER::SetInitValue(val, residual); + AType val, residual; + REDUCER.SetInitValue(val, residual); for (int k = 0; k < M; k++) { - REDUCER::Reduce(val, - OType::from(reinterpret_cast(small_in)[idx + k*small_in_stride]), + REDUCER.Reduce(val, + small_in[idx + k*small_in_stride], residual); } if (idx < N) { - REDUCER::Finalize(val, residual); + REDUCER.Finalize(val, residual); if (req == OpReqType::kAddTo) { - small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), val)); + small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), + static_cast(val))); } else { small_out[idx] = OType::to(val); } } - } } )code"; @@ -215,14 +285,13 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, const TBlob& big, const Tensor& workspace, const ReduceImplConfig& config, const int ndim, const std::string &common_code, int dev_id, - const TBlob *lhs = nullptr, const TBlob *rhs = nullptr) { + const TBlob *lhs = nullptr, const TBlob *rhs = nullptr, + const bool use_index = false) { using namespace common::cuda::rtc; void* small_dptr = small.dptr_; - bool first_kernel_addto = addto; if (config.Mnext > 1) { // small_dptr[] is N*Mnext*sizeof(DType) bytes small_dptr = workspace.dptr_; - first_kernel_addto = false; // Check that the workspace is contigiuous CHECK_EQ(workspace.CheckContiguous(), true); // Check that we have enough storage @@ -281,7 +350,6 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, std::vector args; args.emplace_back(&config.N); args.emplace_back(&config.M); - args.emplace_back(&first_kernel_addto); args.emplace_back(&big.dptr_); if (lhs != nullptr) { args.emplace_back(&(lhs->dptr_)); @@ -295,10 +363,11 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, args.emplace_back(&config.Mnext); const auto &function_code = (lhs == nullptr) - ? reduce_function_code + ? (use_index ? reduce_function_index_code : reduce_function_code) : reduce_function_use_input_code; + const auto& kernel_name = (config.Mnext > 1) ? "reduce_kernel_multi" : "reduce_kernel_single"; auto reduce_kernel_func = get_function(code + function_code, - "reduce_kernel", + kernel_name, reduce_kernel_code, dev_id); launch(reduce_kernel_func, config.kernel_1.gridDim, @@ -313,7 +382,7 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, args.emplace_back(&small_dptr); args.emplace_back(&small.dptr_); - auto reduce_lines_kernel_func = get_function(code, + auto reduce_lines_kernel_func = get_function(code + function_code, "reduce_lines_kernel", reduce_lines_kernel_code, dev_id); @@ -348,9 +417,11 @@ __global__ void reduce_kernel_M1(const int N, using IType1 = AccType; using IType2 = AccType; using OType = AccType; - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + for (index_t index = threadIdx.x + blockIdx.x*blockDim.x; + index < N; + index += blockDim.x*gridDim.x) { index_t coord[ndim]; - util::unravel(idx, params.small_shape, coord); + util::unravel(index, params.small_shape, coord); index_t idx_big[1]; idx_big[0] = util::ravel(coord, params.big_shape); index_t idx_lhs[1], idx_rhs[1]; @@ -358,16 +429,17 @@ __global__ void reduce_kernel_M1(const int N, idx_lhs[0] = util::ravel(coord, params.lhs_shape); idx_rhs[0] = util::ravel(coord, params.rhs_shape); } - typename OType::type val, residual; - REDUCER::SetInitValue(val, residual); + AType val, residual; + REDUCER.SetInitValue(val, residual); const int u = 0; - REDUCER::Reduce(val, FUNC, residual); - REDUCER::Finalize(val, residual); + REDUCER.Reduce(val, FUNC, residual); + REDUCER.Finalize(val, residual); if (req == OpReqType::kAddTo) { - const auto temp = op::add(val, OType::from(small[idx])); - small[idx] = OType::to(temp); + const auto temp = op::add(static_cast(val), + OType::from(small[index])); + small[index] = OType::to(temp); } else { - small[idx] = OType::to(val); + small[index] = OType::to(static_cast(val)); } } } @@ -376,7 +448,8 @@ __global__ void reduce_kernel_M1(const int N, void RTCReduceM1Impl(Stream *s, const TBlob &small, const TBlob &big, const TBlob *lhs, const TBlob *rhs, const ReduceImplConfig &config, const int ndim, - const std::string &common_code, int dev_id) { + const std::string &common_code, int dev_id, + const bool use_index = false) { using namespace common::cuda::rtc; std::string code = common_code + @@ -427,7 +500,7 @@ void RTCReduceM1Impl(Stream *s, const TBlob &small, const TBlob &big, args.emplace_back(¶m); const auto &function_code = (lhs == nullptr) - ? reduce_function_code + ? (use_index ? reduce_function_index_code : reduce_function_code) : reduce_function_use_input_code; auto reduce_kernel_M1_func = get_function(code + function_code, "reduce_kernel_M1", @@ -447,14 +520,12 @@ void RTCReduce(const OpContext& ctx, const TBlob& big, const std::string& reducer, int ndim, - const std::string& OP) { + const std::string& OP, + const bool use_index) { using namespace mxnet::common::cuda::rtc; if (req == kNullOp) return; Stream *s = ctx.get_stream(); - size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size; - size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size; - size_t type_size = std::max(big_type_size, small_type_size); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, type_size); + ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr); std::string common_code = std::string("const OpReqType req = ") + util::to_string(req) + ";\n" @@ -469,10 +540,12 @@ void RTCReduce(const OpContext& ctx, ";\n"; if (config.M == 1) { RTCReduceM1Impl(s, small, big, nullptr, nullptr, config, - ndim, common_code, ctx.run_ctx.ctx.dev_id); + ndim, common_code, ctx.run_ctx.ctx.dev_id, + use_index); } else { RTCReduceImpl(s, small, req == kAddTo, big, workspace, config, - ndim, common_code, ctx.run_ctx.ctx.dev_id); + ndim, common_code, ctx.run_ctx.ctx.dev_id, + nullptr, nullptr, use_index); } } @@ -490,10 +563,7 @@ void RTCReduce(const OpContext& ctx, using namespace mxnet::common::cuda::rtc; if (req == kNullOp) return; Stream *s = ctx.get_stream(); - size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size; - size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size; - size_t type_size = std::max(big_type_size, small_type_size); - ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, type_size); + ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_); std::string common_code = std::string("const OpReqType req = ") + util::to_string(req) + ";\n" diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index ba8e3278330b..1fc7b8e520c8 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3006,8 +3006,10 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): if isinstance(dtype, tuple): assert len(dtype) == 2 ldtype, rdtype = dtype - np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype) - np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype) + npldtype = ldtype if dtype != _np.float16 else _np.float32 + nprdtype = rdtype if dtype != _np.float16 else _np.float32 + np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype).astype(npldtype) + np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype).astype(nprdtype) mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ldtype) mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rdtype) for hybridize in [True, False]: @@ -4372,7 +4374,7 @@ def test_np_argmin_argmax(): ((3, 5, 7), 2, False), ((3, 5, 7, 9, 11), -3, False), ] - dtypes = ['float16', 'float32', 'float64'] + dtypes = ['float16', 'float32', 'float64', 'bool', 'int32'] ops = ['argmin', 'argmax'] class TestArgExtreme(HybridBlock): @@ -4387,7 +4389,7 @@ def hybrid_forward(self, F, x): for op_name in ops: for shape, axis, throw_exception in workloads: for dtype in dtypes: - a = np.random.uniform(size=shape, dtype=dtype) + a = np.random.uniform(low=0, high=100, size=shape).astype(dtype) if throw_exception: # Cannot use assert_exception because sometimes the main thread # proceeds to `assert False` before the exception is thrown