Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,42 @@

import os
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import count
from typing import Any, Dict, List, Optional, Type

from dataclasses_json import dataclass_json

from flytekit.core.base_task import PythonTask
from flytekit.core.constants import SdkTaskType
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, SerializationSettings
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.exceptions import scopes as exception_scopes
from flytekit.models.array_job import ArrayJob
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql


@dataclass_json
@dataclass
class ArrayJob:
    """
    Settings for an array job: how many instances run and when the job counts as successful.

    Every field is optional and defaults to ``None``.

    :param int parallelism: Minimum number of instances to bring up concurrently at any given point.
    :param int size: Upper bound on the number of instances to launch. It should equal the size of the
        input when all input data must be processed. Has to be a positive number.
    :param int min_successes: Absolute count of subtasks that must complete successfully. Once that count
        is reached, the array job is marked successful and its outputs are computed.
    :param float min_success_ratio: Minimum fraction of the total jobs that must complete successfully
        before the job is terminated and marked successful.
    """

    # Field order is part of the interface: callers in the tests construct this positionally.
    parallelism: Optional[int] = None
    size: Optional[int] = None
    min_successes: Optional[int] = None
    min_success_ratio: Optional[float] = None


class MapPythonTask(PythonTask):
"""
A MapPythonTask defines a :py:class:`flytekit.PythonTask` which specifies how to run
Expand Down Expand Up @@ -109,11 +131,11 @@ def get_sql(self, settings: SerializationSettings) -> Sql:
with self.prepare_target():
return self._run_task.get_sql(settings)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    """
    Build the custom payload for this map task from its array job settings.

    :param settings: Serialization settings; not consulted by this method.
    :return: The ``to_dict()`` form of an ``ArrayJob`` whose ``parallelism`` is ``self._max_concurrency``
        and whose ``min_success_ratio`` is ``self._min_success_ratio``.
    """
    job = ArrayJob(
        parallelism=self._max_concurrency,
        min_success_ratio=self._min_success_ratio,
    )
    return job.to_dict()

def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
    """
    Return the array job settings merged with the wrapped task's own config.

    The previous text returned the wrapped task's config on the first line, which left the merge
    below it unreachable; the merge is now the single code path.

    :param settings: Serialization settings, forwarded to the wrapped task's ``get_config``.
    :return: A mapping with every key and value converted to ``str``. It starts from the array job
        fields and is then updated with the wrapped task's config when that config is not ``None``.
    """
    merged = ArrayJob(parallelism=self._max_concurrency, min_success_ratio=self._min_success_ratio).to_dict()

    # Ask the wrapped task once and reuse the answer for both the None check and the merge.
    run_task_config = self._run_task.get_config(settings)
    if run_task_config is not None:
        # Applied last, so the wrapped task's entries win on any key collision.
        merged.update(run_task_config)

    # The declared return type is Dict[str, str]; unset array job fields therefore come out as the
    # string "None", which is what the pod plugin test in this change set asserts.
    return {str(key): str(value) for key, value in merged.items()}

@property
def run_task(self) -> PythonTask:
Expand Down
5 changes: 5 additions & 0 deletions flytekit/models/array_job.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import json as _json

from deprecated import deprecated as _deprecated
from flyteidl.plugins import array_job_pb2 as _array_job
from google.protobuf import json_format as _json_format

from flytekit.models import common as _common


@_deprecated(
reason="Use ArrayJob in flytekit.core.map_task instead",
version="0.25.0",
)
class ArrayJob(_common.FlyteCustomIdlEntity):
def __init__(self, parallelism=None, size=None, min_successes=None, min_success_ratio=None):
"""
Expand Down
8 changes: 7 additions & 1 deletion plugins/flytekit-k8s-pod/tests/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,13 @@ def simple_pod_task(i: int):
"task-name",
"simple_pod_task",
]
assert {"primary_container_name": "primary"} == mapped_task.get_config(serialization_settings)
assert {
"min_success_ratio": "None",
"min_successes": "None",
"parallelism": "None",
"primary_container_name": "primary",
"size": "None",
} == mapped_task.get_config(serialization_settings)


def test_fast_pod_task_serialization():
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/models/test_dynamic_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from google.protobuf import text_format

from flytekit.models import array_job as _array_job
from flytekit.core.map_task import ArrayJob
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literals
from flytekit.models import task as _task
Expand All @@ -18,7 +18,7 @@
"python",
task_metadata,
interfaces,
_array_job.ArrayJob(2, 2, 2).to_dict(),
ArrayJob(2, 2, 2).to_dict(),
container=_task.Container(
"my_image",
["this", "is", "a", "cmd"],
Expand Down