Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,16 +815,28 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]:
return str


def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
def dataclass_from_dict(cls: dataclasses, src: typing.Dict[str, typing.Any]) -> typing.Any:
"""
Utility function to construct a dataclass object from dict
"""
field_types_lookup = {field.name: field.type for field in dataclasses.fields(cls)}

constructor_inputs = {}
for field_name, value in src.items():
if dataclasses.is_dataclass(field_types_lookup[field_name]):
field_type = field_types_lookup[field_name]
if dataclasses.is_dataclass(field_type):
constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value)
elif hasattr(field_type, "__origin__"):
if field_type.__origin__ is list and dataclasses.is_dataclass(ListTransformer.get_sub_type(field_type)):
t = ListTransformer.get_sub_type(field_type)
constructor_inputs[field_name] = [dataclass_from_dict(t, v) for v in value]
elif field_type.__origin__ is dict and dataclasses.is_dataclass(
DictTransformer.get_dict_types(field_type)[1]
):
t = DictTransformer.get_dict_types(field_type)[1]
constructor_inputs[field_name] = {k: dataclass_from_dict(t, v) for k, v in value.items()}
else:
constructor_inputs[field_name] = value
else:
constructor_inputs[field_name] = value

Expand Down
33 changes: 17 additions & 16 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import os
import pathlib
import typing
from dataclasses import dataclass, field

from dataclasses_json import config, dataclass_json
from marshmallow import fields

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer
Expand All @@ -18,7 +22,12 @@ def noop():
T = typing.TypeVar("T")


@dataclass_json
@dataclass
class FlyteFile(os.PathLike, typing.Generic[T]):
path: str = field(metadata=config(mm_field=fields.String()))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to specify metadata?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remove it, I get the error below.

{"asctime": "2021-10-28 20:12:39,540", "name": "flytekit", "levelname": "WARNING", "message": 
"failed to extract schema for object <class 'test_type_engine.TestFileStruct'>, (will run 
schemaless) error: unsupported field type <fields.Field(dump_default=<marshmallow.missing>, 
attribute=None, validate=None, required=False, load_only=False, dump_only=False, load_default=
<marshmallow.missing>, allow_none=False, error_messages={'required': 'Missing data for required 
field.', 'null': 'Field may not be null.', 'validator_failed': 'Invalid value.'})>"}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's very weird; the error only happens with Flyte types.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where were you seeing this? Was it in the test_flyte_file_in_dataclass test?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String()))

"""
Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int
exists for Flyte's Integer type) we need to create one so that users can express that their tasks take
Expand Down Expand Up @@ -161,41 +170,33 @@ def __init__(self, path: str, downloader: typing.Callable = noop, remote_path=No
until a user actually calls open().
:param remote_path: If the user wants to return something and also specify where it should be uploaded to.
"""
self._path = path
self.path = path
self.remote_path = remote_path
self._downloader = downloader
self._downloaded = False
self._remote_path = remote_path
self._remote_source = None

def __fspath__(self):
# This is where a delayed downloading of the file will happen
if not self._downloaded:
self._downloader()
self._downloaded = True
return self._path
return self.path

def __eq__(self, other):
if isinstance(other, FlyteFile):
return (
self._path == other._path
and self._remote_path == other._remote_path
self.path == other.path
and self.remote_path == other.remote_path
and self.extension() == other.extension()
)
else:
return self._path == other
return self.path == other

@property
def downloaded(self) -> bool:
return self._downloaded

@property
def remote_path(self) -> typing.Optional[str]:
return self._remote_path

@property
def path(self) -> str:
return self._path

