diff --git a/dmlc-core b/dmlc-core index b9d440f926e9..35c7762164cd 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit b9d440f926e905122036e0b6f942148110a655bb +Subproject commit 35c7762164cd19379d03ec403008af3a228c15f9 diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 5f8e32cc3f7b..503d4d4fa554 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -727,7 +727,7 @@ MXNET_DLL int MXKVStoreFree(KVStoreHandle handle); * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, - int num, + mx_uint num, int* keys, NDArrayHandle* vals); @@ -737,24 +737,28 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, * \param num the number of key-value pairs * \param keys the list of keys * \param vals the list of values + * \param priority the priority of the action * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, - int num, + mx_uint num, int* keys, - NDArrayHandle* vals); + NDArrayHandle* vals, + int priority); /*! * \brief pull a list of (key, value) pairs from the kvstore * \param handle handle to the kvstore * \param num the number of key-value pairs * \param keys the list of keys * \param vals the list of values + * \param priority the priority of the action * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, - int num, + mx_uint num, int* keys, - NDArrayHandle* vals); + NDArrayHandle* vals, + int priority); /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 885f3bd2147e..065ddcb74aa6 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -36,6 +36,8 @@ enum class FnProperty { kCopyFromGPU, /*! \brief Copy operation from CPU to other devices */ kCopyToGPU, + /*! \brief Prioritized sync operation on CPU */ + kCPUPrioritized, /*! 
\brief Asynchronous function call */ kAsync }; // enum class FnProperty @@ -116,8 +118,9 @@ class Engine { * \brief Push an operator to the engine. * \param op The operator to push. * \param exec_ctx Execution context. + * \param priority Priority of the action, as hint to the engine. */ - virtual void Push(OprHandle op, Context exec_ctx) = 0; + virtual void Push(OprHandle op, Context exec_ctx, int priority = 0) = 0; /*! * \brief Push an asynchronous operation to the engine. * \param exec_fun Execution function, this function takes a parameter @@ -128,11 +131,13 @@ class Engine { * mutate. * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. + * \param priority Priority of the action, as hint to the engine. */ virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal) = 0; + FnProperty prop = FnProperty::kNormal, + int priority = 0) = 0; /*! * \brief Schedule the deletion of a variable. * @@ -180,17 +185,19 @@ class Engine { * mutate. * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. + * \param priority Priority of the action, as hint to the engine. * \tparam SyncFn the synchronous function to be pushed. 
*/ template inline void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal) { + FnProperty prop = FnProperty::kNormal, + int priority = 0) { this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) { exec_fn(ctx); on_complete(); - }, exec_ctx, const_vars, mutable_vars, prop); + }, exec_ctx, const_vars, mutable_vars, prop, priority); } protected: diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 0a92de56a243..ba7bcdd6e972 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -62,9 +62,11 @@ class KVStore { * * \param keys the list of keys * \param value the list of values + * \param priority Priority of the action. */ virtual void Push(const std::vector& keys, - const std::vector& values) = 0; + const std::vector& values, + int priority = 0) = 0; /*! * \brief pull a list of key-value pairs from the store * @@ -80,9 +82,11 @@ class KVStore { * * \param keys the list of keys * \param values the list of buffers for the pulled data, they should be preallocated + * \param priority Priority of the action. */ virtual void Pull(const std::vector& keys, - const std::vector& values) = 0; + const std::vector& values, + int priority = 0) = 0; #if DMLC_USE_CXX11 /** * \brief the prototype of user-defined updater diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index f9d21441f09e..587a698c9616 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -337,8 +337,9 @@ class NDArray { * due to different possible convention carried by copy function * \param from the ndarray we want to copy data from * \param to the target ndarray + * \param priority Priority of the action. */ -void CopyFromTo(const NDArray &from, NDArray *to); +void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); /*! 
* \brief elementwise add diff --git a/python/mxnet/io.py b/python/mxnet/io.py index c39b0f01bed6..43aebf30bce7 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -182,14 +182,32 @@ class MXDataIter(DataIter): def __init__(self, handle): super(MXDataIter, self).__init__() self.handle = handle + # debug option, used to test the speed with io effect eliminated + self._debug_skip_load = False + self._debug_at_begin = True def __del__(self): check_call(_LIB.MXDataIterFree(self.handle)) + def debug_skip_load(self): + """Set the iterator to simply return always first batch. + + Notes + ----- + This can be used to test the speed of network without taking + the loading delay into account. + """ + self._debug_skip_load = True + logging.info('Set debug_skip_load to be true, will simply return first batch') + def reset(self): + self._debug_at_begin = True check_call(_LIB.MXDataIterBeforeFirst(self.handle)) def next(self): + if self._debug_skip_load and not self._debug_at_begin: + return self.getdata(), self.getlabel() + self._debug_at_begin = False next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) if next_res.value: diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ededb08c2f14..3d26c3cf4e50 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -6,7 +6,7 @@ import ctypes from .ndarray import NDArray from .base import _LIB -from .base import check_call, c_array, c_str, string_types +from .base import check_call, c_array, c_str, string_types, mx_uint from .base import NDArrayHandle, KVStoreHandle @@ -92,17 +92,22 @@ def init(self, key, value): >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStoreInit(self.handle, len(ckeys), ckeys, cvals)) + check_call(_LIB.MXKVStoreInit( + self.handle, mx_uint(len(ckeys)), ckeys, cvals)) - def push(self, key, value): + def push(self, key, value, priority=0): """ Push a single or a 
sequence of key-value pairs into the store. Parameters ---------- key : int or list of int Keys - value: NDArray or list of NDArray or list of list of NDArray + value : NDArray or list of NDArray or list of list of NDArray According values + priority : int, optional + The priority of the push operation. + The higher the priority, the faster this action is likely + to be executed before other push actions. Examples -------- @@ -140,9 +145,11 @@ def push(self, key, value): [ 4. 4. 4.]] """ ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStorePush(self.handle, len(ckeys), ckeys, cvals)) + check_call(_LIB.MXKVStorePush( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) - def pull(self, key, out=None): + def pull(self, key, out=None, priority=0): """ Pull a single value or a sequence of values from the store Parameters ---------- @@ -151,6 +158,10 @@ def pull(self, key, out=None): Keys out: NDArray or list of NDArray or list of list of NDArray According values + priority : int, optional + The priority of the pull operation. + The higher the priority, the faster this action is likely + to be executed before other pull actions. Examples -------- @@ -185,7 +196,9 @@ def pull(self, key, out=None): """ assert(out is not None) ckeys, cvals = _ctype_key_value(key, out) - check_call(_LIB.MXKVStorePull( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) def set_updater(self, updater): """Set a push updater into the store. 
diff --git a/python/mxnet/misc.py b/python/mxnet/misc.py index 43da2e1fc350..2d3ffc6e5abd 100644 --- a/python/mxnet/misc.py +++ b/python/mxnet/misc.py @@ -1,4 +1,4 @@ -# pylint: disable=invalid-name, logging-not-lazy, arguments-differ +# pylint: disable=invalid-name """learning rate scheduler""" import math @@ -9,8 +9,15 @@ class LearningRateScheduler(object): def __init__(self): self.base_lr = 0.01 - def __call__(self): - """lr calculation function""" + def __call__(self, iteration): + """ + Call to schedule current learning rate + + Parameters + ---------- + iteration: int + Current iteration count + """ raise NotImplementedError("must override this") @@ -51,8 +58,8 @@ def __call__(self, iteration): lr = self.base_lr * math.pow(self.factor, int(iteration / self.step)) if lr != self.old_lr: self.old_lr = lr - logging.info("At Iteration [%d]: Swith to new learning rate %.5f" \ - % (iteration, lr)) + logging.info("At Iteration [%d]: Swith to new learning rate %.5f", + iteration, lr) return lr diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 9f576f5eeb16..3e0fa1556266 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -258,10 +258,10 @@ def _train_multi_device(symbol, ctx, input_shape, continue # Gradient synchronization if kv: - # push gradient - kv.push(index, grad_list) + # push gradient, priority is negative index + kv.push(index, grad_list, priority=-index) # pull back the sum, to the same locations. 
- kv.pull(index, grad_list) + kv.pull(index, grad_list, priority=-index) opt_list = opt_state_blocks[index] # optimizea for w, g, state in zip(arg_list, grad_list, opt_list): diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index bef3aa0d83c0..686f8cca3554 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -53,7 +53,7 @@ def plot_network(symbol, title="plot", shape=None): if shape != None: draw_shape = True interals = symbol.get_internals() - _, out_shapes, __ = interals.infer_shape(**shape) + _, out_shapes, _ = interals.infer_shape(**shape) if out_shapes == None: raise ValueError("Input shape is incompete") shape_dict = dict(zip(interals.list_outputs(), out_shapes)) diff --git a/src/c_api.cc b/src/c_api.cc index b71deade4c03..0e3cd487d99c 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -945,12 +945,12 @@ int MXKVStoreFree(KVStoreHandle handle) { } int MXKVStoreInit(KVStoreHandle handle, - int num, int* keys, + mx_uint num, int* keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); - for (int i = 0; i < num; ++i) { + for (mx_uint i = 0; i < num; ++i) { v_keys[i] = keys[i]; v_vals[i] = *static_cast(vals[i]); } @@ -959,28 +959,32 @@ int MXKVStoreInit(KVStoreHandle handle, } int MXKVStorePush(KVStoreHandle handle, - int num, int* keys, NDArrayHandle* vals) { + mx_uint num, int* keys, + NDArrayHandle* vals, + int priority) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); - for (int i = 0; i < num; ++i) { + for (mx_uint i = 0; i < num; ++i) { v_keys[i] = keys[i]; v_vals[i] = *static_cast(vals[i]); } - static_cast(handle)->Push(v_keys, v_vals); + static_cast(handle)->Push(v_keys, v_vals, priority); API_END(); } int MXKVStorePull(KVStoreHandle handle, - int num, int* keys, NDArrayHandle* vals) { + mx_uint num, int* keys, + NDArrayHandle* vals, + int priority) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); - for (int i = 0; i < num; ++i) { + for 
(mx_uint i = 0; i < num; ++i) { v_keys[i] = keys[i]; v_vals[i] = static_cast(vals[i]); } - static_cast(handle)->Pull(v_keys, v_vals); + static_cast(handle)->Pull(v_keys, v_vals, priority); API_END(); } diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index dcfd5371ba79..dd112e879f3a 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -55,7 +55,7 @@ class NaiveEngine final : public Engine { NaiveOpr *opr = op->Cast(); delete opr; } - void Push(OprHandle op, Context exec_ctx) override { + void Push(OprHandle op, Context exec_ctx, int priority) override { NaiveOpr *opr = op->Cast(); this->PushAsync(opr->fn, exec_ctx, @@ -67,7 +67,8 @@ class NaiveEngine final : public Engine { Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop) override { + FnProperty prop, + int priority = 0) override { CallbackOnComplete callback = CreateCallback( NaiveEngine::OnComplete, nullptr); this->req_completed_ = false; diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 35ec65a423c6..ea122a67becf 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -29,7 +29,7 @@ ThreadedVar::ThreadedVar(VersionedVarBlock* head) : head_{head} { #endif // ENGINE_DEBUG } -void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { +inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { std::lock_guard lock{m_}; if (pending_write_ == nullptr) { // invariant: is_ready_to_read() @@ -50,7 +50,7 @@ void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { } } -void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { +inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { auto&& new_var_block = VersionedVarBlock::New(); std::lock_guard lock{m_}; // invariant. 
@@ -79,7 +79,7 @@ void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { } template -void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { +inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { OprBlock *trigger = nullptr; { // this is lock scope @@ -100,7 +100,7 @@ void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { } template -bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { +inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // this is lock scope VersionedVarBlock *old_pending_write, *end_of_read_chain; OprBlock* trigger_write = nullptr; @@ -167,12 +167,12 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { return false; } -void ThreadedVar::SetToDelete() { +inline void ThreadedVar::SetToDelete() { std::lock_guard lock{m_}; to_delete_ = true; } -bool ThreadedVar::ready_to_read() { +inline bool ThreadedVar::ready_to_read() { std::lock_guard lock{m_}; return this->is_ready_to_read(); } @@ -252,7 +252,7 @@ void ThreadedEngine::DeleteOperator(OprHandle op) { }, Context::CPU(), {}, deps, FnProperty::kAsync); } -void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { +void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) { ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); OprBlock* opr_block = OprBlock::New(); opr_block->opr = threaded_opr; @@ -261,6 +261,7 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size() + 1)); opr_block->ctx = exec_ctx; + opr_block->priority = priority; ++pending_; // Add read dependencies. 
for (auto&& i : threaded_opr->const_vars) { @@ -278,10 +279,10 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop) { + FnProperty prop, int priority) { ThreadedOpr *opr = NewOperator(fn, const_vars, mutable_vars, prop); opr->temporary = true; - Push(opr, exec_ctx); + Push(opr, exec_ctx, priority); } void ThreadedEngine::DeleteVariable(SyncFn delete_fn, diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 5ca2a503563a..13c55c6d5747 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -48,6 +48,8 @@ struct OprBlock : public common::ObjectPoolAllocatable { ThreadedOpr* opr{nullptr}; /*! \brief The context this operator */ Context ctx; + /*! \brief priority of the function */ + int priority; // define possible debug information DEFINE_ENGINE_DEBUG_INFO(OprBlock); /*! @@ -98,7 +100,7 @@ class ThreadedVar final : public Var, * Otherwise, the opr_block will be added to waiting queue. * \param opr_block The operation to be scheduled. */ - void AppendReadDependency(OprBlock* opr_block); + inline void AppendReadDependency(OprBlock* opr_block); /*! * \brief Schedule a write operation on this variable. * If the opr_block can be runed right away, @@ -106,7 +108,7 @@ class ThreadedVar final : public Var, * Otherwise, the opr_block will be added to waiting queue. * \param opr_block The operation to be scheduled. */ - void AppendWriteDependency(OprBlock* opr_block); + inline void AppendWriteDependency(OprBlock* opr_block); /*! * \brief A read operation is completed on this variable. * This function may trigger subsequent waiting operations on this variable. @@ -116,7 +118,7 @@ class ThreadedVar final : public Var, * \tparam Dispatcher the function called to trigger an operation. 
*/ template - void CompleteReadDependency(Dispatcher dispatcher); + inline void CompleteReadDependency(Dispatcher dispatcher); /*! * \brief A write operation is completed on this variable. * This function may trigger subsequent waiting operations on this variable. @@ -127,11 +129,11 @@ class ThreadedVar final : public Var, * \return to_delete, whether this Variable can be deleted after this functin. */ template - bool CompleteWriteDependency(Dispatcher dispatcher); + inline bool CompleteWriteDependency(Dispatcher dispatcher); /*! \brief Mark this variable to be deleted. */ - void SetToDelete(); + inline void SetToDelete(); /*! \return whether this variable is ready to read. */ - bool ready_to_read(); + inline bool ready_to_read(); /*! * \brief Cast a Var pointer to ThreadedVar pointer * \param ptr pointer from base. @@ -234,11 +236,12 @@ class ThreadedEngine : public Engine { std::vector const& mutable_vars, FnProperty prop) override; void DeleteOperator(OprHandle op) override; - void Push(OprHandle op, Context exec_ctx) override; + void Push(OprHandle op, Context exec_ctx, int priority) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop) override; + FnProperty prop, + int priority) override; void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override; void WaitForVar(VarHandle var) override; void WaitForAll() override; diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index ad4612718abf..62738cbaedac 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -25,23 +25,30 @@ namespace engine { */ class ThreadedEnginePerDevice : public ThreadedEngine { public: + static auto constexpr kFIFO = dmlc::ConcurrentQueueType::kFIFO; + static auto constexpr kPriority = dmlc::ConcurrentQueueType::kPriority; + static auto constexpr kCopyQueue = kPriority; + static auto constexpr 
kPriorityQueue = kPriority; + static auto constexpr kWorkerQueue = kFIFO; + ThreadedEnginePerDevice() noexcept(false) { - cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1); gpu_worker_nthreads_ = common::GetNumThreadPerGPU(); gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 1); + cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1); // create CPU task - cpu_worker_.reset(new ThreadWorkerBlock()); - auto *cpu_queue = &(cpu_worker_->task_queue); - cpu_worker_->pool.reset(new ThreadPool( - cpu_worker_nthreads_, [this, cpu_queue] { - this->CPUWorker(cpu_queue); + int cpu_priority_nthreads = dmlc::GetEnv("MXNET_CPU_PRIORITY_NTHREADS", 4); + cpu_priority_worker_.reset(new ThreadWorkerBlock()); + cpu_priority_worker_->pool.reset(new ThreadPool( + cpu_priority_nthreads, [this] { + this->CPUWorker(cpu_priority_worker_.get()); })); // GPU tasks will be created lazily } ~ThreadedEnginePerDevice() noexcept(false) { gpu_normal_workers_.Clear(); gpu_copy_workers_.Clear(); - cpu_worker_.reset(nullptr); + cpu_normal_workers_.Clear(); + cpu_priority_worker_.reset(nullptr); } protected: @@ -58,21 +65,54 @@ class ThreadedEnginePerDevice : public ThreadedEngine { this->ExecuteOprBlock(run_ctx, opr_block); } else { if (ctx.dev_mask() == cpu::kDevMask) { - cpu_worker_->task_queue.Push(opr_block); + if (opr_block->opr->prop == FnProperty::kCPUPrioritized) { + cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority); + } else { + int dev_id = ctx.dev_id; + int nthread = cpu_worker_nthreads_; + cpu_normal_workers_.Get(dev_id, [this, dev_id, nthread]() { + auto blk = new ThreadWorkerBlock(); + blk->pool.reset(new ThreadPool(nthread, [this, blk] () { + this->CPUWorker(blk); + })); + return blk; + })->task_queue.Push(opr_block, opr_block->priority); + } } else { CHECK_EQ(ctx.dev_mask(), gpu::kDevMask); - ThreadWorkerBlock* block = this->GetGPUWorkerBlock( - ctx.dev_id, opr_block->opr->prop); - block->task_queue.Push(opr_block); + // 
GPU execution. + FnProperty prop = opr_block->opr->prop; + bool is_copy = (prop == FnProperty::kCopyFromGPU || + prop == FnProperty::kCopyToGPU); + int nthread = gpu_worker_nthreads_; + int dev_id = ctx.dev_id; + if (is_copy) { + gpu_copy_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() { + auto blk = new ThreadWorkerBlock(); + blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk] () { + this->GPUWorker(dev_id, is_copy, blk); + })); + return blk; + })->task_queue.Push(opr_block, opr_block->priority); + } else { + gpu_normal_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() { + auto blk = new ThreadWorkerBlock(); + blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk] () { + this->GPUWorker(dev_id, is_copy, blk); + })); + return blk; + })->task_queue.Push(opr_block, opr_block->priority); + } } } } private: // working unit for each of the task. + template struct ThreadWorkerBlock { // task queue on this task - dmlc::ConcurrentBlockingQueue task_queue; + dmlc::ConcurrentBlockingQueue task_queue; // thread pool that works on this task std::unique_ptr pool; // destructor @@ -87,44 +127,23 @@ class ThreadedEnginePerDevice : public ThreadedEngine { /*! \brief number of concurrent thread each gpu copy worker uses */ int gpu_copy_nthreads_; // cpu worker - std::unique_ptr cpu_worker_; + common::LazyAllocArray > cpu_normal_workers_; + // cpu priority worker + std::unique_ptr > cpu_priority_worker_; // workers doing normal works on GPU - common::LazyAllocArray gpu_normal_workers_; + common::LazyAllocArray > gpu_normal_workers_; // workers doing copy works from/to GPU - common::LazyAllocArray gpu_copy_workers_; - /*! - * \brief get GPU Task Worker - * \param dev_id the device id - * \param prop The property of the function. 
- */ - inline ThreadWorkerBlock *GetGPUWorkerBlock(int dev_id, - FnProperty prop) { - bool is_copy = (prop == FnProperty::kCopyFromGPU || - prop == FnProperty::kCopyToGPU); - auto *arr = &gpu_normal_workers_; - int nthread = gpu_worker_nthreads_; - if (is_copy) { - arr = &gpu_copy_workers_; - nthread = gpu_copy_nthreads_; - } - - return arr->Get(dev_id, [this, dev_id, is_copy, nthread]() { - auto block = new ThreadWorkerBlock(); - block->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, block] () { - this->GPUWorker(dev_id, is_copy, &(block->task_queue)); - })); - return block; - }); - } + common::LazyAllocArray > gpu_copy_workers_; /*! * \brief GPU worker that performs operations on a certain device. * \param dev_id The device id of the worker. * \param is_copy_worker whether the worker only do copy job - * \param task_queue the device id of the worker. + * \param block The task block of the worker. */ + template inline void GPUWorker(int dev_id, bool is_copy_worker, - dmlc::ConcurrentBlockingQueue* task_queue) { + ThreadWorkerBlock *block) { #if MXNET_USE_CUDA // allocate stream mshadow::SetDevice(dev_id); @@ -138,6 +157,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { run_ctx.stream = stream; // execute task OprBlock* opr_block; + auto* task_queue = &(block->task_queue); while (task_queue->Pop(&opr_block)) { this->ExecuteOprBlock(run_ctx, opr_block); } @@ -147,9 +167,11 @@ class ThreadedEnginePerDevice : public ThreadedEngine { } /*! * \brief CPU worker that performs operations on CPU. - * \param task_queue the device id of the worker. + * \param block The task block of the worker. 
*/ - inline void CPUWorker(dmlc::ConcurrentBlockingQueue* task_queue) { + template + inline void CPUWorker(ThreadWorkerBlock *block) { + auto* task_queue = &(block->task_queue); RunContext run_ctx; run_ctx.stream = nullptr; // execute task diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc index 406030d24a5b..957c543da9c9 100644 --- a/src/io/iter_image_recordio.cc +++ b/src/io/iter_image_recordio.cc @@ -183,6 +183,7 @@ inline void ImageRecordIOParser::Init( threadget = omp_get_num_threads(); } param_.preprocess_threads = threadget; + // setup decoders for (int i = 0; i < threadget; ++i) { augmenters_.push_back(new ImageAugmenter()); @@ -196,8 +197,12 @@ inline void ImageRecordIOParser::Init( param_.label_width = 1; } CHECK(param_.path_imgrec.length() != 0) - << "ImageRecordIOIterator: must specify image_rec"; + << "ImageRecordIOIterator: must specify image_rec"; + if (param_.silent == 0) { + LOG(INFO) << "ImageRecordIOParser: " << param_.path_imgrec + << ", use " << threadget << " threads for decoding.."; + } // TODO(mu, tianjun) add DMLC env variable to detect parition const int part_index = 0; const int num_parts = 1; diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index f8cbf53d27b3..df88739b7b15 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -22,7 +22,10 @@ class KVStoreLocal : public KVStore { KVStoreLocal() { pinned_ctx_ = (MXNET_USE_CUDA != 0) ? 
Context::CPUPinned(0) : Context::CPU(); - set_updater(DefaultUpdater()); + set_updater(nullptr); + // the server perameters + nthread_reduction_ = dmlc::GetEnv("MXNET_KVSTORE_REDUCTION_NTHREADS", 4); + bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); } void set_updater(Updater updater) override { @@ -39,31 +42,46 @@ class KVStoreLocal : public KVStore { } void Push(const std::vector& keys, - const std::vector& values) override { + const std::vector& values, + int priority) override { std::vector uniq_keys; std::vector > grouped_vals; GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); - CHECK(updater_) << "invalid updater"; + for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; - auto it = local_.find(key); - CHECK(it != local_.end()) << "key " << key << " has not been inited"; - updater_(key, MergePushValue(key, grouped_vals[i]), &(it->second)); + const NDArray& merged = MergePushValue(key, grouped_vals[i], priority); + if (updater_ != nullptr) { + auto it = local_.find(key); + CHECK(it != local_.end()) << "key " << key << " has not been inited"; + updater_(key, merged, &(it->second)); + } } } void Pull(const std::vector& keys, - const std::vector& values) override { + const std::vector& values, + int priority) override { std::vector uniq_keys; std::vector > grouped_vals; GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; - auto it = local_.find(key); - CHECK(it != local_.end()) << "key " << key << " has not been inited"; - for (NDArray* v : grouped_vals[i]) { - CopyFromTo(it->second, v); + if (updater_ != nullptr) { + auto it = local_.find(key); + CHECK(it != local_.end()) << "key " << key << " has not been inited"; + const NDArray& src = it->second; + for (auto* vptr : grouped_vals[i]) { + CopyFromTo(src, vptr, priority); + } + } else { + auto it = merge_buf_.find(key); + CHECK(it != merge_buf_.end()) << "key " << key << " has not been 
pushed"; + auto& src = it->second.merged; + for (auto* vptr : grouped_vals[i]) { + CopyFromTo(src, vptr, priority); + } } } } @@ -102,46 +120,118 @@ class KVStoreLocal : public KVStore { /*! * \brief returns the aggregated push value */ - const NDArray &MergePushValue(int key, const std::vector& val) { + const NDArray& MergePushValue(int key, const std::vector& val, int priority) { CHECK(val.size()); auto& buf = merge_buf_[key]; + // copy buffer + std::vector const_vars(val.size() - 1); + std::vector reduce(val.size()); if (buf.merged.is_none()) { - buf.merged = NDArray(val[0].shape(), pinned_ctx_); + Context ctx = Context::CPUPinned(val[0].ctx().dev_id); + if (MXNET_USE_CUDA == 0) ctx = Context::CPU(); + buf.merged = NDArray(val[0].shape(), ctx); } - CopyFromTo(val[0], &buf.merged); + + CopyFromTo(val[0], &(buf.merged), priority); + reduce[0] = buf.merged; for (size_t i = 1; i < val.size(); ++i) { - const auto& v = val[i]; + const NDArray& v = val[i]; Context ctx = v.ctx(); - if (v.ctx().dev_mask() == cpu::kDevMask) { - buf.merged += v; + const_vars[i - 1] = v.var(); + if (ctx.dev_mask() == cpu::kDevMask) { + reduce[i] = val[i]; } else { - CHECK_EQ(ctx.dev_mask(), gpu::kDevMask); - NDArray *copy_buf = buf.AllocCopyBuf(ctx.dev_id, val[0].shape()); - CopyFromTo(val[i], copy_buf); - buf.merged += *copy_buf; + NDArray *copy_buf = buf.AllocCopyBuf(i, ctx.dev_id, val[0].shape()); + CopyFromTo(val[i], copy_buf, priority); + reduce[i] = *copy_buf; } } + + Engine::Get()->PushSync([reduce, this](RunContext rctx) { + ReduceSum(reduce); + }, Context::CPU(), const_vars, {reduce[0].var()}, + FnProperty::kCPUPrioritized, priority); return buf.merged; } private: + inline static void ReduceSum(const std::vector &dptr, + size_t offset, index_t size) { + using namespace mshadow; // NOLINT(*) + Tensor in_0(dptr[0] + offset, Shape1(size)); + switch (dptr.size()) { + case 2: { + Tensor in_1(dptr[1] + offset, Shape1(size)); + in_0 += in_1; + break; + } + case 3: { + Tensor in_1(dptr[1] 
+ offset, Shape1(size)); + Tensor in_2(dptr[2] + offset, Shape1(size)); + in_0 += in_1 + in_2; + break; + } + case 4: { + Tensor in_1(dptr[1] + offset, Shape1(size)); + Tensor in_2(dptr[2] + offset, Shape1(size)); + Tensor in_3(dptr[3] + offset, Shape1(size)); + in_0 += in_1 + in_2 + in_3; + break; + } + default: { + for (size_t i = 1; i < dptr.size(); ++i) { + Tensor in_k(dptr[i] + offset, Shape1(size)); + in_0 += in_k; + } + } + } + } + // reduce sum into val[0] + // this is performance critical + inline void ReduceSum(const std::vector &in_data) { + const size_t step = 4 << 10; + // ge ptr out + std::vector dptr(in_data.size()); + for (size_t i = 0; i < in_data.size(); ++i) { + TBlob data = in_data[i].data(); + CHECK(data.CheckContiguous()); + dptr[i] = data.FlatTo2D().dptr_; + } + size_t total = in_data[0].shape().Size(); + long ntask = (total + 1 - step) / step; // NOLINT(*) + if (total < bigarray_bound_ || nthread_reduction_ <= 1) { + ReduceSum(dptr, 0, total); + } else { + #pragma omp parallel for schedule(static) num_threads(nthread_reduction_) + for (long j = 0; j < ntask; ++j) { // NOLINT(*) + size_t k = static_cast(j); + size_t begin = std::min(k * step, total); + size_t end = std::min((k + 1) * step, total); + ReduceSum(dptr, begin, static_cast(end - begin)); + } + } + } /// \brief temperal space for pushing and pull struct BufferEntry { + // the merged value + NDArray merged; /// \brief the cpu buffer for gpu data std::vector copy_buf; - /// \brief merged data in cpu - NDArray merged; // allocate copy buffer, if it has not been allocated - inline NDArray *AllocCopyBuf(uint32_t dev_id, const TShape& shape) { - if (dev_id >= copy_buf.size()) copy_buf.resize(dev_id + 1); - if (copy_buf[dev_id].is_none()) { - copy_buf[dev_id] = NDArray(shape, Context::CPUPinned(dev_id)); + inline NDArray *AllocCopyBuf(size_t index, uint32_t dev_id, const TShape& shape) { + if (index >= copy_buf.size()) copy_buf.resize(index + 1); + if (copy_buf[index].is_none()) { + 
copy_buf[index] = NDArray(shape, Context::CPUPinned(dev_id)); } - return &copy_buf[dev_id]; + return &copy_buf[index]; } }; + // number of threads to do reduction + int nthread_reduction_; + // size bound above which an array is treated as a big array + size_t bigarray_bound_; /// \brief buffer for merging push value std::unordered_map merge_buf_; /// \brief buffer for storing local values diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 335baef17198..121058a95e20 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -153,7 +153,7 @@ inline void ScalarOp(const NDArray &lhs, } } -void CopyFromTo(const NDArray &from, NDArray *to) { +void CopyFromTo(const NDArray &from, NDArray *to, int priority) { CHECK(from.shape() == to->shape()) << "operands shape mismatch"; CHECK(from.shape().ndim() != 0) @@ -172,7 +172,8 @@ void CopyFromTo(const NDArray &from, NDArray *to) { TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); - }, from.ctx(), const_vars, {ret.var()}); + }, from.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, priority); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.CheckAndAlloc(); @@ -183,7 +184,8 @@ void CopyFromTo(const NDArray &from, NDArray *to) { from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU); + }, ret.ctx(), const_vars, {ret.var()}, + FnProperty::kCopyToGPU, priority); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.CheckAndAlloc(); @@ -192,7 +194,8 @@ void CopyFromTo(const NDArray &from, NDArray *to) { from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU); + }, from.ctx(), const_vars, {ret.var()}, + FnProperty::kCopyFromGPU, priority); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext 
ctx) { ret.CheckAndAlloc(); @@ -201,7 +204,8 @@ void CopyFromTo(const NDArray &from, NDArray *to) { from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU); + }, from.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, priority); } else { LOG(FATAL) << "unknown device mask"; } @@ -211,6 +215,9 @@ void CopyFromTo(const NDArray &from, NDArray *to) { } } +inline void CopyFromToSimple(const NDArray &from, NDArray *to) { + CopyFromTo(from, to, 0); +} template inline void SampleOP(const real_t &a, @@ -520,7 +527,7 @@ MXNET_REGISTER_NDARRAY_FUN(_rdiv_scalar).set_function(ScalarOp