Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@ namespace common {

// heuristic to determine number of threads per GPU
// Returns the number of worker threads to use per GPU.
// The value is looked up under the name MXNET_GPU_WORKER_NTHREADS via dmlc::GetEnv;
// the second argument is presumably the fallback used when that setting is absent
// (confirm against the dmlc::GetEnv documentation).
inline int GetNumThreadPerGPU() {
int nthread = std::thread::hardware_concurrency();
// Fewer than 8 reported hardware threads: fallback of 1; otherwise fallback of 2.
if (nthread < 8) {
return dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 1);
} else {
return dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 2);
}
// NOTE(review): both branches above return, so as listed the statement below is never
// reached. This listing looks like a diff view with removed and added lines interleaved;
// confirm which version of the body is the live one.
// This is the resource-efficient option.
return dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 1);
}

/*!
Expand Down
6 changes: 1 addition & 5 deletions src/io/image_augmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ struct ImageAugmentParam : public dmlc::Parameter<ImageAugmentParam> {
int rotate;
/*! \brief filled color while padding */
int fill_value;
/*! \brief whether to print augment info */
bool silent;
/*! \brief shape of the image data*/
TShape data_shape;
// declare parameters
DMLC_DECLARE_PARAMETER(ImageAugmentParam) {
DMLC_DECLARE_FIELD(rand_crop).set_default(true)
DMLC_DECLARE_FIELD(rand_crop).set_default(false)
.describe("Augmentation Param: Whether to random crop on the image");
DMLC_DECLARE_FIELD(crop_y_start).set_default(-1)
.describe("Augmentation Param: Where to nonrandom crop on y.");
Expand Down Expand Up @@ -82,8 +80,6 @@ struct ImageAugmentParam : public dmlc::Parameter<ImageAugmentParam> {
.describe("Augmentation Param: Rotate angle.");
DMLC_DECLARE_FIELD(fill_value).set_default(255)
.describe("Augmentation Param: Maximum value of illumination variation.");
DMLC_DECLARE_FIELD(silent).set_default(true)
.describe("Augmentation Param: Whether to print augmentor info.");
DMLC_DECLARE_FIELD(data_shape)
.set_expect_ndim(3).enforce_nonzero()
.describe("Dataset Param: Shape of each instance generated by the DataIter.");
Expand Down
2 changes: 1 addition & 1 deletion src/operator/activation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
// use int for enumeration
int act_type;
DMLC_DECLARE_PARAMETER(ActivationParam) {
DMLC_DECLARE_FIELD(act_type).set_default(kReLU)
DMLC_DECLARE_FIELD(act_type)
.add_enum("relu", kReLU)
.add_enum("sigmoid", kSigmoid)
.add_enum("tanh", kTanh)
Expand Down
2 changes: 1 addition & 1 deletion src/operator/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
.set_expect_ndim(2).enforce_nonzero()
.describe("pooling kernel size: (y, x)");

DMLC_DECLARE_FIELD(pool_type).set_default(kMaxPooling)
DMLC_DECLARE_FIELD(pool_type)
.add_enum("max", kMaxPooling)
.add_enum("avg", kAvgPooling)
.add_enum("sum", kSumPooling)
Expand Down
4 changes: 3 additions & 1 deletion src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,9 @@ void GraphExecutor::InitResources() {
CHECK_LE(cnt, 1) << "Node can only have one temp space request";
req_temp_cnt[nid] = cnt;
}
uint32_t num_color = kMaxNumColor;
// restrict allocation to the maximum degree of parallelism per device
uint32_t num_color = std::min(static_cast<uint32_t>(common::GetNumThreadPerGPU()),
kMaxNumColor);
std::vector<uint32_t> req_temp_color;
// use graph coloring to find nodes that won't run in parallel
num_color = graph::ColorNodeGroup(graph_, topo_order_, req_temp_cnt,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_symbol_compose():
'fc2_weight', 'fc2_bias']

net2 = mx.symbol.FullyConnected(name='fc3', num_hidden=10)
net2 = mx.symbol.Activation(data=net2)
net2 = mx.symbol.Activation(data=net2, act_type='relu')
net2 = mx.symbol.FullyConnected(data=net2, name='fc4', num_hidden=20)
print(net2.debug_str())

Expand Down