diff --git a/airflow/providers/amazon/aws/links/base_aws.py b/airflow/providers/amazon/aws/links/base_aws.py index 97130fabd6cc7..83da4dd93cba9 100644 --- a/airflow/providers/amazon/aws/links/base_aws.py +++ b/airflow/providers/amazon/aws/links/base_aws.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, ClassVar from airflow.models import BaseOperatorLink, XCom +from airflow.providers.amazon.aws.utils.suppress import return_on_error if TYPE_CHECKING: from airflow.models import BaseOperator @@ -60,6 +61,7 @@ def format_link(self, **kwargs) -> str: except KeyError: return "" + @return_on_error("") def get_link( self, operator: BaseOperator, @@ -77,6 +79,7 @@ def get_link( return self.format_link(**conf) if conf else "" @classmethod + @return_on_error(None) def persist( cls, context: Context, operator: BaseOperator, region_name: str, aws_partition: str, **kwargs ) -> None: diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 6c8cd2181eee1..1bd651a00cfb0 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -43,7 +43,7 @@ class EmrLogsLink(BaseAwsLink): format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/" def format_link(self, **kwargs) -> str: - if not kwargs["log_uri"]: + if not kwargs.get("log_uri"): return "" return super().format_link(**kwargs) diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py index a8bf17c3db4ea..546ead164d6e0 100644 --- a/tests/providers/amazon/aws/links/test_base_aws.py +++ b/tests/providers/amazon/aws/links/test_base_aws.py @@ -18,16 +18,17 @@ from abc import abstractmethod from typing import TYPE_CHECKING, NamedTuple -from unittest.mock import MagicMock +from unittest import mock import pytest +from airflow.models.xcom import XCom from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink from airflow.serialization.serialized_objects import SerializedDAG from tests.test_utils.mock_operators import MockOperator if TYPE_CHECKING: - from airflow.models import TaskInstance + from airflow.models.taskinstance import TaskInstance XCOM_KEY = "test_xcom_key" CUSTOM_KEYS = { @@ -63,7 +64,7 @@ class TestBaseAwsLink: ], ) def test_persist(self, region_name, aws_partition, keywords, expected_value): - mock_context = MagicMock() + mock_context = mock.MagicMock() SimpleBaseAwsLink.persist( context=mock_context, @@ -81,7 +82,7 @@ def test_persist(self, region_name, aws_partition, keywords, expected_value): ) def test_disable_xcom_push(self): - mock_context = MagicMock() + mock_context = mock.MagicMock() SimpleBaseAwsLink.persist( context=mock_context, operator=MockOperator(task_id="test_task_id", do_xcom_push=False), @@ -91,6 +92,21 @@ def test_disable_xcom_push(self): ti = mock_context["ti"] ti.xcom_push.assert_not_called() + def test_suppress_error_on_xcom_push(self): + mock_context = mock.MagicMock() + with mock.patch.object(MockOperator, "xcom_push", side_effect=PermissionError("FakeError")) as m: + SimpleBaseAwsLink.persist( + context=mock_context, + operator=MockOperator(task_id="test_task_id"), + region_name="eu-east-1", + aws_partition="aws", + ) + m.assert_called_once_with( + mock_context, + key="test_xcom_key", + value={"region_name": "eu-east-1", "aws_domain": "aws.amazon.com"}, + ) + def link_test_operator(*links): """Helper for create mock operator class with extra links""" @@ -162,7 +178,7 @@ def assert_extra_link_url( """Helper method for create extra link URL from the parameters.""" task, ti = self.create_op_and_ti(self.link_class, dag_id="test_extra_link", task_id=self.task_id) - mock_context = MagicMock() + mock_context = mock.MagicMock() mock_context.__getitem__.side_effect = {"ti": ti}.__getitem__ self.link_class.persist( @@ -209,6 +225,15 @@ def test_empty_xcom(self): deserialized_task.get_extra_links(ti, self.link_class.name) == "" ), "Operator link should be empty for deserialized task with no XCom push" + def test_suppress_error_on_xcom_pull(self): + """Test ignore any error on XCom pull""" + with mock.patch.object(XCom, "get_value", side_effect=OSError("FakeError")) as m: + op, ti = self.create_op_and_ti( + self.link_class, dag_id="test_error_on_xcom_pull", task_id=self.task_id + ) + self.link_class().get_link(op, ti_key=ti.key) + m.assert_called_once() + @abstractmethod def test_extra_link(self, **kwargs): """Test: Expected URL Link.""" diff --git a/tests/providers/amazon/aws/links/test_emr.py b/tests/providers/amazon/aws/links/test_emr.py index 59c883362a92a..c7f12983e8fb7 100644 --- a/tests/providers/amazon/aws/links/test_emr.py +++ b/tests/providers/amazon/aws/links/test_emr.py @@ -68,7 +68,7 @@ def test_extra_link(self): @pytest.mark.parametrize( "log_url_extra", [ - pytest.param({}, id="no-log-uri", marks=pytest.mark.xfail), + pytest.param({}, id="no-log-uri"), pytest.param({"log_uri": None}, id="log-uri-none"), pytest.param({"log_uri": ""}, id="log-uri-empty"), ],