diff --git a/src/c_api.cc b/src/c_api.cc index 9b31b8e47641..6427a6357c90 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -180,12 +180,6 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, API_END(); } -int MXEngineWaitAll() { - API_BEGIN(); - Engine::Get()->WaitForAll(); - API_END(); -} - // NOTE: return value is added in API_END int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 606103f9e7ab..0a3da50e69be 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -42,13 +42,17 @@ class ThreadedEnginePerDevice : public ThreadedEngine { protected: void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { + const Context& ctx = opr_block->ctx; if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { - CHECK_EQ(opr_block->ctx.dev_mask, cpu::kDevMask); + if (ctx.dev_mask == gpu::kDevMask) { + #if MXNET_USE_CUDA + mshadow::SetDevice(ctx.dev_id); + #endif + } RunContext run_ctx; run_ctx.stream = nullptr; this->ExecuteOprBlock(run_ctx, opr_block); } else { - const Context& ctx = opr_block->ctx; if (ctx.dev_mask == cpu::kDevMask) { cpu_worker_.task_queue.Push(opr_block); } else {