Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ class SdkTaskType(object):


GLOBAL_INPUT_NODE_ID = ''


class CloudProvider(object):
    """
    String identifiers for the cloud providers known to flytekit.
    """

    # Amazon Web Services
    AWS = "aws"

    # Google Cloud Platform
    GCP = "gcp"
5 changes: 5 additions & 0 deletions flytekit/configuration/gcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import absolute_import

from flytekit.configuration import common as _config_common

# Required string configuration entry declared with the pair ('gcp', 'gcs_prefix').
# NOTE(review): presumably (section, key) and holding the gs:// location under which generated
# data paths are placed -- confirm against _config_common and the callers of GCS_PREFIX.get().
GCS_PREFIX = _config_common.FlyteRequiredStringConfigurationEntry('gcp', 'gcs_prefix')
4 changes: 4 additions & 0 deletions flytekit/configuration/platform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import absolute_import

from flytekit.configuration import common as _config_common
from flytekit.common import constants as _constants

# Required string configuration entry declared with ('platform', 'url').
URL = _config_common.FlyteRequiredStringConfigurationEntry('platform', 'url')
# Boolean configuration entry declared with ('platform', 'insecure') and default=False.
INSECURE = _config_common.FlyteBoolConfigurationEntry('platform', 'insecure', default=False)
# String configuration entry declared with ('platform', 'cloud_provider'); its default is the
# CloudProvider.AWS constant.
# NOTE(review): values are presumably expected to be one of the CloudProvider constants -- confirm.
CLOUD_PROVIDER = _config_common.FlyteStringConfigurationEntry(
'platform', 'cloud_provider', default=_constants.CloudProvider.AWS
)
28 changes: 24 additions & 4 deletions flytekit/interfaces/data/data_proxy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import absolute_import

from flytekit.configuration import sdk as _sdk_config
from flytekit.configuration import sdk as _sdk_config, platform as _platform_config
from flytekit.interfaces.data.s3 import s3proxy as _s3proxy
from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy
from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy
from flytekit.interfaces.data.http import http_data_proxy as _http_data_proxy
from flytekit.common.exceptions import user as _user_exception
from flytekit.common import utils as _common_utils
from flytekit.common import utils as _common_utils, constants as _constants
import six as _six


Expand Down Expand Up @@ -58,14 +59,33 @@ def __init__(self, sandbox):


class RemoteDataContext(_OutputDataContext):
    """
    Output data context whose data proxy is selected by cloud provider.
    """

    # Maps each supported CloudProvider constant to the proxy used for it. The proxy objects are
    # created once, when this class body is evaluated, and the same instances are handed to every
    # RemoteDataContext.
    _CLOUD_PROVIDER_TO_PROXIES = {
        _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy(),
        _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy(),
    }

    def __init__(self, cloud_provider=None):
        """
        :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum. When it is
            falsy (None or empty), the value of the platform CLOUD_PROVIDER configuration entry is used.
        :raises flytekit.common.exceptions.user.FlyteAssertion: if no proxy is registered for the resolved
            cloud provider.
        """
        cloud_provider = cloud_provider or _platform_config.CLOUD_PROVIDER.get()
        # Look the table up through type(self) so a subclass may supply its own mapping.
        proxies = type(self)._CLOUD_PROVIDER_TO_PROXIES
        proxy = proxies.get(cloud_provider)
        if proxy is None:
            raise _user_exception.FlyteAssertion(
                "Configured cloud provider is not supported for data I/O. Received: {}, expected one of: {}".format(
                    cloud_provider,
                    list(proxies.keys())
                )
            )
        super(RemoteDataContext, self).__init__(proxy)


class Data(object):
# TODO: More proxies for more environments.
_DATA_PROXIES = {
"s3:/": _s3proxy.AwsS3Proxy(),
"gs:/": _gcs_proxy.GCSProxy(),
"http://": _http_data_proxy.HttpFileProxy(),
"https://": _http_data_proxy.HttpFileProxy(),
}
Expand Down
Empty file.
116 changes: 116 additions & 0 deletions flytekit/interfaces/data/gcs/gcs_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import absolute_import

import os as _os
import sys as _sys
import uuid as _uuid

from flytekit.configuration import gcp as _gcp_config
from flytekit.interfaces import random as _flyte_random
from flytekit.interfaces.data import common as _common_data
from flytekit.tools import subprocess as _subprocess
from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException


