-
Notifications
You must be signed in to change notification settings - Fork 338
Implement Hive Unit Test Behavior #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from __future__ import absolute_import | ||
| import flytekit.plugins | ||
|
|
||
| __version__ = '0.1.5' | ||
| __version__ = '0.1.6' | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,14 +7,16 @@ | |
| from datetime import datetime as _datetime | ||
| from six import moves as _six_moves | ||
|
|
||
| from google.protobuf.json_format import ParseDict as _ParseDict | ||
| from flyteidl.plugins import qubole_pb2 as _qubole_pb2 | ||
| from flytekit.common import constants as _sdk_constants, utils as _common_utils | ||
| from flytekit.common.exceptions import user as _user_exceptions, system as _system_exception | ||
| from flytekit.common.types import helpers as _type_helpers | ||
| from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration | ||
| from flytekit.engines import common as _common_engine | ||
| from flytekit.engines.unit.mock_stats import MockStats | ||
| from flytekit.interfaces.data import data_proxy as _data_proxy | ||
| from flytekit.models import literals as _literals, array_job as _array_job | ||
| from flytekit.models import literals as _literals, array_job as _array_job, qubole as _qubole_models | ||
| from flytekit.models.core.identifier import WorkflowExecutionIdentifier | ||
|
|
||
|
|
||
|
|
@@ -32,9 +34,12 @@ def get_task(self, sdk_task): | |
| return ReturnOutputsTask(sdk_task) | ||
| elif sdk_task.type in { | ||
| _sdk_constants.SdkTaskType.DYNAMIC_TASK, | ||
| _sdk_constants.SdkTaskType.BATCH_HIVE_TASK | ||
| }: | ||
| return DynamicTask(sdk_task) | ||
| elif sdk_task.type in { | ||
| _sdk_constants.SdkTaskType.BATCH_HIVE_TASK, | ||
| }: | ||
| return HiveTask(sdk_task) | ||
| else: | ||
| raise _user_exceptions.FlyteAssertion( | ||
| "Unit tests are not currently supported for tasks of type: {}".format( | ||
|
|
@@ -76,20 +81,20 @@ def execute(self, inputs, context=None): | |
| Just execute the function and return the outputs as a user-readable dictionary. | ||
| :param flytekit.models.literals.LiteralMap inputs: | ||
| :param context: | ||
| :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] | ||
| :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] | ||
| """ | ||
| with _TemporaryConfiguration( | ||
| _os.path.join(_os.path.dirname(__file__), 'unit.config'), | ||
| internal_overrides={'image': 'unit_image'} | ||
| ): | ||
| with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory: | ||
| with _data_proxy.LocalWorkingDirectoryContext(working_directory): | ||
| return self._execute_user_code(inputs) | ||
| return self._transform_for_user_output(self._execute_user_code(inputs)) | ||
|
|
||
| def _execute_user_code(self, inputs): | ||
| """ | ||
| :param flytekit.models.literals.LiteralMap inputs: | ||
| :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] | ||
| :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] | ||
| """ | ||
| with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory: | ||
| return self.sdk_task.execute( | ||
|
|
@@ -107,24 +112,32 @@ def _execute_user_code(self, inputs): | |
| inputs | ||
| ) | ||
|
|
||
| def _transform_for_user_output(self, outputs): | ||
| """ | ||
| Take whatever is returned from the task execution and convert to a reasonable output for the behavior of this | ||
| task's unit test. | ||
| :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: | ||
| :rtype: T | ||
| """ | ||
| return outputs | ||
|
|
||
| def register(self, identifier, version): | ||
| raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.") | ||
|
|
||
|
|
||
| class ReturnOutputsTask(UnitTestEngineTask): | ||
| def execute(self, inputs, context=None): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this no longer needed?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's a refactor: overriding this method is now sufficient to make the superclass's execute function behave as we desire. |
||
| def _transform_for_user_output(self, outputs): | ||
| """ | ||
| Just execute the function and return the outputs as a user-readable dictionary. | ||
| :param flytekit.models.literals.LiteralMap inputs: | ||
| :param context: | ||
| :rtype: dict[Text, T] | ||
| Just return the outputs as a user-readable dictionary. | ||
| :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: | ||
| :rtype: T | ||
| """ | ||
| outputs = super(ReturnOutputsTask, self).execute(inputs)[_sdk_constants.OUTPUT_FILE_NAME] | ||
| literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME] | ||
| return { | ||
| name: _type_helpers.get_sdk_type_from_literal_type( | ||
| variable.type | ||
| ).promote_from_model( | ||
| outputs.literals[name] | ||
| literal_map.literals[name] | ||
| ).to_python_std() | ||
| for name, variable in _six.iteritems(self.sdk_task.interface.outputs) | ||
| } | ||
|
|
@@ -135,7 +148,7 @@ class DynamicTask(ReturnOutputsTask): | |
| def _execute_user_code(self, inputs): | ||
| """ | ||
| :param flytekit.models.literals.LiteralMap inputs: | ||
| :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] | ||
| :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] | ||
| """ | ||
| results = super(DynamicTask, self)._execute_user_code(inputs) | ||
| if _sdk_constants.FUTURES_FILE_NAME in results: | ||
|
|
@@ -151,7 +164,7 @@ def _execute_user_code(self, inputs): | |
| # TODO: futures.outputs should have the Schema instances. | ||
| # After schema is implemented, fill out random data into the random locations | ||
| # then check output in test function | ||
| # From Haytham even though we recommend people use typed schemas, they might not always do so... | ||
| # Even though we recommend people use typed schemas, they might not always do so... | ||
| # in which case it'll be impossible to predict the actual schema, we should support a | ||
| # way for unit test authors to provide fake data regardless | ||
| sub_task_output = None | ||
|
|
@@ -201,7 +214,7 @@ def fulfil_bindings(binding_data, fulfilled_promises): | |
| fulfilled_promises | ||
|
|
||
| :param _interface.BindingData binding_data: | ||
| :param dict[Text, T] fulfilled_promises: | ||
| :param dict[Text,T] fulfilled_promises: | ||
| :rtype: | ||
| """ | ||
| if binding_data.scalar: | ||
|
|
@@ -228,3 +241,26 @@ def fulfil_bindings(binding_data, fulfilled_promises): | |
| k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in | ||
| _six.iteritems(binding_data.map.bindings) | ||
| })) | ||
|
|
||
|
|
||
| class HiveTask(DynamicTask): | ||
| def _transform_for_user_output(self, outputs): | ||
| """ | ||
| Just execute the function and return the list of Hive queries returned. | ||
| :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: | ||
| :rtype: list[Text] | ||
| """ | ||
| futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME) | ||
| if futures: | ||
| task_ids_to_defs = { | ||
| t.id.name: _qubole_models.QuboleHiveJob.from_flyte_idl( | ||
| _ParseDict(t.custom, _qubole_pb2.QuboleHiveJob()) | ||
| ) | ||
| for t in futures.tasks | ||
| } | ||
| return [ | ||
| q.query | ||
| for q in task_ids_to_defs[futures.nodes[0].task_node.reference_id.name].query_collection.queries | ||
| ] | ||
| else: | ||
| return [] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
tests/flytekit/unit/use_scenarios/unit_testing/hive_tasks.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| from __future__ import absolute_import | ||
| from flytekit.sdk.tasks import hive_task | ||
| import pytest | ||
|
|
||
|
|
||
| def test_no_queries(): | ||
| @hive_task | ||
| def test_hive_task(wf_params): | ||
| pass | ||
|
|
||
| assert test_hive_task.unit_test() == [] | ||
|
|
||
|
|
||
| def test_empty_list_queries(): | ||
| @hive_task | ||
| def test_hive_task(wf_params): | ||
| return [] | ||
|
|
||
| assert test_hive_task.unit_test() == [] | ||
|
|
||
|
|
||
| def test_one_query(): | ||
| @hive_task | ||
| def test_hive_task(wf_params): | ||
| return "abc" | ||
|
|
||
| assert test_hive_task.unit_test() == ["abc"] | ||
|
|
||
|
|
||
| def test_multiple_queries(): | ||
| @hive_task | ||
| def test_hive_task(wf_params): | ||
| return ["abc", "cde"] | ||
|
|
||
| assert test_hive_task.unit_test() == ["abc", "cde"] | ||
|
|
||
|
|
||
| def test_raise_exception(): | ||
| @hive_task | ||
| def test_hive_task(wf_params): | ||
| raise FloatingPointError("Floating point error for some reason.") | ||
|
|
||
| with pytest.raises(FloatingPointError): | ||
| test_hive_task.unit_test() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are probably already doing this, but you can submit this after your other PR so that both changes are covered by the same version bump.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, in the future I will bump the version via an explicit PR.