From 0b4acd80e98ef2fd7d781ae97a8b3d633f3d7b19 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Tue, 11 Jan 2022 07:27:03 -0300 Subject: [PATCH 1/5] Sql to s3 operator --- CONTRIBUTING.rst | 2 +- .../amazon/aws/transfers/mysql_to_s3.py | 162 ++-------------- .../amazon/aws/transfers/sql_to_s3.py | 177 ++++++++++++++++++ airflow/providers/amazon/provider.yaml | 3 + airflow/providers/dependencies.json | 1 - .../prepare_provider_packages.py | 1 + .../amazon/aws/transfers/test_mysql_to_s3.py | 29 +-- .../amazon/aws/transfers/test_sql_to_s3.py | 115 ++++++++++++ 8 files changed, 330 insertions(+), 160 deletions(-) create mode 100644 airflow/providers/amazon/aws/transfers/sql_to_s3.py create mode 100644 tests/providers/amazon/aws/transfers/test_sql_to_s3.py diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index b7e25c12d9483..3688481551f85 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -674,7 +674,7 @@ Here is the list of packages and their extras: Package Extras ========================== =========================== airbyte http -amazon apache.hive,cncf.kubernetes,exasol,ftp,google,imap,mongo,mysql,salesforce,ssh +amazon apache.hive,cncf.kubernetes,exasol,ftp,google,imap,mongo,salesforce,ssh apache.beam google apache.druid apache.hive apache.hive amazon,microsoft.mssql,mysql,presto,samba,vertica diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py index 72e28d720c256..ed40fdb164947 100644 --- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py @@ -16,176 +16,50 @@ # specific language governing permissions and limitations # under the License. -import os import warnings -from collections import namedtuple -from enum import Enum -from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union - -import numpy as np -import pandas as pd -from typing_extensions import Literal +from typing import Optional from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.mysql.hooks.mysql import MySqlHook - -if TYPE_CHECKING: - from airflow.utils.context import Context +from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator - -FILE_FORMAT = Enum( - "FILE_FORMAT", - "CSV, PARQUET", +warnings.warn( + "This module is deprecated. Please use airflow.providers.amazon.aws.transfers.sql_to_s3`.", + DeprecationWarning, + stacklevel=2, ) -FileOptions = namedtuple('FileOptions', ['mode', 'suffix']) - -FILE_OPTIONS_MAP = { - FILE_FORMAT.CSV: FileOptions('r+', '.csv'), - FILE_FORMAT.PARQUET: FileOptions('rb+', '.parquet'), -} - -class MySQLToS3Operator(BaseOperator): +class MySQLToS3Operator(SqlToS3Operator): """ - Saves data from an specific MySQL query into a file in S3. - - :param query: the sql query to be executed. If you want to execute a file, place the absolute path of it, - ending with .sql extension. (templated) - :type query: str - :param s3_bucket: bucket where the data will be stored. (templated) - :type s3_bucket: str - :param s3_key: desired key for the file. It includes the name of the file. (templated) - :type s3_key: str - :param replace: whether or not to replace the file in S3 if it previously existed - :type replace: bool - :param mysql_conn_id: Reference to :ref:`mysql connection id `. - :type mysql_conn_id: str - :param aws_conn_id: reference to a specific S3 connection - :type aws_conn_id: str - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. SSL will still be used - (unless use_ssl is False), but SSL certificates will not be verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. - :type verify: bool or str - :param pd_csv_kwargs: arguments to include in pd.to_csv (header, index, columns...) - :type pd_csv_kwargs: dict - :param index: whether to have the index or not in the dataframe - :type index: str - :param header: whether to include header or not into the S3 file - :type header: bool - :param file_format: the destination file format, only string 'csv' or 'parquet' is accepted. - :type file_format: str - :param pd_kwargs: arguments to include in ``DataFrame.to_parquet()`` or - ``DataFrame.to_csv()``. This is preferred than ``pd_csv_kwargs``. - :type pd_kwargs: dict + This class is deprecated. + Please use `airflow.providers.amazon.aws.transfers.sql_to_s3.SqlToS3Operator`. """ - template_fields: Sequence[str] = ( - 's3_bucket', - 's3_key', - 'query', - ) - template_ext: Sequence[str] = ('.sql',) template_fields_renderers = { - "query": "sql", "pd_csv_kwargs": "json", - "pd_kwargs": "json", } def __init__( self, *, - query: str, - s3_bucket: str, - s3_key: str, - replace: bool = False, mysql_conn_id: str = 'mysql_default', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, pd_csv_kwargs: Optional[dict] = None, index: bool = False, header: bool = False, - file_format: Literal['csv', 'parquet'] = 'csv', - pd_kwargs: Optional[dict] = None, **kwargs, ) -> None: - super().__init__(**kwargs) - self.query = query - self.s3_bucket = s3_bucket - self.s3_key = s3_key - self.mysql_conn_id = mysql_conn_id - self.aws_conn_id = aws_conn_id - self.verify = verify - self.replace = replace - if file_format == "csv": - self.file_format = FILE_FORMAT.CSV - else: - self.file_format = FILE_FORMAT.PARQUET - - if pd_csv_kwargs: - warnings.warn( - "pd_csv_kwargs is deprecated. Please use pd_kwargs.", - DeprecationWarning, - stacklevel=2, - ) - if index or header: - warnings.warn( - "index and header are deprecated. Please pass them via pd_kwargs.", - DeprecationWarning, - stacklevel=2, - ) - - self.pd_kwargs = pd_kwargs or pd_csv_kwargs or {} - if self.file_format == FILE_FORMAT.CSV: - if "path_or_buf" in self.pd_kwargs: + pd_kwargs = kwargs.get('pd_kwargs', {}) + if kwargs.get('file_format', "csv") == "csv": + if "path_or_buf" in pd_kwargs: raise AirflowException('The argument path_or_buf is not allowed, please remove it') - if "index" not in self.pd_kwargs: - self.pd_kwargs["index"] = index - if "header" not in self.pd_kwargs: - self.pd_kwargs["header"] = header + if "index" not in pd_kwargs: + pd_kwargs["index"] = index + if "header" not in pd_kwargs: + pd_kwargs["header"] = header + kwargs["pd_kwargs"] = {**kwargs.get('pd_kwargs', {}), **pd_kwargs} else: if pd_csv_kwargs is not None: raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'") - @staticmethod - def _fix_int_dtypes(df: pd.DataFrame) -> None: - """Mutate DataFrame to set dtypes for int columns containing NaN values.""" - for col in df: - if "float" in df[col].dtype.name and df[col].hasnans: - # inspect values to determine if dtype of non-null values is int or float - notna_series = df[col].dropna().values - if np.isclose(notna_series, notna_series.astype(int)).all(): - # set to dtype that retains integers and supports NaNs - df[col] = np.where(df[col].isnull(), None, df[col]) - df[col] = df[col].astype(pd.Int64Dtype()) - - def execute(self, context: 'Context') -> None: - mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id) - s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - data_df = mysql_hook.get_pandas_df(self.query) - self.log.info("Data from MySQL obtained") - - self._fix_int_dtypes(data_df) - file_options = FILE_OPTIONS_MAP[self.file_format] - with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file: - if self.file_format == FILE_FORMAT.CSV: - data_df.to_csv(tmp_file.name, **self.pd_kwargs) - else: - data_df.to_parquet(tmp_file.name, **self.pd_kwargs) - s3_conn.load_file( - filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace - ) - - if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket): - file_location = os.path.join(self.s3_bucket, self.s3_key) - self.log.info("File saved correctly in %s", file_location) + super().__init__(sql_conn_id=mysql_conn_id, **kwargs) diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py new file mode 100644 index 0000000000000..7b6fa7f5195d7 --- /dev/null +++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from collections import namedtuple +from enum import Enum +from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union + +import numpy as np +import pandas as pd +from typing_extensions import Literal + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.hooks.dbapi import DbApiHook +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +FILE_FORMAT = Enum( + "FILE_FORMAT", + "CSV, PARQUET", +) + +FileOptions = namedtuple('FileOptions', ['mode', 'suffix']) + +FILE_OPTIONS_MAP = { + FILE_FORMAT.CSV: FileOptions('r+', '.csv'), + FILE_FORMAT.PARQUET: FileOptions('rb+', '.parquet'), +} + + +class SqlToS3Operator(BaseOperator): + """ + Saves data from an specific SQL query into a file in S3. + + :param query: the sql query to be executed. If you want to execute a file, place the absolute path of it, + ending with .sql extension. (templated) + :type query: str + :param s3_bucket: bucket where the data will be stored. (templated) + :type s3_bucket: str + :param s3_key: desired key for the file. It includes the name of the file. (templated) + :type s3_key: str + :param replace: whether or not to replace the file in S3 if it previously existed + :type replace: bool + :param sql_conn_id: reference to a specific database. + :type sql_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param aws_conn_id: reference to a specific S3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :param file_format: the destination file format, only string 'csv' or 'parquet' is accepted. + :type file_format: str + :param pd_kwargs: arguments to include in ``DataFrame.to_parquet()`` or + ``DataFrame.to_csv()``. + :type pd_kwargs: dict + """ + + template_fields: Sequence[str] = ( + 's3_bucket', + 's3_key', + 'query', + ) + template_ext: Sequence[str] = ('.sql',) + template_fields_renderers = { + "query": "sql", + "pd_csv_kwargs": "json", + "pd_kwargs": "json", + } + + def __init__( + self, + *, + query: str, + s3_bucket: str, + s3_key: str, + sql_conn_id: str, + parameters: Optional[Union[Mapping, Iterable]] = None, + replace: bool = False, + aws_conn_id: str = 'aws_default', + verify: Optional[Union[bool, str]] = None, + file_format: Literal['csv', 'parquet'] = 'csv', + pd_kwargs: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.query = query + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.sql_conn_id = sql_conn_id + self.aws_conn_id = aws_conn_id + self.verify = verify + self.replace = replace + self.pd_kwargs = pd_kwargs or {} + self.parameters = parameters + + if file_format == "csv": + self.file_format = FILE_FORMAT.CSV + if "path_or_buf" in self.pd_kwargs: + raise AirflowException('The argument path_or_buf is not allowed, please remove it') + elif file_format == "parquet": + self.file_format = FILE_FORMAT.PARQUET + else: + raise AirflowException(f"The argument file_format doesn't support {file_format} value.") + + @staticmethod + def _fix_int_dtypes(df: pd.DataFrame) -> None: + """Mutate DataFrame to set dtypes for int columns containing NaN values.""" + for col in df: + if "float" in df[col].dtype.name and df[col].hasnans: + # inspect values to determine if dtype of non-null values is int or float + notna_series = df[col].dropna().values + if np.isclose(notna_series, notna_series.astype(int)).all(): + # set to dtype that retains integers and supports NaNs + df[col] = np.where(df[col].isnull(), None, df[col]) + df[col] = df[col].astype(pd.Int64Dtype()) + + def execute(self, context: 'Context') -> None: + sql_hook = self._get_hook() + s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters) + self.log.info("Data from SQL obtained") + + self._fix_int_dtypes(data_df) + file_options = FILE_OPTIONS_MAP[self.file_format] + + with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file: + + if self.file_format == FILE_FORMAT.CSV: + data_df.to_csv(tmp_file.name, **self.pd_kwargs) + else: + data_df.to_parquet(tmp_file.name, **self.pd_kwargs) + + s3_conn.load_file( + filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace + ) + + if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket): + file_location = os.path.join(self.s3_bucket, self.s3_key) + self.log.info("File saved correctly in %s", file_location) + + def _get_hook(self) -> DbApiHook: + self.log.debug("Get connection for %s", self.sql_conn_id) + conn = BaseHook.get_connection(self.sql_conn_id) + hook = conn.get_hook() + if not isinstance(hook, DbApiHook): + raise AirflowException("This hook is not supported. The hook class must extends DbApiHook.") + return hook diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 9eae4ab8dd02f..3768373c0269f 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -465,6 +465,9 @@ transfers: - source-integration-name: Local target-integration-name: Amazon Simple Storage Service (S3) python-module: airflow.providers.amazon.aws.transfers.local_to_s3 + - source-integration-name: SQL + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.sql_to_s3 hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - airflow.providers.amazon.aws.hooks.s3.S3Hook diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json index 3c511465a7c43..bb73d81a099b2 100644 --- a/airflow/providers/dependencies.json +++ b/airflow/providers/dependencies.json @@ -10,7 +10,6 @@ "google", "imap", "mongo", - "mysql", "salesforce", "ssh" ], diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py index 65636bcd88ec3..619494fd834f3 100755 --- a/dev/provider_packages/prepare_provider_packages.py +++ b/dev/provider_packages/prepare_provider_packages.py @@ -2192,6 +2192,7 @@ def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnin 'This module is deprecated. Please use `airflow.providers.amazon.aws.operators.redshift_sql` or ' '`airflow.providers.amazon.aws.operators.redshift_cluster` as appropriate.', 'This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.', + "This module is deprecated. Please use airflow.providers.amazon.aws.transfers.sql_to_s3`.", } diff --git a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py index 8ef0a43cb5fd4..63c40da59aacb 100644 --- a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py @@ -26,16 +26,16 @@ class TestMySqlToS3Operator(unittest.TestCase): - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.NamedTemporaryFile") - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.S3Hook") - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.MySqlHook") - def test_execute_csv(self, mock_mysql_hook, mock_s3_hook, temp_mock): + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") + def test_execute_csv(self, mock_s3_hook, temp_mock): query = "query" s3_bucket = "bucket" s3_key = "key" + mock_dbapi_hook = mock.Mock() test_df = pd.DataFrame({'a': '1', 'b': '2'}, index=[0, 1]) - get_pandas_df_mock = mock_mysql_hook.return_value.get_pandas_df + get_pandas_df_mock = mock_dbapi_hook.return_value.get_pandas_df get_pandas_df_mock.return_value = test_df with NamedTemporaryFile() as f: temp_mock.return_value.__enter__.return_value.name = f.name @@ -53,11 +53,11 @@ def test_execute_csv(self, mock_mysql_hook, mock_s3_hook, temp_mock): pd_csv_kwargs={'index': False, 'header': False}, dag=None, ) + op._get_hook = mock_dbapi_hook op.execute(None) - mock_mysql_hook.assert_called_once_with(mysql_conn_id="mysql_conn_id") mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None) - get_pandas_df_mock.assert_called_once_with(query) + get_pandas_df_mock.assert_called_once_with(sql=query, parameters=None) temp_mock.assert_called_once_with(mode='r+', suffix=".csv") mock_s3_hook.return_value.load_file.assert_called_once_with( @@ -67,16 +67,17 @@ def test_execute_csv(self, mock_mysql_hook, mock_s3_hook, temp_mock): replace=True, ) - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.NamedTemporaryFile") - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.S3Hook") - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.MySqlHook") - def test_execute_parquet(self, mock_mysql_hook, mock_s3_hook, temp_mock): + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") + def test_execute_parquet(self, mock_s3_hook, temp_mock): query = "query" s3_bucket = "bucket" s3_key = "key" + mock_dbapi_hook = mock.Mock() + test_df = pd.DataFrame({'a': '1', 'b': '2'}, index=[0, 1]) - get_pandas_df_mock = mock_mysql_hook.return_value.get_pandas_df + get_pandas_df_mock = mock_dbapi_hook.return_value.get_pandas_df get_pandas_df_mock.return_value = test_df with NamedTemporaryFile() as f: temp_mock.return_value.__enter__.return_value.name = f.name @@ -92,11 +93,11 @@ def test_execute_parquet(self, mock_mysql_hook, mock_s3_hook, temp_mock): replace=False, dag=None, ) + op._get_hook = mock_dbapi_hook op.execute(None) - mock_mysql_hook.assert_called_once_with(mysql_conn_id="mysql_conn_id") mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None) - get_pandas_df_mock.assert_called_once_with(query) + get_pandas_df_mock.assert_called_once_with(sql=query, parameters=None) temp_mock.assert_called_once_with(mode='rb+', suffix=".parquet") mock_s3_hook.return_value.load_file.assert_called_once_with( diff --git a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py new file mode 100644 index 0000000000000..97585ba983fa5 --- /dev/null +++ b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import unittest +from tempfile import NamedTemporaryFile +from unittest import mock + +import pandas as pd + +from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator + + +class TestSqlToS3Operator(unittest.TestCase): + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") + def test_execute_csv(self, mock_s3_hook, temp_mock): + query = "query" + s3_bucket = "bucket" + s3_key = "key" + + mock_dbapi_hook = mock.Mock() + test_df = pd.DataFrame({'a': '1', 'b': '2'}, index=[0, 1]) + get_pandas_df_mock = mock_dbapi_hook.return_value.get_pandas_df + get_pandas_df_mock.return_value = test_df + with NamedTemporaryFile() as f: + temp_mock.return_value.__enter__.return_value.name = f.name + + op = SqlToS3Operator( + query=query, + s3_bucket=s3_bucket, + s3_key=s3_key, + sql_conn_id="mysql_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + replace=True, + pd_kwargs={'index': False, 'header': False}, + dag=None, + ) + op._get_hook = mock_dbapi_hook + op.execute(None) + mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None) + + get_pandas_df_mock.assert_called_once_with(sql=query, parameters=None) + + temp_mock.assert_called_once_with(mode='r+', suffix=".csv") + mock_s3_hook.return_value.load_file.assert_called_once_with( + filename=f.name, + key=s3_key, + bucket_name=s3_bucket, + replace=True, + ) + + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") + def test_execute_parquet(self, mock_s3_hook, temp_mock): + query = "query" + s3_bucket = "bucket" + s3_key = "key" + + mock_dbapi_hook = mock.Mock() + + test_df = pd.DataFrame({'a': '1', 'b': '2'}, index=[0, 1]) + get_pandas_df_mock = mock_dbapi_hook.return_value.get_pandas_df + get_pandas_df_mock.return_value = test_df + with NamedTemporaryFile() as f: + temp_mock.return_value.__enter__.return_value.name = f.name + + op = SqlToS3Operator( + query=query, + s3_bucket=s3_bucket, + s3_key=s3_key, + sql_conn_id="mysql_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + file_format="parquet", + replace=False, + dag=None, + ) + op._get_hook = mock_dbapi_hook + op.execute(None) + mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None) + + get_pandas_df_mock.assert_called_once_with(sql=query, parameters=None) + + temp_mock.assert_called_once_with(mode='rb+', suffix=".parquet") + mock_s3_hook.return_value.load_file.assert_called_once_with( + filename=f.name, key=s3_key, bucket_name=s3_bucket, replace=False + ) + + def test_fix_int_dtypes(self): + op = SqlToS3Operator( + query="query", + s3_bucket="s3_bucket", + s3_key="s3_key", + task_id="task_id", + sql_conn_id="mysql_conn_id", + ) + dirty_df = pd.DataFrame({"strings": ["a", "b", "c"], "ints": [1, 2, None]}) + op._fix_int_dtypes(df=dirty_df) + assert dirty_df["ints"].dtype.kind == "i" From 4e5efdbffb62cad71e83f23f992b7358f5dc29e4 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Tue, 11 Jan 2022 16:42:29 +0000 Subject: [PATCH 2/5] update sql_to_s3.py --- .../providers/amazon/aws/transfers/sql_to_s3.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py index 7b6fa7f5195d7..4008c01a21c21 100644 --- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. -import os from collections import namedtuple from enum import Enum from tempfile import NamedTemporaryFile @@ -80,8 +79,7 @@ class SqlToS3Operator(BaseOperator): :type verify: bool or str :param file_format: the destination file format, only string 'csv' or 'parquet' is accepted. :type file_format: str - :param pd_kwargs: arguments to include in ``DataFrame.to_parquet()`` or - ``DataFrame.to_csv()``. + :param pd_kwargs: arguments to include in ``DataFrame.to_parquet()`` or ``DataFrame.to_csv()``. :type pd_kwargs: dict """ @@ -104,7 +102,7 @@ def __init__( s3_bucket: str, s3_key: str, sql_conn_id: str, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Union[None, Mapping, Iterable] = None, replace: bool = False, aws_conn_id: str = 'aws_default', verify: Optional[Union[bool, str]] = None, @@ -164,14 +162,12 @@ def execute(self, context: 'Context') -> None: filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace ) - if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket): - file_location = os.path.join(self.s3_bucket, self.s3_key) - self.log.info("File saved correctly in %s", file_location) - def _get_hook(self) -> DbApiHook: self.log.debug("Get connection for %s", self.sql_conn_id) conn = BaseHook.get_connection(self.sql_conn_id) hook = conn.get_hook() - if not isinstance(hook, DbApiHook): - raise AirflowException("This hook is not supported. The hook class must extends DbApiHook.") + if not callable(getattr(hook, 'get_pandas_df', None)): + raise AirflowException( + "This hook is not supported. The hook class must have get_pandas_df method." + ) return hook From 06a597a3f0cb4a757d6f66b94f1dd3dae77c30fc Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Wed, 12 Jan 2022 20:59:32 -0300 Subject: [PATCH 3/5] Ignore deprecation warning on mysql to s3 test --- tests/providers/amazon/aws/transfers/test_mysql_to_s3.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py index 63c40da59aacb..85bc1110d01fd 100644 --- a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py @@ -21,9 +21,12 @@ from unittest import mock import pandas as pd +import pytest from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") + class TestMySqlToS3Operator(unittest.TestCase): @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") From bd9737714538ba260415ac07826a21d089d4071d Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Wed, 12 Jan 2022 21:00:50 -0300 Subject: [PATCH 4/5] Update test_mysql_to_s3.py --- tests/providers/amazon/aws/transfers/test_mysql_to_s3.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py index 85bc1110d01fd..22fb1cc0aa5c4 100644 --- a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py @@ -23,8 +23,6 @@ import pandas as pd import pytest -from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator - pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") @@ -32,6 +30,8 @@ class TestMySqlToS3Operator(unittest.TestCase): @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") def test_execute_csv(self, mock_s3_hook, temp_mock): + from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator + query = "query" s3_bucket = "bucket" s3_key = "key" @@ -73,6 +73,8 @@ def test_execute_csv(self, mock_s3_hook, temp_mock): @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile") @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") def test_execute_parquet(self, mock_s3_hook, temp_mock): + from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator + query = "query" s3_bucket = "bucket" s3_key = "key" @@ -108,6 +110,8 @@ def test_execute_parquet(self, mock_s3_hook, temp_mock): ) def test_fix_int_dtypes(self): + from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator + op = MySQLToS3Operator(query="query", s3_bucket="s3_bucket", s3_key="s3_key", task_id="task_id") dirty_df = pd.DataFrame({"strings": ["a", "b", "c"], "ints": [1, 2, None]}) op._fix_int_dtypes(df=dirty_df) From 1b37e13594bdfd82de5e27cf72d9c552ba2d4e7b Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Thu, 13 Jan 2022 08:12:02 -0300 Subject: [PATCH 5/5] Add deprecation warning on MySQLToS3Operator constructor --- airflow/providers/amazon/aws/transfers/mysql_to_s3.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py index ed40fdb164947..728aaddcba0a9 100644 --- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py @@ -48,6 +48,14 @@ def __init__( header: bool = False, **kwargs, ) -> None: + warnings.warn( + """ + MySQLToS3Operator is deprecated. + Please use `airflow.providers.amazon.aws.transfers.sql_to_s3.SqlToS3Operator`. + """, + DeprecationWarning, + stacklevel=2, + ) pd_kwargs = kwargs.get('pd_kwargs', {}) if kwargs.get('file_format', "csv") == "csv":