diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index a34bf31a48..5a6bbc4fdc 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -676,12 +676,11 @@ def get_node_execution(self, node_execution_identifier): ) ) - def get_node_execution_data(self, node_execution_identifier): + def get_node_execution_data(self, node_execution_identifier) -> _execution.NodeExecutionGetDataResponse: """ Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available). :param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier: - :rtype: flytekit.models.execution.NodeExecutionGetDataResponse """ return _execution.NodeExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_node_execution_data( diff --git a/flytekit/models/core/catalog.py b/flytekit/models/core/catalog.py new file mode 100644 index 0000000000..97cc2c34ff --- /dev/null +++ b/flytekit/models/core/catalog.py @@ -0,0 +1,75 @@ +from flyteidl.core import catalog_pb2 + +from flytekit.models import common as _common_models +from flytekit.models.core import identifier as _identifier + + +class CatalogArtifactTag(_common_models.FlyteIdlEntity): + def __init__(self, artifact_id: str, name: str): + self._artifact_id = artifact_id + self._name = name + + @property + def artifact_id(self) -> str: + return self._artifact_id + + @property + def name(self) -> str: + return self._name + + def to_flyte_idl(self) -> catalog_pb2.CatalogArtifactTag: + return catalog_pb2.CatalogArtifactTag(artifact_id=self.artifact_id, name=self.name) + + @classmethod + def from_flyte_idl(cls, p: catalog_pb2.CatalogArtifactTag) -> "CatalogArtifactTag": + return cls( + artifact_id=p.artifact_id, + name=p.name, + ) + + +class CatalogMetadata(_common_models.FlyteIdlEntity): + def __init__( + self, + dataset_id: _identifier.Identifier, + artifact_tag: CatalogArtifactTag, + source_task_execution: 
_identifier.TaskExecutionIdentifier, + ): + self._dataset_id = dataset_id + self._artifact_tag = artifact_tag + self._source_task_execution = source_task_execution + + @property + def dataset_id(self) -> _identifier.Identifier: + return self._dataset_id + + @property + def artifact_tag(self) -> CatalogArtifactTag: + return self._artifact_tag + + @property + def source_task_execution(self) -> _identifier.TaskExecutionIdentifier: + return self._source_task_execution + + @property + def source_execution(self) -> _identifier.TaskExecutionIdentifier: + """ + This is a one of but for now there's only one thing in the one of + """ + return self._source_task_execution + + def to_flyte_idl(self) -> catalog_pb2.CatalogMetadata: + return catalog_pb2.CatalogMetadata( + dataset_id=self.dataset_id.to_flyte_idl(), + artifact_tag=self.artifact_tag.to_flyte_idl(), + source_task_execution=self.source_task_execution.to_flyte_idl(), + ) + + @classmethod + def from_flyte_idl(cls, pb: catalog_pb2.CatalogMetadata) -> "CatalogMetadata": + return cls( + dataset_id=_identifier.Identifier.from_flyte_idl(pb.dataset_id), + artifact_tag=CatalogArtifactTag.from_flyte_idl(pb.artifact_tag), + # Add HasField check if more things are ever added to the one of + source_task_execution=_identifier.TaskExecutionIdentifier.from_flyte_idl(pb.source_task_execution), + ) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 042b3ad909..a29bad80f1 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -1,3 +1,5 @@ +import typing + import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 @@ -7,6 +9,7 @@ from flytekit.models import literals as _literals_models from flytekit.models.core import execution as _core_execution from flytekit.models.core import identifier as _identifier +from flytekit.models.node_execution import 
DynamicWorkflowNodeMetadata class ExecutionMetadata(_common_models.FlyteIdlEntity): @@ -238,7 +241,6 @@ class Execution(_common_models.FlyteIdlEntity): def __init__(self, id, spec, closure): """ :param flytekit.models.core.identifier.WorkflowExecutionIdentifier id: - :param Text id: :param ExecutionSpec spec: :param ExecutionClosure closure: """ @@ -403,8 +405,8 @@ def __init__(self, inputs, outputs, full_inputs, full_outputs): """ :param _common_models.UrlBlob inputs: :param _common_models.UrlBlob outputs: - :param _literals_pb2.LiteralMap full_inputs: - :param _literals_pb2.LiteralMap full_outputs: + :param _literals_models.LiteralMap full_inputs: + :param _literals_models.LiteralMap full_outputs: """ self._inputs = inputs self._outputs = outputs @@ -428,14 +430,14 @@ def outputs(self): @property def full_inputs(self): """ - :rtype: _literals_pb2.LiteralMap + :rtype: _literals_models.LiteralMap """ return self._full_inputs @property def full_outputs(self): """ - :rtype: _literals_pb2.LiteralMap + :rtype: _literals_models.LiteralMap """ return self._full_outputs @@ -493,6 +495,14 @@ def to_flyte_idl(self): class NodeExecutionGetDataResponse(_CommonDataResponse): + def __init__(self, *args, dynamic_workflow: typing.Optional[DynamicWorkflowNodeMetadata] = None, **kwargs): + super().__init__(*args, **kwargs) + self._dynamic_workflow = dynamic_workflow + + @property + def dynamic_workflow(self) -> typing.Optional[DynamicWorkflowNodeMetadata]: + return self._dynamic_workflow + @classmethod def from_flyte_idl(cls, pb2_object): """ @@ -504,6 +514,9 @@ def from_flyte_idl(cls, pb2_object): outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), + dynamic_workflow=DynamicWorkflowNodeMetadata.from_flyte_idl(pb2_object.dynamic_workflow) + if pb2_object.HasField("dynamic_workflow") + else None, ) def 
to_flyte_idl(self): @@ -515,4 +528,5 @@ def to_flyte_idl(self): outputs=self.outputs.to_flyte_idl(), full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), + dynamic_workflow=self.dynamic_workflow.to_flyte_idl() if self.dynamic_workflow else None, ) diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index 5a0cda6a6e..762dfd196f 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -1,13 +1,101 @@ +import typing + import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import pytz as _pytz from flytekit.models import common as _common_models +from flytekit.models.core import catalog as catalog_models +from flytekit.models.core import compiler as core_compiler_models from flytekit.models.core import execution as _core_execution from flytekit.models.core import identifier as _identifier +class WorkflowNodeMetadata(_common_models.FlyteIdlEntity): + def __init__(self, execution_id: _identifier.WorkflowExecutionIdentifier): + self._execution_id = execution_id + + @property + def execution_id(self) -> _identifier.WorkflowExecutionIdentifier: + return self._execution_id + + def to_flyte_idl(self) -> _node_execution_pb2.WorkflowNodeMetadata: + return _node_execution_pb2.WorkflowNodeMetadata( + executionId=self.execution_id.to_flyte_idl(), + ) + + @classmethod + def from_flyte_idl(cls, p: _node_execution_pb2.WorkflowNodeMetadata) -> "WorkflowNodeMetadata": + return cls( + execution_id=_identifier.WorkflowExecutionIdentifier.from_flyte_idl(p.executionId), + ) + + +class DynamicWorkflowNodeMetadata(_common_models.FlyteIdlEntity): + def __init__(self, id: _identifier.Identifier, compiled_workflow: core_compiler_models.CompiledWorkflowClosure): + self._id = id + self._compiled_workflow = compiled_workflow + + @property + def id(self) -> _identifier.Identifier: + return self._id + + @property + def compiled_workflow(self) -> 
core_compiler_models.CompiledWorkflowClosure: + return self._compiled_workflow + + def to_flyte_idl(self) -> _node_execution_pb2.DynamicWorkflowNodeMetadata: + return _node_execution_pb2.DynamicWorkflowNodeMetadata( + id=self.id.to_flyte_idl(), + compiled_workflow=self.compiled_workflow.to_flyte_idl(), + ) + + @classmethod + def from_flyte_idl(cls, p: _node_execution_pb2.DynamicWorkflowNodeMetadata) -> "DynamicWorkflowNodeMetadata": + yy = cls( + id=_identifier.Identifier.from_flyte_idl(p.id), + compiled_workflow=core_compiler_models.CompiledWorkflowClosure.from_flyte_idl(p.compiled_workflow), + ) + return yy + + +class TaskNodeMetadata(_common_models.FlyteIdlEntity): + def __init__(self, cache_status: int, catalog_key: catalog_models.CatalogMetadata): + self._cache_status = cache_status + self._catalog_key = catalog_key + + @property + def cache_status(self) -> int: + return self._cache_status + + @property + def catalog_key(self) -> catalog_models.CatalogMetadata: + return self._catalog_key + + def to_flyte_idl(self) -> _node_execution_pb2.TaskNodeMetadata: + return _node_execution_pb2.TaskNodeMetadata( + cache_status=self.cache_status, + catalog_key=self.catalog_key.to_flyte_idl(), + ) + + @classmethod + def from_flyte_idl(cls, p: _node_execution_pb2.TaskNodeMetadata) -> "TaskNodeMetadata": + return cls( + cache_status=p.cache_status, + catalog_key=catalog_models.CatalogMetadata.from_flyte_idl(p.catalog_key), + ) + + class NodeExecutionClosure(_common_models.FlyteIdlEntity): - def __init__(self, phase, started_at, duration, output_uri=None, error=None): + def __init__( + self, + phase, + started_at, + duration, + output_uri=None, + error=None, + workflow_node_metadata: typing.Optional[WorkflowNodeMetadata] = None, + task_node_metadata: typing.Optional[TaskNodeMetadata] = None, + ): """ :param int phase: :param datetime.datetime started_at: @@ -20,6 +108,9 @@ def __init__(self, phase, started_at, duration, output_uri=None, error=None): self._duration = duration 
self._output_uri = output_uri self._error = error + self._workflow_node_metadata = workflow_node_metadata + self._task_node_metadata = task_node_metadata + # TODO: Add output_data field as well. @property def phase(self): @@ -56,6 +147,18 @@ def error(self): """ return self._error + @property + def workflow_node_metadata(self) -> typing.Optional[WorkflowNodeMetadata]: + return self._workflow_node_metadata + + @property + def task_node_metadata(self) -> typing.Optional[TaskNodeMetadata]: + return self._task_node_metadata + + @property + def target_metadata(self) -> typing.Union[WorkflowNodeMetadata, TaskNodeMetadata]: + return self.workflow_node_metadata or self.task_node_metadata + def to_flyte_idl(self): """ :rtype: flyteidl.admin.node_execution_pb2.NodeExecutionClosure @@ -64,6 +167,10 @@ def to_flyte_idl(self): phase=self.phase, output_uri=self.output_uri, error=self.error.to_flyte_idl() if self.error is not None else None, + workflow_node_metadata=self.workflow_node_metadata.to_flyte_idl() + if self.workflow_node_metadata is not None + else None, + task_node_metadata=self.task_node_metadata.to_flyte_idl() if self.task_node_metadata is not None else None, ) obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None)) obj.duration.FromTimedelta(self.duration) @@ -81,6 +188,12 @@ def from_flyte_idl(cls, p): error=_core_execution.ExecutionError.from_flyte_idl(p.error) if p.HasField("error") else None, started_at=p.started_at.ToDatetime().replace(tzinfo=_pytz.UTC), duration=p.duration.ToTimedelta(), + workflow_node_metadata=WorkflowNodeMetadata.from_flyte_idl(p.workflow_node_metadata) + if p.HasField("workflow_node_metadata") + else None, + task_node_metadata=TaskNodeMetadata.from_flyte_idl(p.task_node_metadata) + if p.HasField("task_node_metadata") + else None, ) diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 9dc4a5f0ed..51def80359 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ 
-79,10 +79,9 @@ """ from flytekit.remote.component_nodes import FlyteTaskNode, FlyteWorkflowNode +from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.launch_plan import FlyteLaunchPlan -from flytekit.remote.nodes import FlyteNode, FlyteNodeExecution +from flytekit.remote.nodes import FlyteNode from flytekit.remote.remote import FlyteRemote -from flytekit.remote.tasks.executions import FlyteTaskExecution -from flytekit.remote.tasks.task import FlyteTask +from flytekit.remote.task import FlyteTask from flytekit.remote.workflow import FlyteWorkflow -from flytekit.remote.workflow_execution import FlyteWorkflowExecution diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py index 06b885abfd..367cab8997 100644 --- a/flytekit/remote/component_nodes.py +++ b/flytekit/remote/component_nodes.py @@ -1,23 +1,22 @@ import logging as _logging from typing import Dict -import flytekit from flytekit.common.exceptions import system as _system_exceptions from flytekit.models import launch_plan as _launch_plan_model from flytekit.models import task as _task_model +from flytekit.models.core import identifier as id_models from flytekit.models.core import workflow as _workflow_model -from flytekit.remote import identifier as _identifier class FlyteTaskNode(_workflow_model.TaskNode): """A class encapsulating a task that a Flyte node needs to execute.""" - def __init__(self, flyte_task: "flytekit.remote.tasks.task.FlyteTask"): + def __init__(self, flyte_task: "flytekit.remote.task.FlyteTask"): self._flyte_task = flyte_task super(FlyteTaskNode, self).__init__(None) @property - def reference_id(self) -> _identifier.Identifier: + def reference_id(self) -> id_models.Identifier: """A globally unique identifier for the task.""" return self._flyte_task.id @@ -29,7 +28,7 @@ def flyte_task(self) -> "flytekit.remote.tasks.task.FlyteTask": def promote_from_model( cls, base_model: 
_workflow_model.TaskNode, - tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], ) -> "FlyteTaskNode": """ Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the @@ -38,12 +37,12 @@ def promote_from_model( :param base_model: :param tasks: """ - from flytekit.remote.tasks import task as _task + from flytekit.remote.task import FlyteTask if base_model.reference_id in tasks: task = tasks[base_model.reference_id] _logging.info(f"Found existing task template for {task.id}, will not retrieve from Admin") - flyte_task = _task.FlyteTask.promote_from_model(task) + flyte_task = FlyteTask.promote_from_model(task) return cls(flyte_task) raise _system_exceptions.FlyteSystemException(f"Task template {base_model.reference_id} not found.") @@ -76,7 +75,7 @@ def __repr__(self) -> str: return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" @property - def launchplan_ref(self) -> _identifier.Identifier: + def launchplan_ref(self) -> id_models.Identifier: """A globally unique identifier for the launch plan, which should map to Admin.""" return self._flyte_launch_plan.id if self._flyte_launch_plan else None @@ -96,9 +95,9 @@ def flyte_workflow(self) -> "flytekit.remote.workflow.FlyteWorkflow": def promote_from_model( cls, base_model: _workflow_model.WorkflowNode, - sub_workflows: Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[_identifier.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], ) -> "FlyteWorkflowNode": from flytekit.remote import launch_plan as _launch_plan from flytekit.remote import workflow as _workflow 
diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py new file mode 100644 index 0000000000..05b94b2302 --- /dev/null +++ b/flytekit/remote/executions.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.exceptions import user as user_exceptions +from flytekit.models import execution as execution_models +from flytekit.models import node_execution as node_execution_models +from flytekit.models.admin import task_execution as admin_task_execution_models +from flytekit.models.core import execution as core_execution_models +from flytekit.remote.workflow import FlyteWorkflow + + +class FlyteTaskExecution(admin_task_execution_models.TaskExecution): + """A class encapsulating a task execution being run on a Flyte remote backend.""" + + def __init__(self, *args, **kwargs): + super(FlyteTaskExecution, self).__init__(*args, **kwargs) + self._inputs = None + self._outputs = None + + @property + def is_complete(self) -> bool: + """Whether or not the execution is complete.""" + return self.closure.phase in { + core_execution_models.TaskExecutionPhase.ABORTED, + core_execution_models.TaskExecutionPhase.FAILED, + core_execution_models.TaskExecutionPhase.SUCCEEDED, + } + + @property + def inputs(self) -> Dict[str, Any]: + """ + Returns the inputs of the task execution in the standard Python format that is produced by + the type engine. + """ + return self._inputs + + @property + def outputs(self) -> Dict[str, Any]: + """ + Returns the outputs of the task execution, if available, in the standard Python format that is produced by + the type engine. + + :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. + """ + if not self.is_complete: + raise user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting the outputs." 
+ ) + if self.error: + raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + return self._outputs + + @property + def error(self) -> Optional[core_execution_models.ExecutionError]: + """ + If execution is in progress, raise an exception. Otherwise, return None if no error was present upon + reaching completion. + """ + if not self.is_complete: + raise user_exceptions.FlyteAssertion( + "Please wait until the task execution has completed before requesting error information." + ) + return self.closure.error + + @classmethod + def promote_from_model(cls, base_model: admin_task_execution_models.TaskExecution) -> "FlyteTaskExecution": + return cls( + closure=base_model.closure, + id=base_model.id, + input_uri=base_model.input_uri, + is_parent=base_model.is_parent, + ) + + +class FlyteWorkflowExecution(execution_models.Execution): + """A class encapsulating a workflow execution being run on a Flyte remote backend.""" + + def __init__(self, *args, **kwargs): + super(FlyteWorkflowExecution, self).__init__(*args, **kwargs) + self._node_executions = None + self._inputs = None + self._outputs = None + self._flyte_workflow: Optional[FlyteWorkflow] = None + + @property + def node_executions(self) -> Dict[str, "FlyteNodeExecution"]: + """Get a dictionary of node executions that are a part of this workflow execution.""" + return self._node_executions or {} + + @property + def inputs(self) -> Dict[str, Any]: + """ + Returns the inputs to the execution in the standard python format as dictated by the type engine. + """ + return self._inputs + + @property + def outputs(self) -> Dict[str, Any]: + """ + Returns the outputs to the execution in the standard python format as dictated by the type engine. + + :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. 
+ """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting the outputs." + ) + if self.error: + raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + return self._outputs + + @property + def error(self) -> core_execution_models.ExecutionError: + """ + If execution is in progress, raise an exception. Otherwise, return None if no error was present upon + reaching completion. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until a workflow has completed before checking for an error." + ) + return self.closure.error + + @property + def is_complete(self) -> bool: + """ + Whether or not the execution is complete. + """ + return self.closure.phase in { + core_execution_models.WorkflowExecutionPhase.ABORTED, + core_execution_models.WorkflowExecutionPhase.FAILED, + core_execution_models.WorkflowExecutionPhase.SUCCEEDED, + core_execution_models.WorkflowExecutionPhase.TIMED_OUT, + } + + @classmethod + def promote_from_model(cls, base_model: execution_models.Execution) -> "FlyteWorkflowExecution": + return cls( + closure=base_model.closure, + id=base_model.id, + spec=base_model.spec, + ) + + +class FlyteNodeExecution(node_execution_models.NodeExecution): + """A class encapsulating a node execution being run on a Flyte remote backend.""" + + def __init__(self, *args, **kwargs): + super(FlyteNodeExecution, self).__init__(*args, **kwargs) + self._task_executions = None + self._workflow_executions = [] + self._underlying_node_executions = None + self._inputs = None + self._outputs = None + self._interface = None + + @property + def task_executions(self) -> List[FlyteTaskExecution]: + return self._task_executions or [] + + @property + def workflow_executions(self) -> List[FlyteWorkflowExecution]: + return self._workflow_executions + + @property + def subworkflow_node_executions(self) -> Dict[str, 
FlyteNodeExecution]: + """ + This returns underlying node executions in instances where the current node execution is + a parent node. This happens when it's either a static or dynamic subworkflow. + """ + return ( + {} + if self._underlying_node_executions is None + else {n.id.node_id: n for n in self._underlying_node_executions} + ) + + @property + def executions(self) -> List[Union[FlyteTaskExecution, FlyteWorkflowExecution]]: + return self.task_executions or self._underlying_node_executions or [] + + @property + def inputs(self) -> Dict[str, Any]: + """ + Returns the inputs to the execution in the standard python format as dictated by the type engine. + """ + return self._inputs + + @property + def outputs(self) -> Dict[str, Any]: + """ + Returns the outputs to the execution in the standard python format as dictated by the type engine. + + :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting the outputs." + ) + if self.error: + raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + return self._outputs + + @property + def error(self) -> core_execution_models.ExecutionError: + """ + If execution is in progress, raise an exception. Otherwise, return None if no error was present upon + reaching completion. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting error information." 
+ ) + return self.closure.error + + @property + def is_complete(self) -> bool: + """Whether or not the execution is complete.""" + return self.closure.phase in { + core_execution_models.NodeExecutionPhase.ABORTED, + core_execution_models.NodeExecutionPhase.FAILED, + core_execution_models.NodeExecutionPhase.SKIPPED, + core_execution_models.NodeExecutionPhase.SUCCEEDED, + core_execution_models.NodeExecutionPhase.TIMED_OUT, + } + + @classmethod + def promote_from_model(cls, base_model: node_execution_models.NodeExecution) -> "FlyteNodeExecution": + return cls( + closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata + ) + + @property + def interface(self) -> "flytekit.remote.interface.TypedInterface": + """ + Return the interface of the task or subworkflow associated with this node execution. + """ + return self._interface diff --git a/flytekit/remote/identifier.py b/flytekit/remote/identifier.py deleted file mode 100644 index 611c9af639..0000000000 --- a/flytekit/remote/identifier.py +++ /dev/null @@ -1,137 +0,0 @@ -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models.core import identifier as _core_identifier - - -class Identifier(_core_identifier.Identifier): - - _STRING_TO_TYPE_MAP = { - "lp": _core_identifier.ResourceType.LAUNCH_PLAN, - "wf": _core_identifier.ResourceType.WORKFLOW, - "tsk": _core_identifier.ResourceType.TASK, - } - _TYPE_TO_STRING_MAP = {v: k for k, v in _STRING_TO_TYPE_MAP.items()} - - @classmethod - def promote_from_model(cls, base_model: _core_identifier.Identifier) -> "Identifier": - return cls(base_model.resource_type, base_model.project, base_model.domain, base_model.name, base_model.version) - - @classmethod - def from_urn(cls, urn: str) -> "Identifier": - """ - Parses a string urn in the correct format into an identifier - """ - segments = urn.split(":") - if len(segments) != 5: - raise _user_exceptions.FlyteValueException( - urn, - "The provided string 
was not in a parseable format. The string for an identifier must be in the " - "format entity_type:project:domain:name:version.", - ) - - resource_type, project, domain, name, version = segments - - if resource_type not in cls._STRING_TO_TYPE_MAP: - raise _user_exceptions.FlyteValueException( - resource_type, - "The provided string could not be parsed. The first element of an identifier must be one of: " - f"{list(cls._STRING_TO_TYPE_MAP.keys())}. ", - ) - - return cls(cls._STRING_TO_TYPE_MAP[resource_type], project, domain, name, version) - - def __str__(self): - return ( - f"{type(self)._TYPE_TO_STRING_MAP.get(self.resource_type, '')}:" - f"{self.project}:" - f"{self.domain}:" - f"{self.name}:" - f"{self.version}" - ) - - -class WorkflowExecutionIdentifier(_core_identifier.WorkflowExecutionIdentifier): - @classmethod - def promote_from_model( - cls, base_model: _core_identifier.WorkflowExecutionIdentifier - ) -> "WorkflowExecutionIdentifier": - return cls(base_model.project, base_model.domain, base_model.name) - - @classmethod - def from_urn(cls, string: str) -> "WorkflowExecutionIdentifier": - """ - Parses a string in the correct format into an identifier - """ - segments = string.split(":") - if len(segments) != 4: - raise _user_exceptions.FlyteValueException( - string, - "The provided string was not in a parseable format. The string for an identifier must be in the format" - " ex:project:domain:name.", - ) - - resource_type, project, domain, name = segments - - if resource_type != "ex": - raise _user_exceptions.FlyteValueException( - resource_type, - "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", - ) - - return cls(project, domain, name) - - def __str__(self): - return f"ex:{self.project}:{self.domain}:{self.name}" - - -class TaskExecutionIdentifier(_core_identifier.TaskExecutionIdentifier): - @classmethod - def promote_from_model(cls, base_model: _core_identifier.TaskExecutionIdentifier) -> "TaskExecutionIdentifier": - return cls( - task_id=base_model.task_id, - node_execution_id=base_model.node_execution_id, - retry_attempt=base_model.retry_attempt, - ) - - @classmethod - def from_urn(cls, string: str) -> "TaskExecutionIdentifier": - """ - Parses a string in the correct format into an identifier - """ - segments = string.split(":") - if len(segments) != 10: - raise _user_exceptions.FlyteValueException( - string, - "The provided string was not in a parseable format. The string for an identifier must be in the format" - " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.", - ) - - resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments - - if resource_type != "te": - raise _user_exceptions.FlyteValueException( - resource_type, - "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", - ) - - return cls( - task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv), - node_execution_id=_core_identifier.NodeExecutionIdentifier( - node_id=node_id, - execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), - ), - retry_attempt=int(retry), - ) - - def __str__(self): - return ( - "te:" - f"{self.node_execution_id.execution_id.project}:" - f"{self.node_execution_id.execution_id.domain}:" - f"{self.node_execution_id.execution_id.name}:" - f"{self.node_execution_id.node_id}:" - f"{self.task_id.project}:" - f"{self.task_id.domain}:" - f"{self.task_id.name}:" - f"{self.task_id.version}:" - f"{self.retry_attempt}" - ) diff --git a/flytekit/remote/interface.py b/flytekit/remote/interface.py index 6aeb0e236d..df61c8e336 100644 --- a/flytekit/remote/interface.py +++ b/flytekit/remote/interface.py @@ -1,8 +1,4 @@ -from typing import Any, Dict, List, Tuple - from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models -from flytekit.remote import nodes as _nodes class TypedInterface(_interface_models.TypedInterface): @@ -13,12 +9,3 @@ def promote_from_model(cls, model): :rtype: TypedInterface """ return cls(model.inputs, model.outputs) - - def create_bindings_for_inputs( - self, map_of_bindings: Dict[str, Any] - ) -> Tuple[List[_literal_models.Binding], List[_nodes.FlyteNode]]: - """ - :param: map_of_bindings: this can be scalar primitives, it can be node output references, lists, etc. 
- :raises: flytekit.common.exceptions.user.FlyteAssertion - """ - return [], [] diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py index 200244f394..016e3a3489 100644 --- a/flytekit/remote/launch_plan.py +++ b/flytekit/remote/launch_plan.py @@ -7,8 +7,7 @@ from flytekit.engines.flyte import engine as _flyte_engine from flytekit.models import interface as _interface_models from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models.core import identifier as _identifier_model -from flytekit.remote import identifier as _identifier +from flytekit.models.core import identifier as id_models from flytekit.remote import interface as _interface @@ -27,11 +26,11 @@ def __init__(self, id, *args, **kwargs): @classmethod def promote_from_model( - cls, id: _identifier.Identifier, model: _launch_plan_models.LaunchPlanSpec + cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec ) -> "FlyteLaunchPlan": lp = cls( id=id, - workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id), + workflow_id=model.workflow_id, default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters), fixed_inputs=model.fixed_inputs, entity_metadata=model.entity_metadata, @@ -50,7 +49,7 @@ def promote_from_model( return lp @property - def id(self) -> _identifier.Identifier: + def id(self) -> id_models.Identifier: return self._id @property @@ -65,7 +64,7 @@ def is_scheduled(self) -> bool: return False @property - def workflow_id(self) -> _identifier.Identifier: + def workflow_id(self) -> id_models.Identifier: return self._workflow_id @property @@ -78,8 +77,8 @@ def interface(self) -> _interface.TypedInterface: return self._interface @property - def resource_type(self) -> _identifier_model.ResourceType: - return _identifier_model.ResourceType.LAUNCH_PLAN + def resource_type(self) -> id_models.ResourceType: + return id_models.ResourceType.LAUNCH_PLAN @property def entity_type_text(self) -> str: diff --git 
a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py index 68d84f00ea..f8ae1b2d6a 100644 --- a/flytekit/remote/nodes.py +++ b/flytekit/remote/nodes.py @@ -1,24 +1,18 @@ +from __future__ import annotations + import logging as _logging -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union -import flytekit -from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions from flytekit.common import constants as _constants from flytekit.common.exceptions import system as _system_exceptions from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import artifact as _artifact_mixin from flytekit.common.mixins import hash as _hash_mixin -from flytekit.common.utils import _dnsify from flytekit.core.promise import NodeOutput -from flytekit.engines.flyte import engine as _flyte_engine from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import node_execution as _node_execution_models from flytekit.models import task as _task_model -from flytekit.models.core import execution as _execution_models +from flytekit.models.core import identifier as id_models from flytekit.models.core import workflow as _workflow_model from flytekit.remote import component_nodes as _component_nodes -from flytekit.remote import identifier as _identifier -from flytekit.remote.tasks.executions import FlyteTaskExecution class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): @@ -34,7 +28,6 @@ def __init__( flyte_workflow: Optional["FlyteWorkflow"] = None, flyte_launch_plan: Optional["FlyteLaunchPlan"] = None, flyte_branch=None, - parameter_mapping=True, ): non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch])) if len(non_none_entities) != 1: @@ -50,15 +43,20 @@ def __init__( elif flyte_launch_plan is not None: workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan) + 
task_node = None + if flyte_task: + task_node = _component_nodes.FlyteTaskNode(flyte_task) + branch_node = None + super(FlyteNode, self).__init__( - id=_dnsify(id) if id else None, + id=id, metadata=metadata, inputs=bindings, upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], - task_node=_component_nodes.FlyteTaskNode(flyte_task) if flyte_task else None, + task_node=task_node, workflow_node=workflow_node, - branch_node=flyte_branch, + branch_node=branch_node, ) self._upstream = upstream_nodes @@ -70,11 +68,12 @@ def flyte_entity(self) -> Union["FlyteTask", "FlyteWorkflow", "FlyteLaunchPlan"] def promote_from_model( cls, model: _workflow_model.Node, - sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate]], - node_launch_plans: Optional[Dict[_identifier.Identifier, _launch_plan_model.LaunchPlanSpec]], - tasks: Optional[Dict[_identifier.Identifier, _task_model.TaskTemplate]], - ) -> "FlyteNode": - id = model.id + sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], + node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], + tasks: Optional[Dict[id_models.Identifier, _task_model.TaskTemplate]], + ) -> FlyteNode: + node_model_id = model.id + # TODO: Consider removing if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: _logging.warning(f"Should not call promote from model on a start node or end node {model}") return None @@ -97,6 +96,7 @@ def promote_from_model( # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. 
+ # TODO: Consider removing for model_input in model.inputs: if ( model_input.binding.promise is not None @@ -106,7 +106,7 @@ def promote_from_model( if flyte_task_node is not None: return cls( - id=id, + id=node_model_id, upstream_nodes=[], # set downstream, model doesn't contain this information bindings=model.inputs, metadata=model.metadata, @@ -115,7 +115,7 @@ def promote_from_model( elif flyte_workflow_node is not None: if flyte_workflow_node.flyte_workflow is not None: return cls( - id=id, + id=node_model_id, upstream_nodes=[], # set downstream, model doesn't contain this information bindings=model.inputs, metadata=model.metadata, @@ -123,7 +123,7 @@ def promote_from_model( ) elif flyte_workflow_node.flyte_launch_plan is not None: return cls( - id=id, + id=node_model_id, upstream_nodes=[], # set downstream, model doesn't contain this information bindings=model.inputs, metadata=model.metadata, @@ -135,7 +135,7 @@ def promote_from_model( raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty") @property - def upstream_nodes(self) -> List["FlyteNode"]: + def upstream_nodes(self) -> List[FlyteNode]: return self._upstream @property @@ -146,136 +146,5 @@ def upstream_node_ids(self) -> List[str]: def outputs(self) -> Dict[str, NodeOutput]: return self._outputs - def assign_id_and_return(self, id: str): - if self.id: - raise _user_exceptions.FlyteAssertion( - f"Error assigning ID: {id} because {self} is already assigned. Has this node been ssigned to another " - "workflow already?" 
- ) - self._id = _dnsify(id) if id else None - self._metadata.name = id - return self - - def with_overrides(self, *args, **kwargs): - # TODO: Implement overrides - raise NotImplementedError("Overrides are not supported in Flyte yet.") - def __repr__(self) -> str: return f"Node(ID: {self.id})" - - -class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact): - """A class encapsulating a node execution being run on a Flyte remote backend.""" - - def __init__(self, *args, **kwargs): - super(FlyteNodeExecution, self).__init__(*args, **kwargs) - self._task_executions = None - self._subworkflow_node_executions = None - self._inputs = None - self._outputs = None - self._interface = None - - @property - def task_executions(self) -> List["flytekit.remote.tasks.executions.FlyteTaskExecution"]: - return self._task_executions or [] - - @property - def subworkflow_node_executions(self) -> Dict[str, "flytekit.remote.nodes.FlyteNodeExecution"]: - return ( - {} - if self._subworkflow_node_executions is None - else {n.id.node_id: n for n in self._subworkflow_node_executions} - ) - - @property - def executions(self) -> List[_artifact_mixin.ExecutionArtifact]: - return self.task_executions or self._subworkflow_node_executions or [] - - @property - def inputs(self) -> Dict[str, Any]: - """ - Returns the inputs to the execution in the standard python format as dictated by the type engine. - """ - return self._inputs - - @property - def outputs(self) -> Dict[str, Any]: - """ - Returns the outputs to the execution in the standard python format as dictated by the type engine. - - :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the node execution has completed before requesting the outputs." 
- ) - if self.error: - raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") - return self._outputs - - @property - def error(self) -> _execution_models.ExecutionError: - """ - If execution is in progress, raise an exception. Otherwise, return None if no error was present upon - reaching completion. - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the node execution has completed before requesting error information." - ) - return self.closure.error - - @property - def is_complete(self) -> bool: - """Whether or not the execution is complete.""" - return self.closure.phase in { - _execution_models.NodeExecutionPhase.ABORTED, - _execution_models.NodeExecutionPhase.FAILED, - _execution_models.NodeExecutionPhase.SKIPPED, - _execution_models.NodeExecutionPhase.SUCCEEDED, - _execution_models.NodeExecutionPhase.TIMED_OUT, - } - - @classmethod - def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) -> "FlyteNodeExecution": - return cls( - closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata - ) - - @property - def interface(self) -> "flytekit.remote.interface.TypedInterface": - """ - Return the interface of the task or subworkflow associated with this node execution. - """ - return self._interface - - def sync(self): - """ - Syncs the state of the underlying execution artifact with the state observed by the platform. 
- """ - if self.metadata.is_parent_node: - if not self.is_complete or self._subworkflow_node_executions is None: - self._subworkflow_node_executions = [ - FlyteNodeExecution.promote_from_model(n) - for n in iterate_node_executions( - _flyte_engine.get_client(), - workflow_execution_identifier=self.id.execution_id, - unique_parent_id=self.id.node_id, - ) - ] - else: - if not self.is_complete or self._task_executions is None: - self._task_executions = [ - FlyteTaskExecution.promote_from_model(t) - for t in iterate_task_executions(_flyte_engine.get_client(), self.id) - ] - - self._sync_closure() - for execution in self.executions: - execution.sync() - - def _sync_closure(self): - """ - Syncs the closure of the underlying execution artifact with the state observed by the platform. - """ - self._closure = _flyte_engine.get_client().get_node_execution(self.id).closure diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index e859db730d..04e18d6794 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1,4 +1,8 @@ -"""Module defining main Flyte backend entrypoint.""" +""" +This module provides the ``FlyteRemote`` object, which is the end-user's main starting point for interacting +with a Flyte backend in an interactive and programmatic way. Think of this experience as kind of like the web UI +but in Python object form. 
+""" from __future__ import annotations import logging @@ -50,7 +54,8 @@ from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models from flytekit.models.admin.common import Sort -from flytekit.models.core.identifier import ResourceType +from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier +from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( ExecutionMetadata, ExecutionSpec, @@ -58,14 +63,11 @@ NotificationList, WorkflowExecutionGetDataResponse, ) -from flytekit.remote.identifier import Identifier, WorkflowExecutionIdentifier -from flytekit.remote.interface import TypedInterface +from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.launch_plan import FlyteLaunchPlan -from flytekit.remote.nodes import FlyteNodeExecution -from flytekit.remote.tasks.executions import FlyteTaskExecution -from flytekit.remote.tasks.task import FlyteTask +from flytekit.remote.nodes import FlyteNode +from flytekit.remote.task import FlyteTask from flytekit.remote.workflow import FlyteWorkflow -from flytekit.remote.workflow_execution import FlyteWorkflowExecution ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] @@ -378,8 +380,6 @@ def fetch_workflow( :param domain: fetch entity from this domain. If None, uses the default_domain attribute. :param name: fetch entity with matching name. :param version: fetch entity with matching version. If None, gets the latest version of the entity. 
- :returns: :class:`~flytekit.remote.workflow.FlyteWorkflow` - :raises: FlyteAssertion if name is None """ if name is None: @@ -395,26 +395,15 @@ def fetch_workflow( admin_workflow = self.client.get_workflow(workflow_id) compiled_wf = admin_workflow.closure.compiled_workflow - base_model = compiled_wf.primary.template - sub_workflows = {sw.template.id: sw.template for sw in compiled_wf.sub_workflows} - tasks = {t.template.id: t.template for t in compiled_wf.tasks} - node_launch_plans = {} # TODO: Inspect branch nodes for launch plans - for node in FlyteWorkflow.get_non_system_nodes(base_model.nodes): + for node in FlyteWorkflow.get_non_system_nodes(compiled_wf.primary.template.nodes): if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None: node_launch_plans[node.workflow_node.launchplan_ref] = self.client.get_launch_plan( node.workflow_node.launchplan_ref ).spec - flyte_workflow = FlyteWorkflow.promote_from_model( - base_model=compiled_wf.primary.template, - sub_workflows=sub_workflows, - node_launch_plans=node_launch_plans, - tasks=tasks, - ) - flyte_workflow._id = workflow_id - return flyte_workflow + return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) def fetch_launch_plan( self, project: str = None, domain: str = None, name: str = None, version: str = None @@ -1015,7 +1004,7 @@ def sync( self, execution: FlyteWorkflowExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None, - sync_nodes: bool = True, + sync_nodes: bool = False, ) -> FlyteWorkflowExecution: """ This function was previously a singledispatchmethod. 
We've removed that but this function remains @@ -1036,63 +1025,198 @@ def sync_workflow_execution( self, execution: FlyteWorkflowExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None, - sync_nodes: bool = True, + sync_nodes: bool = False, ) -> FlyteWorkflowExecution: - - """Sync a FlyteWorkflowExecution object with its corresponding remote state.""" + """ + Sync a FlyteWorkflowExecution object with its corresponding remote state. + """ if entity_definition is not None: raise ValueError("Entity definition arguments aren't supported when syncing workflow executions") + + # Update closure, and then data, because we don't want the execution to finish between when we get the data, + # and then for the closure to have is_complete to be true. + execution._closure = self.client.get_execution(execution.id).closure execution_data = self.client.get_execution_data(execution.id) lp_id = execution.spec.launch_plan + if sync_nodes: + underlying_node_executions = [ + FlyteNodeExecution.promote_from_model(n) for n in iterate_node_executions(self.client, execution.id) + ] + if execution.spec.launch_plan.resource_type == ResourceType.TASK: + # This condition is only true for single-task executions flyte_entity = self.fetch_task(lp_id.project, lp_id.domain, lp_id.name, lp_id.version) + if sync_nodes: + # Need to construct the mapping. There should've been returned exactly three nodes, a start, + # an end, and a task node. + task_node_exec = [ + x + for x in filter( + lambda x: x.id.node_id != constants.START_NODE_ID and x.id.node_id != constants.END_NODE_ID, + underlying_node_executions, + ) + ] + # We need to manually make a map of the nodes since there is none for single task executions + # Assume the first one is the only one. 
+ node_mapping = ( + { + task_node_exec[0].id.node_id: FlyteNode( + id=flyte_entity.id, + upstream_nodes=[], + bindings=[], + metadata=NodeMetadata(name=""), + flyte_task=flyte_entity, + ) + } + if len(task_node_exec) >= 1 + else {} # This is for the case where node executions haven't appeared yet + ) else: + # This is the default case, an execution of a normal workflow through a launch plan wf_id = self.fetch_launch_plan(lp_id.project, lp_id.domain, lp_id.name, lp_id.version).workflow_id flyte_entity = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) + execution._flyte_workflow = flyte_entity + node_mapping = flyte_entity._node_map - # sync closure, node executions, and inputs/outputs - execution._closure = self.client.get_execution(execution.id).closure + # update node executions (if requested), and inputs/outputs if sync_nodes: - execution._node_executions = { - node.id.node_id: self.sync_node_execution(FlyteNodeExecution.promote_from_model(node), flyte_entity) - for node in iterate_node_executions(self.client, execution.id) - } + node_execs = {} + for n in underlying_node_executions: + node_execs[n.id.node_id] = self.sync_node_execution(n, node_mapping) + execution._node_executions = node_execs return self._assign_inputs_and_outputs(execution, execution_data, flyte_entity.interface) def sync_node_execution( - self, execution: FlyteNodeExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None + self, execution: FlyteNodeExecution, node_mapping: typing.Dict[str, FlyteNode] ) -> FlyteNodeExecution: - """Sync a FlyteNodeExecution object with its corresponding remote state.""" - if ( - execution.id.node_id in {constants.START_NODE_ID, constants.END_NODE_ID} - or execution.id.node_id.endswith(constants.START_NODE_ID) - or execution.id.node_id.endswith(constants.END_NODE_ID) - ): + """ + Get data backing a node execution. 
These FlyteNodeExecution objects should've come from Admin with the model + fields already populated correctly. For purposes of the remote experience, we'd like to supplement the object + with some additional fields: + - inputs/outputs + - task/workflow executions, and/or underlying node executions in the case of parent nodes + - TypedInterface (remote wrapper type) + + A node can have several different types of executions behind it. That is, the node could've run (perhaps + multiple times because of retries): + - A task + - A static subworkflow + - A dynamic subworkflow (which in turn may have run additional tasks, subwfs, and/or launch plans) + - A launch plan + + The data model is complicated, so ascertaining which of these happened is a bit tricky. That logic is + encapsulated in this function. + """ + # For single task execution - the metadata spec node id is missing. In these cases, revert to regular node id + node_id = execution.metadata.spec_node_id + if not node_id: + node_id = execution.id.node_id + remote_logger.debug(f"No metadata spec_node_id found, using {node_id}") + + # First see if it's a dummy node, if it is, we just skip it. + if constants.START_NODE_ID in node_id or constants.END_NODE_ID in node_id: return execution - # sync closure, child nodes, interface, and inputs/outputs - execution._closure = self.client.get_node_execution(execution.id).closure + # Look for the Node object in the mapping supplied + if node_id in node_mapping: + execution._node = node_mapping[node_id] + else: + raise Exception(f"Missing node from mapping: {node_id}") + + # Get the node execution data + node_execution_get_data_response = self.client.get_node_execution_data(execution.id) + + # Calling a launch plan directly case + # If a node ran a launch plan directly (i.e. not through a dynamic task or anything) then + # the closure should have a workflow_node_metadata populated with the launched execution id. 
+ # The parent node flag should not be populated here + # This is the simplest case + if not execution.metadata.is_parent_node and execution.closure.workflow_node_metadata: + launched_exec_id = execution.closure.workflow_node_metadata.execution_id + # This is a recursive call, basically going through the same process that brought us here in the first + # place, but on the launched execution. + launched_exec = self.fetch_workflow_execution( + project=launched_exec_id.project, domain=launched_exec_id.domain, name=launched_exec_id.name + ) + self.sync_workflow_execution(launched_exec) + if launched_exec.is_complete: + # The synced underlying execution should've had these populated. + execution._inputs = launched_exec.inputs + execution._outputs = launched_exec.outputs + execution._workflow_executions.append(launched_exec) + execution._interface = launched_exec._flyte_workflow.interface + return execution + + # If a node ran a static subworkflow or a dynamic subworkflow then the parent flag will be set. 
if execution.metadata.is_parent_node: - execution._subworkflow_node_executions = [ - self.sync_node_execution(FlyteNodeExecution.promote_from_model(node), entity_definition) - for node in iterate_node_executions( - self.client, - workflow_execution_identifier=execution.id.execution_id, - unique_parent_id=execution.id.node_id, - ) - ] + # We'll need to query child node executions regardless since this is a parent node + child_node_executions = iterate_node_executions( + self.client, + workflow_execution_identifier=execution.id.execution_id, + unique_parent_id=execution.id.node_id, + ) + child_node_executions = [x for x in child_node_executions] + + # If this was a dynamic task, then there should be a CompiledWorkflowClosure inside the + # NodeExecutionGetDataResponse + if node_execution_get_data_response.dynamic_workflow is not None: + compiled_wf = node_execution_get_data_response.dynamic_workflow.compiled_workflow + node_launch_plans = {} + # TODO: Inspect branch nodes for launch plans + for node in FlyteWorkflow.get_non_system_nodes(compiled_wf.primary.template.nodes): + if ( + node.workflow_node is not None + and node.workflow_node.launchplan_ref is not None + and node.workflow_node.launchplan_ref not in node_launch_plans + ): + node_launch_plans[node.workflow_node.launchplan_ref] = self.client.get_launch_plan( + node.workflow_node.launchplan_ref + ).spec + + dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) + execution._underlying_node_executions = [ + self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), dynamic_flyte_wf._node_map) + for cne in child_node_executions + ] + # This is copied from below - dynamic tasks have both task executions (executions of the parent + # task) as well as underlying node executions (of the generated subworkflow). Feel free to refactor + # if you can think of a better way. 
+ execution._task_executions = [ + self.sync_task_execution(FlyteTaskExecution.promote_from_model(t)) + for t in iterate_task_executions(self.client, execution.id) + ] + execution._interface = dynamic_flyte_wf.interface + else: + # If it does not, then it should be a static subworkflow + if not isinstance(execution._node.flyte_entity, FlyteWorkflow): + remote_logger.error( + f"NE {execution} entity should be a workflow, {type(execution._node)}, {execution._node}" + ) + raise Exception(f"Node entity has type {type(execution._node)}") + sub_flyte_workflow = execution._node.flyte_entity + sub_node_mapping = {n.id: n for n in sub_flyte_workflow.flyte_nodes} + execution._underlying_node_executions = [ + self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) + for cne in child_node_executions + ] + execution._interface = sub_flyte_workflow.interface + + # This is the plain ol' task execution case else: execution._task_executions = [ self.sync_task_execution(FlyteTaskExecution.promote_from_model(t)) for t in iterate_task_executions(self.client, execution.id) ] - execution._interface = self._get_node_execution_interface(execution, entity_definition) - return self._assign_inputs_and_outputs( + execution._interface = execution._node.flyte_entity.interface + + self._assign_inputs_and_outputs( execution, - self.client.get_node_execution_data(execution.id), + node_execution_get_data_response, execution.interface, ) + return execution + def sync_task_execution( self, execution: FlyteTaskExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None ) -> FlyteTaskExecution: @@ -1123,7 +1247,12 @@ def terminate(self, execution: FlyteWorkflowExecution, cause: str): # Helper Methods # ################## - def _assign_inputs_and_outputs(self, execution, execution_data, interface): + def _assign_inputs_and_outputs( + self, + execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution], + execution_data, + 
interface, + ): """Helper for assigning synced inputs and outputs to an execution object.""" with self.remote_context() as ctx: execution._inputs = TypeEngine.literal_map_to_kwargs( @@ -1164,49 +1293,3 @@ def _get_output_literal_map(self, execution_data: ExecutionDataResponse) -> lite common_utils.load_proto_from_file(literals_pb2.LiteralMap, tmp_name) ) return literal_models.LiteralMap({}) - - def _get_node_execution_interface( - self, node_execution: FlyteNodeExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] - ) -> TypedInterface: - """Return the interface of the task or subworkflow associated with this node execution.""" - if isinstance(entity_definition, FlyteTask): - # A single task execution consists of a Flyte workflow with single node whose interface matches that of - # the underlying task - return entity_definition.interface - - for node in entity_definition.flyte_nodes: - if node.id == node_execution.id.node_id: - if node.task_node is not None: - return node.task_node.flyte_task.interface - elif node.workflow_node is not None and node.workflow_node.sub_workflow_ref is not None: - # Fetch the workflow and use its interface - sub_workflow_ref = node.workflow_node.sub_workflow_ref - workflow = self.fetch_workflow( - sub_workflow_ref.project, - sub_workflow_ref.domain, - sub_workflow_ref.name, - sub_workflow_ref.version, - ) - return workflow.interface - elif node.workflow_node is not None and node.workflow_node.launchplan_ref is not None: - # Fetch the launch plan this node launched, and from there fetch the referenced workflow and use its - # interface. 
- lp_ref = node.workflow_node.launchplan_ref - launch_plan = self.fetch_launch_plan(lp_ref.project, lp_ref.domain, lp_ref.name, lp_ref.version) - workflow = self.fetch_workflow( - launch_plan.workflow_id.project, - launch_plan.workflow_id.domain, - launch_plan.workflow_id.name, - launch_plan.workflow_id.version, - ) - return workflow.interface - - # dynamically generated nodes won't have a corresponding node in the compiled workflow closure. - # in that case, we fetch the interface from the underlying task execution they ran - if len(node_execution.task_executions) > 0: - # if not a parent node, assume a task execution node - task_id = node_execution.task_executions[0].id.task_id - task = self.fetch_task(task_id.project, task_id.domain, task_id.name, task_id.version) - return task.interface - - remote_logger.info("failed to find node interface from entity definition closure") diff --git a/flytekit/remote/tasks/task.py b/flytekit/remote/task.py similarity index 95% rename from flytekit/remote/tasks/task.py rename to flytekit/remote/task.py index 967f4e43aa..0c48f5f15e 100644 --- a/flytekit/remote/tasks/task.py +++ b/flytekit/remote/task.py @@ -6,7 +6,6 @@ from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.core import identifier as _identifier_model -from flytekit.remote import identifier as _identifier from flytekit.remote import interface as _interfaces @@ -61,7 +60,7 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask" ) # Override the newly generated name if one exists in the base model if not base_model.id.is_empty: - t._id = _identifier.Identifier.promote_from_model(base_model.id) + t._id = base_model.id if t.interface is not None: try: diff --git a/flytekit/remote/tasks/__init__.py b/flytekit/remote/tasks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/remote/tasks/executions.py b/flytekit/remote/tasks/executions.py deleted file mode 
100644 index 9937b4be77..0000000000 --- a/flytekit/remote/tasks/executions.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Any, Dict, Optional - -from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import artifact as _artifact_mixin -from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models.admin import task_execution as _task_execution_model -from flytekit.models.core import execution as _execution_models - - -class FlyteTaskExecution(_task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact): - """A class encapsulating a task execution being run on a Flyte remote backend.""" - - def __init__(self, *args, **kwargs): - super(FlyteTaskExecution, self).__init__(*args, **kwargs) - self._inputs = None - self._outputs = None - - @property - def is_complete(self) -> bool: - """Whether or not the execution is complete.""" - return self.closure.phase in { - _execution_models.TaskExecutionPhase.ABORTED, - _execution_models.TaskExecutionPhase.FAILED, - _execution_models.TaskExecutionPhase.SUCCEEDED, - } - - @property - def inputs(self) -> Dict[str, Any]: - """ - Returns the inputs of the task execution in the standard Python format that is produced by - the type engine. - """ - return self._inputs - - @property - def outputs(self) -> Dict[str, Any]: - """ - Returns the outputs of the task execution, if available, in the standard Python format that is produced by - the type engine. - - :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the node execution has completed before requesting the outputs." 
- ) - if self.error: - raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") - return self._outputs - - @property - def error(self) -> Optional[_execution_models.ExecutionError]: - """ - If execution is in progress, raise an exception. Otherwise, return None if no error was present upon - reaching completion. - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please what until the task execution has completed before requesting error information." - ) - return self.closure.error - - def get_child_executions(self, filters=None): - from flytekit.remote import nodes as _nodes - - if not self.is_parent: - raise _user_exceptions.FlyteAssertion("Only task executions marked with 'is_parent' have child executions.") - client = _flyte_engine.get_client() - models = { - v.id.node_id: v - for v in _iterate_node_executions(client, task_execution_identifier=self.id, filters=filters) - } - - return {k: _nodes.FlyteNodeExecution.promote_from_model(v) for k, v in models.items()} - - @classmethod - def promote_from_model(cls, base_model: _task_execution_model.TaskExecution) -> "FlyteTaskExecution": - return cls( - closure=base_model.closure, - id=base_model.id, - input_uri=base_model.input_uri, - is_parent=base_model.is_parent, - ) - - def sync(self): - """ - Syncs the state of the underlying execution artifact with the state observed by the platform. - """ - self._sync_closure() - - def _sync_closure(self): - """ - Syncs the closure of the underlying execution artifact with the state observed by the platform. 
- """ - self._closure = _flyte_engine.get_client().get_task_execution(self.id).closure diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py index 66bb7f00e7..396f377500 100644 --- a/flytekit/remote/workflow.py +++ b/flytekit/remote/workflow.py @@ -1,16 +1,17 @@ +from __future__ import annotations + from typing import Dict, List, Optional from flytekit.common import constants as _constants -from flytekit.common.exceptions import system as _system_exceptions from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import hash as _hash_mixin from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine -from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import launch_plan as launch_plan_models from flytekit.models import task as _task_models -from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import compiler as compiler_models +from flytekit.models.core import identifier as id_models from flytekit.models.core import workflow as _workflow_models -from flytekit.remote import identifier as _identifier from flytekit.remote import interface as _interfaces from flytekit.remote import nodes as _nodes @@ -23,10 +24,15 @@ def __init__( nodes: List[_nodes.FlyteNode], interface, output_bindings, - id, + id: id_models.Identifier, metadata, metadata_defaults, + subworkflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[id_models.Identifier, _task_models.TaskSpec]] = None, + launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, + compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, ): + # TODO: Remove check for node in nodes: for upstream in node.upstream_nodes: if upstream.id is None: @@ -46,9 +52,13 @@ def __init__( self._flyte_nodes = nodes self._python_interface = None - @property - def 
upstream_entities(self): - return set(n.executable_flyte_object for n in self._flyte_nodes) + # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure + self._subworkflows = subworkflows + self._tasks = tasks + self._launch_plans = launch_plans + self._compiled_closure = compiled_closure + + self._node_map = None @property def interface(self) -> _interfaces.TypedInterface: @@ -60,7 +70,7 @@ def entity_type_text(self) -> str: @property def resource_type(self): - return _identifier_model.ResourceType.WORKFLOW + return id_models.ResourceType.WORKFLOW @property def flyte_nodes(self) -> List[_nodes.FlyteNode]: @@ -76,37 +86,6 @@ def guessed_python_interface(self, value): return self._python_interface = value - def get_sub_workflows(self) -> List["FlyteWorkflow"]: - result = [] - for node in self.flyte_nodes: - if node.workflow_node is not None and node.workflow_node.sub_workflow_ref is not None: - if node.flyte_entity is not None and node.flyte_entity.entity_type_text == "Workflow": - result.append(node.flyte_entity) - result.extend(node.flyte_entity.get_sub_workflows()) - else: - raise _system_exceptions.FlyteSystemException( - "workflow node with subworkflow found but bad executable " "object {}".format(node.flyte_entity) - ) - - # get subworkflows in conditional branches - if node.branch_node is not None: - if_else: _workflow_models.IfElseBlock = node.branch_node.if_else - leaf_nodes: List[_nodes.FlyteNode] = filter( - None, - [ - if_else.case.then_node, - *([] if if_else.other is None else [x.then_node for x in if_else.other]), - if_else.else_node, - ], - ) - for leaf_node in leaf_nodes: - exec_flyte_obj = leaf_node.flyte_entity - if exec_flyte_obj is not None and exec_flyte_obj.entity_type_text == "Workflow": - result.append(exec_flyte_obj) - result.extend(exec_flyte_obj.get_sub_workflows()) - - return result - @classmethod def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> 
List[_workflow_models.Node]: return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] @@ -115,10 +94,10 @@ def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workf def promote_from_model( cls, base_model: _workflow_models.WorkflowTemplate, - sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_models.WorkflowTemplate]] = None, - node_launch_plans: Optional[Dict[_identifier.Identifier, _launch_plan_models.LaunchPlanSpec]] = None, - tasks: Optional[Dict[_identifier.Identifier, _task_models.TaskTemplate]] = None, - ) -> "FlyteWorkflow": + sub_workflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None, + node_launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, + tasks: Optional[Dict[id_models.Identifier, _task_models.TaskTemplate]] = None, + ) -> FlyteWorkflow: base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) sub_workflows = sub_workflows or {} tasks = tasks or {} @@ -137,11 +116,14 @@ def promote_from_model( # No inputs/outputs specified, see the constructor for more information on the overrides. 
wf = cls( nodes=list(node_map.values()), - id=_identifier.Identifier.promote_from_model(base_model.id), + id=base_model.id, metadata=base_model.metadata, metadata_defaults=base_model.metadata_defaults, interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), output_bindings=base_model.outputs, + subworkflows=sub_workflows, + tasks=tasks, + launch_plans=node_launch_plans, ) if wf.interface is not None: @@ -149,8 +131,35 @@ def promote_from_model( inputs=TypeEngine.guess_python_types(wf.interface.inputs), outputs=TypeEngine.guess_python_types(wf.interface.outputs), ) + wf._node_map = node_map return wf + @classmethod + def promote_from_closure( + cls, + closure: compiler_models.CompiledWorkflowClosure, + node_launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, + ): + """ + Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane. + + :param closure: This is the closure returned by Admin + :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. + It only has subworkflows and tasks. Why this is is unclear. 
If supplied, it is passed to promote_from_model. + :return: The promoted FlyteWorkflow, with ``closure`` saved on it. + """ + sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows} + tasks = {t.template.id: t.template for t in closure.tasks} + + flyte_wf = FlyteWorkflow.promote_from_model( + base_model=closure.primary.template, + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + flyte_wf._compiled_closure = closure + return flyte_wf + def __call__(self, *args, **input_map): raise NotImplementedError diff --git a/flytekit/remote/workflow_execution.py b/flytekit/remote/workflow_execution.py index 0c201c9056..e69de29bb2 100644 --- a/flytekit/remote/workflow_execution.py +++ b/flytekit/remote/workflow_execution.py @@ -1,76 +0,0 @@ -from typing import Any, Dict - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import execution as _execution_models -from flytekit.models.core import execution as _core_execution_models -from flytekit.remote import identifier as _core_identifier -from flytekit.remote import nodes as _nodes - - -class FlyteWorkflowExecution(_execution_models.Execution): - """A class encapsulating a workflow execution being run on a Flyte remote backend.""" - - def __init__(self, *args, **kwargs): - super(FlyteWorkflowExecution, self).__init__(*args, **kwargs) - self._node_executions = None - self._inputs = None - self._outputs = None - - @property - def node_executions(self) -> Dict[str, _nodes.FlyteNodeExecution]: - """Get a dictionary of node executions that are a part of this workflow execution.""" - return self._node_executions or {} - - @property - def inputs(self) -> Dict[str, Any]: - """ - Returns the inputs to the execution in the standard python format as dictated by the type engine. - """ - return self._inputs - - @property - def outputs(self) -> Dict[str, Any]: - """ - Returns the outputs to the execution in the standard python format as dictated by the type engine. 
- - :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the node execution has completed before requesting the outputs." - ) - if self.error: - raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") - return self._outputs - - @property - def error(self) -> _core_execution_models.ExecutionError: - """ - If execution is in progress, raise an exception. Otherwise, return None if no error was present upon - reaching completion. - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until a workflow has completed before checking for an error." - ) - return self.closure.error - - @property - def is_complete(self) -> bool: - """ - Whether or not the execution is complete. - """ - return self.closure.phase in { - _core_execution_models.WorkflowExecutionPhase.ABORTED, - _core_execution_models.WorkflowExecutionPhase.FAILED, - _core_execution_models.WorkflowExecutionPhase.SUCCEEDED, - _core_execution_models.WorkflowExecutionPhase.TIMED_OUT, - } - - @classmethod - def promote_from_model(cls, base_model: _execution_models.Execution) -> "FlyteWorkflowExecution": - return cls( - closure=base_model.closure, - id=_core_identifier.WorkflowExecutionIdentifier.promote_from_model(base_model.id), - spec=base_model.spec, - ) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index c00564a5f3..9dc32a2e54 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -82,7 +82,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte poll_interval = datetime.timedelta(seconds=1) time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(seconds=60) - execution = remote.sync_workflow_execution(execution) + execution = 
remote.sync_workflow_execution(execution, sync_nodes=True) while datetime.datetime.utcnow() < time_to_give_up: if execution.is_complete: @@ -94,7 +94,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte execution.outputs time.sleep(poll_interval.total_seconds()) - execution = remote.sync_workflow_execution(execution) + execution = remote.sync_workflow_execution(execution, sync_nodes=True) if execution.node_executions: assert execution.node_executions["start-node"].closure.phase == 3 # SUCCEEEDED diff --git a/tests/flytekit/unit/models/admin/test_node_executions.py b/tests/flytekit/unit/models/admin/test_node_executions.py index 84d8785d09..b4cd77e5e8 100644 --- a/tests/flytekit/unit/models/admin/test_node_executions.py +++ b/tests/flytekit/unit/models/admin/test_node_executions.py @@ -1,7 +1,51 @@ from flytekit.models import node_execution as node_execution_models +from flytekit.models.core import catalog, identifier +from tests.flytekit.unit.common_tests.test_workflow_promote import get_compiled_workflow_closure def test_metadata(): md = node_execution_models.NodeExecutionMetaData(retry_group="0", is_parent_node=True, spec_node_id="n0") md2 = node_execution_models.NodeExecutionMetaData.from_flyte_idl(md.to_flyte_idl()) assert md == md2 + + +def test_workflow_node_metadata(): + wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") + + obj = node_execution_models.WorkflowNodeMetadata(execution_id=wf_exec_id) + assert obj.execution_id is wf_exec_id + + obj2 = node_execution_models.WorkflowNodeMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj == obj2 + + +def test_task_node_metadata(): + task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version") + wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") + node_exec_id = identifier.NodeExecutionIdentifier( + "node_id", + wf_exec_id, + ) + te_id = 
identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3) + ds_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef") + tag = catalog.CatalogArtifactTag("my-artifact-id", "some name") + catalog_metadata = catalog.CatalogMetadata(dataset_id=ds_id, artifact_tag=tag, source_task_execution=te_id) + + obj = node_execution_models.TaskNodeMetadata(cache_status=0, catalog_key=catalog_metadata) + assert obj.cache_status == 0 + assert obj.catalog_key == catalog_metadata + + obj2 = node_execution_models.TaskNodeMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj2 == obj + + +def test_dynamic_wf_node_metadata(): + wf_id = identifier.Identifier(identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version") + cwc = get_compiled_workflow_closure() + + obj = node_execution_models.DynamicWorkflowNodeMetadata(id=wf_id, compiled_workflow=cwc) + assert obj.id == wf_id + assert obj.compiled_workflow == cwc + + obj2 = node_execution_models.DynamicWorkflowNodeMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj2 == obj diff --git a/tests/flytekit/unit/models/core/test_catalog.py b/tests/flytekit/unit/models/core/test_catalog.py new file mode 100644 index 0000000000..c89fe6545c --- /dev/null +++ b/tests/flytekit/unit/models/core/test_catalog.py @@ -0,0 +1,32 @@ +from flytekit.models.core import catalog, identifier + + +def test_catalog_artifact_tag(): + obj = catalog.CatalogArtifactTag("my-artifact-id", "some name") + assert obj.artifact_id == "my-artifact-id" + assert obj.name == "some name" + + obj2 = catalog.CatalogArtifactTag.from_flyte_idl(obj.to_flyte_idl()) + assert obj == obj2 + assert obj2.artifact_id == "my-artifact-id" + assert obj2.name == "some name" + + +def test_catalog_metadata(): + task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version") + wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") + node_exec_id = 
identifier.NodeExecutionIdentifier( + "node_id", + wf_exec_id, + ) + te_id = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3) + ds_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef") + tag = catalog.CatalogArtifactTag("my-artifact-id", "some name") + obj = catalog.CatalogMetadata(dataset_id=ds_id, artifact_tag=tag, source_task_execution=te_id) + assert obj.dataset_id is ds_id + assert obj.source_execution is te_id + assert obj.source_task_execution is te_id + assert obj.artifact_tag is tag + + obj2 = catalog.CatalogMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj == obj2 diff --git a/tests/flytekit/unit/remote/test_identifier.py b/tests/flytekit/unit/remote/test_identifier.py deleted file mode 100644 index 0335b1db76..0000000000 --- a/tests/flytekit/unit/remote/test_identifier.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models.core import identifier as _core_identifier -from flytekit.remote import identifier as _identifier - - -def test_identifier(): - identifier = _identifier.Identifier(_core_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "v1") - assert identifier == _identifier.Identifier.from_urn("wf:project:domain:name:v1") - assert identifier == _core_identifier.Identifier( - _core_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "v1" - ) - assert identifier.__str__() == "wf:project:domain:name:v1" - - -@pytest.mark.parametrize( - "urn", - [ - "", - "project:domain:name:v1", - "wf:project:domain:name:v1:foobar", - "foobar:project:domain:name:v1", - ], -) -def test_identifier_exceptions(urn): - with pytest.raises(_user_exceptions.FlyteValueException): - _identifier.Identifier.from_urn(urn) - - -def test_workflow_execution_identifier(): - identifier = _identifier.WorkflowExecutionIdentifier("project", "domain", "name") - assert identifier == 
_identifier.WorkflowExecutionIdentifier.from_urn("ex:project:domain:name") - assert identifier == _identifier.WorkflowExecutionIdentifier.promote_from_model( - _core_identifier.WorkflowExecutionIdentifier("project", "domain", "name") - ) - assert identifier.__str__() == "ex:project:domain:name" - - -@pytest.mark.parametrize( - "urn", ["", "project:domain:name", "project:domain:name:foobar", "ex:project:domain:name:foobar"] -) -def test_workflow_execution_identifier_exceptions(urn): - with pytest.raises(_user_exceptions.FlyteValueException): - _identifier.WorkflowExecutionIdentifier.from_urn(urn) - - -def test_task_execution_identifier(): - task_id = _identifier.Identifier(_core_identifier.ResourceType.TASK, "project", "domain", "name", "version") - node_execution_id = _core_identifier.NodeExecutionIdentifier( - node_id="n0", execution_id=_core_identifier.WorkflowExecutionIdentifier("project", "domain", "name") - ) - identifier = _identifier.TaskExecutionIdentifier( - task_id=task_id, - node_execution_id=node_execution_id, - retry_attempt=0, - ) - assert identifier == _identifier.TaskExecutionIdentifier.from_urn( - "te:project:domain:name:n0:project:domain:name:version:0" - ) - assert identifier == _identifier.TaskExecutionIdentifier.promote_from_model( - _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0) - ) - assert identifier.__str__() == "te:project:domain:name:n0:project:domain:name:version:0" - - -@pytest.mark.parametrize( - "urn", - [ - "", - "te:project:domain:name:n0:project:domain:name:version", - "foobar:project:domain:name:n0:project:domain:name:version:0", - ], -) -def test_task_execution_identifier_exceptions(urn): - with pytest.raises(_user_exceptions.FlyteValueException): - _identifier.TaskExecutionIdentifier.from_urn(urn) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 873f67d752..2bb8174069 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ 
b/tests/flytekit/unit/remote/test_remote.py @@ -6,20 +6,8 @@ from flytekit.common.exceptions import user as user_exceptions from flytekit.configuration import internal from flytekit.models import common as common_models -from flytekit.models.admin.workflow import Workflow -from flytekit.models.core.identifier import ( - Identifier, - NodeExecutionIdentifier, - ResourceType, - WorkflowExecutionIdentifier, -) +from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution -from flytekit.models.interface import TypedInterface, Variable -from flytekit.models.launch_plan import LaunchPlan -from flytekit.models.node_execution import NodeExecution, NodeExecutionMetaData -from flytekit.models.task import Task -from flytekit.models.types import LiteralType, SimpleType -from flytekit.remote import FlyteWorkflow from flytekit.remote.remote import FlyteRemote CLIENT_METHODS = { @@ -41,50 +29,6 @@ } -@patch("flytekit.clients.friendly.SynchronousFlyteClient") -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -@pytest.mark.parametrize( - "entity_cls,resource_type", - [ - [Workflow, ResourceType.WORKFLOW], - [Task, ResourceType.TASK], - [LaunchPlan, ResourceType.LAUNCH_PLAN], - ], -) -def test_remote_fetch_execute_entities_task_workflow_launchplan( - mock_insecure, - mock_url, - mock_client, - entity_cls, - resource_type, -): - admin_entities = [ - entity_cls( - Identifier(resource_type, "p1", "d1", "n1", version), - *([MagicMock()] if resource_type != ResourceType.LAUNCH_PLAN else [MagicMock(), MagicMock()]), - ) - for version in ["latest", "old"] - ] - - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True - mock_client = MagicMock() - getattr(mock_client, CLIENT_METHODS[resource_type]).return_value = admin_entities, "" - - remote = FlyteRemote.from_config("p1", "d1") - remote._client = mock_client - fetch_method = getattr(remote, 
REMOTE_METHODS[resource_type]) - flyte_entity_latest = fetch_method(name="n1", version="latest") - flyte_entity_latest_implicit = fetch_method(name="n1") - flyte_entity_old = fetch_method(name="n1", version="old") - - assert flyte_entity_latest.entity_type_text == ENTITY_TYPE_TEXT[resource_type] - assert flyte_entity_latest.id == admin_entities[0].id - assert flyte_entity_latest.id == flyte_entity_latest_implicit.id - assert flyte_entity_latest.id != flyte_entity_old.id - - @patch("flytekit.clients.friendly.SynchronousFlyteClient") @patch("flytekit.configuration.platform.URL") @patch("flytekit.configuration.platform.INSECURE") @@ -106,39 +50,7 @@ def test_remote_fetch_workflow_execution(mock_insecure, mock_url, mock_client_ma assert flyte_workflow_execution.id == admin_workflow_execution.id -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_get_node_execution_interface(mock_insecure, mock_url): - expected_interface = TypedInterface( - {"in1": Variable(LiteralType(simple=SimpleType.STRING), "in1 description")}, - {"out1": Variable(LiteralType(simple=SimpleType.INTEGER), "out1 description")}, - ) - - node_exec_id = NodeExecutionIdentifier("node_id", WorkflowExecutionIdentifier("p1", "d1", "exec_name")) - - mock_node = MagicMock() - mock_node.id = node_exec_id.node_id - task_node = MagicMock() - flyte_task = MagicMock() - flyte_task.interface = expected_interface - task_node.flyte_task = flyte_task - mock_node.task_node = task_node - - flyte_workflow = FlyteWorkflow([mock_node], None, None, None, None, None) - - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True - mock_client = MagicMock() - - remote = FlyteRemote.from_config("p1", "d1") - remote._client = mock_client - actual_interface = remote._get_node_execution_interface( - NodeExecution(node_exec_id, None, None, NodeExecutionMetaData(None, True, None)), flyte_workflow - ) - assert actual_interface == expected_interface - - 
-@patch("flytekit.remote.workflow_execution.FlyteWorkflowExecution.promote_from_model") +@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") @patch("flytekit.configuration.platform.URL") @patch("flytekit.configuration.platform.INSECURE") def test_underscore_execute_uses_launch_plan_attributes(mock_insecure, mock_url, mock_wf_exec): @@ -171,7 +83,7 @@ def local_assertions(*args, **kwargs): ) -@patch("flytekit.remote.workflow_execution.FlyteWorkflowExecution.promote_from_model") +@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") @patch("flytekit.configuration.auth.ASSUMABLE_IAM_ROLE") @patch("flytekit.configuration.platform.URL") @patch("flytekit.configuration.platform.INSECURE") @@ -201,7 +113,7 @@ def local_assertions(*args, **kwargs): ) -@patch("flytekit.remote.workflow_execution.FlyteWorkflowExecution.promote_from_model") +@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") @patch("flytekit.configuration.platform.URL") @patch("flytekit.configuration.platform.INSECURE") def test_execute_with_wrong_input_key(mock_insecure, mock_url, mock_wf_exec): diff --git a/tests/flytekit/unit/remote/test_wrapper_classes.py b/tests/flytekit/unit/remote/test_wrapper_classes.py index d37fd738c2..b26253e1db 100644 --- a/tests/flytekit/unit/remote/test_wrapper_classes.py +++ b/tests/flytekit/unit/remote/test_wrapper_classes.py @@ -68,9 +68,6 @@ def wf(b: int) -> int: assert list(fwf.interface.inputs.keys()) == ["b"] assert len(fwf.nodes) == 1 assert len(fwf.flyte_nodes) == 1 - flyte_subwfs = fwf.get_sub_workflows() - assert len(flyte_subwfs) == 1 - assert fwf.nodes[0].workflow_node.sub_workflow_ref == flyte_subwfs[0].id # Test another subwf that calls a launch plan instead of the sub_wf directly @workflow @@ -88,8 +85,6 @@ def wf2(b: int) -> int: assert list(fwf.interface.inputs.keys()) == ["b"] assert len(fwf.nodes) == 1 assert len(fwf.flyte_nodes) == 1 - flyte_subwfs = fwf.get_sub_workflows() - 
assert len(flyte_subwfs) == 0 # The resource type will be different, so just check the name assert fwf.nodes[0].workflow_node.launchplan_ref.name == list(lp_specs.values())[0].workflow_id.name