Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/datacatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from google.api_core.retry import Retry
from google.cloud import datacatalog
from google.cloud.datacatalog_v1beta1 import (
from google.cloud.datacatalog import (
CreateTagRequest,
DataCatalogClient,
Entry,
Expand Down
9 changes: 6 additions & 3 deletions tests/providers/amazon/aws/operators/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

DEFAULT_DATE = timezone.datetime(2019, 1, 1)

QUEUE_NAME = 'test-queue'
QUEUE_URL = f'https://{QUEUE_NAME}'


class TestSQSPublishOperator(unittest.TestCase):
def setUp(self):
Expand All @@ -38,7 +41,7 @@ def setUp(self):
self.operator = SQSPublishOperator(
task_id='test_task',
dag=self.dag,
sqs_queue='test',
sqs_queue=QUEUE_URL,
message_content='hello',
aws_conn_id='aws_default',
)
Expand All @@ -48,13 +51,13 @@ def setUp(self):

@mock_sqs
def test_execute_success(self):
self.sqs_hook.create_queue('test')
self.sqs_hook.create_queue(QUEUE_NAME)

result = self.operator.execute(self.mock_context)
assert 'MD5OfMessageBody' in result
assert 'MessageId' in result

message = self.sqs_hook.get_conn().receive_message(QueueUrl='test')
message = self.sqs_hook.get_conn().receive_message(QueueUrl=QUEUE_URL)

assert len(message['Messages']) == 1
assert message['Messages'][0]['MessageId'] == result['MessageId']
Expand Down
39 changes: 21 additions & 18 deletions tests/providers/amazon/aws/sensors/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,26 @@

DEFAULT_DATE = timezone.datetime(2017, 1, 1)

QUEUE_NAME = 'test-queue'
QUEUE_URL = f'https://{QUEUE_NAME}'


class TestSQSSensor(unittest.TestCase):
def setUp(self):
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

self.dag = DAG('test_dag_id', default_args=args)
self.sensor = SQSSensor(
task_id='test_task', dag=self.dag, sqs_queue='test', aws_conn_id='aws_default'
task_id='test_task', dag=self.dag, sqs_queue=QUEUE_URL, aws_conn_id='aws_default'
)

self.mock_context = mock.MagicMock()
self.sqs_hook = SQSHook()

@mock_sqs
def test_poke_success(self):
self.sqs_hook.create_queue('test')
self.sqs_hook.send_message(queue_url='test', message_body='hello')
self.sqs_hook.create_queue(QUEUE_NAME)
self.sqs_hook.send_message(queue_url=QUEUE_URL, message_body='hello')

result = self.sensor.poke(self.mock_context)
assert result
Expand All @@ -60,7 +63,7 @@ def test_poke_success(self):
@mock_sqs
def test_poke_no_message_failed(self):

self.sqs_hook.create_queue('test')
self.sqs_hook.create_queue(QUEUE_NAME)
result = self.sensor.poke(self.mock_context)
assert not result

Expand Down Expand Up @@ -112,40 +115,40 @@ def test_poke_receive_raise_exception(self, mock_conn):
@mock.patch.object(SQSHook, 'get_conn')
def test_poke_visibility_timeout(self, mock_conn):
# Check without visibility_timeout parameter
self.sqs_hook.create_queue('test')
self.sqs_hook.send_message(queue_url='test', message_body='hello')
self.sqs_hook.create_queue(QUEUE_NAME)
self.sqs_hook.send_message(queue_url=QUEUE_URL, message_body='hello')

self.sensor.poke(self.mock_context)

calls_receive_message = [
mock.call().receive_message(QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1)
mock.call().receive_message(QueueUrl=QUEUE_URL, MaxNumberOfMessages=5, WaitTimeSeconds=1)
]
mock_conn.assert_has_calls(calls_receive_message)
# Check with visibility_timeout parameter
self.sensor = SQSSensor(
task_id='test_task2',
dag=self.dag,
sqs_queue='test',
sqs_queue=QUEUE_URL,
aws_conn_id='aws_default',
visibility_timeout=42,
)
self.sensor.poke(self.mock_context)

calls_receive_message = [
mock.call().receive_message(
QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1, VisibilityTimeout=42
QueueUrl=QUEUE_URL, MaxNumberOfMessages=5, WaitTimeSeconds=1, VisibilityTimeout=42
)
]
mock_conn.assert_has_calls(calls_receive_message)

@mock_sqs
def test_poke_message_invalid_filtering(self):
self.sqs_hook.create_queue('test')
self.sqs_hook.send_message(queue_url='test', message_body='hello')
self.sqs_hook.create_queue(QUEUE_NAME)
self.sqs_hook.send_message(queue_url=QUEUE_URL, message_body='hello')
sensor = SQSSensor(
task_id='test_task2',
dag=self.dag,
sqs_queue='test',
sqs_queue=QUEUE_URL,
aws_conn_id='aws_default',
message_filtering='invalid_option',
)
Expand All @@ -155,7 +158,7 @@ def test_poke_message_invalid_filtering(self):

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_literal_values(self, mock_conn):
self.sqs_hook.create_queue('test')
self.sqs_hook.create_queue(QUEUE_NAME)
matching = [{"id": 11, "body": "a matching message"}]
non_matching = [{"id": 12, "body": "a non-matching message"}]
all = matching + non_matching
Expand Down Expand Up @@ -188,13 +191,13 @@ def mock_delete_message_batch(**kwargs):
# Test that only filtered messages are deleted
delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
calls_delete_message_batch = [
mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
mock.call().delete_message_batch(QueueUrl=QUEUE_URL, Entries=delete_entries)
]
mock_conn.assert_has_calls(calls_delete_message_batch)

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_jsonpath(self, mock_conn):
self.sqs_hook.create_queue('test')
self.sqs_hook.create_queue(QUEUE_NAME)
matching = [
{"id": 11, "key": {"matches": [1, 2]}},
{"id": 12, "key": {"matches": [3, 4, 5]}},
Expand Down Expand Up @@ -234,13 +237,13 @@ def mock_delete_message_batch(**kwargs):
# Test that only filtered messages are deleted
delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
calls_delete_message_batch = [
mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
mock.call().delete_message_batch(QueueUrl=QUEUE_URL, Entries=delete_entries)
]
mock_conn.assert_has_calls(calls_delete_message_batch)

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_jsonpath_values(self, mock_conn):
self.sqs_hook.create_queue('test')
self.sqs_hook.create_queue(QUEUE_NAME)
matching = [
{"id": 11, "key": {"matches": [1, 2]}},
{"id": 12, "key": {"matches": [1, 4, 5]}},
Expand Down Expand Up @@ -282,6 +285,6 @@ def mock_delete_message_batch(**kwargs):
# Test that only filtered messages are deleted
delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
calls_delete_message_batch = [
mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
mock.call().delete_message_batch(QueueUrl='https://test-queue', Entries=delete_entries)
]
mock_conn.assert_has_calls(calls_delete_message_batch)
3 changes: 1 addition & 2 deletions tests/providers/google/cloud/hooks/test_datacatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

import pytest
from google.api_core.retry import Retry
from google.cloud.datacatalog_v1beta1 import CreateTagRequest, CreateTagTemplateRequest
from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate
from google.cloud.datacatalog import CreateTagRequest, CreateTagTemplateRequest, Entry, Tag, TagTemplate

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook
Expand Down