diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index ac7c3d3825ab..0d9ab9ec0a02 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -51,8 +51,10 @@ class Estimator(object):
         The model used for training.
     loss : gluon.loss.Loss
         Loss (objective) function to calculate during training.
-    metrics : EvalMetric or list of EvalMetric
-        Metrics for evaluating models.
+    train_metrics : EvalMetric or list of EvalMetric
+        Training metrics for evaluating models on the training dataset.
+    val_metrics : EvalMetric or list of EvalMetric
+        Validation metrics for evaluating models on the validation dataset.
     initializer : Initializer
         Initializer to initialize the network.
     trainer : Trainer
@@ -105,7 +107,8 @@ class Estimator(object):

     def __init__(self, net,
                  loss,
-                 metrics=None,
+                 train_metrics=None,
+                 val_metrics=None,
                  initializer=None,
                  trainer=None,
                  context=None,
@@ -113,7 +116,8 @@ def __init__(self, net,
                  eval_net=None):
         self.net = net
         self.loss = self._check_loss(loss)
-        self._train_metrics = _check_metrics(metrics)
+        self._train_metrics = _check_metrics(train_metrics)
+        self._val_metrics = _check_metrics(val_metrics)
         self._add_default_training_metrics()
         self._add_validation_metrics()
         self.evaluation_loss = self.loss
@@ -226,13 +230,21 @@ def _add_default_training_metrics(self):
                 self._train_metrics.append(metric_loss(loss_name))

         for metric in self._train_metrics:
-            metric.name = "training " + metric.name
+            # add the 'training' prefix to the metric name so that event
+            # handlers can distinguish training metrics from validation metrics
+            metric.name = 'training ' + metric.name

     def _add_validation_metrics(self):
-        self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
+        if not self._val_metrics:
+            self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]

         for metric in self._val_metrics:
-            metric.name = "validation " + metric.name
+            # add the 'validation' prefix to the metric name so that event
+            # handlers can distinguish validation metrics from training metrics
+            if 'training' in metric.name:
+                metric.name = metric.name.replace('training', 'validation')
+            else:
+                metric.name = 'validation ' + metric.name

     @property
     def train_metrics(self):
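Reviewer note: the renaming above gives every metric a `training `/`validation ` name prefix, and `val_metrics` becomes optional. A minimal sketch of the resulting behavior on this branch; the network is a toy placeholder and the printed names are indicative only:

```python
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

net = gluon.nn.Dense(1)  # toy placeholder network
net.initialize()

# val_metrics may be omitted: it is then deep-copied from train_metrics and
# the 'training' prefix is rewritten to 'validation'.
est = Estimator(net=net,
                loss=gluon.loss.L2Loss(),
                train_metrics=mx.metric.RMSE(),
                context=mx.cpu())
print([m.name for m in est.train_metrics])  # e.g. ['training rmse']
print([m.name for m in est.val_metrics])    # e.g. ['validation rmse']
```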
@@ -244,7 +256,6 @@ def val_metrics(self):

     def evaluate_batch(self,
                        val_batch,
-                       val_metrics,
                        batch_axis=0):
         """Evaluate model on a batch of validation data.

@@ -252,25 +263,19 @@ def evaluate_batch(self,
         ----------
         val_batch : tuple
             Data and label of a batch from the validation data loader.
-        val_metrics : EvalMetric or list of EvalMetrics
-            Metrics to update validation result.
         batch_axis : int, default 0
             Batch axis to split the validation data into devices.
         """
         data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
         pred = [self.eval_net(x) for x in data]
         loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
-        # update metrics
-        for metric in val_metrics:
-            if isinstance(metric, metric_loss):
-                metric.update(0, loss)
-            else:
-                metric.update(label, pred)
+
+        return data, label, pred, loss

     def evaluate(self,
                  val_data,
-                 val_metrics,
-                 batch_axis=0):
+                 batch_axis=0,
+                 event_handlers=None):
         """Evaluate model on validation data.

         This function calls :py:func:`evaluate_batch` on each of the batches from the
@@ -281,21 +286,42 @@ def evaluate(self,
         ----------
         val_data : DataLoader
             Validation data loader with data and labels.
-        val_metrics : EvalMetric or list of EvalMetrics
-            Metrics to update validation result.
         batch_axis : int, default 0
             Batch axis to split the validation data into devices.
+        event_handlers : EventHandler or list of EventHandler
+            List of :py:class:`EventHandler` to apply during validation. Besides
+            the event handlers specified here, a default MetricHandler and a
+            LoggingHandler will be added if not specified explicitly.
         """
         if not isinstance(val_data, DataLoader):
             raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                              "can transform your DataIter or any NDArray into Gluon DataLoader. "
                              "Refer to gluon.data.DataLoader")

-        for metric in val_metrics:
+        for metric in self.val_metrics:
             metric.reset()

+        event_handlers = self._prepare_default_validation_handlers(event_handlers)
+
+        _, epoch_begin, batch_begin, batch_end, \
+        epoch_end, _ = self._categorize_handlers(event_handlers)
+
+        estimator_ref = self
+
+        for handler in epoch_begin:
+            handler.epoch_begin(estimator_ref)
+
         for _, batch in enumerate(val_data):
-            self.evaluate_batch(batch, val_metrics, batch_axis)
+            for handler in batch_begin:
+                handler.batch_begin(estimator_ref, batch=batch)
+
+            _, label, pred, loss = self.evaluate_batch(batch, batch_axis)
+
+            for handler in batch_end:
+                handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)
+
+        for handler in epoch_end:
+            handler.epoch_end(estimator_ref)

     def fit_batch(self, train_batch, batch_axis=0):
         """Trains the model on a batch of training data.

@@ -441,23 +467,17 @@ def _prepare_default_handlers(self, val_data, event_handlers):
             added_default_handlers.append(GradientUpdateHandler())

         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
-            added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
+            added_default_handlers.append(MetricHandler(metrics=self.train_metrics))

         if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
             # no validation handler
             if val_data:
-                val_metrics = self.val_metrics
                 # add default validation handler if validation data found
                 added_default_handlers.append(ValidationHandler(val_data=val_data,
-                                                                eval_fn=self.evaluate,
-                                                                val_metrics=val_metrics))
-            else:
-                # set validation metrics to None if no validation data and no validation handler
-                val_metrics = []
+                                                                eval_fn=self.evaluate))

         if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
-                                                         val_metrics=val_metrics))
+            added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))

         # if there is a mix of user defined event handlers and default event handlers
         # they should have the same set of metrics
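Reviewer note: a minimal end-to-end sketch of the new `evaluate()` call pattern on this branch; the network and data are toy placeholders, and the default validation handlers come from `_prepare_default_validation_handlers()` in the next hunk:

```python
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

net = gluon.nn.Dense(1)  # toy placeholder network
net.initialize()
dataset = gluon.data.ArrayDataset(mx.nd.random.uniform(shape=(16, 4)),
                                  mx.nd.random.uniform(shape=(16, 1)))
val_loader = gluon.data.DataLoader(dataset, batch_size=4)

est = Estimator(net=net,
                loss=gluon.loss.L2Loss(),
                train_metrics=mx.metric.RMSE(),
                context=mx.cpu())

# No val_metrics argument anymore: evaluate() reads est.val_metrics and
# dispatches epoch_begin/batch_begin/batch_end/epoch_end to the default
# (or user-supplied) event handlers.
est.evaluate(val_data=val_loader)
print([m.get() for m in est.val_metrics])  # e.g. [('validation rmse', ...)]
```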
@@ -474,6 +494,29 @@ def _prepare_default_handlers(self, val_data, event_handlers):
         event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
         return event_handlers

+    def _prepare_default_validation_handlers(self, event_handlers):
+        event_handlers = _check_event_handlers(event_handlers)
+        added_default_handlers = []
+
+        # add a default logging handler and metric handler for validation
+        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
+            added_default_handlers.append(MetricHandler(metrics=self.val_metrics))
+
+        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+            added_default_handlers.append(LoggingHandler(metrics=self.val_metrics))
+
+        mixing_handlers = event_handlers and added_default_handlers
+        event_handlers.extend(added_default_handlers)
+
+        # check that all handlers refer to well-defined validation metrics
+        if mixing_handlers:
+            known_metrics = set(self.val_metrics)
+            for handler in event_handlers:
+                _check_handler_metric_ref(handler, known_metrics)
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+        return event_handlers
+
     def _categorize_handlers(self, event_handlers):
         """
         categorize handlers into 6 event lists to avoid calling empty methods
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 64777608bef0..c7551362fa5b 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -128,28 +128,28 @@ class MetricHandler(EpochBegin, BatchEnd):

     Parameters
     ----------
-    train_metrics : List of EvalMetrics
-        Training metrics to be updated at batch end.
+    metrics : List of EvalMetrics
+        Metrics to be updated at batch end.
     priority : scalar
         Priority level of the MetricHandler. Priority level is sorted in
         ascending order. The lower the number is, the higher priority level the
         handler is.
     """
-    def __init__(self, train_metrics, priority=-1000):
-        self.train_metrics = _check_metrics(train_metrics)
+    def __init__(self, metrics, priority=-1000):
+        self.metrics = _check_metrics(metrics)
         # order to be called among all callbacks
         # metrics need to be calculated before other callbacks can access them
         self.priority = priority

     def epoch_begin(self, estimator, *args, **kwargs):
-        for metric in self.train_metrics:
+        for metric in self.metrics:
             metric.reset()

     def batch_end(self, estimator, *args, **kwargs):
         pred = kwargs['pred']
         label = kwargs['label']
         loss = kwargs['loss']
-        for metric in self.train_metrics:
+        for metric in self.metrics:
             if isinstance(metric, metric_loss):
                 # metric wrapper for loss values
                 metric.update(0, loss)
@@ -171,8 +171,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     eval_fn : function
         A function defines how to run evaluation and
         calculate loss and metrics.
-    val_metrics : List of EvalMetrics
-        Validation metrics to be updated.
     epoch_period : int, default 1
         How often to run validation at epoch end, by default
         :py:class:`ValidationHandler` validate every epoch.
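Reviewer note: since `eval_fn` no longer receives `val_metrics`, wiring periodic validation into `fit()` reduces to the sketch below; it reuses `est`, `dataset`, and `val_loader` from the previous sketch, and `train_loader` is an assumed placeholder:

```python
from mxnet.gluon.contrib.estimator.event_handler import ValidationHandler

train_loader = gluon.data.DataLoader(dataset, batch_size=4)  # assumed placeholder

# eval_fn is now called as eval_fn(val_data=...); the metrics live on the
# estimator itself, so nothing metric-related is threaded through the handler.
val_handler = ValidationHandler(val_data=val_loader, eval_fn=est.evaluate)
est.fit(train_data=train_loader, epochs=1, event_handlers=[val_handler])
```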
@@ -188,7 +186,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     def __init__(self,
                  val_data,
                  eval_fn,
-                 val_metrics=None,
                  epoch_period=1,
                  batch_period=None,
                  priority=-1000):
@@ -196,7 +193,6 @@ def __init__(self,
         self.eval_fn = eval_fn
         self.epoch_period = epoch_period
         self.batch_period = batch_period
-        self.val_metrics = _check_metrics(val_metrics)
         self.current_batch = 0
         self.current_epoch = 0
         # order to be called among all callbacks
@@ -211,20 +207,12 @@ def train_begin(self, estimator, *args, **kwargs):
     def batch_end(self, estimator, *args, **kwargs):
         self.current_batch += 1
         if self.batch_period and self.current_batch % self.batch_period == 0:
-            self.eval_fn(val_data=self.val_data,
-                         val_metrics=self.val_metrics)
-            msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \
-                  % (self.current_epoch, self.current_batch)
-            for monitor in self.val_metrics:
-                name, value = monitor.get()
-                msg += '%s: %.4f, ' % (name, value)
-            estimator.logger.info(msg.rstrip(','))
+            self.eval_fn(val_data=self.val_data)

     def epoch_end(self, estimator, *args, **kwargs):
         self.current_epoch += 1
         if self.epoch_period and self.current_epoch % self.epoch_period == 0:
-            self.eval_fn(val_data=self.val_data,
-                         val_metrics=self.val_metrics)
+            self.eval_fn(val_data=self.val_data)


 class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
@@ -239,10 +227,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         Logging interval during training.
         log_interval='epoch': display metrics every epoch
         log_interval=integer k: display metrics every interval of k batches
-    train_metrics : list of EvalMetrics
-        Training metrics to be logged, logged at batch end, epoch end, train end.
-    val_metrics : list of EvalMetrics
-        Validation metrics to be logged, logged at epoch end, train end.
+    metrics : list of EvalMetrics
+        Metrics to be logged at batch end, epoch end, and train end.
     priority : scalar, default np.Inf
         Priority level of the LoggingHandler. Priority level is sorted in
         ascending order. The lower the number is, the higher priority level the
@@ -250,14 +236,12 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
     """

     def __init__(self, log_interval='epoch',
-                 train_metrics=None,
-                 val_metrics=None,
+                 metrics=None,
                  priority=np.Inf):
         super(LoggingHandler, self).__init__()
         if not isinstance(log_interval, int) and log_interval != 'epoch':
             raise ValueError("log_interval must be either an integer or string 'epoch'")
-        self.train_metrics = _check_metrics(train_metrics)
-        self.val_metrics = _check_metrics(val_metrics)
+        self.metrics = _check_metrics(metrics)
         self.batch_index = 0
         self.current_epoch = 0
         self.processed_samples = 0
@@ -265,6 +249,7 @@ def __init__(self, log_interval='epoch',
         # it will also shut down logging at train end
         self.priority = priority
         self.log_interval = log_interval
+        self.log_interval_time = 0

     def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
@@ -288,7 +273,7 @@ def train_end(self, estimator, *args, **kwargs):
         train_time = time.time() - self.train_start
         msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
         # log every result in train stats including train/validation loss & metrics
-        for metric in self.train_metrics + self.val_metrics:
+        for metric in self.metrics:
             name, value = metric.get()
             msg += '%s: %.4f, ' % (name, value)
         estimator.logger.info(msg.rstrip(', '))
@@ -307,7 +292,7 @@ def batch_end(self, estimator, *args, **kwargs):
             if self.batch_index % self.log_interval == 0:
                 msg += 'time/interval: %.3fs ' % self.log_interval_time
                 self.log_interval_time = 0
-                for metric in self.train_metrics:
+                for metric in self.metrics:
                     # only log current training loss & metric after each interval
                     name, value = metric.get()
                     msg += '%s: %.4f, ' % (name, value)
@@ -316,15 +301,23 @@ def epoch_begin(self, estimator, *args, **kwargs):
     def epoch_begin(self, estimator, *args, **kwargs):
         if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
+            is_training = False
+            # rely on the 'training'/'validation' name prefix set by the Estimator
+            for metric in self.metrics:
+                if 'training' in metric.name:
+                    is_training = True
             self.epoch_start = time.time()
-            estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
-                                  self.current_epoch, estimator.trainer.learning_rate)
+            if is_training:
+                estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
+                                      self.current_epoch, estimator.trainer.learning_rate)
+            else:
+                estimator.logger.info("Validation Begin")

     def epoch_end(self, estimator, *args, **kwargs):
         if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
             epoch_time = time.time() - self.epoch_start
             msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
-            for monitor in self.train_metrics + self.val_metrics:
+            for monitor in self.metrics:
                 name, value = monitor.get()
                 msg += '%s: %.4f, ' % (name, value)
             estimator.logger.info(msg.rstrip(', '))
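Reviewer note: with the handler-side rename, `MetricHandler` and `LoggingHandler` each take a single `metrics` list, and the caller decides whether to pass `est.train_metrics` or `est.val_metrics`. A hedged sketch of per-batch validation logging, reusing `est` and `val_loader` from the earlier sketch:

```python
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler, MetricHandler

# The 'validation' name prefix on est.val_metrics is what makes this
# LoggingHandler print "Validation Begin" instead of the training banner.
handlers = [MetricHandler(metrics=est.val_metrics),
            LoggingHandler(log_interval=1, metrics=est.val_metrics)]
est.evaluate(val_data=val_loader, event_handlers=handlers)
```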
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index dba3f122a9b6..924dd083bef4 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -63,7 +63,7 @@ def test_fit():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx)

@@ -93,7 +93,7 @@ def test_validation():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx,
                     evaluation_loss=evaluation_loss)
@@ -105,8 +105,7 @@ def test_validation():
     # using validation handler
     train_metrics = est.train_metrics
     val_metrics = est.val_metrics
-    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate,
-                                           val_metrics=val_metrics)
+    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate)

     with assert_raises(ValueError):
         est.fit(train_data=dataiter,
@@ -132,7 +131,7 @@ def test_initializer():
     # no initializer
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     context=ctx)
     est.fit(train_data=train_data,
             epochs=num_epochs)
@@ -145,7 +144,7 @@ def test_initializer():
     with warnings.catch_warnings(record=True) as w:
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=acc,
+                        train_metrics=acc,
                         initializer=mx.init.MSRAPrelu(),
                         trainer=trainer,
                         context=ctx)
@@ -153,7 +152,7 @@ def test_initializer():
     # net partially initialized, fine tuning use case
     net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx)
     net.output = gluon.nn.Dense(10) #last layer not initialized
-    est = Estimator(net, loss=loss, metrics=acc, context=ctx)
+    est = Estimator(net, loss=loss, train_metrics=acc, context=ctx)
     dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
     train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5)
     est.fit(train_data=train_data,
@@ -175,7 +174,7 @@ def test_trainer():
     with warnings.catch_warnings(record=True) as w:
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=acc,
+                        train_metrics=acc,
                         context=ctx)
     assert 'No trainer specified' in str(w[-1].message)
     est.fit(train_data=train_data,
@@ -186,7 +185,7 @@ def test_trainer():
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=acc,
+                        train_metrics=acc,
                         trainer=trainer,
                         context=ctx)

@@ -212,7 +211,7 @@ def test_metric():
     metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=metrics,
+                    train_metrics=metrics,
                     trainer=trainer,
                     context=ctx)
     est.fit(train_data=train_data,
@@ -221,7 +220,7 @@ def test_metric():
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics='acc',
+                        train_metrics='acc',
                         trainer=trainer,
                         context=ctx)
     # test default metric
@@ -244,7 +243,7 @@ def test_loss():
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss='mse',
-                        metrics=acc,
+                        train_metrics=acc,
                         trainer=trainer,
                         context=ctx)

@@ -257,26 +256,26 @@ def test_context():
     # input no context
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=metrics)
+                    train_metrics=metrics)
     # input list of context
     gpus = mx.context.num_gpus()
     ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
     net = _get_test_network()
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=metrics,
+                    train_metrics=metrics,
                     context=ctx)
     # input invalid context
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=metrics,
+                        train_metrics=metrics,
                         context='cpu')
     with assert_raises(AssertionError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=metrics,
+                        train_metrics=metrics,
                         context=[mx.gpu(0), mx.gpu(100)])

@@ -341,7 +340,7 @@ def test_default_handlers():

     est = Estimator(net=net,
                     loss=loss,
-                    metrics=train_acc,
+                    train_metrics=train_acc,
                     trainer=trainer,
                     context=ctx)
     # no handler(all default handlers), no warning
@@ -352,18 +351,18 @@ def test_default_handlers():
     # use mix of default and user defined handlers
     train_metrics = est.train_metrics
     val_metrics = est.val_metrics
-    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
+    logging = LoggingHandler(metrics=train_metrics)
    est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])

     # handler with all user defined metrics
     # use mix of default and user defined handlers
-    metric = MetricHandler(train_metrics=[train_acc])
-    logging = LoggingHandler(train_metrics=[train_acc])
+    metric = MetricHandler(metrics=[train_acc])
+    logging = LoggingHandler(metrics=[train_acc])
     est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging])

     # handler with mixed metrics, some handler use metrics prepared by estimator
     # some handler use metrics user prepared
-    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")])
+    logging = LoggingHandler(metrics=[mx.metric.RMSE("val acc")])
     with assert_raises(ValueError):
         est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])

@@ -392,7 +391,7 @@ def test_eval_net():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                    loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx,
                     evaluation_loss=evaluation_loss,
@@ -410,7 +409,7 @@ def test_eval_net():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx,
                     evaluation_loss=evaluation_loss,
@@ -432,7 +431,7 @@ def test_eval_net():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx,
                     evaluation_loss=evaluation_loss,
@@ -442,3 +441,29 @@ def test_eval_net():
             val_data=dataloader,
             epochs=num_epochs)

+def test_val_handlers():
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    val_data, _ = _get_test_data()
+
+    num_epochs = 1
+    ctx = mx.cpu()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+
+    train_acc = mx.metric.RMSE()
+    loss = gluon.loss.L2Loss()
+
+    est = Estimator(net=net,
+                    loss=loss,
+                    train_metrics=train_acc,
+                    trainer=trainer,
+                    context=ctx)
+
+    with warnings.catch_warnings(record=True) as w:
+        est.fit(train_data=train_data, epochs=num_epochs)
+        est.evaluate(val_data=val_data)
+
+    logging = LoggingHandler(log_interval=1, metrics=est.val_metrics)
+    est.evaluate(val_data=val_data, event_handlers=[logging])
+
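Reviewer note: because `evaluate()` now runs the full handler dispatch, custom handlers also work during validation. A speculative sketch of a batch counter, reusing `est` and `val_loader` from the earlier sketches; the class is illustrative and not part of this patch:

```python
from mxnet.gluon.contrib.estimator.event_handler import BatchEnd, EpochEnd

class ValBatchCounter(BatchEnd, EpochEnd):
    """Count the validation batches seen by a single evaluate() call."""

    def __init__(self):
        self.num_batches = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.num_batches += 1

    def epoch_end(self, estimator, *args, **kwargs):
        estimator.logger.info('evaluate() processed %d batches', self.num_batches)

est.evaluate(val_data=val_loader, event_handlers=[ValBatchCounter()])
```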
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index 658fb88f47e5..41b790102f62 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -54,7 +54,7 @@ def test_checkpoint_handler():
         net = _get_test_network()
         ce_loss = loss.SoftmaxCrossEntropyLoss()
         acc = mx.metric.Accuracy()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
         checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                              model_prefix=model_prefix,
                                                              monitor=acc,
@@ -72,7 +72,7 @@ def test_checkpoint_handler():
         file_path = os.path.join(tmpdir, model_prefix)
         net = _get_test_network(nn.HybridSequential())
         net.hybridize()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
         checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                              model_prefix=model_prefix,
                                                              epoch_period=None,
@@ -100,7 +100,7 @@ def test_resume_checkpoint():
         net = _get_test_network()
         ce_loss = loss.SoftmaxCrossEntropyLoss()
         acc = mx.metric.Accuracy()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
         checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                              model_prefix=model_prefix,
                                                              monitor=acc,
@@ -125,7 +125,7 @@ def test_early_stopping():
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
     early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
                                                         patience=0,
                                                         mode='min')
@@ -149,14 +149,13 @@ def test_logging():
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
     est.logger.addHandler(logging.FileHandler(output_dir))

     train_metrics = est.train_metrics
     val_metrics = est.val_metrics
-    logging_handler = event_handler.LoggingHandler(train_metrics=train_metrics,
-                                                   val_metrics=val_metrics)
+    logging_handler = event_handler.LoggingHandler(metrics=train_metrics)
     est.fit(test_data, event_handlers=[logging_handler], epochs=3)
     assert logging_handler.batch_index == 0
     assert logging_handler.current_epoch == 3
@@ -197,7 +196,7 @@ def epoch_end(self, estimator, *args, **kwargs):
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
     custom_handler = CustomStopHandler(3, 2)
     est.fit(test_data, event_handlers=[custom_handler], epochs=3)
     assert custom_handler.num_batch == 3
@@ -220,10 +219,10 @@ def test_logging_interval():
     num_epochs = 1
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+    logging = LoggingHandler(metrics=[acc], log_interval=log_interval)
     est = estimator.Estimator(net=net,
                               loss=ce_loss,
-                              metrics=acc)
+                              train_metrics=acc)
     est.fit(train_data=dataloader,
             epochs=num_epochs,
@@ -245,10 +244,10 @@ def test_logging_interval():
     sys.stdout = mystdout = StringIO()
     acc = mx.metric.Accuracy()
     log_interval = 5
-    logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+    logging = LoggingHandler(metrics=[acc], log_interval=log_interval)
     est = estimator.Estimator(net=net,
                               loss=ce_loss,
-                              metrics=acc)
+                              train_metrics=acc)
     est.fit(train_data=dataloader,
             epochs=num_epochs,
             event_handlers=[logging])