if _sys.version_info >= (3,):
from shutil import which as _which
else:
from distutils.spawn import find_executable as _which


def _update_cmd_config_and_execute(cmd):
    """
    Run ``cmd`` through the project's subprocess wrapper with a copy of the current environment.

    :param list[Text] cmd: command and arguments to execute
    :return: whatever ``_subprocess.check_call`` returns for the command
    """
    # Pass a copy so the child's environment is decoupled from later changes to os.environ.
    return _subprocess.check_call(cmd, env=_os.environ.copy())


def _amend_path(path):
return _os.path.join(path, "*") if not path.endswith("*") else path


class GCSProxy(_common_data.DataProxy):
    """
    Data proxy that reads from and writes to Google Cloud Storage by invoking the ``gsutil`` CLI.
    """

    _GS_UTIL_CLI = "gsutil"

    @staticmethod
    def _check_binary():
        """
        Make sure that the gsutil cli is present

        :raises flytekit.common.exceptions.user.FlyteUserException: if the gsutil executable cannot be located
        """
        if not _which(GCSProxy._GS_UTIL_CLI):
            raise _FlyteUserException('gsutil (gcloud cli) not found. Please install.')

    @staticmethod
    def _check_gs_path(remote_path):
        """
        Make sure that the given remote path uses the gs:// scheme

        :param Text remote_path: remote path to validate
        :raises ValueError: if the path does not start with gs://
        """
        if not remote_path.startswith("gs://"):
            raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

    def exists(self, remote_path):
        """
        :param Text remote_path: remote gs:// path
        :rtype: bool
        :return: whether the gs file exists or not
        :raises ValueError: if remote_path is not a gs:// path
        """
        GCSProxy._check_binary()
        GCSProxy._check_gs_path(remote_path)

        cmd = [GCSProxy._GS_UTIL_CLI, "-q", "stat", remote_path]
        try:
            _update_cmd_config_and_execute(cmd)
        except Exception:
            # Any failure of the stat command is reported as "does not exist".
            return False
        return True

    def download_directory(self, remote_path, local_path):
        """
        :param Text remote_path: remote gs:// path
        :param Text local_path: directory to copy to
        :raises ValueError: if remote_path is not a gs:// path
        """
        GCSProxy._check_binary()
        GCSProxy._check_gs_path(remote_path)

        # Copy the contents of the remote directory (via a trailing wildcard) recursively.
        cmd = [GCSProxy._GS_UTIL_CLI, "cp", "-r", _amend_path(remote_path), local_path]
        return _update_cmd_config_and_execute(cmd)

    def download(self, remote_path, local_path):
        """
        :param Text remote_path: remote gs:// path
        :param Text local_path: local path to copy to
        :raises ValueError: if remote_path is not a gs:// path
        """
        GCSProxy._check_gs_path(remote_path)
        GCSProxy._check_binary()

        cmd = [GCSProxy._GS_UTIL_CLI, "cp", remote_path, local_path]
        return _update_cmd_config_and_execute(cmd)

    def upload(self, file_path, to_path):
        """
        :param Text file_path: local file to copy from
        :param Text to_path: destination path handed to ``gsutil cp`` as-is (not validated here)
        """
        GCSProxy._check_binary()

        cmd = [GCSProxy._GS_UTIL_CLI, "cp", file_path, to_path]
        return _update_cmd_config_and_execute(cmd)

    def upload_directory(self, local_path, remote_path):
        """
        :param Text local_path: local directory to copy from
        :param Text remote_path: remote gs:// path to copy to
        :raises ValueError: if remote_path is not a gs:// path
        """
        GCSProxy._check_gs_path(remote_path)
        GCSProxy._check_binary()

        # Copy the contents of the local directory (via a trailing wildcard) recursively.
        cmd = [GCSProxy._GS_UTIL_CLI, "cp", "-r", _amend_path(local_path), remote_path]
        return _update_cmd_config_and_execute(cmd)

    def get_random_path(self):
        """
        :rtype: Text
        :return: the configured GCS prefix joined with a random 32-character hex key
        """
        # Build the UUID from flytekit's random source rather than uuid4().
        key = _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex
        return _os.path.join(_gcp_config.GCS_PREFIX.get(), key)

    def get_random_directory(self):
        """
        :rtype: Text
        :return: a random path, as produced by get_random_path, with a trailing slash
        """
        return self.get_random_path() + "/"