Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection
from snowflake.connector.util_text import split_statements
from sqlalchemy import create_engine

from airflow.hooks.dbapi import DbApiHook
from airflow.utils.strings import to_boolean


class SnowflakeHook(DbApiHook):
Expand Down Expand Up @@ -64,6 +66,10 @@ class SnowflakeHook(DbApiHook):
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:type session_parameters: Optional[dict]
:param insecure_mode: Turns off OCSP certificate checks.
For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community
<https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`__
:type insecure_mode: Optional[bool]

.. note::
get_sqlalchemy_engine() depends on snowflake-sqlalchemy
Expand All @@ -84,7 +90,7 @@ def get_connection_form_widgets() -> Dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField
from wtforms import BooleanField, StringField

return {
"extra__snowflake__account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()),
Expand All @@ -94,6 +100,9 @@ def get_connection_form_widgets() -> Dict[str, Any]:
"extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()),
"extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()),
"extra__snowflake__role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()),
"extra__snowflake__insecure_mode": BooleanField(
label=lazy_gettext('Insecure mode'), description="Turns off OCSP certificate checks"
),
}

@staticmethod
Expand All @@ -113,7 +122,6 @@ def get_ui_field_behaviour() -> Dict:
},
indent=1,
),
'host': 'snowflake hostname',
'schema': 'snowflake schema',
'login': 'snowflake username',
'password': 'snowflake password',
Expand All @@ -122,6 +130,7 @@ def get_ui_field_behaviour() -> Dict:
'extra__snowflake__database': 'snowflake db name',
'extra__snowflake__region': 'snowflake hosted region',
'extra__snowflake__role': 'snowflake role',
'extra__snowflake__insecure_mode': 'insecure mode',
},
}

Expand Down Expand Up @@ -157,6 +166,11 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:
schema = conn.schema or ''
authenticator = conn.extra_dejson.get('authenticator', 'snowflake')
session_parameters = conn.extra_dejson.get('session_parameters')
insecure_mode = to_boolean(
conn.extra_dejson.get(
'extra__snowflake__insecure_mode', conn.extra_dejson.get('insecure_mode', None)
)
)

conn_config = {
"user": conn.login,
Expand All @@ -172,6 +186,8 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:
# application is used to track origin of the requests
"application": os.environ.get("AIRFLOW_SNOWFLAKE_PARTNER", "AIRFLOW"),
}
if insecure_mode:
conn_config['insecure_mode'] = insecure_mode

# If private_key_file is specified in the extra json, load the contents of the file as a private
# key and specify that in the connection configuration. The connection password then becomes the
Expand Down Expand Up @@ -202,19 +218,36 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:

def get_uri(self) -> str:
"""Override DbApiHook get_uri method for get_sqlalchemy_engine()"""
conn_config = self._get_conn_params()
conn_params = self._get_conn_params()
return self._conn_params_to_sqlalchemy_uri(conn_params)

def _conn_params_to_sqlalchemy_uri(self, conn_params: Dict) -> str:
uri = (
'snowflake://{user}:{password}@{account}.{region}/{database}/{schema}'
'?warehouse={warehouse}&role={role}&authenticator={authenticator}'
)
return uri.format(**conn_config)
return uri.format(**conn_params)

def get_conn(self) -> SnowflakeConnection:
"""Returns a snowflake.connection object"""
conn_config = self._get_conn_params()
conn = connector.connect(**conn_config)
return conn

def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
Get an sqlalchemy_engine object.

:param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
:return: the created engine.
"""
engine_kwargs = engine_kwargs or {}
conn_params = self._get_conn_params()
if 'insecure_mode' in conn_params:
engine_kwargs.setdefault('connect_args', dict())
engine_kwargs['connect_args']['insecure_mode'] = True
return create_engine(self._conn_params_to_sqlalchemy_uri(conn_params), **engine_kwargs)

def set_autocommit(self, conn, autocommit: Any) -> None:
conn.autocommit(autocommit)
conn.autocommit_mode = autocommit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ Extra (optional)
* ``private_key_file``: Specify the path to the private key file.
* ``session_parameters``: Specify `session level parameters
<https://docs.snowflake.com/en/user-guide/python-connector-example.html#setting-session-parameters>`_
* ``insecure_mode``: Turn off OCSP certificate checks
For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community
<https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`__.

When specifying the connection in environment variable you should specify
it using URI syntax.
Expand Down
98 changes: 97 additions & 1 deletion tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#
import re
import unittest
from copy import deepcopy
from pathlib import Path
from unittest import mock

Expand Down Expand Up @@ -122,9 +123,72 @@ class TestPytestSnowflakeHook:
'warehouse': 'af_wh',
},
),
(
{
**BASE_CONNECTION_KWARGS,
'extra': {
'extra__snowflake__database': 'db',
'extra__snowflake__account': 'airflow',
'extra__snowflake__warehouse': 'af_wh',
'extra__snowflake__region': 'af_region',
'extra__snowflake__role': 'af_role',
'extra__snowflake__insecure_mode': 'True',
},
},
(
'snowflake://user:pw@airflow.af_region/db/public?'
'warehouse=af_wh&role=af_role&authenticator=snowflake'
),
{
'account': 'airflow',
'application': 'AIRFLOW',
'authenticator': 'snowflake',
'database': 'db',
'password': 'pw',
'region': 'af_region',
'role': 'af_role',
'schema': 'public',
'session_parameters': None,
'user': 'user',
'warehouse': 'af_wh',
'insecure_mode': True,
},
),
(
{
**BASE_CONNECTION_KWARGS,
'extra': {
'extra__snowflake__database': 'db',
'extra__snowflake__account': 'airflow',
'extra__snowflake__warehouse': 'af_wh',
'extra__snowflake__region': 'af_region',
'extra__snowflake__role': 'af_role',
'extra__snowflake__insecure_mode': 'False',
},
},
(
'snowflake://user:pw@airflow.af_region/db/public?'
'warehouse=af_wh&role=af_role&authenticator=snowflake'
),
{
'account': 'airflow',
'application': 'AIRFLOW',
'authenticator': 'snowflake',
'database': 'db',
'password': 'pw',
'region': 'af_region',
'role': 'af_role',
'schema': 'public',
'session_parameters': None,
'user': 'user',
'warehouse': 'af_wh',
},
),
],
)
def test_hook_should_support_pass_auth(self, connection_kwargs, expected_uri, expected_conn_params):
def test_hook_should_support_prepare_basic_conn_params_and_uri(
self, connection_kwargs, expected_uri, expected_conn_params
):
with unittest.mock.patch.dict(
'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
):
Expand Down Expand Up @@ -201,6 +265,38 @@ def test_get_conn_should_call_connect(self):
mock_connector.connect.assert_called_once_with(**hook._get_conn_params())
assert mock_connector.connect.return_value == conn

def test_get_sqlalchemy_engine_should_support_pass_auth(self):
with unittest.mock.patch.dict(
'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()
), unittest.mock.patch(
'airflow.providers.snowflake.hooks.snowflake.create_engine'
) as mock_create_engine:
hook = SnowflakeHook(snowflake_conn_id='test_conn')
conn = hook.get_sqlalchemy_engine()
mock_create_engine.assert_called_once_with(
'snowflake://user:pw@airflow.af_region/db/public'
'?warehouse=af_wh&role=af_role&authenticator=snowflake'
)
assert mock_create_engine.return_value == conn

def test_get_sqlalchemy_engine_should_support_insecure_mode(self):
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs['extra']['extra__snowflake__insecure_mode'] = 'True'

with unittest.mock.patch.dict(
'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
), unittest.mock.patch(
'airflow.providers.snowflake.hooks.snowflake.create_engine'
) as mock_create_engine:
hook = SnowflakeHook(snowflake_conn_id='test_conn')
conn = hook.get_sqlalchemy_engine()
mock_create_engine.assert_called_once_with(
'snowflake://user:pw@airflow.af_region/db/public'
'?warehouse=af_wh&role=af_role&authenticator=snowflake',
connect_args={'insecure_mode': True},
)
assert mock_create_engine.return_value == conn

def test_hook_parameters_should_take_precedence(self):
with unittest.mock.patch.dict(
'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()
Expand Down