diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index f357ce88d8a37..80a40bb734954 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -529,6 +529,11 @@ def conn_config(self) -> AwsConnectionWrapper: conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify ) + @property + def service_config(self) -> dict: + service_name = self.client_type or self.resource_type + return self.conn_config.get_service_config(service_name) + @property def region_name(self) -> str | None: """AWS Region Name read-only property.""" diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index b474cc12a7d1a..89c9261cb6d40 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -24,6 +24,7 @@ import logging import re import shutil +import warnings from contextlib import suppress from copy import deepcopy from datetime import datetime @@ -64,7 +65,17 @@ def wrapper(*args, **kwargs) -> T: if "bucket_name" not in bound_args.arguments: self = args[0] - if self.conn_config and self.conn_config.schema: + + if "bucket_name" in self.service_config: + bound_args.arguments["bucket_name"] = self.service_config["bucket_name"] + elif self.conn_config and self.conn_config.schema: + warnings.warn( + "s3 conn_type, and the associated schema field, is deprecated." 
+ " Please use aws conn_type instead, and specify `bucket_name`" + " in `service_config.s3` within `extras`.", + DeprecationWarning, + stacklevel=2, + ) bound_args.arguments["bucket_name"] = self.conn_config.schema return func(*bound_args.args, **bound_args.kwargs) diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py index 237109fa75ae1..f5a30240d4b2d 100644 --- a/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -125,6 +125,9 @@ class AwsConnectionWrapper(LoggingMixin): def conn_repr(self): return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})" + def get_service_config(self, service_name): + return self.extra_dejson.get("service_config", {}).get(service_name, {}) + def __post_init__(self, conn: Connection): if isinstance(conn, type(self)): # For every field with init=False we copy reference value from original wrapper diff --git a/docs/apache-airflow-providers-amazon/connections/aws.rst b/docs/apache-airflow-providers-amazon/connections/aws.rst index e885887980d17..ef667756e60da 100644 --- a/docs/apache-airflow-providers-amazon/connections/aws.rst +++ b/docs/apache-airflow-providers-amazon/connections/aws.rst @@ -67,6 +67,8 @@ Extra (optional) Specify the extra parameters (as json dictionary) that can be used in AWS connection. All parameters are optional. + * ``service_config``: json used to specify configuration/parameters for different AWS services, such as S3 or STS. + The following extra parameters used to create an initial :external:py:class:`boto3.session.Session`: * ``aws_access_key_id``: AWS access key ID used for the initial connection. @@ -269,6 +271,24 @@ This assumes all other Connection fields eg **AWS Access Key ID** or **AWS Secre "assume_role_kwargs": { "something":"something" } } +5. 
Using **service_config** to specify configuration for services such as S3, STS, and EMR + +.. code-block:: json + + { + "service_config": { + "s3": { + "bucket_name": "awesome-bucket" + }, + "sts": { + "endpoint_url": "https://example.org" + }, + "emr": { + "job_flow_overrides": {"Name": "PiCalc", "ReleaseLabel": "emr-6.7.0"}, + "endpoint_url": "https://emr.example.org" + } + } + } The following settings may be used within the ``assume_role_with_saml`` container in Extra. diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index b4dcc0e1daba6..6ee49ffb2671c 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -467,16 +467,24 @@ def test_delete_bucket_if_not_bucket_exist(self, s3_bucket): assert mock_hook.delete_bucket(bucket_name=s3_bucket, force_delete=True) assert ctx.value.response["Error"]["Code"] == "NoSuchBucket" - @mock.patch.object(S3Hook, "get_connection", return_value=Connection(schema="test_bucket")) + @mock.patch.object( + S3Hook, + "get_connection", + return_value=Connection(extra={"service_config": {"s3": {"bucket_name": "bucket_name"}}}), + ) def test_provide_bucket_name(self, mock_get_connection): class FakeS3Hook(S3Hook): @provide_bucket_name def test_function(self, bucket_name=None): return bucket_name - hook = FakeS3Hook() - assert hook.test_function() == "test_bucket" - assert hook.test_function(bucket_name="bucket") == "bucket" + fake_s3_hook = FakeS3Hook() + + test_bucket_name = fake_s3_hook.test_function() + assert test_bucket_name == "bucket_name" + + test_bucket_name = fake_s3_hook.test_function(bucket_name="bucket") + assert test_bucket_name == "bucket" def test_delete_objects_key_does_not_exist(self, s3_bucket): # The behaviour of delete changed in recent version of s3 mock libraries.