diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 38b7668975..84a9cd1301 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -815,7 +815,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]: return str -def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any: +def dataclass_from_dict(cls: dataclasses, src: typing.Dict[str, typing.Any]) -> typing.Any: """ Utility function to construct a dataclass object from dict """ @@ -823,8 +823,20 @@ def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing. constructor_inputs = {} for field_name, value in src.items(): - if dataclasses.is_dataclass(field_types_lookup[field_name]): + field_type = field_types_lookup[field_name] + if dataclasses.is_dataclass(field_type): constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value) + elif hasattr(field_type, "__origin__"): + if field_type.__origin__ is list and dataclasses.is_dataclass(ListTransformer.get_sub_type(field_type)): + t = ListTransformer.get_sub_type(field_type) + constructor_inputs[field_name] = [dataclass_from_dict(t, v) for v in value] + elif field_type.__origin__ is dict and dataclasses.is_dataclass( + DictTransformer.get_dict_types(field_type)[1] + ): + t = DictTransformer.get_dict_types(field_type)[1] + constructor_inputs[field_name] = {k: dataclass_from_dict(t, v) for k, v in value.items()} + else: + constructor_inputs[field_name] = value else: constructor_inputs[field_name] = value diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 5032184038..474358b722 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -3,6 +3,10 @@ import os import pathlib import typing +from dataclasses import dataclass, field + +from dataclasses_json import config, dataclass_json +from marshmallow import fields from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine, TypeTransformer @@ -18,7 +22,12 @@ def noop(): T = typing.TypeVar("T") +@dataclass_json +@dataclass class FlyteFile(os.PathLike, typing.Generic[T]): + path: str = field(metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) + """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int exists for Flyte's Integer type) we need to create one so that users can express that their tasks take @@ -161,10 +170,10 @@ def __init__(self, path: str, downloader: typing.Callable = noop, remote_path=No until a user actually calls open(). :param remote_path: If the user wants to return something and also specify where it should be uploaded to. """ - self._path = path + self.path = path + self.remote_path = remote_path self._downloader = downloader self._downloaded = False - self._remote_path = remote_path self._remote_source = None def __fspath__(self): @@ -172,30 +181,22 @@ def __fspath__(self): if not self._downloaded: self._downloader() self._downloaded = True - return self._path + return self.path def __eq__(self, other): if isinstance(other, FlyteFile): return ( - self._path == other._path - and self._remote_path == other._remote_path + self.path == other.path + and self.remote_path == other.remote_path and self.extension() == other.extension() ) else: - return self._path == other + return self.path == other @property def downloaded(self) -> bool: return self._downloaded - @property - def remote_path(self) -> typing.Optional[str]: - return self._remote_path - - @property - def path(self) -> str: - return self._path - @property def remote_source(self) -> str: """ @@ -215,10 +216,10 @@ def download(self) -> str: raise ValueError(f"Attempting to trigger download on non-downloadable file {self}") def __repr__(self): - return self._path + return self.path def __str__(self): - return self._path + return self.path class FlyteFilePathTransformer(TypeTransformer[FlyteFile]): diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3912cae4b3..675c193d08 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -27,7 +27,7 @@ from flytekit.models.core.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar from flytekit.models.core.types import BlobType, LiteralType, SimpleType from flytekit.types.directory.types import FlyteDirectory -from flytekit.types.file import JPEGImageFile +from flytekit.types.file import JPEGImageFile, PNGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer @@ -500,6 +500,38 @@ def test_dataclass_int_preserving(): assert ot == o +@dataclass_json +@dataclass +class TestInnerFileStruct(object): + a: PNGImageFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + + +@dataclass_json +@dataclass +class TestFileStruct(object): + a: FlyteFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + d: TestInnerFileStruct + + +def test_flyte_file_in_dataclass(): + f = FlyteFile("s3://tmp/file.jpeg") + o = TestFileStruct( + a=f, b=[f], c={"hello": f}, d=TestInnerFileStruct(a=PNGImageFile("s3://tmp/file.png"), b=[f], c={"hello": f}) + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + gt = tf.guess_python_type(lt) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=gt) + assert o == dataclass_from_dict(TestFileStruct, asdict(typing.cast(dataclass, ot))) + + # Enums should have string values class Color(Enum): RED = "red" diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index dacfa392bd..7ad47b49dc 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -31,6 +31,7 @@ from flytekit.models.core.interface import Parameter from flytekit.models.core.task import Resources as _resource_models from flytekit.models.core.types import LiteralType, SimpleType +from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema, SchemaOpenMode serialization_settings = context_manager.SerializationSettings( @@ -347,6 +348,36 @@ def test_user_demo_test(mock_sql): assert context_manager.FlyteContextManager.size() == 1 +def test_flytefile_in_dataclass(): + @dataclass_json + @dataclass + class InnerFileStruct(object): + a: FlyteFile + + @dataclass_json + @dataclass + class FileStruct(object): + a: FlyteFile + b: InnerFileStruct + + @task + def t1(path: str) -> FileStruct: + file = FlyteFile(path) + fs = FileStruct(a=file, b=InnerFileStruct(a=file)) + return fs + + @task + def t2(fs: FileStruct) -> str: + return fs.a.path + + @workflow + def wf(path: str) -> str: + n1 = t1(path=path) + return t2(fs=n1) + + assert wf(path="/tmp/demo.txt") == "/tmp/demo.txt" + + def test_wf1_with_map(): @task def t1(a: int) -> int: