Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.1.5'
__version__ = '0.1.6'

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you probably are doing this, but you can submit this after your other PR so you get both changes the version bump

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah in the future I will bump the version via an explicit PR.

66 changes: 51 additions & 15 deletions flytekit/engines/unit/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from datetime import datetime as _datetime
from six import moves as _six_moves

from google.protobuf.json_format import ParseDict as _ParseDict
from flyteidl.plugins import qubole_pb2 as _qubole_pb2
from flytekit.common import constants as _sdk_constants, utils as _common_utils
from flytekit.common.exceptions import user as _user_exceptions, system as _system_exception
from flytekit.common.types import helpers as _type_helpers
from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration
from flytekit.engines import common as _common_engine
from flytekit.engines.unit.mock_stats import MockStats
from flytekit.interfaces.data import data_proxy as _data_proxy
from flytekit.models import literals as _literals, array_job as _array_job
from flytekit.models import literals as _literals, array_job as _array_job, qubole as _qubole_models
from flytekit.models.core.identifier import WorkflowExecutionIdentifier


Expand All @@ -32,9 +34,12 @@ def get_task(self, sdk_task):
return ReturnOutputsTask(sdk_task)
elif sdk_task.type in {
_sdk_constants.SdkTaskType.DYNAMIC_TASK,
_sdk_constants.SdkTaskType.BATCH_HIVE_TASK
}:
return DynamicTask(sdk_task)
elif sdk_task.type in {
_sdk_constants.SdkTaskType.BATCH_HIVE_TASK,
}:
return HiveTask(sdk_task)
else:
raise _user_exceptions.FlyteAssertion(
"Unit tests are not currently supported for tasks of type: {}".format(
Expand Down Expand Up @@ -76,20 +81,20 @@ def execute(self, inputs, context=None):
Just execute the function and return the outputs as a user-readable dictionary.
:param flytekit.models.literals.LiteralMap inputs:
:param context:
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
"""
with _TemporaryConfiguration(
_os.path.join(_os.path.dirname(__file__), 'unit.config'),
internal_overrides={'image': 'unit_image'}
):
with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory:
with _data_proxy.LocalWorkingDirectoryContext(working_directory):
return self._execute_user_code(inputs)
return self._transform_for_user_output(self._execute_user_code(inputs))

def _execute_user_code(self, inputs):
"""
:param flytekit.models.literals.LiteralMap inputs:
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
"""
with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
return self.sdk_task.execute(
Expand All @@ -107,24 +112,32 @@ def _execute_user_code(self, inputs):
inputs
)

def _transform_for_user_output(self, outputs):
"""
Take whatever is returned from the task execution and convert to a reasonable output for the behavior of this
task's unit test.
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
:rtype: T
"""
return outputs

def register(self, identifier, version):
raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.")


class ReturnOutputsTask(UnitTestEngineTask):
def execute(self, inputs, context=None):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this no longer needed?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a refactor so this method is sufficient for transforming the execute function of the super class to behave as we desire.

def _transform_for_user_output(self, outputs):
"""
Just execute the function and return the outputs as a user-readable dictionary.
:param flytekit.models.literals.LiteralMap inputs:
:param context:
:rtype: dict[Text, T]
Just return the outputs as a user-readable dictionary.
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
:rtype: T
"""
outputs = super(ReturnOutputsTask, self).execute(inputs)[_sdk_constants.OUTPUT_FILE_NAME]
literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME]
return {
name: _type_helpers.get_sdk_type_from_literal_type(
variable.type
).promote_from_model(
outputs.literals[name]
literal_map.literals[name]
).to_python_std()
for name, variable in _six.iteritems(self.sdk_task.interface.outputs)
}
Expand All @@ -135,7 +148,7 @@ class DynamicTask(ReturnOutputsTask):
def _execute_user_code(self, inputs):
"""
:param flytekit.models.literals.LiteralMap inputs:
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
"""
results = super(DynamicTask, self)._execute_user_code(inputs)
if _sdk_constants.FUTURES_FILE_NAME in results:
Expand All @@ -151,7 +164,7 @@ def _execute_user_code(self, inputs):
# TODO: futures.outputs should have the Schema instances.
# After schema is implemented, fill out random data into the random locations
# then check output in test function
# From Haytham even though we recommend people use typed schemas, they might not always do so...
# Even though we recommend people use typed schemas, they might not always do so...
# in which case it'll be impossible to predict the actual schema, we should support a
# way for unit test authors to provide fake data regardless
sub_task_output = None
Expand Down Expand Up @@ -201,7 +214,7 @@ def fulfil_bindings(binding_data, fulfilled_promises):
fulfilled_promises

:param _interface.BindingData binding_data:
:param dict[Text, T] fulfilled_promises:
:param dict[Text,T] fulfilled_promises:
:rtype:
"""
if binding_data.scalar:
Expand All @@ -228,3 +241,26 @@ def fulfil_bindings(binding_data, fulfilled_promises):
k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in
_six.iteritems(binding_data.map.bindings)
}))


class HiveTask(DynamicTask):
def _transform_for_user_output(self, outputs):
"""
Just execute the function and return the list of Hive queries returned.
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
:rtype: list[Text]
"""
futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME)
if futures:
task_ids_to_defs = {
t.id.name: _qubole_models.QuboleHiveJob.from_flyte_idl(
_ParseDict(t.custom, _qubole_pb2.QuboleHiveJob())
)
for t in futures.tasks
}
return [
q.query
for q in task_ids_to_defs[futures.nodes[0].task_node.reference_id.name].query_collection.queries
]
else:
return []
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ max-complexity=16
[tool:pytest]
norecursedirs = common workflows spark
log_cli = true
log_cli_level = 100
log_cli_level = 20

[pep8]
max-line-length = 120
Expand Down
44 changes: 44 additions & 0 deletions tests/flytekit/unit/use_scenarios/unit_testing/hive_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import absolute_import
from flytekit.sdk.tasks import hive_task
import pytest


def test_no_queries():
@hive_task
def test_hive_task(wf_params):
pass

assert test_hive_task.unit_test() == []


def test_empty_list_queries():
@hive_task
def test_hive_task(wf_params):
return []

assert test_hive_task.unit_test() == []


def test_one_query():
@hive_task
def test_hive_task(wf_params):
return "abc"

assert test_hive_task.unit_test() == ["abc"]


def test_multiple_queries():
@hive_task
def test_hive_task(wf_params):
return ["abc", "cde"]

assert test_hive_task.unit_test() == ["abc", "cde"]


def test_raise_exception():
@hive_task
def test_hive_task(wf_params):
raise FloatingPointError("Floating point error for some reason.")

with pytest.raises(FloatingPointError):
test_hive_task.unit_test()