Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ class SdkTaskType(object):


GLOBAL_INPUT_NODE_ID = ''


class CloudProvider(object):
    """
    String identifiers for the cloud providers known to the SDK.
    """
    # Identifier string for Amazon Web Services.
    AWS = "aws"
    # Identifier string for Google Cloud Platform.
    GCP = "gcp"
5 changes: 5 additions & 0 deletions flytekit/configuration/gcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import absolute_import

from flytekit.configuration import common as _config_common

# Required string configuration entry read from the 'gcp' section, key 'gcs_prefix'.
GCS_PREFIX = _config_common.FlyteRequiredStringConfigurationEntry('gcp', 'gcs_prefix')
4 changes: 4 additions & 0 deletions flytekit/configuration/platform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import absolute_import

from flytekit.configuration import common as _config_common
from flytekit.common import constants as _constants

# Required string entry: 'platform' section, 'url' key.
URL = _config_common.FlyteRequiredStringConfigurationEntry('platform', 'url')
# Boolean entry: 'platform' section, 'insecure' key. Defaults to False when not configured.
INSECURE = _config_common.FlyteBoolConfigurationEntry('platform', 'insecure', default=False)
# String entry: 'platform' section, 'cloud_provider' key.
# Falls back to CloudProvider.AWS when not configured.
CLOUD_PROVIDER = _config_common.FlyteStringConfigurationEntry(
    'platform', 'cloud_provider', default=_constants.CloudProvider.AWS
)
30 changes: 25 additions & 5 deletions flytekit/interfaces/data/data_proxy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import absolute_import

from flytekit.configuration import sdk as _sdk_config
from flytekit.configuration import sdk as _sdk_config, platform as _platform_config
from flytekit.interfaces.data.s3 import s3proxy as _s3proxy
from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy
from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy
from flytekit.common.exceptions import user as _user_exception
from flytekit.common import utils as _common_utils
from flytekit.common import utils as _common_utils, constants as _constants
import six as _six


Expand Down Expand Up @@ -57,14 +58,33 @@ def __init__(self, sandbox):


class RemoteDataContext(_OutputDataContext):
    """
    Output data context whose data proxy is selected by cloud provider.
    """

    # Lookup table from cloud provider identifier to the proxy used for data I/O.
    # NOTE: the proxy objects are created once, when this class body is executed,
    # and are shared by every RemoteDataContext instance.
    _CLOUD_PROVIDER_TO_PROXIES = {
        _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy(),
        _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy(),
    }

    def __init__(self, cloud_provider=None):
        """
        :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum. When omitted
            (or otherwise falsy), the value of the platform `cloud_provider` configuration entry is used.
        :raises flytekit.common.exceptions.user.FlyteAssertion: if the resolved provider has no registered proxy.
        """
        # An explicit argument wins; otherwise defer to configuration.
        cloud_provider = cloud_provider or _platform_config.CLOUD_PROVIDER.get()
        supported = type(self)._CLOUD_PROVIDER_TO_PROXIES
        proxy = supported.get(cloud_provider)
        if proxy is None:
            raise _user_exception.FlyteAssertion(
                "Configured cloud provider is not supported for data I/O. Received: {}, expected one of: {}".format(
                    cloud_provider,
                    list(supported.keys())
                )
            )
        super(RemoteDataContext, self).__init__(proxy)


class Data(object):
# TODO: More proxies for more environments.
_DATA_PROXIES = {
"s3:/": _s3proxy.AwsS3Proxy()
"s3:/": _s3proxy.AwsS3Proxy(),
"gs:/": _gcs_proxy.GCSProxy()
}

@classmethod
Expand Down
Empty file.
112 changes: 112 additions & 0 deletions flytekit/interfaces/data/gcs/gcs_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import absolute_import

import os as _os
import sys as _sys
import uuid as _uuid

from flytekit.configuration import gcp as _gcp_config
from flytekit.interfaces import random as _flyte_random
from flytekit.interfaces.data import common as _common_data
from flytekit.tools import subprocess as _subprocess
from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException


if _sys.version_info >= (3,):
from shutil import which as _which
else:
from distutils.spawn import find_executable as _which


def _update_cmd_config_and_execute(cmd):
    """
    Run a command through the flytekit subprocess helper using a copy of the current environment.

    :param list[Text] cmd: the command and its arguments
    :return: whatever ``_subprocess.check_call`` returns for the command
    """
    # Hand the child a snapshot of the environment rather than the live os.environ mapping.
    environment_snapshot = dict(_os.environ)
    return _subprocess.check_call(cmd, env=environment_snapshot)


