Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airflow/providers/amazon/aws/links/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import TYPE_CHECKING, ClassVar

from airflow.models import BaseOperatorLink, XCom
from airflow.providers.amazon.aws.utils.suppress import return_on_error

if TYPE_CHECKING:
from airflow.models import BaseOperator
Expand Down Expand Up @@ -60,6 +61,7 @@ def format_link(self, **kwargs) -> str:
except KeyError:
return ""

@return_on_error("")
def get_link(
self,
operator: BaseOperator,
Expand All @@ -77,6 +79,7 @@ def get_link(
return self.format_link(**conf) if conf else ""

@classmethod
@return_on_error(None)
def persist(
cls, context: Context, operator: BaseOperator, region_name: str, aws_partition: str, **kwargs
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/links/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class EmrLogsLink(BaseAwsLink):
format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"

def format_link(self, **kwargs) -> str:
if not kwargs["log_uri"]:
if not kwargs.get("log_uri"):
return ""
return super().format_link(**kwargs)

Expand Down
35 changes: 30 additions & 5 deletions tests/providers/amazon/aws/links/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@

from abc import abstractmethod
from typing import TYPE_CHECKING, NamedTuple
from unittest.mock import MagicMock
from unittest import mock

import pytest

from airflow.models.xcom import XCom
from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
from airflow.serialization.serialized_objects import SerializedDAG
from tests.test_utils.mock_operators import MockOperator

if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.models.taskinstance import TaskInstance

XCOM_KEY = "test_xcom_key"
CUSTOM_KEYS = {
Expand Down Expand Up @@ -63,7 +64,7 @@ class TestBaseAwsLink:
],
)
def test_persist(self, region_name, aws_partition, keywords, expected_value):
mock_context = MagicMock()
mock_context = mock.MagicMock()

SimpleBaseAwsLink.persist(
context=mock_context,
Expand All @@ -81,7 +82,7 @@ def test_persist(self, region_name, aws_partition, keywords, expected_value):
)

def test_disable_xcom_push(self):
mock_context = MagicMock()
mock_context = mock.MagicMock()
SimpleBaseAwsLink.persist(
context=mock_context,
operator=MockOperator(task_id="test_task_id", do_xcom_push=False),
Expand All @@ -91,6 +92,21 @@ def test_disable_xcom_push(self):
ti = mock_context["ti"]
ti.xcom_push.assert_not_called()

def test_suppress_error_on_xcom_push(self):
mock_context = mock.MagicMock()
with mock.patch.object(MockOperator, "xcom_push", side_effect=PermissionError("FakeError")) as m:
SimpleBaseAwsLink.persist(
context=mock_context,
operator=MockOperator(task_id="test_task_id"),
region_name="eu-east-1",
aws_partition="aws",
)
m.assert_called_once_with(
mock_context,
key="test_xcom_key",
value={"region_name": "eu-east-1", "aws_domain": "aws.amazon.com"},
)


def link_test_operator(*links):
"""Helper for create mock operator class with extra links"""
Expand Down Expand Up @@ -162,7 +178,7 @@ def assert_extra_link_url(
"""Helper method for create extra link URL from the parameters."""
task, ti = self.create_op_and_ti(self.link_class, dag_id="test_extra_link", task_id=self.task_id)

mock_context = MagicMock()
mock_context = mock.MagicMock()
mock_context.__getitem__.side_effect = {"ti": ti}.__getitem__

self.link_class.persist(
Expand Down Expand Up @@ -209,6 +225,15 @@ def test_empty_xcom(self):
deserialized_task.get_extra_links(ti, self.link_class.name) == ""
), "Operator link should be empty for deserialized task with no XCom push"

def test_suppress_error_on_xcom_pull(self):
"""Test ignore any error on XCom pull"""
with mock.patch.object(XCom, "get_value", side_effect=OSError("FakeError")) as m:
op, ti = self.create_op_and_ti(
self.link_class, dag_id="test_error_on_xcom_pull", task_id=self.task_id
)
self.link_class().get_link(op, ti_key=ti.key)
m.assert_called_once()

@abstractmethod
def test_extra_link(self, **kwargs):
"""Test: Expected URL Link."""
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/links/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_extra_link(self):
@pytest.mark.parametrize(
"log_url_extra",
[
pytest.param({}, id="no-log-uri", marks=pytest.mark.xfail),
pytest.param({}, id="no-log-uri"),
pytest.param({"log_uri": None}, id="log-uri-none"),
pytest.param({"log_uri": ""}, id="log-uri-empty"),
],
Expand Down