diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index ce887249c6991..cef9389073e82 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -25,8 +25,10 @@ from snowflake import connector from snowflake.connector import DictCursor, SnowflakeConnection from snowflake.connector.util_text import split_statements +from sqlalchemy import create_engine from airflow.hooks.dbapi import DbApiHook +from airflow.utils.strings import to_boolean class SnowflakeHook(DbApiHook): @@ -64,6 +66,10 @@ class SnowflakeHook(DbApiHook): :param session_parameters: You can set session-level parameters at the time you connect to Snowflake :type session_parameters: Optional[dict] + :param insecure_mode: Turns off OCSP certificate checks. + For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community + `__ + :type insecure_mode: Optional[bool] .. note:: get_sqlalchemy_engine() depends on snowflake-sqlalchemy @@ -84,7 +90,7 @@ def get_connection_form_widgets() -> Dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext - from wtforms import StringField + from wtforms import BooleanField, StringField return { "extra__snowflake__account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()), @@ -94,6 +100,9 @@ def get_connection_form_widgets() -> Dict[str, Any]: "extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()), "extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()), "extra__snowflake__role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()), + "extra__snowflake__insecure_mode": BooleanField( + label=lazy_gettext('Insecure mode'), description="Turns off OCSP certificate checks" + ), } @staticmethod @@ -113,7 +122,6 @@ def get_ui_field_behaviour() -> Dict: }, indent=1, ), - 'host': 'snowflake hostname', 'schema': 'snowflake schema', 'login': 'snowflake username', 'password': 'snowflake password', @@ -122,6 +130,7 @@ def get_ui_field_behaviour() -> Dict: 'extra__snowflake__database': 'snowflake db name', 'extra__snowflake__region': 'snowflake hosted region', 'extra__snowflake__role': 'snowflake role', + 'extra__snowflake__insecure_mode': 'insecure mode', }, } @@ -157,6 +166,11 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: schema = conn.schema or '' authenticator = conn.extra_dejson.get('authenticator', 'snowflake') session_parameters = conn.extra_dejson.get('session_parameters') + insecure_mode = to_boolean( + conn.extra_dejson.get( + 'extra__snowflake__insecure_mode', conn.extra_dejson.get('insecure_mode', None) + ) + ) conn_config = { "user": conn.login, @@ -172,6 +186,8 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: # application is used to track origin of the requests "application": os.environ.get("AIRFLOW_SNOWFLAKE_PARTNER", "AIRFLOW"), } + if insecure_mode: + conn_config['insecure_mode'] = insecure_mode # If private_key_file is specified in the extra json, load the contents of the file as a private # key and specify that in the connection configuration. The connection password then becomes the @@ -202,12 +218,15 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: def get_uri(self) -> str: """Override DbApiHook get_uri method for get_sqlalchemy_engine()""" - conn_config = self._get_conn_params() + conn_params = self._get_conn_params() + return self._conn_params_to_sqlalchemy_uri(conn_params) + + def _conn_params_to_sqlalchemy_uri(self, conn_params: Dict) -> str: uri = ( 'snowflake://{user}:{password}@{account}.{region}/{database}/{schema}' '?warehouse={warehouse}&role={role}&authenticator={authenticator}' ) - return uri.format(**conn_config) + return uri.format(**conn_params) def get_conn(self) -> SnowflakeConnection: """Returns a snowflake.connection object""" @@ -215,6 +234,20 @@ def get_conn(self) -> SnowflakeConnection: conn = connector.connect(**conn_config) return conn + def get_sqlalchemy_engine(self, engine_kwargs=None): + """ + Get an sqlalchemy_engine object. + + :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`. + :return: the created engine. + """ + engine_kwargs = engine_kwargs or {} + conn_params = self._get_conn_params() + if 'insecure_mode' in conn_params: + engine_kwargs.setdefault('connect_args', dict()) + engine_kwargs['connect_args']['insecure_mode'] = True + return create_engine(self._conn_params_to_sqlalchemy_uri(conn_params), **engine_kwargs) + def set_autocommit(self, conn, autocommit: Any) -> None: conn.autocommit(autocommit) conn.autocommit_mode = autocommit diff --git a/docs/apache-airflow-providers-snowflake/connections/snowflake.rst b/docs/apache-airflow-providers-snowflake/connections/snowflake.rst index 598d2e106326b..cb9e76928825e 100644 --- a/docs/apache-airflow-providers-snowflake/connections/snowflake.rst +++ b/docs/apache-airflow-providers-snowflake/connections/snowflake.rst @@ -63,6 +63,9 @@ Extra (optional) * ``private_key_file``: Specify the path to the private key file. * ``session_parameters``: Specify `session level parameters `_ + * ``insecure_mode``: Turn off OCSP certificate checks + For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community + `__. When specifying the connection in environment variable you should specify it using URI syntax. diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index cfced9442b2d1..436cde6be41b4 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -18,6 +18,7 @@ # import re import unittest +from copy import deepcopy from pathlib import Path from unittest import mock @@ -122,9 +123,72 @@ class TestPytestSnowflakeHook: 'warehouse': 'af_wh', }, ), + ( + { + **BASE_CONNECTION_KWARGS, + 'extra': { + 'extra__snowflake__database': 'db', + 'extra__snowflake__account': 'airflow', + 'extra__snowflake__warehouse': 'af_wh', + 'extra__snowflake__region': 'af_region', + 'extra__snowflake__role': 'af_role', + 'extra__snowflake__insecure_mode': 'True', + }, + }, + ( + 'snowflake://user:pw@airflow.af_region/db/public?' + 'warehouse=af_wh&role=af_role&authenticator=snowflake' + ), + { + 'account': 'airflow', + 'application': 'AIRFLOW', + 'authenticator': 'snowflake', + 'database': 'db', + 'password': 'pw', + 'region': 'af_region', + 'role': 'af_role', + 'schema': 'public', + 'session_parameters': None, + 'user': 'user', + 'warehouse': 'af_wh', + 'insecure_mode': True, + }, + ), + ( + { + **BASE_CONNECTION_KWARGS, + 'extra': { + 'extra__snowflake__database': 'db', + 'extra__snowflake__account': 'airflow', + 'extra__snowflake__warehouse': 'af_wh', + 'extra__snowflake__region': 'af_region', + 'extra__snowflake__role': 'af_role', + 'extra__snowflake__insecure_mode': 'False', + }, + }, + ( + 'snowflake://user:pw@airflow.af_region/db/public?' + 'warehouse=af_wh&role=af_role&authenticator=snowflake' + ), + { + 'account': 'airflow', + 'application': 'AIRFLOW', + 'authenticator': 'snowflake', + 'database': 'db', + 'password': 'pw', + 'region': 'af_region', + 'role': 'af_role', + 'schema': 'public', + 'session_parameters': None, + 'user': 'user', + 'warehouse': 'af_wh', + }, + ), ], ) - def test_hook_should_support_pass_auth(self, connection_kwargs, expected_uri, expected_conn_params): + def test_hook_should_support_prepare_basic_conn_params_and_uri( + self, connection_kwargs, expected_uri, expected_conn_params + ): with unittest.mock.patch.dict( 'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() ): @@ -201,6 +265,38 @@ def test_get_conn_should_call_connect(self): mock_connector.connect.assert_called_once_with(**hook._get_conn_params()) assert mock_connector.connect.return_value == conn + def test_get_sqlalchemy_engine_should_support_pass_auth(self): + with unittest.mock.patch.dict( + 'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri() + ), unittest.mock.patch( + 'airflow.providers.snowflake.hooks.snowflake.create_engine' + ) as mock_create_engine: + hook = SnowflakeHook(snowflake_conn_id='test_conn') + conn = hook.get_sqlalchemy_engine() + mock_create_engine.assert_called_once_with( + 'snowflake://user:pw@airflow.af_region/db/public' + '?warehouse=af_wh&role=af_role&authenticator=snowflake' + ) + assert mock_create_engine.return_value == conn + + def test_get_sqlalchemy_engine_should_support_insecure_mode(self): + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs['extra']['extra__snowflake__insecure_mode'] = 'True' + + with unittest.mock.patch.dict( + 'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() + ), unittest.mock.patch( + 'airflow.providers.snowflake.hooks.snowflake.create_engine' + ) as mock_create_engine: + hook = SnowflakeHook(snowflake_conn_id='test_conn') + conn = hook.get_sqlalchemy_engine() + mock_create_engine.assert_called_once_with( + 'snowflake://user:pw@airflow.af_region/db/public' + '?warehouse=af_wh&role=af_role&authenticator=snowflake', + connect_args={'insecure_mode': True}, + ) + assert mock_create_engine.return_value == conn + def test_hook_parameters_should_take_precedence(self): with unittest.mock.patch.dict( 'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()