@property
def remote_source(self) -> str:
"""
Expand All @@ -215,10 +216,10 @@ def download(self) -> str:
raise ValueError(f"Attempting to trigger download on non-downloadable file {self}")

def __repr__(self):
return self._path
return self.path

def __str__(self):
return self._path
return self.path


class FlyteFilePathTransformer(TypeTransformer[FlyteFile]):
Expand Down
34 changes: 33 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from flytekit.models.core.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar
from flytekit.models.core.types import BlobType, LiteralType, SimpleType
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file import JPEGImageFile
from flytekit.types.file import JPEGImageFile, PNGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer


Expand Down Expand Up @@ -500,6 +500,38 @@ def test_dataclass_int_preserving():
assert ot == o


# Fixture dataclass used as the nested member (`d`) of TestFileStruct. It holds
# file-typed values directly, as list elements, and as dict values.
@dataclass_json
@dataclass
class TestInnerFileStruct(object):
    a: PNGImageFile  # a single file, using the PNG-specific file type
    b: typing.List[FlyteFile]  # files held as list elements
    c: typing.Dict[str, FlyteFile]  # files held as dict values, keyed by str


# Top-level fixture dataclass for test_flyte_file_in_dataclass: a FlyteFile held
# directly, in a list, in a dict, and inside a nested dataclass.
@dataclass_json
@dataclass
class TestFileStruct(object):
    a: FlyteFile  # a single file field
    b: typing.List[FlyteFile]  # files held as list elements
    c: typing.Dict[str, FlyteFile]  # files held as dict values, keyed by str
    d: TestInnerFileStruct  # nested dataclass that itself contains file fields


def test_flyte_file_in_dataclass():
    """Round-trip a dataclass containing FlyteFile fields through DataclassTransformer.

    The struct is converted to a literal and back using the guessed Python
    type, then rebuilt with ``dataclass_from_dict`` and compared against the
    original value.
    """
    jpeg_file = FlyteFile("s3://tmp/file.jpeg")
    expected = TestFileStruct(
        a=jpeg_file,
        b=[jpeg_file],
        c={"hello": jpeg_file},
        d=TestInnerFileStruct(
            a=PNGImageFile("s3://tmp/file.png"),
            b=[jpeg_file],
            c={"hello": jpeg_file},
        ),
    )

    context = FlyteContext.current_context()
    transformer = DataclassTransformer()

    # Derive the literal type from the declared class, then let the transformer
    # guess a Python type back from it; the guessed type drives deserialization.
    literal_type = transformer.get_literal_type(TestFileStruct)
    guessed_type = transformer.guess_python_type(literal_type)

    literal_value = transformer.to_literal(context, expected, TestFileStruct, literal_type)
    round_tripped = transformer.to_python_value(context, lv=literal_value, expected_python_type=guessed_type)

    # The round-tripped object is an instance of the guessed type, so convert it
    # to a plain dict and rebuild a TestFileStruct before comparing.
    as_plain_dict = asdict(typing.cast(dataclass, round_tripped))
    rebuilt = dataclass_from_dict(TestFileStruct, as_plain_dict)
    assert expected == rebuilt

Comment thread
eapolinario marked this conversation as resolved.

# Enums should have string values
class Color(Enum):
RED = "red"
Expand Down
31 changes: 31 additions & 0 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from flytekit.models.core.interface import Parameter
from flytekit.models.core.task import Resources as _resource_models
from flytekit.models.core.types import LiteralType, SimpleType
from flytekit.types.file import FlyteFile
from flytekit.types.schema import FlyteSchema, SchemaOpenMode

serialization_settings = context_manager.SerializationSettings(
Expand Down Expand Up @@ -347,6 +348,36 @@ def test_user_demo_test(mock_sql):
assert context_manager.FlyteContextManager.size() == 1


def test_flytefile_in_dataclass():
    """Pass a dataclass holding FlyteFile fields between two tasks in a workflow."""

    @dataclass_json
    @dataclass
    class InnerFileStruct:
        a: FlyteFile

    @dataclass_json
    @dataclass
    class FileStruct:
        a: FlyteFile
        b: InnerFileStruct

    @task
    def t1(path: str) -> FileStruct:
        """Wrap ``path`` in a FlyteFile and place it at both nesting levels."""
        wrapped = FlyteFile(path)
        return FileStruct(a=wrapped, b=InnerFileStruct(a=wrapped))

    @task
    def t2(fs: FileStruct) -> str:
        """Return the path of the top-level file in the struct."""
        return fs.a.path

    @workflow
    def wf(path: str) -> str:
        produced = t1(path=path)
        return t2(fs=produced)

    # The path handed to the workflow must come back unchanged from t2.
    demo_path = "/tmp/demo.txt"
    assert wf(path=demo_path) == demo_path


def test_wf1_with_map():
@task
def t1(a: int) -> int:
Expand Down