class GCSProxy(_common_data.DataProxy):
_GS_UTIL_CLI = "gsutil"

@staticmethod
def _check_binary():
    """
    Make sure that the gsutil cli is present

    :raises flytekit.common.exceptions.user.FlyteUserException: if the gsutil executable cannot be located.
    """
    # _which returns a falsy value when the executable cannot be found.
    if not _which(GCSProxy._GS_UTIL_CLI):
        raise _FlyteUserException('gsutil (gcloud cli) not found on PATH. Please install.')

def exists(self, remote_path):
    """
    Check for a remote object by running ``gsutil -q stat`` against it.

    :param Text remote_path: remote gs:// path
    :rtype bool: whether the gs file exists or not
    :raises ValueError: if remote_path does not start with gs://
    """
    GCSProxy._check_binary()

    is_gs_path = remote_path.startswith("gs://")
    if not is_gs_path:
        raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

    stat_cmd = [GCSProxy._GS_UTIL_CLI, "-q", "stat", remote_path]
    try:
        _update_cmd_config_and_execute(stat_cmd)
    except Exception:
        # Any failure of the stat command is reported as "does not exist".
        return False
    else:
        return True

def download_directory(self, remote_path, local_path):
    """
    Recursively copy a remote location to a local directory using ``gsutil cp -r``.

    :param Text remote_path: remote gs:// path
    :param Text local_path: directory to copy to
    :return: the result of executing the gsutil command
    :raises ValueError: if remote_path does not start with gs://
    """
    GCSProxy._check_binary()

    if not remote_path.startswith("gs://"):
        raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

    cmd = [GCSProxy._GS_UTIL_CLI, "cp", "-r", remote_path, local_path]
    return _update_cmd_config_and_execute(cmd)

def download(self, remote_path, local_path):
    """
    Copy a single remote object to a local path using ``gsutil cp``.

    :param Text remote_path: remote gs:// path
    :param Text local_path: local destination passed to ``gsutil cp``
    :return: the result of executing the gsutil command
    :raises ValueError: if remote_path does not start with gs://
    """
    # Validate the argument before looking for the gsutil binary.
    if not remote_path.startswith("gs://"):
        raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

    GCSProxy._check_binary()
    cmd = [GCSProxy._GS_UTIL_CLI, "cp", remote_path, local_path]
    return _update_cmd_config_and_execute(cmd)

def upload(self, file_path, to_path):
    """
    Copy a single local file to a destination using ``gsutil cp``.

    :param Text file_path: source path passed to ``gsutil cp``
    :param Text to_path: destination path passed to ``gsutil cp``
    :return: the result of executing the gsutil command
    """
    GCSProxy._check_binary()
    return _update_cmd_config_and_execute(
        [GCSProxy._GS_UTIL_CLI, "cp", file_path, to_path]
    )

def upload_directory(self, local_path, remote_path):
    """
    Recursively copy a local directory to a remote location using ``gsutil cp -r``.

    :param Text local_path: local directory to upload
    :param Text remote_path: remote gs:// destination
    :return: the result of executing the gsutil command
    :raises ValueError: if remote_path does not start with gs://
    """
    if not remote_path.startswith("gs://"):
        raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

    GCSProxy._check_binary()
    # NOTE: gsutil and `aws s3` behave differently when copying a directory recursively.
    # `gsutil cp -r test gs://bucket` places the directory itself under the destination
    # (gs://bucket/test/...), whereas `gsutil cp -r test/* gs://bucket` copies only its
    # contents (gs://bucket/test1/...). This method uploads local_path exactly as given;
    # a caller that wants only the directory contents is expected to pass a wildcard path.
    cmd = [GCSProxy._GS_UTIL_CLI, "cp", "-r", local_path, remote_path]
    return _update_cmd_config_and_execute(cmd)

def get_random_path(self):
    """
    Build a random remote path under the configured GCS prefix.

    :rtype: Text
    """
    # Imported here so this method is self-contained.
    import posixpath as _posixpath

    key = _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex
    # Remote URIs always use '/' as the separator. os.path.join would use the host
    # platform's separator (a backslash on Windows), so join with posixpath instead.
    return _posixpath.join(_gcp_config.GCS_PREFIX.get(), key)

def get_random_directory(self):
    """
    Return a random remote path with a trailing slash, so that it denotes a directory.

    :rtype: Text
    """
    random_path = self.get_random_path()
    return random_path + "/"