diff --git a/airflow/providers/google/cloud/example_dags/example_pubsub.py b/airflow/providers/google/cloud/example_dags/example_pubsub.py index 728531a25b728..bd2cd0868db71 100644 --- a/airflow/providers/google/cloud/example_dags/example_pubsub.py +++ b/airflow/providers/google/cloud/example_dags/example_pubsub.py @@ -25,13 +25,14 @@ from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.pubsub import ( PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, PubSubPublishMessageOperator, + PubSubDeleteTopicOperator, PubSubPublishMessageOperator, PubSubPullOperator, ) from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -TOPIC = "PubSubTestTopic" +TOPIC_FOR_SENSOR_DAG = "PubSubSensorTestTopic" +TOPIC_FOR_OPERATOR_DAG = "PubSubOperatorTestTopic" MESSAGE = {"data": b"Tool", "attributes": {"name": "wrench", "mass": "1.3kg", "count": "3"}} default_args = {"start_date": days_ago(1)} @@ -45,23 +46,23 @@ # [END howto_operator_gcp_pubsub_pull_messages_result_cmd] with models.DAG( - "example_gcp_pubsub", + "example_gcp_pubsub_sensor", default_args=default_args, schedule_interval=None, # Override to match your needs -) as example_dag: +) as example_sensor_dag: # [START howto_operator_gcp_pubsub_create_topic] create_topic = PubSubCreateTopicOperator( - task_id="create_topic", topic=TOPIC, project_id=GCP_PROJECT_ID + task_id="create_topic", topic=TOPIC_FOR_SENSOR_DAG, project_id=GCP_PROJECT_ID ) # [END howto_operator_gcp_pubsub_create_topic] # [START howto_operator_gcp_pubsub_create_subscription] subscribe_task = PubSubCreateSubscriptionOperator( - task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC + task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC_FOR_SENSOR_DAG ) # [END 
howto_operator_gcp_pubsub_create_subscription] - # [START howto_operator_gcp_pubsub_pull_message] + # [START howto_operator_gcp_pubsub_pull_message_with_sensor] subscription = "{{ task_instance.xcom_pull('subscribe_task') }}" pull_messages = PubSubPullSensor( @@ -70,7 +71,7 @@ project_id=GCP_PROJECT_ID, subscription=subscription, ) - # [END howto_operator_gcp_pubsub_pull_message] + # [END howto_operator_gcp_pubsub_pull_message_with_sensor] # [START howto_operator_gcp_pubsub_pull_messages_result] pull_messages_result = BashOperator( @@ -82,7 +83,7 @@ publish_task = PubSubPublishMessageOperator( task_id="publish_task", project_id=GCP_PROJECT_ID, - topic=TOPIC, + topic=TOPIC_FOR_SENSOR_DAG, messages=[MESSAGE, MESSAGE, MESSAGE], ) # [END howto_operator_gcp_pubsub_publish] @@ -97,9 +98,72 @@ # [START howto_operator_gcp_pubsub_delete_topic] delete_topic = PubSubDeleteTopicOperator( - task_id="delete_topic", topic=TOPIC, project_id=GCP_PROJECT_ID + task_id="delete_topic", topic=TOPIC_FOR_SENSOR_DAG, project_id=GCP_PROJECT_ID ) # [END howto_operator_gcp_pubsub_delete_topic] create_topic >> subscribe_task >> publish_task subscribe_task >> pull_messages >> pull_messages_result >> unsubscribe_task >> delete_topic + + +with models.DAG( + "example_gcp_pubsub_operator", + default_args=default_args, + schedule_interval=None, # Override to match your needs +) as example_operator_dag: + # [START howto_operator_gcp_pubsub_create_topic] + create_topic = PubSubCreateTopicOperator( + task_id="create_topic", topic=TOPIC_FOR_OPERATOR_DAG, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_gcp_pubsub_create_topic] + + # [START howto_operator_gcp_pubsub_create_subscription] + subscribe_task = PubSubCreateSubscriptionOperator( + task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC_FOR_OPERATOR_DAG + ) + # [END howto_operator_gcp_pubsub_create_subscription] + + # [START howto_operator_gcp_pubsub_pull_message_with_operator] + subscription = "{{ 
task_instance.xcom_pull('subscribe_task') }}" + + pull_messages = PubSubPullOperator( + task_id="pull_messages", + ack_messages=True, + project_id=GCP_PROJECT_ID, + subscription=subscription, + ) + # [END howto_operator_gcp_pubsub_pull_message_with_operator] + + # [START howto_operator_gcp_pubsub_pull_messages_result] + pull_messages_result = BashOperator( + task_id="pull_messages_result", bash_command=echo_cmd + ) + # [END howto_operator_gcp_pubsub_pull_messages_result] + + # [START howto_operator_gcp_pubsub_publish] + publish_task = PubSubPublishMessageOperator( + task_id="publish_task", + project_id=GCP_PROJECT_ID, + topic=TOPIC_FOR_OPERATOR_DAG, + messages=[MESSAGE, MESSAGE, MESSAGE], + ) + # [END howto_operator_gcp_pubsub_publish] + + # [START howto_operator_gcp_pubsub_unsubscribe] + unsubscribe_task = PubSubDeleteSubscriptionOperator( + task_id="unsubscribe_task", + project_id=GCP_PROJECT_ID, + subscription="{{ task_instance.xcom_pull('subscribe_task') }}", + ) + # [END howto_operator_gcp_pubsub_unsubscribe] + + # [START howto_operator_gcp_pubsub_delete_topic] + delete_topic = PubSubDeleteTopicOperator( + task_id="delete_topic", topic=TOPIC_FOR_OPERATOR_DAG, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_gcp_pubsub_delete_topic] + + ( + create_topic >> subscribe_task >> publish_task + >> pull_messages >> pull_messages_result >> unsubscribe_task >> delete_topic + ) diff --git a/airflow/providers/google/cloud/hooks/pubsub.py b/airflow/providers/google/cloud/hooks/pubsub.py index 13c32850178c6..a581fa8c90381 100644 --- a/airflow/providers/google/cloud/hooks/pubsub.py +++ b/airflow/providers/google/cloud/hooks/pubsub.py @@ -28,7 +28,7 @@ from google.api_core.retry import Retry from google.cloud.exceptions import NotFound from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient -from google.cloud.pubsub_v1.types import Duration, MessageStoragePolicy, PushConfig +from google.cloud.pubsub_v1.types import Duration, MessageStoragePolicy, PushConfig, 
ReceivedMessage from googleapiclient.errors import HttpError from airflow.providers.google.cloud.hooks.base import CloudBaseHook @@ -460,7 +460,7 @@ def pull( retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, - ) -> List[Dict]: + ) -> List[ReceivedMessage]: """ Pulls up to ``max_messages`` messages from Pub/Sub subscription. @@ -496,7 +496,7 @@ def pull( subscriber = self.subscriber_client subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long - self.log.info("Pulling mex %d messages from subscription (path) %s", max_messages, subscription_path) + self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path) try: # pylint: disable=no-member response = subscriber.pull( @@ -517,7 +517,8 @@ def pull( def acknowledge( self, subscription: str, - ack_ids: List[str], + ack_ids: Optional[List[str]] = None, + messages: Optional[List[ReceivedMessage]] = None, project_id: Optional[str] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, @@ -529,9 +530,12 @@ def acknowledge( :param subscription: the Pub/Sub subscription name to delete; do not include the 'projects/{project}/topics/' prefix. :type subscription: str - :param ack_ids: List of ReceivedMessage ackIds from a previous pull - response + :param ack_ids: List of ReceivedMessage ackIds from a previous pull response. + Mutually exclusive with ``messages`` argument. :type ack_ids: list + :param messages: List of ReceivedMessage objects to acknowledge. + Mutually exclusive with ``ack_ids`` argument. + :type messages: list :param project_id: Optional, the GCP project name or ID in which to create the topic If set to None or missing, the default project_id from the GCP connection is used. 
:type project_id: str @@ -545,8 +549,20 @@ def acknowledge( :param metadata: (Optional) Additional metadata that is provided to the method. :type metadata: Sequence[Tuple[str, str]]] """ + if not project_id: raise ValueError("Project ID should be set.") + + if ack_ids is not None and messages is None: + pass + elif ack_ids is None and messages is not None: + ack_ids = [ + message.ack_id + for message in messages + ] + else: + raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided") + subscriber = self.subscriber_client subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long diff --git a/airflow/providers/google/cloud/operators/pubsub.py b/airflow/providers/google/cloud/operators/pubsub.py index 1aa2680fb7176..233d3fb79c4e6 100644 --- a/airflow/providers/google/cloud/operators/pubsub.py +++ b/airflow/providers/google/cloud/operators/pubsub.py @@ -19,10 +19,11 @@ This module contains Google PubSub operators. """ import warnings -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from google.api_core.retry import Retry -from google.cloud.pubsub_v1.types import Duration, MessageStoragePolicy, PushConfig +from google.cloud.pubsub_v1.types import Duration, MessageStoragePolicy, PushConfig, ReceivedMessage +from google.protobuf.json_format import MessageToDict from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.pubsub import PubSubHook @@ -666,3 +667,123 @@ def execute(self, context): self.log.info("Publishing to topic %s", self.topic) hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages) self.log.info("Published to topic %s", self.topic) + + +class PubSubPullOperator(BaseOperator): + """Pulls messages from a PubSub subscription and passes them through XCom. 
+ If the queue is empty, returns an empty list - never waits for messages. + If you do need to wait, please use :class:`airflow.providers.google.cloud.sensors.PubSubPullSensor` + instead. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubPullSensor` + + This operator will pull up to ``max_messages`` messages from the + specified PubSub subscription. When the subscription returns messages, + they will be returned from the operator and passed through XCom for + downstream tasks. + + If ``ack_messages`` is set to True, messages will be immediately + acknowledged before being returned, otherwise, downstream tasks will be + responsible for acknowledging them. + + ``project_id`` and ``subscription`` are templated so you can use + variables in them. + + :param project_id: the GCP project ID for the subscription (templated) + :type project_id: str + :param subscription: the Pub/Sub subscription name. Do not include the + full subscription path. + :type subscription: str + :param max_messages: The maximum number of messages to retrieve per + PubSub pull request + :type max_messages: int + :param ack_messages: If True, each message will be acknowledged + immediately rather than by any downstream tasks + :type ack_messages: bool + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud Platform. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request + must have domain-wide delegation enabled. + :type delegate_to: str + :param messages_callback: (Optional) Callback to process received messages. + Its return value will be saved to XCom. + If you are pulling large messages, you probably want to provide a custom callback. 
+ If not provided, the default implementation will convert `ReceivedMessage` objects + into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function. + :type messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] + """ + template_fields = ['project_id', 'subscription'] + + @apply_defaults + def __init__( + self, + project_id: str, + subscription: str, + max_messages: int = 5, + ack_messages: bool = False, + messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, + **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.project_id = project_id + self.subscription = subscription + self.max_messages = max_messages + self.ack_messages = ack_messages + self.messages_callback = messages_callback + + def execute(self, context): + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + ) + + pulled_messages = hook.pull( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + return_immediately=True, + ) + + handle_messages = self.messages_callback or self._default_message_callback + + ret = handle_messages(pulled_messages, context) + + if pulled_messages and self.ack_messages: + hook.acknowledge( + project_id=self.project_id, + subscription=self.subscription, + messages=pulled_messages, + ) + + return ret + + def _default_message_callback( + self, + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], # pylint: disable=unused-argument + ): + """ + This method can be overridden by subclasses or by `messages_callback` constructor argument. + This default implementation converts `ReceivedMessage` objects into JSON-serializable dicts. + + :param pulled_messages: messages received from the topic. 
+ :type pulled_messages: List[ReceivedMessage] + :param context: same as in `execute` + :return: value to be saved to XCom. + """ + + messages_json = [ + MessageToDict(m) + for m in pulled_messages + ] + + return messages_json diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index 4380152796d65..75bfab02d4ec8 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -19,8 +19,9 @@ This module contains a Google PubSub sensor. """ import warnings -from typing import Optional +from typing import Any, Callable, Dict, List, Optional +from google.cloud.pubsub_v1.types import ReceivedMessage from google.protobuf.json_format import MessageToDict from airflow.providers.google.cloud.hooks.pubsub import PubSubHook @@ -30,11 +31,16 @@ class PubSubPullSensor(BaseSensorOperator): """Pulls messages from a PubSub subscription and passes them through XCom. + Always waits for at least one message to be returned from the subscription. .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:PubSubPullSensor` + .. seealso:: + If you don't want to wait for at least one message to come, use Operator instead: + :class:`airflow.providers.google.cloud.operators.PubSubPullOperator` + This sensor operator will pull up to ``max_messages`` messages from the specified PubSub subscription. When the subscription returns messages, the poke method's criteria will be fulfilled and the messages will be @@ -55,8 +61,15 @@ class PubSubPullSensor(BaseSensorOperator): :param max_messages: The maximum number of messages to retrieve per PubSub pull request :type max_messages: int - :param return_immediately: If True, instruct the PubSub API to return - immediately if no messages are available for delivery. + :param return_immediately: + (Deprecated) This is an underlying PubSub API implementation detail. 
+ It has no real effect on Sensor behaviour other than some internal wait time before retrying + on empty queue. + The Sensor task will (by definition) always wait for a message, regardless of this argument value. + + If you want a non-blocking task that does not wait for messages, please use + :class:`airflow.providers.google.cloud.operators.PubSubPullOperator` + instead. :type return_immediately: bool :param ack_messages: If True, each message will be acknowledged immediately rather than by any downstream tasks @@ -68,6 +81,12 @@ class PubSubPullSensor(BaseSensorOperator): For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str + :param messages_callback: (Optional) Callback to process received messages. + Its return value will be saved to XCom. + If you are pulling large messages, you probably want to provide a custom callback. + If not provided, the default implementation will convert `ReceivedMessage` objects + into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function. 
+ :type messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] """ template_fields = ['project_id', 'subscription'] ui_color = '#ff7f50' @@ -78,14 +97,15 @@ def __init__( project_id: str, subscription: str, max_messages: int = 5, - return_immediately: bool = False, + return_immediately: bool = True, ack_messages: bool = False, gcp_conn_id: str = 'google_cloud_default', + messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] = None, delegate_to: Optional[str] = None, project: Optional[str] = None, *args, - **kwargs) -> None: - + **kwargs + ) -> None: # To preserve backward compatibility # TODO: remove one day if project: @@ -94,6 +114,18 @@ def __init__( "the project_id parameter.", DeprecationWarning, stacklevel=2) project_id = project + if not return_immediately: + warnings.warn( + "The return_immediately parameter is deprecated.\n" + " It exposes what is really just an implementation detail of underlying PubSub API.\n" + " It has no effect on PubSubPullSensor behaviour.\n" + " It should be left as default value of True.\n" + " If is here only because of backwards compatibility.\n" + " If may be removed in the future.\n", + DeprecationWarning, + stacklevel=2 + ) + super().__init__(*args, **kwargs) self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to @@ -102,27 +134,59 @@ def __init__( self.max_messages = max_messages self.return_immediately = return_immediately self.ack_messages = ack_messages + self.messages_callback = messages_callback - self._messages = None + self._return_value = None def execute(self, context): """Overridden to allow messages to be passed""" super().execute(context) - return self._messages + return self._return_value def poke(self, context): - hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + ) + pulled_messages = hook.pull( project_id=self.project_id, 
subscription=self.subscription, max_messages=self.max_messages, - return_immediately=self.return_immediately + return_immediately=self.return_immediately, ) - self._messages = [MessageToDict(m) for m in pulled_messages] + handle_messages = self.messages_callback or self._default_message_callback + + self._return_value = handle_messages(pulled_messages, context) + + if pulled_messages and self.ack_messages: + hook.acknowledge( + project_id=self.project_id, + subscription=self.subscription, + messages=pulled_messages, + ) - if self._messages and self.ack_messages: - ack_ids = [m['ackId'] for m in self._messages if m.get('ackId')] - hook.acknowledge(project_id=self.project_id, subscription=self.subscription, ack_ids=ack_ids) - return self._messages + return bool(pulled_messages) + + def _default_message_callback( + self, + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], # pylint: disable=unused-argument + ): + """ + This method can be overridden by subclasses or by `messages_callback` constructor argument. + This default implementation converts `ReceivedMessage` objects into JSON-serializable dicts. + + :param pulled_messages: messages received from the topic. + :type pulled_messages: List[ReceivedMessage] + :param context: same as in `execute` + :return: value to be saved to XCom. 
+ """ + + messages_json = [ + MessageToDict(m) + for m in pulled_messages + ] + + return messages_json diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py index 2d797434ba296..e45d511f6da85 100644 --- a/airflow/sensors/base_sensor_operator.py +++ b/airflow/sensors/base_sensor_operator.py @@ -19,7 +19,7 @@ import hashlib from datetime import timedelta from time import sleep -from typing import Dict, Iterable +from typing import Any, Dict, Iterable from airflow.exceptions import ( AirflowException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, @@ -103,7 +103,7 @@ def poke(self, context: Dict) -> bool: """ raise AirflowException('Override me.') - def execute(self, context: Dict) -> None: + def execute(self, context: Dict) -> Any: started_at = timezone.utcnow() try_number = 1 if self.reschedule: diff --git a/docs/howto/operator/gcp/pubsub.rst b/docs/howto/operator/gcp/pubsub.rst index 33346111817d5..b246e4ead79ba 100644 --- a/docs/howto/operator/gcp/pubsub.rst +++ b/docs/howto/operator/gcp/pubsub.rst @@ -89,8 +89,13 @@ and pass them through XCom. .. exampleinclude:: ../../../../airflow/providers/google/cloud/example_dags/example_pubsub.py :language: python - :start-after: [START howto_operator_gcp_pubsub_pull_message] - :end-before: [END howto_operator_gcp_pubsub_pull_message] + :start-after: [START howto_operator_gcp_pubsub_pull_message_with_sensor] + :end-before: [END howto_operator_gcp_pubsub_pull_message_with_sensor] + +.. exampleinclude:: ../../../../airflow/providers/google/cloud/example_dags/example_pubsub.py + :language: python + :start-after: [START howto_operator_gcp_pubsub_pull_message_with_operator] + :end-before: [END howto_operator_gcp_pubsub_pull_message_with_operator] To pull messages from XCom use the :class:`~airflow.operators.bash.BashOperator`. 
diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py index 5c5dd41ba0082..a760f4a618c63 100644 --- a/tests/providers/google/cloud/hooks/test_pubsub.py +++ b/tests/providers/google/cloud/hooks/test_pubsub.py @@ -17,10 +17,13 @@ # under the License. import unittest +from typing import List import mock from google.api_core.exceptions import AlreadyExists, GoogleAPICallError from google.cloud.exceptions import NotFound +from google.cloud.pubsub_v1.types import ReceivedMessage +from google.protobuf.json_format import ParseDict from googleapiclient.errors import HttpError from parameterized import parameterized @@ -58,6 +61,21 @@ def setUp(self): new=mock_init): self.pubsub_hook = PubSubHook(gcp_conn_id='test') + def _generate_messages(self, count) -> List[ReceivedMessage]: + return [ + ParseDict( + { + "ack_id": str(i), + "message": { + "data": f'Message {i}'.encode('utf8'), + "attributes": {"type": "generated message"}, + }, + }, + ReceivedMessage(), + ) + for i in range(1, count + 1) + ] + @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook.client_info", new_callable=mock.PropertyMock) @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook._get_credentials") @@ -389,7 +407,7 @@ def test_pull_fails_on_exception(self, exception, mock_service): ) @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) - def test_acknowledge(self, mock_service): + def test_acknowledge_by_ack_ids(self, mock_service): ack_method = mock_service.acknowledge self.pubsub_hook.acknowledge( @@ -405,6 +423,23 @@ def test_acknowledge(self, mock_service): metadata=None ) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + def test_acknowledge_by_message_objects(self, mock_service): + ack_method = mock_service.acknowledge + + self.pubsub_hook.acknowledge( + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + messages=self._generate_messages(3), + ) + 
ack_method.assert_called_once_with( + subscription=EXPANDED_SUBSCRIPTION, + ack_ids=['1', '2', '3'], + retry=None, + timeout=None, + metadata=None, + ) + @parameterized.expand([ (exception,) for exception in [ HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py index bc1db0b014472..e9265da48665e 100644 --- a/tests/providers/google/cloud/operators/test_pubsub.py +++ b/tests/providers/google/cloud/operators/test_pubsub.py @@ -17,12 +17,15 @@ # under the License. import unittest +from typing import Any, Dict, List import mock +from google.cloud.pubsub_v1.types import ReceivedMessage +from google.protobuf.json_format import MessageToDict, ParseDict from airflow.providers.google.cloud.operators.pubsub import ( PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, PubSubPublishMessageOperator, + PubSubDeleteTopicOperator, PubSubPublishMessageOperator, PubSubPullOperator, ) TASK_ID = 'test-task-id' @@ -36,11 +39,9 @@ }, {'data': b'Knock, knock'}, {'attributes': {'foo': ''}}] -TEST_POKE_INTERVAl = 0 class TestPubSubTopicCreateOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_failifexists(self, mock_hook): operator = PubSubCreateTopicOperator( @@ -87,7 +88,6 @@ def test_succeedifexists(self, mock_hook): class TestPubSubTopicDeleteOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubDeleteTopicOperator( @@ -108,7 +108,6 @@ def test_execute(self, mock_hook): class TestPubSubSubscriptionCreateOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubCreateSubscriptionOperator( @@ -193,7 +192,6 @@ def 
test_execute_no_subscription(self, mock_hook): class TestPubSubSubscriptionDeleteOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubDeleteSubscriptionOperator( @@ -209,20 +207,115 @@ def test_execute(self, mock_hook): fail_if_not_exists=False, retry=None, timeout=None, - metadata=None + metadata=None, ) class TestPubSubPublishOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_publish(self, mock_hook): - operator = PubSubPublishMessageOperator(task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC, - messages=TEST_MESSAGES) + operator = PubSubPublishMessageOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + messages=TEST_MESSAGES, + ) operator.execute(None) mock_hook.return_value.publish.assert_called_once_with( project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES ) + + +class TestPubSubPullOperator(unittest.TestCase): + def _generate_messages(self, count): + return [ + ParseDict( + { + "ack_id": "%s" % i, + "message": { + "data": 'Message {}'.format(i).encode('utf8'), + "attributes": {"type": "generated message"}, + }, + }, + ReceivedMessage(), + ) + for i in range(1, count + 1) + ] + + def _generate_dicts(self, count): + return [ + MessageToDict(m) + for m in self._generate_messages(count) + ] + + @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') + def test_execute_no_messages(self, mock_hook): + operator = PubSubPullOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ) + + mock_hook.return_value.pull.return_value = [] + self.assertEqual([], operator.execute({})) + + @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') + def test_execute_with_ack_messages(self, mock_hook): + operator = PubSubPullOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + 
subscription=TEST_SUBSCRIPTION, + ack_messages=True, + ) + + generated_messages = self._generate_messages(5) + generated_dicts = self._generate_dicts(5) + mock_hook.return_value.pull.return_value = generated_messages + + self.assertEqual(generated_dicts, operator.execute({})) + mock_hook.return_value.acknowledge.assert_called_once_with( + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + messages=generated_messages, + ) + + @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') + def test_execute_with_messages_callback(self, mock_hook): + generated_messages = self._generate_messages(5) + messages_callback_return_value = 'asdfg' + + def messages_callback( + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], + ): + assert pulled_messages == generated_messages + + assert isinstance(context, dict) + for key in context.keys(): + assert isinstance(key, str) + + return messages_callback_return_value + + messages_callback = mock.Mock(side_effect=messages_callback) + + operator = PubSubPullOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + messages_callback=messages_callback, + ) + + mock_hook.return_value.pull.return_value = generated_messages + + response = operator.execute({}) + mock_hook.return_value.pull.assert_called_once_with( + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + max_messages=5, + return_immediately=True + ) + + messages_callback.assert_called_once() + + assert response == messages_callback_return_value diff --git a/tests/providers/google/cloud/operators/test_pubsub_system.py b/tests/providers/google/cloud/operators/test_pubsub_system.py index 2b000ecabd560..8594dde2279ec 100644 --- a/tests/providers/google/cloud/operators/test_pubsub_system.py +++ b/tests/providers/google/cloud/operators/test_pubsub_system.py @@ -25,5 +25,9 @@ @pytest.mark.credential_file(GCP_PUBSUB_KEY) class PubSubSystemTest(GoogleSystemTest): @provide_gcp_context(GCP_PUBSUB_KEY) - def 
test_run_example_dag(self): - self.run_dag(dag_id="example_gcp_pubsub", dag_folder=CLOUD_DAG_FOLDER) + def test_run_example_sensor_dag(self): + self.run_dag(dag_id="example_gcp_pubsub_sensor", dag_folder=CLOUD_DAG_FOLDER) + + @provide_gcp_context(GCP_PUBSUB_KEY) + def test_run_example_operator_dag(self): + self.run_dag(dag_id="example_gcp_pubsub_operator", dag_folder=CLOUD_DAG_FOLDER) diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py index 88fe1874d586f..c2305a0e8216f 100644 --- a/tests/providers/google/cloud/sensors/test_pubsub.py +++ b/tests/providers/google/cloud/sensors/test_pubsub.py @@ -17,6 +17,7 @@ # under the License. import unittest +from typing import Any, Dict, List import mock from google.cloud.pubsub_v1.types import ReceivedMessage @@ -47,28 +48,40 @@ def _generate_messages(self, count): ] def _generate_dicts(self, count): - return [MessageToDict(m) for m in self._generate_messages(count)] + return [ + MessageToDict(m) + for m in self._generate_messages(count) + ] @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') def test_poke_no_messages(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION) + operator = PubSubPullSensor( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ) + mock_hook.return_value.pull.return_value = [] - self.assertEqual([], operator.poke(None)) + self.assertEqual(False, operator.poke({})) @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') def test_poke_with_ack_messages(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - ack_messages=True) + operator = PubSubPullSensor( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ack_messages=True, + ) + generated_messages = self._generate_messages(5) - generated_dicts = 
self._generate_dicts(5) + mock_hook.return_value.pull.return_value = generated_messages - self.assertEqual(generated_dicts, operator.poke(None)) + + self.assertEqual(True, operator.poke({})) mock_hook.return_value.acknowledge.assert_called_once_with( project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, - ack_ids=['1', '2', '3', '4', '5'] + messages=generated_messages, ) @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') @@ -77,31 +90,80 @@ def test_execute(self, mock_hook): task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, - poke_interval=0 + poke_interval=0, ) + generated_messages = self._generate_messages(5) generated_dicts = self._generate_dicts(5) mock_hook.return_value.pull.return_value = generated_messages - response = operator.execute(None) + + response = operator.execute({}) mock_hook.return_value.pull.assert_called_once_with( project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, - return_immediately=False + return_immediately=True ) self.assertEqual(generated_dicts, response) @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') def test_execute_timeout(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - poke_interval=0, timeout=1) + operator = PubSubPullSensor( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + poke_interval=0, + timeout=1, + ) + mock_hook.return_value.pull.return_value = [] + with self.assertRaises(AirflowSensorTimeout): - operator.execute(None) + operator.execute({}) mock_hook.return_value.pull.assert_called_once_with( project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, return_immediately=False ) + + @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') + def test_execute_with_messages_callback(self, mock_hook): + generated_messages = self._generate_messages(5) + messages_callback_return_value = 'asdfg' + + def 
messages_callback( + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], + ): + assert pulled_messages == generated_messages + + assert isinstance(context, dict) + for key in context.keys(): + assert isinstance(key, str) + + return messages_callback_return_value + + messages_callback = mock.Mock(side_effect=messages_callback) + + operator = PubSubPullSensor( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + poke_interval=0, + messages_callback=messages_callback, + ) + + mock_hook.return_value.pull.return_value = generated_messages + + response = operator.execute({}) + mock_hook.return_value.pull.assert_called_once_with( + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + max_messages=5, + return_immediately=True + ) + + messages_callback.assert_called_once() + + assert response == messages_callback_return_value