From 9befc60b25976f0127a81639f10ba9890d625ab5 Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Wed, 5 Feb 2025 15:01:35 +0000 Subject: [PATCH] Create operators for working with Consumer Groups for GCP Apache Kafka --- dev/breeze/tests/test_selective_checks.py | 8 +- generated/provider_dependencies.json | 4 +- providers/apache/kafka/README.rst | 19 ++ providers/apache/kafka/pyproject.toml | 7 + .../apache/kafka/get_provider_info.py | 1 + .../providers/apache/kafka/hooks/base.py | 11 + .../docs/operators/cloud/managed_kafka.rst | 60 ++++ providers/google/provider.yaml | 1 + .../google/cloud/hooks/managed_kafka.py | 227 ++++++++++++++- .../google/cloud/links/managed_kafka.py | 30 ++ .../google/cloud/operators/managed_kafka.py | 265 ++++++++++++++++++ .../providers/google/get_provider_info.py | 1 + .../example_managed_kafka_consumer_group.py | 254 +++++++++++++++++ .../google/cloud/hooks/test_managed_kafka.py | 189 +++++++++++++ .../google/cloud/links/test_managed_kafka.py | 40 +++ .../cloud/operators/test_managed_kafka.py | 131 +++++++++ 16 files changed, 1241 insertions(+), 7 deletions(-) create mode 100644 providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index b4fdfa36a2182..1fcf0e7aa20df 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1621,7 +1621,7 @@ def test_expected_output_push( "providers/google/tests/unit/google/file.py", ), { - "selected-providers-list-as-string": "amazon apache.beam apache.cassandra " + "selected-providers-list-as-string": "amazon apache.beam apache.cassandra apache.kafka " "cncf.kubernetes common.compat common.sql " "facebook google hashicorp microsoft.azure microsoft.mssql mysql " "openlineage oracle postgres presto salesforce samba sftp ssh trino", @@ -1635,14 +1635,14 @@ def test_expected_output_push( "test-groups": "['core', 
'providers']", "docs-build": "true", "docs-list-as-string": "apache-airflow helm-chart amazon apache.beam apache.cassandra " - "cncf.kubernetes common.compat common.sql facebook google hashicorp microsoft.azure " + "apache.kafka cncf.kubernetes common.compat common.sql facebook google hashicorp microsoft.azure " "microsoft.mssql mysql openlineage oracle postgres " "presto salesforce samba sftp ssh trino", "skip-pre-commits": "identity,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,ts-compile-format-lint-ui", "run-kubernetes-tests": "true", "upgrade-to-newer-dependencies": "false", "core-test-types-list-as-string": "Always CLI", - "providers-test-types-list-as-string": "Providers[amazon] Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.compat,common.sql,facebook," + "providers-test-types-list-as-string": "Providers[amazon] Providers[apache.beam,apache.cassandra,apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook," "hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto," "salesforce,samba,sftp,ssh,trino] Providers[google]", "needs-mypy": "true", @@ -1890,7 +1890,7 @@ def test_upgrade_to_newer_dependencies( pytest.param( ("providers/google/docs/some_file.rst",), { - "docs-list-as-string": "amazon apache.beam apache.cassandra " + "docs-list-as-string": "amazon apache.beam apache.cassandra apache.kafka " "cncf.kubernetes common.compat common.sql facebook google hashicorp " "microsoft.azure microsoft.mssql mysql openlineage oracle " "postgres presto salesforce samba sftp ssh trino", diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index d2dd5e538f33e..b86941fcb0918 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -222,7 +222,9 @@ ], "devel-deps": [], "plugins": [], - "cross-providers-deps": [], + "cross-providers-deps": [ + "google" + ], "excluded-python-versions": [], "state": "ready" }, diff --git 
a/providers/apache/kafka/README.rst b/providers/apache/kafka/README.rst index cfa6d23c9f7dd..f85a6682b27fb 100644 --- a/providers/apache/kafka/README.rst +++ b/providers/apache/kafka/README.rst @@ -58,5 +58,24 @@ PIP package Version required ``confluent-kafka`` ``>=2.3.0`` =================== ================== +Cross provider package dependencies +----------------------------------- + +Those are dependencies that might be needed in order to use all the features of the package. +You need to install the specified provider packages in order to use them. + +You can install such cross-provider dependencies when installing from PyPI. For example: + +.. code-block:: bash + + pip install apache-airflow-providers-apache-kafka[google] + + +==================================================================================================== ========== +Dependent package Extra +==================================================================================================== ========== +`apache-airflow-providers-google <https://airflow.apache.org/docs/apache-airflow-providers-google>`_ ``google`` +==================================================================================================== ========== + The changelog for the provider package can be found in the `changelog <https://airflow.apache.org/docs/apache-airflow-providers-apache-kafka/1.7.0/changelog.html>`_. 
diff --git a/providers/apache/kafka/pyproject.toml b/providers/apache/kafka/pyproject.toml index 51c6126c5cab5..c236cf0e6625f 100644 --- a/providers/apache/kafka/pyproject.toml +++ b/providers/apache/kafka/pyproject.toml @@ -62,6 +62,13 @@ dependencies = [ "confluent-kafka>=2.3.0", ] +# The optional dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +[project.optional-dependencies] +"google" = [ + "apache-airflow-providers-google" +] + [project.urls] "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-apache-kafka/1.7.0" "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-apache-kafka/1.7.0/changelog.html" diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py index 97287ea00e267..d53eee67dc7aa 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py @@ -90,4 +90,5 @@ def get_provider_info(): } ], "dependencies": ["apache-airflow>=2.9.0", "asgiref>=2.3.0", "confluent-kafka>=2.3.0"], + "optional-dependencies": {"google": ["apache-airflow-providers-google"]}, } diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py index 9b20c7dfc91cc..5d02903a4d692 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py @@ -22,6 +22,7 @@ from confluent_kafka.admin import AdminClient from airflow.hooks.base import BaseHook +from airflow.providers.google.cloud.hooks.managed_kafka import ManagedKafkaHook class KafkaBaseHook(BaseHook): @@ -63,6 +64,16 @@ def get_conn(self) -> Any: if not 
(config.get("bootstrap.servers", None)): raise ValueError("config['bootstrap.servers'] must be provided.") + bootstrap_servers = config.get("bootstrap.servers") + if ( + bootstrap_servers + and bootstrap_servers.find("cloud.goog") != -1 + and bootstrap_servers.find("managedkafka") != -1 + ): + self.log.info("Adding token generation for Google Auth to the confluent configuration.") + hook = ManagedKafkaHook() + token = hook.get_confluent_token + config.update({"oauth_cb": token}) return self._get_client(config) def test_connection(self) -> tuple[bool, str]: diff --git a/providers/google/docs/operators/cloud/managed_kafka.rst b/providers/google/docs/operators/cloud/managed_kafka.rst index a81f81592eeda..791d721827da1 100644 --- a/providers/google/docs/operators/cloud/managed_kafka.rst +++ b/providers/google/docs/operators/cloud/managed_kafka.rst @@ -117,6 +117,66 @@ To update topic you can use :start-after: [START how_to_cloud_managed_kafka_update_topic_operator] :end-before: [END how_to_cloud_managed_kafka_update_topic_operator] +Interacting with Apache Kafka Consumer Groups +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To delete consumer group you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaDeleteConsumerGroupOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_delete_consumer_group_operator] + :end-before: [END how_to_cloud_managed_kafka_delete_consumer_group_operator] + +To get consumer group you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaGetConsumerGroupOperator`. + +.. 
exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_get_consumer_group_operator] + :end-before: [END how_to_cloud_managed_kafka_get_consumer_group_operator] + +To get a list of consumer groups you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaListConsumerGroupsOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_list_consumer_group_operator] + :end-before: [END how_to_cloud_managed_kafka_list_consumer_group_operator] + +To update consumer group you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaUpdateConsumerGroupOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_update_consumer_group_operator] + :end-before: [END how_to_cloud_managed_kafka_update_consumer_group_operator] + +Using Apache Kafka provider with Google Cloud Managed Service for Apache Kafka +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To produce data to topic you can use +:class:`~airflow.providers.apache.kafka.operators.produce.ProduceToTopicOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_produce_to_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_produce_to_topic_operator] + +To consume data from topic you can use +:class:`~airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator`. + +.. 
exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_consume_from_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_consume_from_topic_operator] + Reference ^^^^^^^^^ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 3e3a9e7c1063a..a251ef1db9a94 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -1231,6 +1231,7 @@ extra-links: - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink + - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink secrets-backends: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py index f71e8a158c83f..738727332d4c0 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py @@ -19,12 +19,17 @@ from __future__ import annotations +import base64 +import datetime +import json +import time from collections.abc import Sequence from copy import deepcopy from typing import TYPE_CHECKING from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, Topic, types +from google.auth.transport import requests as google_requests +from google.cloud.managedkafka_v1 import Cluster, ConsumerGroup, ManagedKafkaClient, Topic, types from airflow.exceptions import AirflowException from airflow.providers.google.common.consts import CLIENT_INFO @@ -33,10 +38,62 @@ if TYPE_CHECKING: from google.api_core.operation import Operation from 
google.api_core.retry import Retry - from google.cloud.managedkafka_v1.services.managed_kafka.pagers import ListClustersPager, ListTopicsPager + from google.auth.credentials import Credentials + from google.cloud.managedkafka_v1.services.managed_kafka.pagers import ( + ListClustersPager, + ListConsumerGroupsPager, + ListTopicsPager, + ) from google.protobuf.field_mask_pb2 import FieldMask +class ManagedKafkaTokenProvider: + """Helper for providing authentication token for establishing connection via confluent to Apache Kafka cluster managed by Google Cloud.""" + + def __init__( + self, + credentials: Credentials, + ): + self._credentials = credentials + self._header = json.dumps(dict(typ="JWT", alg="GOOG_OAUTH2_TOKEN")) + + def _valid_credentials(self): + if not self._credentials.valid: + self._credentials.refresh(google_requests.Request()) + return self._credentials + + def _get_jwt(self, credentials): + return json.dumps( + dict( + exp=credentials.expiry.timestamp(), + iss="Google", + iat=datetime.datetime.now(datetime.timezone.utc).timestamp(), + scope="kafka", + sub=credentials.service_account_email, + ) + ) + + def _b64_encode(self, source): + return base64.urlsafe_b64encode(source.encode("utf-8")).decode("utf-8").rstrip("=") + + def _get_kafka_access_token(self, credentials): + return ".".join( + [ + self._b64_encode(self._header), + self._b64_encode(self._get_jwt(credentials)), + self._b64_encode(credentials.token), + ] + ) + + def confluent_token(self): + credentials = self._valid_credentials() + + utc_expiry = credentials.expiry.replace(tzinfo=datetime.timezone.utc) + expiry_seconds = (utc_expiry - datetime.datetime.now(datetime.timezone.utc)).total_seconds() + + return self._get_kafka_access_token(credentials), time.time() + expiry_seconds + + class ManagedKafkaHook(GoogleBaseHook): """Hook for Managed Service for Apache Kafka APIs.""" @@ -63,6 +120,12 @@ def wait_for_operation(self, operation: Operation, timeout: float | None = None) error = 
operation.exception(timeout=timeout) raise AirflowException(error) + def get_confluent_token(self): + """Get the authentication token for confluent client.""" + token_provider = ManagedKafkaTokenProvider(credentials=self.get_credentials()) + token = token_provider.confluent_token() + return token + @GoogleBaseHook.fallback_to_default_project_id def create_cluster( self, @@ -481,3 +544,163 @@ def delete_topic( timeout=timeout, metadata=metadata, ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_consumer_groups( + self, + project_id: str, + location: str, + cluster_id: str, + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListConsumerGroupsPager: + """ + List the consumer groups in a given cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose consumer groups are to be listed. + :param page_size: Optional. The maximum number of consumer groups to return. The service may return + fewer than this value. If unset or zero, all consumer groups for the parent is returned. + :param page_token: Optional. A page token, received from a previous ``ListConsumerGroups`` call. + Provide this to retrieve the subsequent page. When paginating, all other parameters provided to + ``ListConsumerGroups`` must match the call that provided the page token. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + parent = client.cluster_path(project_id, location, cluster_id) + + result = client.list_consumer_groups( + request={ + "parent": parent, + "page_size": page_size, + "page_token": page_token, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_consumer_group( + self, + project_id: str, + location: str, + cluster_id: str, + consumer_group_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.ConsumerGroup: + """ + Return the properties of a single consumer group. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose consumer group is to be returned. + :param consumer_group_id: Required. The ID of the consumer group whose configuration to return. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + name = client.consumer_group_path(project_id, location, cluster_id, consumer_group_id) + + result = client.get_consumer_group( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_consumer_group( + self, + project_id: str, + location: str, + cluster_id: str, + consumer_group_id: str, + consumer_group: types.ConsumerGroup | dict, + update_mask: FieldMask | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.ConsumerGroup: + """ + Update the properties of a single consumer group. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be updated. + :param consumer_group_id: Required. The ID of the consumer group whose configuration to update. + :param consumer_group: Required. The consumer_group to update. Its ``name`` field must be populated. + :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the + ConsumerGroup resource by the update. The fields specified in the update_mask are relative to the + resource, not the full request. A field will be overwritten if it is in the mask. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + _consumer_group = ( + deepcopy(consumer_group) + if isinstance(consumer_group, dict) + else ConsumerGroup.to_dict(consumer_group) + ) + _consumer_group["name"] = client.consumer_group_path( + project_id, location, cluster_id, consumer_group_id + ) + + result = client.update_consumer_group( + request={ + "update_mask": update_mask, + "consumer_group": _consumer_group, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_consumer_group( + self, + project_id: str, + location: str, + cluster_id: str, + consumer_group_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Delete a single consumer group. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose consumer group is to be deleted. + :param consumer_group_id: Required. The ID of the consumer group to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + name = client.consumer_group_path(project_id, location, cluster_id, consumer_group_id) + + client.delete_consumer_group( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py index 0aafe2f202daa..45b62901c5515 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py @@ -31,6 +31,10 @@ MANAGED_KAFKA_TOPIC_LINK = ( MANAGED_KAFKA_BASE_LINK + "/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}" ) +MANAGED_KAFKA_CONSUMER_GROUP_LINK = ( + MANAGED_KAFKA_BASE_LINK + + "/{location}/clusters/{cluster_id}/consumer_groups/{consumer_group_id}?project={project_id}" +) class ApacheKafkaClusterLink(BaseGoogleLink): @@ -102,3 +106,29 @@ def persist( "project_id": task_instance.project_id, }, ) + + +class ApacheKafkaConsumerGroupLink(BaseGoogleLink): + """Helper class for constructing Apache Kafka Consumer Group link.""" + + name = "Apache Kafka Consumer Group" + key = "consumer_group_conf" + format_str = MANAGED_KAFKA_CONSUMER_GROUP_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + cluster_id: str, + consumer_group_id: str, + ): + task_instance.xcom_push( + context=context, + key=ApacheKafkaConsumerGroupLink.key, + value={ + "location": task_instance.location, + "cluster_id": cluster_id, + "consumer_group_id": consumer_group_id, + "project_id": task_instance.project_id, + }, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py index 0ded649858f62..b649149ccc0da 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py 
+++ b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py @@ -32,6 +32,7 @@ from airflow.providers.google.cloud.links.managed_kafka import ( ApacheKafkaClusterLink, ApacheKafkaClusterListLink, + ApacheKafkaConsumerGroupLink, ApacheKafkaTopicLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -788,3 +789,267 @@ def execute(self, context: Context): except NotFound as not_found_err: self.log.info("The Apache Kafka topic ID %s does not exist.", self.topic_id) raise AirflowException(not_found_err) + + +class ManagedKafkaListConsumerGroupsOperator(ManagedKafkaBaseOperator): + """ + List the consumer groups in a given cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose consumer groups are to be listed. + :param page_size: Optional. The maximum number of consumer groups to return. The service may return + fewer than this value. If unset or zero, all consumer groups for the parent is returned. + :param page_token: Optional. A page token, received from a previous ``ListConsumerGroups`` call. + Provide this to retrieve the subsequent page. When paginating, all other parameters provided to + ``ListConsumerGroups`` must match the call that provided the page token. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple({"cluster_id"} | set(ManagedKafkaBaseOperator.template_fields)) + operator_extra_links = (ApacheKafkaClusterLink(),) + + def __init__( + self, + cluster_id: str, + page_size: int | None = None, + page_token: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.page_size = page_size + self.page_token = page_token + + def execute(self, context: Context): + ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id) + self.log.info("Listing Consumer Groups for cluster %s.", self.cluster_id) + try: + consumer_group_list_pager = self.hook.list_consumer_groups( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + page_size=self.page_size, + page_token=self.page_token, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push( + context=context, + key="consumer_group_page", + value=types.ListConsumerGroupsResponse.to_dict(consumer_group_list_pager._response), + ) + except Exception as error: + raise AirflowException(error) + return [types.ConsumerGroup.to_dict(consumer_group) for consumer_group in consumer_group_list_pager] + + +class ManagedKafkaGetConsumerGroupOperator(ManagedKafkaBaseOperator): + """ + Return the properties of a single consumer group. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. 
The ID of the cluster whose consumer group is to be returned. + :param consumer_group_id: Required. The ID of the consumer group whose configuration to return. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "consumer_group_id"} | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaConsumerGroupLink(),) + + def __init__( + self, + cluster_id: str, + consumer_group_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.consumer_group_id = consumer_group_id + + def execute(self, context: Context): + ApacheKafkaConsumerGroupLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + consumer_group_id=self.consumer_group_id, + ) + self.log.info("Getting Consumer Group: %s", self.consumer_group_id) + try: + consumer_group = self.hook.get_consumer_group( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + consumer_group_id=self.consumer_group_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info( + "The consumer group %s from cluster %s was retrieved.", + self.consumer_group_id, + self.cluster_id, + ) + return types.ConsumerGroup.to_dict(consumer_group) + except NotFound as not_found_err: + self.log.info("The Consumer Group %s does not exist.", self.consumer_group_id) + raise AirflowException(not_found_err) + + +class ManagedKafkaUpdateConsumerGroupOperator(ManagedKafkaBaseOperator): + """ + Update the properties of a single consumer group. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be updated. + :param consumer_group_id: Required. The ID of the consumer group whose configuration to update. + :param consumer_group: Required. The consumer_group to update. Its ``name`` field must be populated. + :param update_mask: Required. 
Field mask is used to specify the fields to be overwritten in the + ConsumerGroup resource by the update. The fields specified in the update_mask are relative to the + resource, not the full request. A field will be overwritten if it is in the mask. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "consumer_group_id", "consumer_group", "update_mask"} + | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaConsumerGroupLink(),) + + def __init__( + self, + cluster_id: str, + consumer_group_id: str, + consumer_group: types.Topic | dict, + update_mask: FieldMask | dict, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.consumer_group_id = consumer_group_id + self.consumer_group = consumer_group + self.update_mask = update_mask + + def execute(self, context: Context): + ApacheKafkaConsumerGroupLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + consumer_group_id=self.consumer_group_id, + ) + self.log.info("Updating an Apache Kafka consumer group.") + try: + consumer_group_obj = self.hook.update_consumer_group( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + consumer_group_id=self.consumer_group_id, + consumer_group=self.consumer_group, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Apache Kafka consumer group %s was updated.", self.consumer_group_id) + return types.ConsumerGroup.to_dict(consumer_group_obj) + except NotFound as not_found_err: + self.log.info("The Consumer Group %s does not exist.", self.consumer_group_id) + raise AirflowException(not_found_err) + except Exception as error: + raise AirflowException(error) + + +class ManagedKafkaDeleteConsumerGroupOperator(ManagedKafkaBaseOperator): + """ + Delete a single consumer group. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose consumer group is to be deleted. + :param consumer_group_id: Required. 
The ID of the consumer group to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "consumer_group_id"} | set(ManagedKafkaBaseOperator.template_fields) + ) + + def __init__( + self, + cluster_id: str, + consumer_group_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.consumer_group_id = consumer_group_id + + def execute(self, context: Context): + try: + self.log.info("Deleting Apache Kafka consumer group: %s", self.consumer_group_id) + self.hook.delete_consumer_group( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + consumer_group_id=self.consumer_group_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Apache Kafka consumer group was deleted.") + except NotFound as not_found_err: + self.log.info("The Apache Kafka consumer group ID %s does not exist.", self.consumer_group_id) + raise AirflowException(not_found_err) diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py 
b/providers/google/src/airflow/providers/google/get_provider_info.py index cd9b0a75acc6a..f97a32f8a5ff7 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -1570,6 +1570,7 @@ def get_provider_info(): "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink", "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink", "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink", + "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink", ], "secrets-backends": [ "airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend" diff --git a/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py new file mode 100644 index 0000000000000..1c9929a569eb3 --- /dev/null +++ b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py @@ -0,0 +1,254 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +""" +Example Airflow DAG for Google Cloud Managed Service for Apache Kafka testing Consumer Group operations. +""" + +from __future__ import annotations + +import json +import logging +import os +import random +from datetime import datetime + +from airflow.decorators import task +from airflow.models import Connection +from airflow.models.dag import DAG +from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator +from airflow.providers.apache.kafka.operators.produce import ProduceToTopicOperator +from airflow.providers.google.cloud.operators.managed_kafka import ( + ManagedKafkaCreateClusterOperator, + ManagedKafkaCreateTopicOperator, + ManagedKafkaDeleteClusterOperator, + ManagedKafkaDeleteConsumerGroupOperator, + ManagedKafkaGetConsumerGroupOperator, + ManagedKafkaListConsumerGroupsOperator, + ManagedKafkaUpdateConsumerGroupOperator, +) +from airflow.settings import Session +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "managed_kafka_consumer_group_operations" +LOCATION = "us-central1" + +CLUSTER_ID = f"cluster_{DAG_ID}_{ENV_ID}".replace("_", "-") +CLUSTER_CONF = { + "gcp_config": { + "access_config": { + "network_configs": [ + {"subnet": f"projects/{PROJECT_ID}/regions/{LOCATION}/subnetworks/default"}, + ], + }, + }, + "capacity_config": { + "vcpu_count": 3, + "memory_bytes": 3221225472, + }, +} +TOPIC_ID = f"topic_{DAG_ID}_{ENV_ID}".replace("_", "-") +TOPIC_CONF = { + "partition_count": 3, + "replication_factor": 3, +} +CONSUMER_GROUP_ID = f"consumer_group_{DAG_ID}_{ENV_ID}".replace("_", "-") +CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}" +PORT = "9092" +BOOTSTRAP_URL = f"bootstrap.{CLUSTER_ID}.{LOCATION}.managedkafka.{PROJECT_ID}.cloud.goog:{PORT}" + +log = logging.getLogger(__name__) + + +def producer(): + """Produce and submit 10 messages""" + + for i in range(10): + now = datetime.now() + 
datetime_string = now.strftime("%Y-%m-%d %H:%M:%S") + + message_data = {"random_id": f"{ENV_ID}_{random.randint(1, 100)}", "date_time": datetime_string} + + yield ( + json.dumps(i), + json.dumps(message_data), + ) + + +def consumer(message): + "Take in consumed messages and print its contents to the logs." + + message_content = json.loads(message.value()) + random_id = message_content["random_id"] + date_time = message_content["date_time"] + log.info("id: %s, date_time: %s", random_id, date_time) + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "managed_kafka", "consumer_group"], +) as dag: + create_cluster = ManagedKafkaCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster=CLUSTER_CONF, + cluster_id=CLUSTER_ID, + ) + + create_topic = ManagedKafkaCreateTopicOperator( + task_id="create_topic", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + topic_id=TOPIC_ID, + topic=TOPIC_CONF, + ) + + @task + def create_connection(connection_id: str): + conn = Connection( + conn_id=connection_id, + conn_type="kafka", + ) + conn_extra = { + "bootstrap.servers": BOOTSTRAP_URL, + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "group.id": CONSUMER_GROUP_ID, + } + conn_extra_json = json.dumps(conn_extra) + conn.set_extra(conn_extra_json) + + session = Session() + log.info("Removing connection %s if it exists", connection_id) + query = session.query(Connection).filter(Connection.conn_id == connection_id) + query.delete() + + session.add(conn) + session.commit() + log.info("Connection %s created", connection_id) + + create_connection_task = create_connection(connection_id=CONNECTION_ID) + + # [START how_to_cloud_managed_kafka_produce_to_topic_operator] + produce_to_topic = ProduceToTopicOperator( + task_id="produce_to_topic", + kafka_config_id=CONNECTION_ID, + topic=TOPIC_ID, + producer_function=producer, + 
poll_timeout=10, + ) + # [END how_to_cloud_managed_kafka_produce_to_topic_operator] + + # [START how_to_cloud_managed_kafka_consume_from_topic_operator] + consume_from_topic = ConsumeFromTopicOperator( + task_id="consume_from_topic", + kafka_config_id=CONNECTION_ID, + topics=[TOPIC_ID], + apply_function=consumer, + poll_timeout=20, + max_messages=20, + max_batch_size=20, + ) + # [END how_to_cloud_managed_kafka_consume_from_topic_operator] + + # [START how_to_cloud_managed_kafka_update_consumer_group_operator] + update_consumer_group = ManagedKafkaUpdateConsumerGroupOperator( + task_id="update_consumer_group", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + consumer_group_id=CONSUMER_GROUP_ID, + consumer_group={}, + update_mask={}, + ) + # [END how_to_cloud_managed_kafka_update_consumer_group_operator] + + # [START how_to_cloud_managed_kafka_get_consumer_group_operator] + get_consumer_group = ManagedKafkaGetConsumerGroupOperator( + task_id="get_consumer_group", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + consumer_group_id=CONSUMER_GROUP_ID, + ) + # [END how_to_cloud_managed_kafka_get_consumer_group_operator] + + # [START how_to_cloud_managed_kafka_delete_consumer_group_operator] + delete_consumer_group = ManagedKafkaDeleteConsumerGroupOperator( + task_id="delete_consumer_group", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + consumer_group_id=CONSUMER_GROUP_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_cloud_managed_kafka_delete_consumer_group_operator] + + delete_cluster = ManagedKafkaDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + + # [START how_to_cloud_managed_kafka_list_consumer_group_operator] + list_consumer_groups = ManagedKafkaListConsumerGroupsOperator( + task_id="list_consumer_groups", + project_id=PROJECT_ID, + location=LOCATION, + 
cluster_id=CLUSTER_ID, + ) + # [END how_to_cloud_managed_kafka_list_consumer_group_operator] + + ( + # TEST SETUP + create_cluster + >> create_topic + >> create_connection_task + >> produce_to_topic + >> consume_from_topic + # TEST BODY + >> update_consumer_group + >> get_consumer_group + >> list_consumer_groups + >> delete_consumer_group + # TEST TEARDOWN + >> delete_cluster + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py index 7261f079555cb..c8e2a131fca35 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py @@ -66,6 +66,8 @@ "replication_factor": 1912, } +TEST_CONSUMER_GROUP_ID: str = "test-consumer-group-id" + BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" MANAGED_KAFKA_STRING = "airflow.providers.google.cloud.hooks.managed_kafka.{}" @@ -301,6 +303,98 @@ def test_list_topics(self, mock_client) -> None: TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID ) + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_delete_consumer_group(self, mock_client) -> None: + self.hook.delete_consumer_group( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + ) + mock_client.assert_called_once() + 
mock_client.return_value.delete_consumer_group.assert_called_once_with( + request=dict(name=mock_client.return_value.consumer_group_path.return_value), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.consumer_group_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_CONSUMER_GROUP_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_get_consumer_group(self, mock_client) -> None: + self.hook.get_consumer_group( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.get_consumer_group.assert_called_once_with( + request=dict( + name=mock_client.return_value.consumer_group_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.consumer_group_path.assert_called_once_with( + TEST_PROJECT_ID, + TEST_LOCATION, + TEST_CLUSTER_ID, + TEST_CONSUMER_GROUP_ID, + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_update_consumer_group(self, mock_client) -> None: + self.hook.update_consumer_group( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + consumer_group={}, + update_mask={}, + ) + mock_client.assert_called_once() + mock_client.return_value.update_consumer_group.assert_called_once_with( + request=dict( + update_mask={}, + consumer_group={ + "name": mock_client.return_value.consumer_group_path.return_value, + **{}, + }, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.consumer_group_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_CONSUMER_GROUP_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def 
test_list_consumer_groups(self, mock_client) -> None: + self.hook.list_consumer_groups( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.list_consumer_groups.assert_called_once_with( + request=dict( + parent=mock_client.return_value.cluster_path.return_value, + page_size=None, + page_token=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + class TestManagedKafkaWithoutDefaultProjectIdHook: def setup_method(self): @@ -535,3 +629,98 @@ def test_list_topics(self, mock_client) -> None: mock_client.return_value.cluster_path.assert_called_once_with( TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_delete_consumer_group(self, mock_client) -> None: + self.hook.delete_consumer_group( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.delete_consumer_group.assert_called_once_with( + request=dict(name=mock_client.return_value.consumer_group_path.return_value), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.consumer_group_path.assert_called_once_with( + TEST_PROJECT_ID, + TEST_LOCATION, + TEST_CLUSTER_ID, + TEST_CONSUMER_GROUP_ID, + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_get_consumer_group(self, mock_client) -> None: + self.hook.get_consumer_group( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.get_consumer_group.assert_called_once_with( + request=dict( + 
name=mock_client.return_value.consumer_group_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.consumer_group_path.assert_called_once_with( + TEST_PROJECT_ID, + TEST_LOCATION, + TEST_CLUSTER_ID, + TEST_CONSUMER_GROUP_ID, + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_update_consumer_group(self, mock_client) -> None: + self.hook.update_consumer_group( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + consumer_group={}, + update_mask={}, + ) + mock_client.assert_called_once() + mock_client.return_value.update_consumer_group.assert_called_once_with( + request=dict( + update_mask={}, + consumer_group={ + "name": mock_client.return_value.consumer_group_path.return_value, + **{}, + }, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.consumer_group_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_CONSUMER_GROUP_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_list_consumer_groups(self, mock_client) -> None: + self.hook.list_consumer_groups( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.list_consumer_groups.assert_called_once_with( + request=dict( + parent=mock_client.return_value.cluster_path.return_value, + page_size=None, + page_token=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) diff --git a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py index 7bf671c68e669..8867b8c0616d2 100644 --- 
a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py @@ -22,6 +22,7 @@ from airflow.providers.google.cloud.links.managed_kafka import ( ApacheKafkaClusterLink, ApacheKafkaClusterListLink, + ApacheKafkaConsumerGroupLink, ApacheKafkaTopicLink, ) @@ -29,6 +30,7 @@ TEST_CLUSTER_ID = "test-cluster-id" TEST_PROJECT_ID = "test-project-id" TEST_TOPIC_ID = "test-topic-id" +TEST_CONSUMER_GROUP_ID = "test-consumer-group-id" EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME = "Apache Kafka Cluster" EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY = "cluster_conf" EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR = ( @@ -42,6 +44,11 @@ EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR = ( "/managedkafka/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}" ) +EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_NAME = "Apache Kafka Consumer Group" +EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY = "consumer_group_conf" +EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_FORMAT_STR = ( + "/managedkafka/{location}/clusters/{cluster_id}/consumer_groups/{consumer_group_id}?project={project_id}" +) class TestApacheKafkaClusterLink: @@ -125,3 +132,36 @@ def test_persist(self): "project_id": TEST_PROJECT_ID, }, ) + + +class TestApacheKafkaConsumerGroupLink: + def test_class_attributes(self): + assert ApacheKafkaConsumerGroupLink.key == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY + assert ApacheKafkaConsumerGroupLink.name == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_NAME + assert ( + ApacheKafkaConsumerGroupLink.format_str == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_FORMAT_STR + ) + + def test_persist(self): + mock_context, mock_task_instance = ( + mock.MagicMock(), + mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), + ) + + ApacheKafkaConsumerGroupLink.persist( + context=mock_context, + task_instance=mock_task_instance, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + 
) + + mock_task_instance.xcom_push.assert_called_once_with( + context=mock_context, + key=EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY, + value={ + "location": TEST_LOCATION, + "cluster_id": TEST_CLUSTER_ID, + "consumer_group_id": TEST_CONSUMER_GROUP_ID, + "project_id": TEST_PROJECT_ID, + }, + ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py index e9407cc0a50ca..fd41068201439 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py @@ -24,12 +24,16 @@ ManagedKafkaCreateClusterOperator, ManagedKafkaCreateTopicOperator, ManagedKafkaDeleteClusterOperator, + ManagedKafkaDeleteConsumerGroupOperator, ManagedKafkaDeleteTopicOperator, ManagedKafkaGetClusterOperator, + ManagedKafkaGetConsumerGroupOperator, ManagedKafkaGetTopicOperator, ManagedKafkaListClustersOperator, + ManagedKafkaListConsumerGroupsOperator, ManagedKafkaListTopicsOperator, ManagedKafkaUpdateClusterOperator, + ManagedKafkaUpdateConsumerGroupOperator, ManagedKafkaUpdateTopicOperator, ) @@ -80,6 +84,8 @@ "replication_factor": 1912, } +TEST_CONSUMER_GROUP_ID: str = "test-consumer-group-id" + class TestManagedKafkaCreateClusterOperator: @mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict")) @@ -393,3 +399,128 @@ def test_execute(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, ) + + +class TestManagedKafkaListConsumerGroupsOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.ListConsumerGroupsResponse.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("types.ConsumerGroup.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock): + page_token = "page_token" + page_size = 42 + + op = ManagedKafkaListConsumerGroupsOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + 
impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + page_size=page_size, + page_token=page_token, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.list_consumer_groups.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + page_size=page_size, + page_token=page_token, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaGetConsumerGroupOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.ConsumerGroup.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ManagedKafkaGetConsumerGroupOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.get_consumer_group.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaUpdateConsumerGroupOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.ConsumerGroup.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ManagedKafkaUpdateConsumerGroupOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + project_id=GCP_PROJECT, + 
location=GCP_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + consumer_group={}, + update_mask={}, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_consumer_group.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + consumer_group={}, + update_mask={}, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaDeleteConsumerGroupOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook): + op = ManagedKafkaDeleteConsumerGroupOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.delete_consumer_group.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + consumer_group_id=TEST_CONSUMER_GROUP_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + )