From 7cff20673189596494e5a02802b3ab9f992f4e54 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 18 Aug 2023 15:56:09 +0800 Subject: [PATCH 01/18] Add async delete function in base_agent Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index a6f21e1b0e..039b1a31a8 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -182,7 +182,7 @@ def execute(self, **kwargs) -> typing.Any: res = asyncio.run(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs)) else: res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs) - signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta)) + signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta)) state = RUNNING metadata = res.resource_meta progress = Progress(transient=True) @@ -211,5 +211,8 @@ def signal_handler( signum: int, frame: FrameType, ) -> typing.Any: - agent.delete(context, resource_meta) + if agent.asynchronous: + asyncio.run(agent.async_delete(context, resource_meta)) + else: + agent.delete(context, resource_meta) sys.exit(1) From 319e93f3c4333bb403018c0aca2bc6adc1bc37a5 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 20 Aug 2023 12:23:10 +0800 Subject: [PATCH 02/18] Async Agent Delete Function For While Loop Case Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 039b1a31a8..b5f8b7a6a4 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -212,7 +212,9 @@ def signal_handler( frame: FrameType, ) -> typing.Any: if agent.asynchronous: - asyncio.run(agent.async_delete(context, resource_meta)) + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + new_loop.run_until_complete(agent.async_delete(context, resource_meta)) else: agent.delete(context, resource_meta) sys.exit(1) From 1986ffb380a51f5b5deab520744de208edbfa32e Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 20 Aug 2023 12:37:31 +0800 Subject: [PATCH 03/18] add resource allocation mechanism Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index b5f8b7a6a4..227164bb95 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -212,9 +212,12 @@ def signal_handler( frame: FrameType, ) -> typing.Any: if agent.asynchronous: - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - new_loop.run_until_complete(agent.async_delete(context, resource_meta)) + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + new_loop.run_until_complete(agent.async_delete(context, resource_meta)) + finally: + new_loop.close() else: agent.delete(context, resource_meta) sys.exit(1) From bdfcb4ec5668dd551b597f0cf4f50abf4dd2cd94 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 20 Aug 2023 12:43:45 +0800 Subject: [PATCH 04/18] Add better comments Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 227164bb95..13e4988168 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -217,6 +217,7 @@ def signal_handler( asyncio.set_event_loop(new_loop) new_loop.run_until_complete(agent.async_delete(context, resource_meta)) finally: + # Close the loop to prevent resource leakag new_loop.close() else: agent.delete(context, resource_meta) From e8daef3d4b144cbfe22e7820bda99cdedfc29c4c Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 20 Aug 2023 12:45:23 +0800 Subject: [PATCH 05/18] fix spell error Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 13e4988168..8e29ca5ffd 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -217,7 +217,7 @@ def signal_handler( asyncio.set_event_loop(new_loop) new_loop.run_until_complete(agent.async_delete(context, resource_meta)) finally: - # Close the loop to prevent resource leakag + # Close the loop to prevent resource leakage new_loop.close() else: agent.delete(context, resource_meta) From d2dcb87fd03bc99277fb54922a1f78d163e5e0f5 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 21 Aug 2023 14:03:05 +0800 Subject: [PATCH 06/18] Fix the Async Agent Delete Base Agent Interface Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 41 ++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 8e29ca5ffd..bedcf0f289 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -4,6 +4,7 @@ import time import typing from abc import ABC +from asyncio import AbstractEventLoop from collections import OrderedDict from functools import partial from types import FrameType @@ -158,6 +159,8 @@ class AsyncAgentExecutorMixin: Task should inherit from this class if the task can be run in the agent. """ + aysnc_delete = False + def execute(self, **kwargs) -> typing.Any: from unittest.mock import MagicMock @@ -179,10 +182,21 @@ def execute(self, **kwargs) -> typing.Any: output_prefix = ctx.file_access.get_random_local_directory() cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) if agent.asynchronous: - res = asyncio.run(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs)) + loop = asyncio.new_event_loop() + res = loop.run_until_complete(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs)) + loop.add_signal_handler( + signal.SIGINT, + partial( + self.signal_handler, + agent=agent, + context=dummy_context, + resource_meta=res.resource_meta, + loop=loop, + ), + ) else: res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs) - signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta)) + signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta, None)) state = RUNNING metadata = res.resource_meta progress = Progress(transient=True) @@ -192,12 +206,17 @@ def execute(self, **kwargs) -> typing.Any: progress.start_task(task) time.sleep(1) if agent.asynchronous: - res = asyncio.run(agent.async_get(dummy_context, metadata)) + res = loop.run_until_complete(agent.async_get(dummy_context, metadata)) else: res = agent.get(dummy_context, metadata) state = res.resource.state logger.info(f"Task state: {state}") + if agent.asynchronous: + loop.close() + + if self.aysnc_delete: + sys.exit(1) if state != SUCCEEDED: raise Exception(f"Failed to run the task {entity.name}") @@ -208,17 +227,13 @@ def signal_handler( agent: AgentBase, context: grpc.ServicerContext, resource_meta: bytes, - signum: int, - frame: FrameType, + loop: typing.Optional[AbstractEventLoop] = None, + signum: typing.Optional[int] = None, + frame: typing.Optional[FrameType] = None, ) -> typing.Any: if agent.asynchronous: - try: - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - new_loop.run_until_complete(agent.async_delete(context, resource_meta)) - finally: - # Close the loop to prevent resource leakage - new_loop.close() + self.aysnc_delete = True + loop.create_task(agent.async_delete(context, resource_meta)) else: agent.delete(context, resource_meta) - sys.exit(1) + sys.exit(1) From c8c0fa23ae7de6c18e24434a11a7e96d19dbe12f Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 16:53:00 +0800 Subject: [PATCH 07/18] revise base_agent.py async delete Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index bedcf0f289..06ddee6fa3 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -214,9 +214,9 @@ def execute(self, **kwargs) -> typing.Any: if agent.asynchronous: loop.close() + if self.aysnc_delete: + sys.exit(1) - if self.aysnc_delete: - sys.exit(1) if state != SUCCEEDED: raise Exception(f"Failed to run the task {entity.name}") From 1c9f3424b33b1e4a047c77d9abe9435f7bc28290 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 16:57:40 +0800 Subject: [PATCH 08/18] delete new line Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 87e2a71d25..06ddee6fa3 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -197,7 +197,6 @@ def execute(self, **kwargs) -> typing.Any: else: res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs) signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta, None)) - state = RUNNING metadata = res.resource_meta progress = Progress(transient=True) From 3fa0431440305769b2437aa17b2cdc89f480c942 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 30 Aug 2023 09:22:11 +0800 Subject: [PATCH 09/18] change boolean variable sync_delete to is_canceled Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 06ddee6fa3..92f1208fef 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -159,7 +159,7 @@ class AsyncAgentExecutorMixin: Task should inherit from this class if the task can be run in the agent. """ - aysnc_delete = False + is_canceled = False def execute(self, **kwargs) -> typing.Any: from unittest.mock import MagicMock @@ -214,7 +214,7 @@ def execute(self, **kwargs) -> typing.Any: if agent.asynchronous: loop.close() - if self.aysnc_delete: + if self.is_canceled: sys.exit(1) if state != SUCCEEDED: @@ -232,7 +232,7 @@ def signal_handler( frame: typing.Optional[FrameType] = None, ) -> typing.Any: if agent.asynchronous: - self.aysnc_delete = True + self.is_canceled = True loop.create_task(agent.async_delete(context, resource_meta)) else: agent.delete(context, resource_meta) From fd9f755e4448195a03b82a316fc5eafb2810e228 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 1 Sep 2023 14:29:40 +0800 Subject: [PATCH 10/18] support async sensor Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 92f1208fef..d354278f2a 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -207,15 +207,14 @@ def execute(self, **kwargs) -> typing.Any: time.sleep(1) if agent.asynchronous: res = loop.run_until_complete(agent.async_get(dummy_context, metadata)) + if self.is_canceled and len(asyncio.all_tasks(loop=loop)) == 0: + sys.exit(1) else: res = agent.get(dummy_context, metadata) state = res.resource.state logger.info(f"Task state: {state}") - - if agent.asynchronous: - loop.close() - if self.is_canceled: - sys.exit(1) + if agent.asynchronous: + loop.close() if state != SUCCEEDED: raise Exception(f"Failed to run the task {entity.name}") @@ -232,8 +231,8 @@ def signal_handler( frame: typing.Optional[FrameType] = None, ) -> typing.Any: if agent.asynchronous: - self.is_canceled = True loop.create_task(agent.async_delete(context, resource_meta)) + self.is_canceled = True else: agent.delete(context, resource_meta) sys.exit(1) From 5680a258c4d8b1cc4426d7897945543dc4fc7cd3 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 1 Sep 2023 14:39:14 +0800 Subject: [PATCH 11/18] add canceled case after exit the while loop Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index d354278f2a..fe7020db1e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -215,6 +215,8 @@ def execute(self, **kwargs) -> typing.Any: logger.info(f"Task state: {state}") if agent.asynchronous: loop.close() + if self.is_canceled: + sys.exit(1) if state != SUCCEEDED: raise Exception(f"Failed to run the task {entity.name}") From d09eec9ab5769a822df02b48ff91b3bcc9393742 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 1 Sep 2023 15:14:58 +0800 Subject: [PATCH 12/18] remove useless async signal exit case Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index fe7020db1e..d354278f2a 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -215,8 +215,6 @@ def execute(self, **kwargs) -> typing.Any: logger.info(f"Task state: {state}") if agent.asynchronous: loop.close() - if self.is_canceled: - sys.exit(1) if state != SUCCEEDED: raise Exception(f"Failed to run the task {entity.name}") From a8b895205756c9d43d554677ec7e6c7f8856eaee Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 1 Sep 2023 17:08:24 -0700 Subject: [PATCH 13/18] Add agent error handler Signed-off-by: Kevin Su --- flytekit/core/error_handlers.py | 12 +++++ flytekit/extend/backend/agent_service.py | 14 ++---- flytekit/extend/backend/base_agent.py | 64 +++++------------------- 3 files changed, 27 insertions(+), 63 deletions(-) create mode 100644 flytekit/core/error_handlers.py diff --git a/flytekit/core/error_handlers.py b/flytekit/core/error_handlers.py new file mode 100644 index 0000000000..23a3d435f0 --- /dev/null +++ b/flytekit/core/error_handlers.py @@ -0,0 +1,12 @@ +import asyncio + +import grpc + +from flytekit.extend.backend.base_agent import AgentBase + + +def agent_error_handler(agent: AgentBase, context: grpc.ServicerContext, resource_meta: bytes): + if agent.asynchronous: + asyncio.run(agent.async_delete(context, resource_meta)) + else: + agent.delete(context, resource_meta) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 470bd01e2e..9c294e5fae 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -2,14 +2,12 @@ import grpc from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, CreateTaskRequest, CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, GetTaskRequest, GetTaskResponse, - Resource, ) from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer @@ -24,10 +22,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon try: tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(context, tmp.type) + agent = AgentRegistry.get_agent(tmp.type) logger.info(f"{tmp.type} agent start creating the job") - if agent is None: - return CreateTaskResponse() if agent.asynchronous: try: return await agent.async_create( @@ -50,10 +46,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: try: - agent = AgentRegistry.get_agent(context, request.task_type) + agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.task_type} agent start checking the status of the job") - if agent is None: - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) if agent.asynchronous: try: return await agent.async_get(context=context, resource_meta=request.resource_meta) @@ -72,10 +66,8 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: try: - agent = AgentRegistry.get_agent(context, request.task_type) + agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.task_type} agent start deleting the job") - if agent is None: - return DeleteTaskResponse() if agent.asynchronous: try: return await agent.async_delete(context=context, resource_meta=request.resource_meta) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index d354278f2a..e3035cedde 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,13 +1,9 @@ import asyncio -import signal -import sys +import atexit import time import typing from abc import ABC -from asyncio import AbstractEventLoop from collections import OrderedDict -from functools import partial -from types import FrameType import grpc from flyteidl.admin.agent_pb2 import ( @@ -26,6 +22,7 @@ from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask +from flytekit.core.error_handlers import agent_error_handler from flytekit.core.type_engine import TypeEngine from flytekit.models.literals import LiteralMap @@ -123,12 +120,9 @@ def register(agent: AgentBase): logger.info(f"Registering an agent for task type {agent.task_type}") @staticmethod - def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]: + def get_agent(task_type: str) -> typing.Optional[AgentBase]: if task_type not in AgentRegistry._REGISTRY: - logger.error(f"Cannot find agent for task type [{task_type}]") - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details(f"Cannot find the agent for task type [{task_type}]") - return None + raise ValueError(f"Cannot find an agent for task type {task_type}") return AgentRegistry._REGISTRY[task_type] @@ -159,8 +153,6 @@ class AsyncAgentExecutorMixin: Task should inherit from this class if the task can be run in the agent. """ - is_canceled = False - def execute(self, **kwargs) -> typing.Any: from unittest.mock import MagicMock @@ -168,12 +160,10 @@ def execute(self, **kwargs) -> typing.Any: entity = typing.cast(PythonTask, self) m: OrderedDict = OrderedDict() - dummy_context = MagicMock(spec=grpc.ServicerContext) + grpc_context = MagicMock(spec=grpc.ServicerContext) cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - agent = AgentRegistry.get_agent(dummy_context, cp_entity.template.type) + agent = AgentRegistry.get_agent(cp_entity.template.type) - if agent is None: - raise Exception("Cannot run the task locally, please mock.") literals = {} ctx = FlyteContext.current_context() for k, v in kwargs.items(): @@ -182,21 +172,10 @@ def execute(self, **kwargs) -> typing.Any: output_prefix = ctx.file_access.get_random_local_directory() cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) if agent.asynchronous: - loop = asyncio.new_event_loop() - res = loop.run_until_complete(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs)) - loop.add_signal_handler( - signal.SIGINT, - partial( - self.signal_handler, - agent=agent, - context=dummy_context, - resource_meta=res.resource_meta, - loop=loop, - ), - ) + res = asyncio.run(agent.async_create(grpc_context, output_prefix, cp_entity.template, inputs)) + atexit.register(agent_error_handler, agent, grpc_context, res.resource_meta) else: - res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs) - signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta, None)) + res = agent.create(grpc_context, output_prefix, cp_entity.template, inputs) state = RUNNING metadata = res.resource_meta progress = Progress(transient=True) @@ -206,33 +185,14 @@ def execute(self, **kwargs) -> typing.Any: progress.start_task(task) time.sleep(1) if agent.asynchronous: - res = loop.run_until_complete(agent.async_get(dummy_context, metadata)) - if self.is_canceled and len(asyncio.all_tasks(loop=loop)) == 0: - sys.exit(1) + res = asyncio.run(agent.async_get(grpc_context, metadata)) else: - res = agent.get(dummy_context, metadata) + res = agent.get(grpc_context, metadata) state = res.resource.state logger.info(f"Task state: {state}") - if agent.asynchronous: - loop.close() + atexit.unregister(agent_error_handler) if state != SUCCEEDED: raise Exception(f"Failed to run the task {entity.name}") return LiteralMap.from_flyte_idl(res.resource.outputs) - - def signal_handler( - self, - agent: AgentBase, - context: grpc.ServicerContext, - resource_meta: bytes, - loop: typing.Optional[AbstractEventLoop] = None, - signum: typing.Optional[int] = None, - frame: typing.Optional[FrameType] = None, - ) -> typing.Any: - if agent.asynchronous: - loop.create_task(agent.async_delete(context, resource_meta)) - self.is_canceled = True - else: - agent.delete(context, resource_meta) - sys.exit(1) From ae542dd35ea8543b40e8388027acb3492272b7e7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 2 Sep 2023 17:47:13 -0700 Subject: [PATCH 14/18] update base agent Signed-off-by: Kevin Su --- flytekit/core/error_handlers.py | 12 --- flytekit/extend/backend/base_agent.py | 89 +++++++++++++------ plugins/flytekit-bigquery/tests/test_agent.py | 2 +- tests/flytekit/unit/extend/test_agent.py | 2 +- 4 files changed, 64 insertions(+), 41 deletions(-) delete mode 100644 flytekit/core/error_handlers.py diff --git a/flytekit/core/error_handlers.py b/flytekit/core/error_handlers.py deleted file mode 100644 index 23a3d435f0..0000000000 --- a/flytekit/core/error_handlers.py +++ /dev/null @@ -1,12 +0,0 @@ -import asyncio - -import grpc - -from flytekit.extend.backend.base_agent import AgentBase - - -def agent_error_handler(agent: AgentBase, context: grpc.ServicerContext, resource_meta: bytes): - if agent.asynchronous: - asyncio.run(agent.async_delete(context, resource_meta)) - else: - agent.delete(context, resource_meta) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index e3035cedde..c996063e64 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,9 +1,12 @@ import asyncio -import atexit +import signal +import sys import time import typing from abc import ABC from collections import OrderedDict +from functools import partial +from types import FrameType import grpc from flyteidl.admin.agent_pb2 import ( @@ -22,7 +25,6 @@ from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask -from flytekit.core.error_handlers import agent_error_handler from flytekit.core.type_engine import TypeEngine from flytekit.models.literals import LiteralMap @@ -122,7 +124,7 @@ def register(agent: AgentBase): @staticmethod def get_agent(task_type: str) -> typing.Optional[AgentBase]: if task_type not in AgentRegistry._REGISTRY: - raise ValueError(f"Cannot find an agent for task type {task_type}") + raise ValueError(f"Unrecognized task type {task_type}") return AgentRegistry._REGISTRY[task_type] @@ -153,46 +155,79 @@ class AsyncAgentExecutorMixin: Task should inherit from this class if the task can be run in the agent. """ - def execute(self, **kwargs) -> typing.Any: - from unittest.mock import MagicMock + _is_canceled = None + _agent = None + _entity = None + def execute(self, **kwargs) -> typing.Any: from flytekit.tools.translator import get_serializable - entity = typing.cast(PythonTask, self) - m: OrderedDict = OrderedDict() - grpc_context = MagicMock(spec=grpc.ServicerContext) - cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - agent = AgentRegistry.get_agent(cp_entity.template.type) + self._entity = typing.cast(PythonTask, self) + task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template + self._agent = AgentRegistry.get_agent(task_template.type) - literals = {} + res = asyncio.run(self._create(task_template, kwargs)) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) + + if res.resource.state != SUCCEEDED: + raise Exception(f"Failed to run the task {self._entity.name}") + + return LiteralMap.from_flyte_idl(res.resource.outputs) + + async def _create( + self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None + ) -> CreateTaskResponse: ctx = FlyteContext.current_context() - for k, v in kwargs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type) + grpc_ctx = _get_grpc_context() + + # Convert python inputs to literals + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) inputs = LiteralMap(literals) if literals else None output_prefix = ctx.file_access.get_random_local_directory() - cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - if agent.asynchronous: - res = asyncio.run(agent.async_create(grpc_context, output_prefix, cp_entity.template, inputs)) - atexit.register(agent_error_handler, agent, grpc_context, res.resource_meta) + + if self._agent.asynchronous: + res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) else: - res = agent.create(grpc_context, output_prefix, cp_entity.template, inputs) + res = await asyncio.to_thread(self._agent.create, grpc_ctx, output_prefix, task_template, inputs) + + signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore + return res + + async def _get(self, resource_meta: bytes) -> GetTaskResponse: state = RUNNING - metadata = res.resource_meta + grpc_ctx = _get_grpc_context() + progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None) + task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) with progress: while not is_terminal_state(state): progress.start_task(task) time.sleep(1) - if agent.asynchronous: - res = asyncio.run(agent.async_get(grpc_context, metadata)) + if self._agent.asynchronous: + res = await self._agent.async_get(grpc_ctx, resource_meta) + if self._is_canceled: + await self._is_canceled + sys.exit(1) else: - res = agent.get(grpc_context, metadata) + res = await asyncio.to_thread(self._agent.get, grpc_ctx, resource_meta) state = res.resource.state logger.info(f"Task state: {state}") + return res - atexit.unregister(agent_error_handler) - if state != SUCCEEDED: - raise Exception(f"Failed to run the task {entity.name}") + def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: + grpc_ctx = _get_grpc_context() + if self._agent.asynchronous: + if self._is_canceled is None: + self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta)) + else: + self._agent.delete(grpc_ctx, resource_meta) + sys.exit(1) - return LiteralMap.from_flyte_idl(res.resource.outputs) + +def _get_grpc_context(): + from unittest.mock import MagicMock + + grpc_ctx = MagicMock(spec=grpc.ServicerContext) + return grpc_ctx diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index 16b5b7af4d..af53f4031d 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -44,7 +44,7 @@ def __init__(self): mock_instance.cancel_job.return_value = MockJob() ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "bigquery_query_job_task") + agent = AgentRegistry.get_agent("bigquery_query_job_task") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index bf1db6e333..9d0415fb5c 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -97,7 +97,7 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT def test_dummy_agent(): ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "dummy") + agent = AgentRegistry.get_agent("dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED From e5958bfa943da4f4e24d4addc2fae5ecc634420b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 2 Sep 2023 18:03:44 -0700 Subject: [PATCH 15/18] nit Signed-off-by: Kevin Su --- tests/flytekit/unit/extend/test_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 9d0415fb5c..7623396fee 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -114,7 +114,7 @@ def __init__(self, **kwargs): t.execute() t._task_type = "non-exist-type" - with pytest.raises(Exception, match="Cannot find the agent for the task"): + with pytest.raises(Exception, match="Unrecognized task type non-exist-type"): t.execute() From 98bad06273214f1465ca2995c4cce3b72e4c6d86 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 2 Sep 2023 18:06:53 -0700 Subject: [PATCH 16/18] nit Signed-off-by: Kevin Su --- flytekit/extend/backend/base_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index c996063e64..02a010829c 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -190,7 +190,7 @@ async def _create( if self._agent.asynchronous: res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) else: - res = await asyncio.to_thread(self._agent.create, grpc_ctx, output_prefix, task_template, inputs) + res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs) signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore return res @@ -211,7 +211,7 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: await self._is_canceled sys.exit(1) else: - res = await asyncio.to_thread(self._agent.get, grpc_ctx, resource_meta) + res = self._agent.get(grpc_ctx, resource_meta) state = res.resource.state logger.info(f"Task state: {state}") return res From f89c25fe237e20a118e4501af21fa93306ee678d Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 6 Sep 2023 09:24:11 +0800 Subject: [PATCH 17/18] add convert_to_flyte_state test Signed-off-by: Future Outlier --- tests/flytekit/unit/extend/test_agent.py | 25 +++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 7623396fee..b763a7e402 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -9,6 +9,7 @@ import pytest from flyteidl.admin.agent_pb2 import ( PERMANENT_FAILURE, + RETRYABLE_FAILURE, RUNNING, SUCCEEDED, CreateTaskRequest, @@ -23,7 +24,13 @@ import flytekit.models.interface as interface_models from flytekit import PythonFunctionTask from flytekit.extend.backend.agent_service import AsyncAgentService -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, AsyncAgentExecutorMixin, is_terminal_state +from flytekit.extend.backend.base_agent import ( + AgentBase, + AgentRegistry, + AsyncAgentExecutorMixin, + convert_to_flyte_state, + is_terminal_state, +) from flytekit.models import literals, task, types from flytekit.models.core.identifier import Identifier, ResourceType from flytekit.models.literals import LiteralMap @@ -147,3 +154,19 @@ def test_is_terminal_state(): assert is_terminal_state(PERMANENT_FAILURE) assert is_terminal_state(PERMANENT_FAILURE) assert not is_terminal_state(RUNNING) + + +def test_convert_to_flyte_state(): + assert convert_to_flyte_state("FAILED") == RETRYABLE_FAILURE + assert convert_to_flyte_state("TIMEDOUT") == RETRYABLE_FAILURE + assert convert_to_flyte_state("CANCELED") == RETRYABLE_FAILURE + + assert convert_to_flyte_state("DONE") == SUCCEEDED + assert convert_to_flyte_state("SUCCEEDED") == SUCCEEDED + assert convert_to_flyte_state("SUCCESS") == SUCCEEDED + + assert convert_to_flyte_state("RUNNING") == RUNNING + + invalid_state = "INVALID_STATE" + with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"): + convert_to_flyte_state(invalid_state) From 1155c16478d687fad2dfafe82ac5f6fbc9512e27 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 6 Sep 2023 09:29:11 +0800 Subject: [PATCH 18/18] convert_to_flyte_state test function Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 02a010829c..50574e67b1 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -133,9 +133,9 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() - if state in ["failed"]: + if state in ["failed", "timedout", "canceled"]: return RETRYABLE_FAILURE - elif state in ["done", "succeeded"]: + elif state in ["done", "succeeded", "success"]: return SUCCEEDED elif state in ["running"]: return RUNNING