From c6b9a7a054b3dcd8b707f36a405ebfddaa137d2a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 20 Oct 2021 18:46:08 +0800 Subject: [PATCH 1/5] Add support FlyteFile in dataclass Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 14 +++++++-- flytekit/types/file/file.py | 33 ++++++++++---------- tests/flytekit/unit/core/test_type_engine.py | 30 ++++++++++++++++++ 3 files changed, 59 insertions(+), 18 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 38b7668975..c56fba91a9 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -815,7 +815,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]: return str -def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any: +def dataclass_from_dict(cls: dataclasses, src: typing.Dict[str, typing.Any]) -> typing.Any: """ Utility function to construct a dataclass object from dict """ @@ -823,8 +823,18 @@ def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing. constructor_inputs = {} for field_name, value in src.items(): - if dataclasses.is_dataclass(field_types_lookup[field_name]): + field_type = field_types_lookup[field_name] + if dataclasses.is_dataclass(field_type): constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value) + elif hasattr(field_type, "__origin__"): + if field_type.__origin__ is list and dataclasses.is_dataclass(ListTransformer.get_sub_type(field_type)): + t = ListTransformer.get_sub_type(field_type) + constructor_inputs[field_name] = [dataclass_from_dict(t, v) for v in value] + elif field_type.__origin__ is dict and dataclasses.is_dataclass( + DictTransformer.get_dict_types(field_type)[1] + ): + t = DictTransformer.get_dict_types(field_type)[1] + constructor_inputs[field_name] = {k: dataclass_from_dict(t, v) for k, v in value.items()} else: constructor_inputs[field_name] = value diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 5032184038..8bed962a44 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -3,6 +3,10 @@ import os import pathlib import typing +from dataclasses import dataclass, field + +from dataclasses_json import config, dataclass_json +from marshmallow import Schema, fields from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine, TypeTransformer @@ -18,7 +22,12 @@ def noop(): T = typing.TypeVar("T") +@dataclass_json +@dataclass class FlyteFile(os.PathLike, typing.Generic[T]): + path: str = field(metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) + """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int exists for Flyte's Integer type) we need to create one so that users can express that their tasks take @@ -161,10 +170,10 @@ def __init__(self, path: str, downloader: typing.Callable = noop, remote_path=No until a user actually calls open(). :param remote_path: If the user wants to return something and also specify where it should be uploaded to. """ - self._path = path + self.path = path + self.remote_path = remote_path self._downloader = downloader self._downloaded = False - self._remote_path = remote_path self._remote_source = None def __fspath__(self): @@ -172,30 +181,22 @@ def __fspath__(self): if not self._downloaded: self._downloader() self._downloaded = True - return self._path + return self.path def __eq__(self, other): if isinstance(other, FlyteFile): return ( - self._path == other._path - and self._remote_path == other._remote_path + self.path == other.path + and self.remote_path == other.remote_path and self.extension() == other.extension() ) else: - return self._path == other + return self.path == other @property def downloaded(self) -> bool: return self._downloaded - @property - def remote_path(self) -> typing.Optional[str]: - return self._remote_path - - @property - def path(self) -> str: - return self._path - @property def remote_source(self) -> str: """ @@ -215,10 +216,10 @@ def download(self) -> str: raise ValueError(f"Attempting to trigger download on non-downloadable file {self}") def __repr__(self): - return self._path + return self.path def __str__(self): - return self._path + return self.path class FlyteFilePathTransformer(TypeTransformer[FlyteFile]): diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3912cae4b3..dce00373a4 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -500,6 +500,36 @@ def test_dataclass_int_preserving(): assert ot == o +@dataclass_json +@dataclass +class TestInnerFileStruct(object): + a: FlyteFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + + +@dataclass_json +@dataclass +class TestFileStruct(object): + a: FlyteFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + d: TestInnerFileStruct + + +def test_flyte_file_in_dataclass(): + f = FlyteFile("s3://tmp/file.jpeg") + o = TestFileStruct(a=f, b=[f], c={"hello": f}, d=TestInnerFileStruct(a=f, b=[f], c={"hello": f})) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + gt = tf.guess_python_type(lt) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=gt) + assert o == dataclass_from_dict(TestFileStruct, asdict(typing.cast(dataclass, ot))) + + # Enums should have string values class Color(Enum): RED = "red" From be078fd2ce28db0783a77d4632898ec84a6b6b5f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 20 Oct 2021 19:04:08 +0800 Subject: [PATCH 2/5] Fixed test Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c56fba91a9..84a9cd1301 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -835,6 +835,8 @@ def dataclass_from_dict(cls: dataclasses, src: typing.Dict[str, typing.Any]) -> ): t = DictTransformer.get_dict_types(field_type)[1] constructor_inputs[field_name] = {k: dataclass_from_dict(t, v) for k, v in value.items()} + else: + constructor_inputs[field_name] = value else: constructor_inputs[field_name] = value From 5b53447b0bbce321d80d859ed02e08419d6a2f28 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 20 Oct 2021 19:16:02 +0800 Subject: [PATCH 3/5] Fixed lint Signed-off-by: Kevin Su --- flytekit/types/file/file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 8bed962a44..474358b722 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from dataclasses_json import config, dataclass_json -from marshmallow import Schema, fields +from marshmallow import fields from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine, TypeTransformer From 243d282ace2bb96360cacf886b94617404cb84cf Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 26 Oct 2021 17:32:23 +0800 Subject: [PATCH 4/5] Added png file Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_type_engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index dce00373a4..675c193d08 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -27,7 +27,7 @@ from flytekit.models.core.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar from flytekit.models.core.types import BlobType, LiteralType, SimpleType from flytekit.types.directory.types import FlyteDirectory -from flytekit.types.file import JPEGImageFile +from flytekit.types.file import JPEGImageFile, PNGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer @@ -503,7 +503,7 @@ def test_dataclass_int_preserving(): @dataclass_json @dataclass class TestInnerFileStruct(object): - a: FlyteFile + a: PNGImageFile b: typing.List[FlyteFile] c: typing.Dict[str, FlyteFile] @@ -519,7 +519,9 @@ class TestFileStruct(object): def test_flyte_file_in_dataclass(): f = FlyteFile("s3://tmp/file.jpeg") - o = TestFileStruct(a=f, b=[f], c={"hello": f}, d=TestInnerFileStruct(a=f, b=[f], c={"hello": f})) + o = TestFileStruct( + a=f, b=[f], c={"hello": f}, d=TestInnerFileStruct(a=PNGImageFile("s3://tmp/file.png"), b=[f], c={"hello": f}) + ) ctx = FlyteContext.current_context() tf = DataclassTransformer() From 862248b5cbce4e197fc94500a8a9e1974f83f94c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 28 Oct 2021 20:34:26 +0800 Subject: [PATCH 5/5] Added tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_type_hints.py | 31 +++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index dacfa392bd..7ad47b49dc 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -31,6 +31,7 @@ from flytekit.models.core.interface import Parameter from flytekit.models.core.task import Resources as _resource_models from flytekit.models.core.types import LiteralType, SimpleType +from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema, SchemaOpenMode serialization_settings = context_manager.SerializationSettings( @@ -347,6 +348,36 @@ def test_user_demo_test(mock_sql): assert context_manager.FlyteContextManager.size() == 1 +def test_flytefile_in_dataclass(): + @dataclass_json + @dataclass + class InnerFileStruct(object): + a: FlyteFile + + @dataclass_json + @dataclass + class FileStruct(object): + a: FlyteFile + b: InnerFileStruct + + @task + def t1(path: str) -> FileStruct: + file = FlyteFile(path) + fs = FileStruct(a=file, b=InnerFileStruct(a=file)) + return fs + + @task + def t2(fs: FileStruct) -> str: + return fs.a.path + + @workflow + def wf(path: str) -> str: + n1 = t1(path=path) + return t2(fs=n1) + + assert wf(path="/tmp/demo.txt") == "/tmp/demo.txt" + + def test_wf1_with_map(): @task def t1(a: int) -> int: