Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow/providers/google/ads/hooks/ads.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow import AirflowException
from airflow.compat.functools import cached_property
from airflow.hooks.base import BaseHook
from airflow.providers.google.common.hooks.base_google import get_field


class GoogleAdsHook(BaseHook):
Expand Down Expand Up @@ -200,8 +201,10 @@ def _update_config_with_secret(self, secrets_temp: IO[str]) -> None:
Updates google ads config with file path of the temp file containing the secret
Note, the secret must be passed as a file path for Google Ads API
"""
secret_conn = self.get_connection(self.gcp_conn_id)
secret = secret_conn.extra_dejson["extra__google_cloud_platform__keyfile_dict"]
extras = self.get_connection(self.gcp_conn_id).extra_dejson
secret = get_field(extras, 'keyfile_dict')
if not secret:
raise KeyError("secret_conn.extra_dejson does not contain keyfile_dict")
secrets_temp.write(secret)
secrets_temp.flush()

Expand Down
21 changes: 9 additions & 12 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin

Expand Down Expand Up @@ -156,15 +156,14 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
if engine_kwargs is None:
engine_kwargs = {}
connection = self.get_connection(self.gcp_conn_id)
if connection.extra_dejson.get("extra__google_cloud_platform__key_path"):
credentials_path = connection.extra_dejson['extra__google_cloud_platform__key_path']
extras = self.get_connection(self.gcp_conn_id).extra_dejson
credentials_path = get_field(extras, 'key_path')
if credentials_path:
return create_engine(self.get_uri(), credentials_path=credentials_path, **engine_kwargs)
elif connection.extra_dejson.get("extra__google_cloud_platform__keyfile_dict"):
credential_file_content = json.loads(
connection.extra_dejson["extra__google_cloud_platform__keyfile_dict"]
)
return create_engine(self.get_uri(), credentials_info=credential_file_content, **engine_kwargs)
keyfile_dict = get_field(extras, 'keyfile_dict')
if keyfile_dict:
keyfile_content = keyfile_dict if isinstance(keyfile_dict, dict) else json.loads(keyfile_dict)
return create_engine(self.get_uri(), credentials_info=keyfile_content, **engine_kwargs)
try:
# 1. If the environment variable GOOGLE_APPLICATION_CREDENTIALS is set
# ADC uses the service account key or configuration file that the variable points to.
Expand All @@ -175,9 +174,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
self.log.error(e)
raise AirflowException(
"For now, we only support instantiating SQLAlchemy engine by"
" using ADC"
", extra__google_cloud_platform__key_path"
"and extra__google_cloud_platform__keyfile_dict"
" using ADC or extra fields `key_path` and `keyfile_dict`."
)

def get_records(self, sql, parameters=None):
Expand Down
22 changes: 10 additions & 12 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# For requests that are "retriable"
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -375,9 +375,6 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str)
"https://storage.googleapis.com/cloudsql-proxy/{}/cloud_sql_proxy.{}.{}"
)

GCP_CREDENTIALS_KEY_PATH = "extra__google_cloud_platform__key_path"
GCP_CREDENTIALS_KEYFILE_DICT = "extra__google_cloud_platform__keyfile_dict"


class CloudSqlProxyRunner(LoggingMixin):
"""
Expand Down Expand Up @@ -484,15 +481,16 @@ def _download_sql_proxy_if_needed(self) -> None:
self.sql_proxy_was_downloaded = True

def _get_credential_parameters(self) -> list[str]:
connection = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id)

if connection.extra_dejson.get(GCP_CREDENTIALS_KEY_PATH):
credential_params = ['-credential_file', connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH]]
elif connection.extra_dejson.get(GCP_CREDENTIALS_KEYFILE_DICT):
credential_file_content = json.loads(connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT])
extras = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id).extra_dejson
key_path = get_field(extras, 'key_path')
keyfile_dict = get_field(extras, 'keyfile_dict')
if key_path:
credential_params = ['-credential_file', key_path]
elif keyfile_dict:
keyfile_content = keyfile_dict if isinstance(keyfile_dict, dict) else json.loads(keyfile_dict)
self.log.info("Saving credentials to %s", self.credentials_path)
with open(self.credentials_path, "w") as file:
json.dump(credential_file_content, file)
json.dump(keyfile_content, file)
credential_params = ['-credential_file', self.credentials_path]
else:
self.log.info(
Expand All @@ -504,7 +502,7 @@ def _get_credential_parameters(self) -> list[str]:
credential_params = []

if not self.instance_specification:
project_id = connection.extra_dejson.get('extra__google_cloud_platform__project')
project_id = get_field(extras, 'project')
if self.project_id:
project_id = self.project_id
if not project_id:
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/google/cloud/operators/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
from airflow.providers.google.common.hooks.base_google import get_field
from airflow.providers.google.common.links.storage import FileDetailsLink
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
Expand Down Expand Up @@ -1092,9 +1093,7 @@ def execute(self, context: Context):
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
gcp_conn_id=self.gcp_conn_id,
default_gcp_project_id=self.gcp_connection.extra_dejson.get(
'extra__google_cloud_platform__project'
),
default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, 'project'),
)
hook.validate_ssl_certs()
connection = hook.create_connection()
Expand Down
8 changes: 3 additions & 5 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ def build_gcp_conn(
:return: String representing Airflow connection.
"""
conn = "google-cloud-platform://?{}"
extras = "extra__google_cloud_platform"

query_params = {}
if key_file_path:
query_params[f"{extras}__key_path"] = key_file_path
query_params["key_path"] = key_file_path
if scopes:
scopes_string = ",".join(scopes)
query_params[f"{extras}__scope"] = scopes_string
query_params["scope"] = scopes_string
if project_id:
query_params[f"{extras}__projects"] = project_id
query_params["projects"] = project_id

query = urlencode(query_params)
return conn.format(query)
Expand Down
41 changes: 21 additions & 20 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ def __init__(self):
RT = TypeVar('RT')


def get_field(extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
if field_name.startswith('extra__'):
raise ValueError(
f"Got prefixed name {field_name}; please remove the 'extra__google_cloud_platform__' prefix "
"when using this method."
)
if field_name in extras:
return extras[field_name] or None
prefixed_name = f"extra__google_cloud_platform__{field_name}"
return extras.get(prefixed_name) or None


class GoogleBaseHook(BaseHook):
"""
A base hook for Google cloud-related hooks. Google cloud has a shared REST
Expand Down Expand Up @@ -179,25 +192,17 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms.validators import NumberRange

return {
"extra__google_cloud_platform__project": StringField(
lazy_gettext('Project Id'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__key_path": StringField(
lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__keyfile_dict": PasswordField(
lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget()
),
"extra__google_cloud_platform__scope": StringField(
lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__key_secret_name": StringField(
"project": StringField(lazy_gettext('Project Id'), widget=BS3TextFieldWidget()),
"key_path": StringField(lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget()),
"keyfile_dict": PasswordField(lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget()),
"scope": StringField(lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget()),
"key_secret_name": StringField(
lazy_gettext('Keyfile Secret Name (in GCP Secret Manager)'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__key_secret_project_id": StringField(
"key_secret_project_id": StringField(
lazy_gettext('Keyfile Secret Project Id (in GCP Secret Manager)'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__num_retries": IntegerField(
"num_retries": IntegerField(
lazy_gettext('Number of Retries'),
validators=[NumberRange(min=0)],
widget=BS3TextFieldWidget(),
Expand Down Expand Up @@ -325,11 +330,7 @@ def _get_field(self, f: str, default: Any = None) -> Any:
to the hook page, which allow admins to specify service_account,
key_path, etc. They get formatted as shown below.
"""
long_f = f'extra__google_cloud_platform__{f}'
if hasattr(self, 'extras') and long_f in self.extras:
return self.extras[long_f]
else:
return default
return hasattr(self, 'extras') and get_field(self.extras, f) or default

@property
def project_id(self) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,5 +147,5 @@ For connecting to a google cloud conn, all the fields must be in the extra field

.. code-block:: ini

{'extra__google_cloud_platform__key_path': '/opt/airflow/service_account.json',
'extra__google_cloud_platform__scope': 'https://www.googleapis.com/auth/devstorage.read_only'}
{'key_path': '/opt/airflow/service_account.json',
'scope': 'https://www.googleapis.com/auth/devstorage.read_only'}
24 changes: 15 additions & 9 deletions docs/apache-airflow-providers-google/connections/gcp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,21 +124,27 @@ Number of Retries
* query parameters contains information specific to this type of
connection. The following keys are accepted:

* ``extra__google_cloud_platform__project`` - Project Id
* ``extra__google_cloud_platform__key_path`` - Keyfile Path
* ``extra__google_cloud_platform__keyfile_dict`` - Keyfile JSON
* ``extra__google_cloud_platform__key_secret_name`` - Secret name which holds Keyfile JSON
* ``extra__google_cloud_platform__key_secret_project_id`` - Project Id which holds Keyfile JSON
* ``extra__google_cloud_platform__scope`` - Scopes
* ``extra__google_cloud_platform__num_retries`` - Number of Retries
* ``project`` - Project Id
* ``key_path`` - Keyfile Path
* ``keyfile_dict`` - Keyfile JSON
* ``key_secret_name`` - Secret name which holds Keyfile JSON
* ``key_secret_project_id`` - Project Id which holds Keyfile JSON
* ``scope`` - Scopes
* ``num_retries`` - Number of Retries

Note that all components of the URI should be URL-encoded.

For example:
For example, with URI format:

.. code-block:: bash

export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='google-cloud-platform://?extra__google_cloud_platform__key_path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra__google_cloud_platform__project=airflow&extra__google_cloud_platform__num_retries=5'
export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='google-cloud-platform://?key_path=%2Fkeys%2Fkey.json&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&project=airflow&num_retries=5'

And using JSON format:

.. code-block:: bash

export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='{"conn_type": "google-cloud-platform", "key_path": "/keys/key.json", "scope": "https://www.googleapis.com/auth/cloud-platform", "project": "airflow", "num_retries": 5}'

.. _howto/connection:gcp:impersonation:

Expand Down
24 changes: 12 additions & 12 deletions docs/apache-airflow-providers-google/connections/gcp_ssh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ Extra (optional)
connection. The following parameters are supported in addition to those describing
the Google Cloud connection.

* ``extra__google_cloud_platform__instance_name`` - The name of the Compute Engine instance.
* ``extra__google_cloud_platform__zone`` - The zone of the Compute Engine instance.
* ``extra__google_cloud_platform__use_internal_ip`` - Whether to connect using internal IP.
* ``extra__google_cloud_platform__use_iap_tunnel`` - Whether to connect through IAP tunnel.
* ``extra__google_cloud_platform__use_oslogin`` - Whether to manage keys using OsLogin API. If false, keys are managed using instance metadata.
* ``extra__google_cloud_platform__expire_time`` - The maximum amount of time in seconds before the private key expires.
* ``instance_name`` - The name of the Compute Engine instance.
* ``zone`` - The zone of the Compute Engine instance.
* ``use_internal_ip`` - Whether to connect using internal IP.
* ``use_iap_tunnel`` - Whether to connect through IAP tunnel.
* ``use_oslogin`` - Whether to manage keys using OsLogin API. If false, keys are managed using instance metadata.
* ``expire_time`` - The maximum amount of time in seconds before the private key expires.


Environment variable
Expand All @@ -64,9 +64,9 @@ For example:
.. code-block:: bash

export AIRFLOW_CONN_GOOGLE_CLOUD_SQL_DEFAULT="gcpssh://conn-user@conn-host?\
extra__google_cloud_platform__instance_name=conn-instance-name&\
extra__google_cloud_platform__zone=zone&\
extra__google_cloud_platform__use_internal_ip=True&\
extra__google_cloud_platform__use_iap_tunnel=True&\
extra__google_cloud_platform__use_oslogin=False&\
extra__google_cloud_platform__expire_time=4242"
instance_name=conn-instance-name&\
zone=zone&\
use_internal_ip=True&\
use_iap_tunnel=True&\
use_oslogin=False&\
expire_time=4242"
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ raise an exception. The following is a sample file.
.. code-block:: text

mysql_conn_id=mysql://log:password@13.1.21.1:3306/mysqldbrd
google_custom_key=google-cloud-platform://?extra__google_cloud_platform__key_path=%2Fkeys%2Fkey.json
google_custom_key=google-cloud-platform://?key_path=%2Fkeys%2Fkey.json

Storing and Retrieving Variables
""""""""""""""""""""""""""""""""
Expand Down
12 changes: 5 additions & 7 deletions tests/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,8 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
description='no schema',
),
UriTestCaseConfig(
test_conn_uri='google-cloud-platform://?extra__google_cloud_platform__key_'
'path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope='
'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra'
'__google_cloud_platform__project=airflow',
test_conn_uri='google-cloud-platform://?key_path=%2Fkeys%2Fkey.json&scope='
'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&project=airflow',
test_conn_attributes=dict(
conn_type='google_cloud_platform',
host='',
Expand All @@ -287,9 +285,9 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
password=None,
port=None,
extra_dejson=dict(
extra__google_cloud_platform__key_path='/keys/key.json',
extra__google_cloud_platform__scope='https://www.googleapis.com/auth/cloud-platform',
extra__google_cloud_platform__project='airflow',
key_path='/keys/key.json',
scope='https://www.googleapis.com/auth/cloud-platform',
project='airflow',
),
),
description='with underscore',
Expand Down
20 changes: 10 additions & 10 deletions tests/always/test_secrets_local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def test_missing_file(self, mock_exists):
extra_dejson:
arbitrary_dict:
a: b
extra__google_cloud_platform__keyfile_dict: '{"a": "b"}'
extra__google_cloud_platform__keyfile_path: asaa""",
keyfile_dict: '{"a": "b"}'
keyfile_path: asaa""",
{
"conn_a": {'conn_type': 'mysql', 'host': 'hosta'},
"conn_b": {
Expand All @@ -270,8 +270,8 @@ def test_missing_file(self, mock_exists):
'port': 1234,
'extra_dejson': {
'arbitrary_dict': {"a": "b"},
'extra__google_cloud_platform__keyfile_dict': '{"a": "b"}',
'extra__google_cloud_platform__keyfile_path': 'asaa',
'keyfile_dict': '{"a": "b"}',
'keyfile_path': 'asaa',
},
},
},
Expand Down Expand Up @@ -314,14 +314,14 @@ def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dic
password: None
port: 1234
extra_dejson:
extra__google_cloud_platform__keyfile_dict:
keyfile_dict:
a: b
extra__google_cloud_platform__key_path: xxx
key_path: xxx
""",
{
"conn_d": {
"extra__google_cloud_platform__keyfile_dict": {"a": "b"},
"extra__google_cloud_platform__key_path": "xxx",
"keyfile_dict": {"a": "b"},
"key_path": "xxx",
}
},
),
Expand All @@ -334,9 +334,9 @@ def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dic
login: Login
password: None
port: 1234
extra: '{\"extra__google_cloud_platform__keyfile_dict\": {\"a\": \"b\"}}'
extra: '{\"keyfile_dict\": {\"a\": \"b\"}}'
""",
{"conn_d": {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}},
{"conn_d": {"keyfile_dict": {"a": "b"}}},
),
],
)
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/google/ads/hooks/test_ads.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ADS_CLIENT = {"key": "value"}
SECRET = "secret"
EXTRAS = {
"extra__google_cloud_platform__keyfile_dict": SECRET,
"keyfile_dict": SECRET,
"google_ads_client": ADS_CLIENT,
}

Expand Down
Loading