Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
CreateTaskRequest,
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer

Expand All @@ -24,10 +22,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(context, tmp.type)
agent = AgentRegistry.get_agent(tmp.type)
logger.info(f"{tmp.type} agent start creating the job")
if agent is None:
return CreateTaskResponse()
if agent.asynchronous:
try:
return await agent.async_create(
Expand All @@ -50,10 +46,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon

async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start checking the status of the job")
if agent is None:
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
if agent.asynchronous:
try:
return await agent.async_get(context=context, resource_meta=request.resource_meta)
Expand All @@ -72,10 +66,8 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)

async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
if agent is None:
return DeleteTaskResponse()
if agent.asynchronous:
try:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
Expand Down
105 changes: 60 additions & 45 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,9 @@ def register(agent: AgentBase):
logger.info(f"Registering an agent for task type {agent.task_type}")

@staticmethod
def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]:
def get_agent(task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
logger.error(f"Cannot find agent for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find the agent for task type [{task_type}]")
return None
raise ValueError(f"Unrecognized task type {task_type}")
return AgentRegistry._REGISTRY[task_type]


Expand All @@ -136,9 +133,9 @@ def convert_to_flyte_state(state: str) -> State:
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
if state in ["failed", "timedout", "canceled"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
elif state in ["done", "succeeded", "success"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
Expand All @@ -158,61 +155,79 @@ class AsyncAgentExecutorMixin:
Task should inherit from this class if the task can be run in the agent.
"""

def execute(self, **kwargs) -> typing.Any:
from unittest.mock import MagicMock
_is_canceled = None
_agent = None
_entity = None

def execute(self, **kwargs) -> typing.Any:
from flytekit.tools.translator import get_serializable

entity = typing.cast(PythonTask, self)
m: OrderedDict = OrderedDict()
dummy_context = MagicMock(spec=grpc.ServicerContext)
cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity)
agent = AgentRegistry.get_agent(dummy_context, cp_entity.template.type)
self._entity = typing.cast(PythonTask, self)
task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template
self._agent = AgentRegistry.get_agent(task_template.type)

if agent is None:
raise Exception("Cannot find the agent for the task")
literals = {}
res = asyncio.run(self._create(task_template, kwargs))
res = asyncio.run(self._get(resource_meta=res.resource_meta))

if res.resource.state != SUCCEEDED:
raise Exception(f"Failed to run the task {self._entity.name}")

return LiteralMap.from_flyte_idl(res.resource.outputs)

async def _create(
self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None
) -> CreateTaskResponse:
ctx = FlyteContext.current_context()
for k, v in kwargs.items():
literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type)
grpc_ctx = _get_grpc_context()

# Convert python inputs to literals
literals = {}
for k, v in inputs.items():
literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type)
inputs = LiteralMap(literals) if literals else None
output_prefix = ctx.file_access.get_random_local_directory()
cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity)
if agent.asynchronous:
res = asyncio.run(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs))

if self._agent.asynchronous:
res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs)
else:
res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs)
signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta))
res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs)

signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore
return res

async def _get(self, resource_meta: bytes) -> GetTaskResponse:
state = RUNNING
metadata = res.resource_meta
grpc_ctx = _get_grpc_context()

progress = Progress(transient=True)
task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None)
task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None)
with progress:
while not is_terminal_state(state):
progress.start_task(task)
time.sleep(1)
if agent.asynchronous:
res = asyncio.run(agent.async_get(dummy_context, metadata))
if self._agent.asynchronous:
res = await self._agent.async_get(grpc_ctx, resource_meta)
if self._is_canceled:
await self._is_canceled
sys.exit(1)
else:
res = agent.get(dummy_context, metadata)
res = self._agent.get(grpc_ctx, resource_meta)
state = res.resource.state
logger.info(f"Task state: {state}")
return res

def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any:
grpc_ctx = _get_grpc_context()
if self._agent.asynchronous:
if self._is_canceled is None:
self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta))
else:
self._agent.delete(grpc_ctx, resource_meta)
sys.exit(1)

if state != SUCCEEDED:
raise Exception(f"Failed to run the task {entity.name}")

return LiteralMap.from_flyte_idl(res.resource.outputs)
def _get_grpc_context():
from unittest.mock import MagicMock

def signal_handler(
self,
agent: AgentBase,
context: grpc.ServicerContext,
resource_meta: bytes,
signum: int,
frame: FrameType,
) -> typing.Any:
if agent.asynchronous:
asyncio.run(agent.async_delete(context, resource_meta))
else:
agent.delete(context, resource_meta)
sys.exit(1)
grpc_ctx = MagicMock(spec=grpc.ServicerContext)
return grpc_ctx
2 changes: 1 addition & 1 deletion plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self):
mock_instance.cancel_job.return_value = MockJob()

ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "bigquery_query_job_task")
agent = AgentRegistry.get_agent("bigquery_query_job_task")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
Expand Down
29 changes: 26 additions & 3 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
CreateTaskRequest,
Expand All @@ -23,7 +24,13 @@
import flytekit.models.interface as interface_models
from flytekit import PythonFunctionTask
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, AsyncAgentExecutorMixin, is_terminal_state
from flytekit.extend.backend.base_agent import (
AgentBase,
AgentRegistry,
AsyncAgentExecutorMixin,
convert_to_flyte_state,
is_terminal_state,
)
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import Identifier, ResourceType
from flytekit.models.literals import LiteralMap
Expand Down Expand Up @@ -97,7 +104,7 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT

def test_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "dummy")
agent = AgentRegistry.get_agent("dummy")
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes
assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED
Expand All @@ -114,7 +121,7 @@ def __init__(self, **kwargs):
t.execute()

t._task_type = "non-exist-type"
with pytest.raises(Exception, match="Cannot find the agent for the task"):
with pytest.raises(Exception, match="Unrecognized task type non-exist-type"):
t.execute()


Expand Down Expand Up @@ -147,3 +154,19 @@ def test_is_terminal_state():
assert is_terminal_state(PERMANENT_FAILURE)
assert is_terminal_state(PERMANENT_FAILURE)
assert not is_terminal_state(RUNNING)


def test_convert_to_flyte_state():
assert convert_to_flyte_state("FAILED") == RETRYABLE_FAILURE
assert convert_to_flyte_state("TIMEDOUT") == RETRYABLE_FAILURE
assert convert_to_flyte_state("CANCELED") == RETRYABLE_FAILURE

assert convert_to_flyte_state("DONE") == SUCCEEDED
assert convert_to_flyte_state("SUCCEEDED") == SUCCEEDED
assert convert_to_flyte_state("SUCCESS") == SUCCEEDED

assert convert_to_flyte_state("RUNNING") == RUNNING

invalid_state = "INVALID_STATE"
with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"):
convert_to_flyte_state(invalid_state)