diff --git a/flytekit/common/constants.py b/flytekit/common/constants.py index fa4c4d70db..6b423a40a0 100644 --- a/flytekit/common/constants.py +++ b/flytekit/common/constants.py @@ -22,3 +22,8 @@ class SdkTaskType(object): GLOBAL_INPUT_NODE_ID = '' + + +class CloudProvider(object): + AWS = "aws" + GCP = "gcp" diff --git a/flytekit/configuration/gcp.py b/flytekit/configuration/gcp.py new file mode 100644 index 0000000000..f54862f54c --- /dev/null +++ b/flytekit/configuration/gcp.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import + +from flytekit.configuration import common as _config_common + +GCS_PREFIX = _config_common.FlyteRequiredStringConfigurationEntry('gcp', 'gcs_prefix') diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index 06440dcfd8..4d63e97b0a 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -1,6 +1,10 @@ from __future__ import absolute_import from flytekit.configuration import common as _config_common +from flytekit.common import constants as _constants URL = _config_common.FlyteRequiredStringConfigurationEntry('platform', 'url') INSECURE = _config_common.FlyteBoolConfigurationEntry('platform', 'insecure', default=False) +CLOUD_PROVIDER = _config_common.FlyteStringConfigurationEntry( + 'platform', 'cloud_provider', default=_constants.CloudProvider.AWS +) diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py index 71092c7dd8..0cc27fc95a 100644 --- a/flytekit/interfaces/data/data_proxy.py +++ b/flytekit/interfaces/data/data_proxy.py @@ -1,10 +1,11 @@ from __future__ import absolute_import -from flytekit.configuration import sdk as _sdk_config +from flytekit.configuration import sdk as _sdk_config, platform as _platform_config from flytekit.interfaces.data.s3 import s3proxy as _s3proxy +from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy from flytekit.common.exceptions import user as _user_exception -from flytekit.common import utils as _common_utils +from flytekit.common import utils as _common_utils, constants as _constants import six as _six @@ -57,14 +58,33 @@ def __init__(self, sandbox): class RemoteDataContext(_OutputDataContext): - def __init__(self): - super(RemoteDataContext, self).__init__(_s3proxy.AwsS3Proxy()) + + _CLOUD_PROVIDER_TO_PROXIES = { + _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy(), + _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy(), + } + + def __init__(self, cloud_provider=None): + """ + :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum + """ + cloud_provider = cloud_provider or _platform_config.CLOUD_PROVIDER.get() + proxy = type(self)._CLOUD_PROVIDER_TO_PROXIES.get(cloud_provider, None) + if proxy is None: + raise _user_exception.FlyteAssertion( + "Configured cloud provider is not supported for data I/O. Received: {}, expected one of: {}".format( + cloud_provider, + list(type(self)._CLOUD_PROVIDER_TO_PROXIES.keys()) + ) + ) + super(RemoteDataContext, self).__init__(proxy) class Data(object): # TODO: More proxies for more environments. _DATA_PROXIES = { - "s3:/": _s3proxy.AwsS3Proxy() + "s3:/": _s3proxy.AwsS3Proxy(), + "gs:/": _gcs_proxy.GCSProxy() } @classmethod diff --git a/flytekit/interfaces/data/gcs/__init__.py b/flytekit/interfaces/data/gcs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/interfaces/data/gcs/gcs_proxy.py b/flytekit/interfaces/data/gcs/gcs_proxy.py new file mode 100644 index 0000000000..cb40135e71 --- /dev/null +++ b/flytekit/interfaces/data/gcs/gcs_proxy.py @@ -0,0 +1,112 @@ +from __future__ import absolute_import + +import os as _os +import sys as _sys +import uuid as _uuid + +from flytekit.configuration import gcp as _gcp_config +from flytekit.interfaces import random as _flyte_random +from flytekit.interfaces.data import common as _common_data +from flytekit.tools import subprocess as _subprocess +from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException + + +if _sys.version_info >= (3,): + from shutil import which as _which +else: + from distutils.spawn import find_executable as _which + + +def _update_cmd_config_and_execute(cmd): + env = _os.environ.copy() + return _subprocess.check_call(cmd, env=env) + + +class GCSProxy(_common_data.DataProxy): + _GS_UTIL_CLI = "gsutil" + + @staticmethod + def _check_binary(): + """ + Make sure that the AWS cli is present + """ + if not _which(GCSProxy._GS_UTIL_CLI): + raise _FlyteUserException('gsutil (gcloud cli) not found at Please install.') + + def exists(self, remote_path): + """ + :param Text remote_path: remote gs:// path + :rtype bool: whether the gs file exists or not + """ + GCSProxy._check_binary() + + if not remote_path.startswith("gs://"): + raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") + + cmd = [GCSProxy._GS_UTIL_CLI, "-q", "stat", remote_path] + try: + _update_cmd_config_and_execute(cmd) + return True + except Exception: + return False + + def download_directory(self, remote_path, local_path): + """ + :param Text remote_path: remote s3:// path + :param Text local_path: directory to copy to + """ + GCSProxy._check_binary() + + if not remote_path.startswith("gs://"): + raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") + + cmd = [GCSProxy._GS_UTIL_CLI, "cp", "-r", remote_path, local_path] + return _update_cmd_config_and_execute(cmd) + + def download(self, remote_path, local_path): + """ + :param Text remote_path: remote s3:// path + :param Text local_path: directory to copy to + """ + if not remote_path.startswith("gs://"): + raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") + + GCSProxy._check_binary() + cmd = [GCSProxy._GS_UTIL_CLI, "cp", remote_path, local_path] + return _update_cmd_config_and_execute(cmd) + + def upload(self, file_path, to_path): + """ + :param Text file_path: + :param Text to_path: + """ + GCSProxy._check_binary() + + cmd = [GCSProxy._GS_UTIL_CLI, "cp", file_path, to_path] + + return _update_cmd_config_and_execute(cmd) + + def upload_directory(self, local_path, remote_path): + """ + :param Text local_path: + :param Text remote_path: + """ + if not remote_path.startswith("gs://"): + raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") + + GCSProxy._check_binary() + cmd = [GCSProxy._GS_UTIL_CLI, "cp", "-r", local_path, remote_path] + return _update_cmd_config_and_execute(cmd) + + def get_random_path(self): + """ + :rtype: Text + """ + key = _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex + return _os.path.join(_gcp_config.GCS_PREFIX.get(), key) + + def get_random_directory(self): + """ + :rtype: Text + """ + return self.get_random_path() + "/"