diff --git a/src/common/object_pool.h b/src/common/object_pool.h index 5787f53ff497..ac12e2b30c6e 100644 --- a/src/common/object_pool.h +++ b/src/common/object_pool.h @@ -41,6 +41,12 @@ class ObjectPool { */ static ObjectPool* Get(); + /*! + * \brief Get a shared ptr of the singleton instance of pool. + * \return Shared pointer to the Object Pool. + */ + static std::shared_ptr _GetSharedRef(); + private: /*! * \brief Internal structure to hold pointers. @@ -141,8 +147,13 @@ void ObjectPool::Delete(T* ptr) { template ObjectPool* ObjectPool::Get() { - static ObjectPool inst; - return &inst; + return _GetSharedRef().get(); +} + +template +std::shared_ptr > ObjectPool::_GetSharedRef() { + static std::shared_ptr > inst_ptr(new ObjectPool()); + return inst_ptr; } template diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 7a8708484f9c..9a7c32525961 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -251,6 +251,11 @@ class ThreadedEngine : public Engine { ThreadedEngine() { engine_info_ = dmlc::GetEnv("MXNET_ENGINE_INFO", false); + + objpool_opr_ref_ = common::ObjectPool::_GetSharedRef(); + objpool_blk_ref_ = common::ObjectPool::_GetSharedRef(); + objpool_varblk_ref_ = common::ObjectPool::_GetSharedRef(); + objpool_var_ref_ = common::ObjectPool::_GetSharedRef(); } ~ThreadedEngine() { { @@ -329,6 +334,15 @@ class ThreadedEngine : public Engine { */ std::mutex finished_m_; std::condition_variable finished_cv_; + + /*! + * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early + * See also #309 (https://github.com/dmlc/mxnet/issues/309) + */ + std::shared_ptr > objpool_opr_ref_; + std::shared_ptr > objpool_blk_ref_; + std::shared_ptr > objpool_varblk_ref_; + std::shared_ptr > objpool_var_ref_; /*! * \brief Disallow copy construction and assignment. */