Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions python/mxnet/module/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import time
import logging
import warnings
import numpy as np

from .. import metric
from .. import ndarray

from ..context import cpu
from ..model import BatchEndParam
from ..initializer import Uniform
from ..io import DataDesc
from ..io import DataDesc, DataIter, DataBatch
from ..base import _as_list


Expand Down Expand Up @@ -333,7 +334,7 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,

Parameters
----------
eval_data : DataIter
eval_data : DataIter or NDArray or numpy array
Evaluation data to run prediction on.
num_batch : int
Defaults to ``None``, indicates running all the batches in the data iterator.
Expand Down Expand Up @@ -363,6 +364,15 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
"""
assert self.binded and self.params_initialized

if isinstance(eval_data, (ndarray.NDArray, np.ndarray)):
if isinstance(eval_data, np.ndarray):
eval_data = ndarray.array(eval_data)
self.forward(DataBatch([eval_data]))
return self.get_outputs()[0]

if not isinstance(eval_data, DataIter):
raise ValueError('eval_data must be of type NDArray or DataIter')

if reset:
eval_data.reset()

Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,8 @@ def test_forward_reshape():
for_training=False, force_rebind=True)
assert mod.predict(pred_dataiter).shape == tuple([10, num_class])

@with_seed()
def test_forward_types():
#Test forward with other data batch API
Batch = namedtuple('Batch', ['data'])
data = mx.sym.Variable('data')
Expand All @@ -786,6 +788,18 @@ def test_forward_reshape():
mod.forward(Batch(data2))
assert mod.get_outputs()[0].shape == (3, 5)

#Test forward with other NDArray and np.ndarray inputs
data = mx.sym.Variable('data')
out = data * 2
mod = mx.mod.Module(symbol=out, label_names=None)
mod.bind(data_shapes=[('data', (1, 10))])
mod.init_params()
data1 = mx.nd.ones((1, 10))
assert mod.predict(data1).shape == (1, 10)
data2 = np.ones((1, 10))
assert mod.predict(data1).shape == (1, 10)



if __name__ == '__main__':
import nose
Expand Down