diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index c58961db2e8f2..2630098bc4650 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -43,18 +43,27 @@ class RdsBaseOperator(BaseOperator): ui_color = "#eeaa88" ui_fgcolor = "#ffffff" - def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs): + def __init__( + self, + *args, + aws_conn_id: str = "aws_conn_id", + region_name: str | None = None, + hook_params: dict | None = None, + **kwargs, + ): if hook_params is not None: warnings.warn( "The parameter hook_params is deprecated and will be removed. " - "If you were using it, please get in touch either on airflow slack, " - "or by opening a github issue on the project. " + "Note that it is also incompatible with deferrable mode. " + "You can use the region_name parameter to specify the region. " + "If you were using hook_params for other purposes, please get in touch either on " + "airflow slack, or by opening a github issue on the project. " "You can mention https://github.com/apache/airflow/pull/32352", AirflowProviderDeprecationWarning, stacklevel=3, # 2 is in the operator's init, 3 is in the user code creating the operator ) - self.hook_params = hook_params or {} - self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params) + self.region_name = region_name + self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=region_name, **(hook_params or {})) super().__init__(*args, **kwargs) self._await_interval = 60 # seconds @@ -588,7 +597,7 @@ def execute(self, context: Context) -> str: waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, - hook_params=self.hook_params, + region_name=self.region_name, waiter_name="db_instance_available", # ignoring type because create_db_instance is a dict response=create_db_instance, # type: ignore[arg-type] @@ -674,7 +683,7 @@ def execute(self, context: Context) -> str: waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, - hook_params=self.hook_params, + region_name=self.region_name, waiter_name="db_instance_deleted", # ignoring type because delete_db_instance is a dict response=delete_db_instance, # type: ignore[arg-type] diff --git a/airflow/providers/amazon/aws/triggers/rds.py b/airflow/providers/amazon/aws/triggers/rds.py index 0897f764bece9..0551d67591dfa 100644 --- a/airflow/providers/amazon/aws/triggers/rds.py +++ b/airflow/providers/amazon/aws/triggers/rds.py @@ -47,14 +47,14 @@ def __init__( waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str, - hook_params: dict[str, Any], + region_name: str | None, response: dict[str, Any], ): self.db_instance_identifier = db_instance_identifier self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id - self.hook_params = hook_params + self.region_name = region_name self.waiter_name = waiter_name self.response = response @@ -67,14 +67,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "waiter_delay": str(self.waiter_delay), "waiter_max_attempts": str(self.waiter_max_attempts), "aws_conn_id": self.aws_conn_id, - "hook_params": self.hook_params, + "region_name": self.region_name, "waiter_name": self.waiter_name, "response": self.response, }, ) async def run(self): - self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params) + self.hook = RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) async with self.hook.async_conn as client: waiter = client.get_waiter(self.waiter_name) await async_wait( diff --git a/tests/providers/amazon/aws/triggers/test_rds.py b/tests/providers/amazon/aws/triggers/test_rds.py index 5ae64b83c3bde..9c518c8eee5f2 100644 --- a/tests/providers/amazon/aws/triggers/test_rds.py +++ b/tests/providers/amazon/aws/triggers/test_rds.py @@ -31,6 +31,7 @@ TEST_WAITER_DELAY = 10 TEST_WAITER_MAX_ATTEMPTS = 10 TEST_AWS_CONN_ID = "test-aws-id" +TEST_REGION = "sa-east-1" TEST_RESPONSE = { "DBInstance": { "DBInstanceIdentifier": "test-db-instance-identifier", @@ -47,7 +48,7 @@ def test_rds_db_instance_trigger_serialize(self): waiter_delay=TEST_WAITER_DELAY, waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, aws_conn_id=TEST_AWS_CONN_ID, - hook_params={}, + region_name=TEST_REGION, response=TEST_RESPONSE, ) class_path, args = rds_db_instance_trigger.serialize() @@ -58,7 +59,7 @@ def test_rds_db_instance_trigger_serialize(self): assert args["waiter_delay"] == str(TEST_WAITER_DELAY) assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["hook_params"] == {} + assert args["region_name"] == TEST_REGION assert args["response"] == TEST_RESPONSE @pytest.mark.asyncio @@ -75,7 +76,7 @@ async def test_rds_db_instance_trigger_run(self, mock_async_conn): waiter_delay=TEST_WAITER_DELAY, waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, aws_conn_id=TEST_AWS_CONN_ID, - hook_params={}, + region_name=TEST_REGION, response=TEST_RESPONSE, ) @@ -104,7 +105,7 @@ async def test_rds_db_instance_trigger_run_multiple_attempts(self, mock_async_co waiter_delay=TEST_WAITER_DELAY, waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, aws_conn_id=TEST_AWS_CONN_ID, - hook_params={}, + region_name=TEST_REGION, response=TEST_RESPONSE, ) @@ -135,7 +136,7 @@ async def test_rds_db_instance_trigger_run_attempts_exceeded(self, mock_async_co waiter_delay=TEST_WAITER_DELAY, waiter_max_attempts=2, aws_conn_id=TEST_AWS_CONN_ID, - hook_params={}, + region_name=TEST_REGION, response=TEST_RESPONSE, ) @@ -173,7 +174,7 @@ async def test_rds_db_instance_trigger_run_attempts_failed(self, mock_async_conn waiter_delay=TEST_WAITER_DELAY, waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, aws_conn_id=TEST_AWS_CONN_ID, - hook_params={}, + region_name=TEST_REGION, response=TEST_RESPONSE, )