diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 081fa42b6ca32..07da0177a7e66 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -661,6 +661,7 @@ "google-cloud-kms>=2.15.0", "google-cloud-language>=2.9.0", "google-cloud-logging>=3.5.0", + "google-cloud-managedkafka>=0.1.6", "google-cloud-memcache>=1.7.0", "google-cloud-monitoring>=2.18.0", "google-cloud-orchestration-airflow>=1.10.0", diff --git a/providers/google/README.rst b/providers/google/README.rst index b1f413e70ed0d..a71f913479c41 100644 --- a/providers/google/README.rst +++ b/providers/google/README.rst @@ -95,6 +95,7 @@ PIP package Version required ``google-cloud-kms`` ``>=2.15.0`` ``google-cloud-language`` ``>=2.9.0`` ``google-cloud-logging`` ``>=3.5.0`` +``google-cloud-managedkafka`` ``>=0.1.6`` ``google-cloud-memcache`` ``>=1.7.0`` ``google-cloud-monitoring`` ``>=2.18.0`` ``google-cloud-orchestration-airflow`` ``>=1.10.0`` diff --git a/providers/google/docs/operators/cloud/managed_kafka.rst b/providers/google/docs/operators/cloud/managed_kafka.rst new file mode 100644 index 0000000000000..0016076f183f7 --- /dev/null +++ b/providers/google/docs/operators/cloud/managed_kafka.rst @@ -0,0 +1,78 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. + +Google Cloud Managed Service for Apache Kafka Operators +======================================================= + +The `Google Cloud Managed Service for Apache Kafka <https://cloud.google.com/managed-service-for-apache-kafka/docs/>`__ +helps you set up, secure, maintain, and scale Apache Kafka clusters. + +Interacting with Apache Kafka Cluster +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To create an Apache Kafka cluster you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaCreateClusterOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_create_cluster_operator] + :end-before: [END how_to_cloud_managed_kafka_create_cluster_operator] + +To delete a cluster you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaDeleteClusterOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_delete_cluster_operator] + :end-before: [END how_to_cloud_managed_kafka_delete_cluster_operator] + +To get a cluster you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaGetClusterOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_get_cluster_operator] + :end-before: [END how_to_cloud_managed_kafka_get_cluster_operator] + +To get a list of clusters you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaListClustersOperator`. + +.. 
exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_list_cluster_operator] + :end-before: [END how_to_cloud_managed_kafka_list_cluster_operator] + +To update a cluster you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaUpdateClusterOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_update_cluster_operator] + :end-before: [END how_to_cloud_managed_kafka_update_cluster_operator] + +Reference +^^^^^^^^^ + +For further information, look at: + +* `Client Library Documentation <https://cloud.google.com/python/docs/reference/managedkafka/latest>`__ +* `Product Documentation <https://cloud.google.com/managed-service-for-apache-kafka/docs/>`__ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 665a6943342a8..525026c54ff34 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -199,6 +199,11 @@ integrations: - /docs/apache-airflow-providers-google/operators/cloud/life_sciences.rst logo: /docs/integration-logos/Google-Cloud-Life-Sciences.png tags: [gcp] + - integration-name: Google Cloud Managed Service for Apache Kafka + external-doc-url: https://cloud.google.com/managed-service-for-apache-kafka/docs/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/managed_kafka.rst + tags: [gcp] - integration-name: Google Cloud Memorystore external-doc-url: https://cloud.google.com/memorystore/ how-to-guide: @@ -599,6 +604,9 @@ operators: - integration-name: Google Cloud Batch python-modules: - airflow.providers.google.cloud.operators.cloud_batch + - integration-name: Google Cloud Managed Service for Apache Kafka + python-modules: + - airflow.providers.google.cloud.operators.managed_kafka sensors: - integration-name: Google BigQuery @@ -877,6 +885,9 @@ hooks: - integration-name: Google Cloud Batch 
python-modules: - airflow.providers.google.cloud.hooks.cloud_batch + - integration-name: Google Cloud Managed Service for Apache Kafka + python-modules: + - airflow.providers.google.cloud.hooks.managed_kafka triggers: - integration-name: Google BigQuery Data Transfer Service @@ -1215,6 +1226,8 @@ extra-links: - airflow.providers.google.cloud.links.translate.TranslationModelsListLink - airflow.providers.google.cloud.links.translate.TranslateResultByOutputConfigLink - airflow.providers.google.cloud.links.translate.TranslationGlossariesListLink + - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink + - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink secrets-backends: diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml index 11ff422d3d03e..0f3601dbaa1e6 100644 --- a/providers/google/pyproject.toml +++ b/providers/google/pyproject.toml @@ -93,6 +93,7 @@ dependencies = [ "google-cloud-kms>=2.15.0", "google-cloud-language>=2.9.0", "google-cloud-logging>=3.5.0", + "google-cloud-managedkafka>=0.1.6", "google-cloud-memcache>=1.7.0", "google-cloud-monitoring>=2.18.0", "google-cloud-orchestration-airflow>=1.10.0", diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py new file mode 100644 index 0000000000000..48768666f8fe0 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py @@ -0,0 +1,288 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Managed Service for Apache Kafka hook.""" + +from __future__ import annotations + +from collections.abc import Sequence +from copy import deepcopy +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, types + +if TYPE_CHECKING: + from google.api_core.operation import Operation + from google.api_core.retry import Retry + from google.cloud.managedkafka_v1.services.managed_kafka.pagers import ListClustersPager + from google.protobuf.field_mask_pb2 import FieldMask + + +class ManagedKafkaHook(GoogleBaseHook): + """Hook for Managed Service for Apache Kafka APIs.""" + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(gcp_conn_id, impersonation_chain, **kwargs) + + def get_managed_kafka_client(self) -> ManagedKafkaClient: + """Return ManagedKafkaClient object.""" + return ManagedKafkaClient( + credentials=self.get_credentials(), + client_info=CLIENT_INFO, + ) + + def wait_for_operation(self, operation: Operation, timeout: float | None = None): + """Wait for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + 
raise AirflowException(error) + + @GoogleBaseHook.fallback_to_default_project_id + def create_cluster( + self, + project_id: str, + location: str, + cluster: types.Cluster | dict, + cluster_id: str, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Create a new Apache Kafka cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster: Required. Configuration of the cluster to create. Its ``name`` field is ignored. + :param cluster_id: Required. The ID to use for the cluster, which will become the final component of + the cluster's name. The ID must be 1-63 characters long, and match the regular expression + ``[a-z]([-a-z0-9]*[a-z0-9])?`` to comply with RFC 1035. This value is structured like: ``my-cluster-id``. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + to avoid duplication of requests. If a request times out or fails, retrying with the same ID + allows the server to recognize the previous attempt. For at least 60 minutes, the server ignores + duplicate requests bearing the same ID. For example, consider a situation where you make an + initial request and the request times out. If you make the request again with the same request ID + within 60 minutes of the last request, the server checks if an original operation with the same + request ID was received. If so, the server ignores the second request. The request ID must be a + valid UUID. A zero UUID is not supported (00000000-0000-0000-0000-000000000000). + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + parent = client.common_location_path(project_id, location) + + operation = client.create_cluster( + request={ + "parent": parent, + "cluster_id": cluster_id, + "cluster": cluster, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + def list_clusters( + self, + project_id: str, + location: str, + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListClustersPager: + """ + List the clusters in a given project and location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param page_size: Optional. The maximum number of clusters to return. The service may return fewer + than this value. If unspecified, server will pick an appropriate default. + :param page_token: Optional. A page token, received from a previous ``ListClusters`` call. Provide + this to retrieve the subsequent page. + When paginating, all other parameters provided to ``ListClusters`` must match the call that + provided the page token. + :param filter: Optional. Filter expression for the result. + :param order_by: Optional. Order by fields for the result. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + parent = client.common_location_path(project_id, location) + + result = client.list_clusters( + request={ + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "order_by": order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_cluster( + self, + project_id: str, + location: str, + cluster_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.Cluster: + """ + Return the properties of a single cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose configuration to return. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_managed_kafka_client() + name = client.cluster_path(project_id, location, cluster_id) + + result = client.get_cluster( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_cluster( + self, + project_id: str, + location: str, + cluster_id: str, + cluster: types.Cluster | dict, + update_mask: FieldMask | dict, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Update the properties of a single cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. 
The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose configuration to update. + :param cluster: Required. The cluster to update. + :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the + cluster resource by the update. The fields specified in the update_mask are relative to the + resource, not the full request. A field will be overwritten if it is in the mask. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + to avoid duplication of requests. If a request times out or fails, retrying with the same ID + allows the server to recognize the previous attempt. For at least 60 minutes, the server ignores + duplicate requests bearing the same ID. + For example, consider a situation where you make an initial request and the request times out. If + you make the request again with the same request ID within 60 minutes of the last request, the + server checks if an original operation with the same request ID was received. If so, the server + ignores the second request. + The request ID must be a valid UUID. A zero UUID is not supported (00000000-0000-0000-0000-000000000000). + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + _cluster = deepcopy(cluster) if isinstance(cluster, dict) else Cluster.to_dict(cluster) + _cluster["name"] = client.cluster_path(project_id, location, cluster_id) + + operation = client.update_cluster( + request={ + "update_mask": update_mask, + "cluster": _cluster, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + def delete_cluster( + self, + project_id: str, + location: str, + cluster_id: str, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete a single cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster to delete. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + to avoid duplication of requests. If a request times out or fails, retrying with the same ID + allows the server to recognize the previous attempt. For at least 60 minutes, the server ignores + duplicate requests bearing the same ID. + For example, consider a situation where you make an initial request and the request times out. If + you make the request again with the same request ID within 60 minutes of the last request, the + server checks if an original operation with the same request ID was received. If so, the server + ignores the second request. + The request ID must be a valid UUID. A zero UUID is not supported (00000000-0000-0000-0000-000000000000). + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_managed_kafka_client() + name = client.cluster_path(project_id, location, cluster_id) + + operation = client.delete_cluster( + request={ + "name": name, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation diff --git a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py new file mode 100644 index 0000000000000..00c626b3814a8 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +MANAGED_KAFKA_BASE_LINK = "/managedkafka" +MANAGED_KAFKA_CLUSTER_LINK = ( + MANAGED_KAFKA_BASE_LINK + "/{location}/clusters/{cluster_id}?project={project_id}" +) +MANAGED_KAFKA_CLUSTER_LIST_LINK = MANAGED_KAFKA_BASE_LINK + "/clusters?project={project_id}" + + +class ApacheKafkaClusterLink(BaseGoogleLink): + """Helper class for constructing Apache Kafka Cluster link.""" + + name = "Apache Kafka Cluster" + key = "cluster_conf" + format_str = MANAGED_KAFKA_CLUSTER_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + cluster_id: str, + ): + task_instance.xcom_push( + context=context, + key=ApacheKafkaClusterLink.key, + value={ + "location": task_instance.location, + "cluster_id": cluster_id, + "project_id": task_instance.project_id, + }, + ) + + +class ApacheKafkaClusterListLink(BaseGoogleLink): + """Helper class for constructing Apache Kafka Clusters link.""" + + name = "Apache Kafka Cluster List" + key = "cluster_list_conf" + format_str = MANAGED_KAFKA_CLUSTER_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=ApacheKafkaClusterListLink.key, + value={ + "project_id": task_instance.project_id, + }, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py new file mode 100644 index 0000000000000..ebf03856216dd --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py @@ -0,0 +1,451 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Managed Service for Apache Kafka operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.managed_kafka import ManagedKafkaHook +from airflow.providers.google.cloud.links.managed_kafka import ( + ApacheKafkaClusterLink, + ApacheKafkaClusterListLink, +) +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from google.api_core.exceptions import AlreadyExists, NotFound +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud.managedkafka_v1 import types + +if TYPE_CHECKING: + from airflow.utils.context import Context + from google.api_core.retry import Retry + from google.protobuf.field_mask_pb2 import FieldMask + + +class ManagedKafkaBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Managed Kafka operators. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param retry: Designation of what errors, if any, should be retried. 
+ :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "location", + "gcp_conn_id", + "project_id", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> ManagedKafkaHook: + return ManagedKafkaHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class ManagedKafkaCreateClusterOperator(ManagedKafkaBaseOperator): + """ + Create a new Apache Kafka cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster: Required. 
Configuration of the cluster to create. Its ``name`` field is ignored. + :param cluster_id: Required. The ID to use for the cluster, which will become the final component of + the cluster's name. The ID must be 1-63 characters long, and match the regular expression + ``[a-z]([-a-z0-9]*[a-z0-9])?`` to comply with RFC 1035. This value is structured like: ``my-cluster-id``. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + to avoid duplication of requests. If a request times out or fails, retrying with the same ID + allows the server to recognize the previous attempt. For at least 60 minutes, the server ignores + duplicate requests bearing the same ID. For example, consider a situation where you make an + initial request and the request times out. If you make the request again with the same request ID + within 60 minutes of the last request, the server checks if an original operation with the same + request ID was received. If so, the server ignores the second request. The request ID must be a + valid UUID. A zero UUID is not supported (00000000-0000-0000-0000-000000000000). + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster", "cluster_id"} | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaClusterLink(),) + + def __init__( + self, + cluster: types.Cluster | dict, + cluster_id: str, + request_id: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster = cluster + self.cluster_id = cluster_id + self.request_id = request_id + + def execute(self, context: Context): + self.log.info("Creating an Apache Kafka cluster.") + ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id) + try: + operation = self.hook.create_cluster( + project_id=self.project_id, + location=self.location, + cluster=self.cluster, + cluster_id=self.cluster_id, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Waiting for operation to complete...") + cluster = self.hook.wait_for_operation(operation=operation, timeout=self.timeout) + self.log.info("Apache Kafka cluster was created.") + return types.Cluster.to_dict(cluster) + except AlreadyExists: + self.log.info("Apache Kafka cluster %s already exists.", self.cluster_id) + cluster = self.hook.get_cluster( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return types.Cluster.to_dict(cluster) + + +class ManagedKafkaListClustersOperator(ManagedKafkaBaseOperator): + """ + List the clusters in a given project and location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. 
The ID of the Google Cloud region that the service belongs to. + :param page_size: Optional. The maximum number of clusters to return. The service may return fewer + than this value. If unspecified, server will pick an appropriate default. + :param page_token: Optional. A page token, received from a previous ``ListClusters`` call. Provide + this to retrieve the subsequent page. + When paginating, all other parameters provided to ``ListClusters`` must match the call that + provided the page token. + :param filter: Optional. Filter expression for the result. + :param order_by: Optional. Order by fields for the result. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple({"page_token"} | set(ManagedKafkaBaseOperator.template_fields)) + operator_extra_links = (ApacheKafkaClusterListLink(),) + + def __init__( + self, + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.page_size = page_size + self.page_token = page_token + self.filter = filter + self.order_by = order_by + + def execute(self, context: Context): + ApacheKafkaClusterListLink.persist(context=context, task_instance=self) + self.log.info("Listing Clusters from location %s.", self.location) + try: + cluster_list_pager = self.hook.list_clusters( + project_id=self.project_id, + location=self.location, + page_size=self.page_size, + page_token=self.page_token, + filter=self.filter, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push( + context=context, + key="cluster_page", + value=types.ListClustersResponse.to_dict(cluster_list_pager._response), + ) + except Exception as error: + raise AirflowException(error) + return [types.Cluster.to_dict(cluster) for cluster in cluster_list_pager] + + +class ManagedKafkaGetClusterOperator(ManagedKafkaBaseOperator): + """ + Get an Apache Kafka cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose configuration to return. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple({"cluster_id"} | set(ManagedKafkaBaseOperator.template_fields)) + operator_extra_links = (ApacheKafkaClusterLink(),) + + def __init__( + self, + cluster_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + + def execute(self, context: Context): + ApacheKafkaClusterLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + ) + self.log.info("Getting Cluster: %s", self.cluster_id) + try: + cluster = self.hook.get_cluster( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Cluster was gotten.") + return types.Cluster.to_dict(cluster) + except NotFound as not_found_err: + self.log.info("The Cluster %s does not exist.", self.cluster_id) + raise AirflowException(not_found_err) + + +class ManagedKafkaUpdateClusterOperator(ManagedKafkaBaseOperator): + """ + Update the properties of a single cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose configuration to update. + :param cluster: Required. The cluster to update. 
+ :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the + cluster resource by the update. The fields specified in the update_mask are relative to the + resource, not the full request. A field will be overwritten if it is in the mask. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + to avoid duplication of requests. If a request times out or fails, retrying with the same ID + allows the server to recognize the previous attempt. For at least 60 minutes, the server ignores + duplicate requests bearing the same ID. + For example, consider a situation where you make an initial request and the request times out. If + you make the request again with the same request ID within 60 minutes of the last request, the + server checks if an original operation with the same request ID was received. If so, the server + ignores the second request. + The request ID must be a valid UUID. A zero UUID is not supported (00000000-0000-0000-0000-000000000000). + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "cluster", "update_mask"} | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaClusterLink(),) + + def __init__( + self, + cluster_id: str, + cluster: types.Cluster | dict, + update_mask: FieldMask | dict, + request_id: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.cluster = cluster + self.update_mask = update_mask + self.request_id = request_id + + def execute(self, context: Context): + ApacheKafkaClusterLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + ) + self.log.info("Updating an Apache Kafka cluster.") + try: + operation = self.hook.update_cluster( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + cluster=self.cluster, + update_mask=self.update_mask, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Waiting for operation to complete...") + cluster = self.hook.wait_for_operation(operation=operation, timeout=self.timeout) + self.log.info("Apache Kafka cluster %s was updated.", self.cluster_id) + return types.Cluster.to_dict(cluster) + except NotFound as not_found_err: + self.log.info("The Cluster %s does not exist.", self.cluster_id) + raise AirflowException(not_found_err) + except Exception as error: + raise AirflowException(error) + + +class ManagedKafkaDeleteClusterOperator(ManagedKafkaBaseOperator): + """ + Delete an Apache Kafka cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster to delete. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + to avoid duplication of requests. 
If a request times out or fails, retrying with the same ID + allows the server to recognize the previous attempt. For at least 60 minutes, the server ignores + duplicate requests bearing the same ID. + For example, consider a situation where you make an initial request and the request times out. If + you make the request again with the same request ID within 60 minutes of the last request, the + server checks if an original operation with the same request ID was received. If so, the server + ignores the second request. + The request ID must be a valid UUID. A zero UUID is not supported (00000000-0000-0000-0000-000000000000). + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple({"cluster_id"} | set(ManagedKafkaBaseOperator.template_fields)) + + def __init__( + self, + cluster_id: str, + request_id: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.request_id = request_id + + def execute(self, context: Context): + try: + self.log.info("Deleting Apache Kafka cluster: %s", self.cluster_id) + operation = self.hook.delete_cluster( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Waiting for operation to complete...") + self.hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Apache Kafka cluster was deleted.") + except NotFound as not_found_err: + self.log.info("The Apache Kafka cluster ID %s does not exist.", self.cluster_id) + raise AirflowException(not_found_err) diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index cc3e9e768e783..c01bf8bed0108 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -220,6 +220,12 @@ def get_provider_info(): "logo": "/docs/integration-logos/Google-Cloud-Life-Sciences.png", "tags": ["gcp"], }, + { + "integration-name": "Google Cloud Managed Service for Apache Kafka", + "external-doc-url": "https://cloud.google.com/managed-service-for-apache-kafka/docs/", + "how-to-guide": ["/docs/apache-airflow-providers-google/operators/cloud/managed_kafka.rst"], + "tags": ["gcp"], + }, { "integration-name": "Google Cloud Memorystore", "external-doc-url": "https://cloud.google.com/memorystore/", @@ -753,6 +759,10 @@ def get_provider_info(): "integration-name": "Google Cloud Batch", "python-modules": 
["airflow.providers.google.cloud.operators.cloud_batch"], }, + { + "integration-name": "Google Cloud Managed Service for Apache Kafka", + "python-modules": ["airflow.providers.google.cloud.operators.managed_kafka"], + }, ], "sensors": [ { @@ -1112,6 +1122,10 @@ def get_provider_info(): "integration-name": "Google Cloud Batch", "python-modules": ["airflow.providers.google.cloud.hooks.cloud_batch"], }, + { + "integration-name": "Google Cloud Managed Service for Apache Kafka", + "python-modules": ["airflow.providers.google.cloud.hooks.managed_kafka"], + }, ], "triggers": [ { @@ -1551,6 +1565,8 @@ def get_provider_info(): "airflow.providers.google.cloud.links.translate.TranslationModelsListLink", "airflow.providers.google.cloud.links.translate.TranslateResultByOutputConfigLink", "airflow.providers.google.cloud.links.translate.TranslationGlossariesListLink", + "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink", + "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink", ], "secrets-backends": [ "airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend" @@ -1595,6 +1611,7 @@ def get_provider_info(): "google-cloud-kms>=2.15.0", "google-cloud-language>=2.9.0", "google-cloud-logging>=3.5.0", + "google-cloud-managedkafka>=0.1.6", "google-cloud-memcache>=1.7.0", "google-cloud-monitoring>=2.18.0", "google-cloud-orchestration-airflow>=1.10.0", diff --git a/providers/google/tests/system/google/cloud/managed_kafka/__init__.py b/providers/google/tests/system/google/cloud/managed_kafka/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/tests/system/google/cloud/managed_kafka/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py new file mode 100644 index 0000000000000..52adf9873a76b --- /dev/null +++ b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_cluster.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Example Airflow DAG for Google Cloud Managed Service for Apache Kafka testing Cluster operations. 
+""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.managed_kafka import ( + ManagedKafkaCreateClusterOperator, + ManagedKafkaDeleteClusterOperator, + ManagedKafkaGetClusterOperator, + ManagedKafkaListClustersOperator, + ManagedKafkaUpdateClusterOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "managed_kafka_cluster_operations" +LOCATION = "us-central1" + +CLUSTER_ID = f"cluster_{DAG_ID}_{ENV_ID}".replace("_", "-") +CLUSTER_CONF = { + "gcp_config": { + "access_config": { + "network_configs": [ + {"subnet": f"projects/{PROJECT_ID}/regions/{LOCATION}/subnetworks/default"}, + ], + }, + }, + "capacity_config": { + "vcpu_count": 3, + "memory_bytes": 3221225472, + }, +} +CLUSTER_TO_UPDATE = { + "capacity_config": { + "vcpu_count": 3, + "memory_bytes": 8589934592, + } +} +CLUSTER_UPDATE_MASK = {"paths": ["capacity_config"]} + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "managed_kafka", "cluster"], +) as dag: + # [START how_to_cloud_managed_kafka_create_cluster_operator] + create_cluster = ManagedKafkaCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster=CLUSTER_CONF, + cluster_id=CLUSTER_ID, + ) + # [END how_to_cloud_managed_kafka_create_cluster_operator] + + # [START how_to_cloud_managed_kafka_update_cluster_operator] + update_cluster = ManagedKafkaUpdateClusterOperator( + task_id="update_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + cluster=CLUSTER_TO_UPDATE, + update_mask=CLUSTER_UPDATE_MASK, + ) + # [END how_to_cloud_managed_kafka_update_cluster_operator] + + # [START how_to_cloud_managed_kafka_get_cluster_operator] + get_cluster = 
ManagedKafkaGetClusterOperator( + task_id="get_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + ) + # [END how_to_cloud_managed_kafka_get_cluster_operator] + + # [START how_to_cloud_managed_kafka_delete_cluster_operator] + delete_cluster = ManagedKafkaDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_cloud_managed_kafka_delete_cluster_operator] + + # [START how_to_cloud_managed_kafka_list_cluster_operator] + list_clusters = ManagedKafkaListClustersOperator( + task_id="list_clusters", + project_id=PROJECT_ID, + location=LOCATION, + ) + # [END how_to_cloud_managed_kafka_list_cluster_operator] + + ( + [ + create_cluster >> update_cluster >> get_cluster >> delete_cluster, + list_clusters, + ] + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py new file mode 100644 index 0000000000000..16cb0d35cb9f1 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py @@ -0,0 +1,291 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from google.api_core.gapic_v1.method import DEFAULT + +from airflow.providers.google.cloud.hooks.managed_kafka import ManagedKafkaHook +from unit.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, +) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_LOCATION: str = "test-location" +TEST_PROJECT_ID: str = "test-project-id" +TEST_CLUSTER_ID: str = "test-cluster-id" +TEST_CLUSTER: dict = { + "gcp_config": { + "access_config": { + "network_configs": { + "subnet": "subnet_value", + }, + }, + }, + "capacity_config": { + "vcpu_count": 1094, + "memory_bytes": 1311, + }, +} +TEST_CLUSTER_UPDATE_MASK: dict = {"paths": ["gcp_config.access_config.network_configs.subnet"]} +TEST_UPDATED_CLUSTER: dict = { + "gcp_config": { + "access_config": { + "network_configs": { + "subnet": "new_subnet_value", + }, + }, + }, +} + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +MANAGED_KAFKA_STRING = "airflow.providers.google.cloud.hooks.managed_kafka.{}" + + +class TestManagedKafkaWithDefaultProjectIdHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = ManagedKafkaHook(gcp_conn_id=TEST_GCP_CONN_ID) + + 
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_create_cluster(self, mock_client) -> None: + self.hook.create_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster=TEST_CLUSTER, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.create_cluster.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + cluster=TEST_CLUSTER, + cluster_id=TEST_CLUSTER_ID, + request_id=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_LOCATION) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_delete_cluster(self, mock_client) -> None: + self.hook.delete_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.delete_cluster.assert_called_once_with( + request=dict(name=mock_client.return_value.cluster_path.return_value, request_id=None), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_get_cluster(self, mock_client) -> None: + self.hook.get_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.get_cluster.assert_called_once_with( + request=dict( + name=mock_client.return_value.cluster_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + 
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_update_cluster(self, mock_client) -> None: + self.hook.update_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster=TEST_UPDATED_CLUSTER, + cluster_id=TEST_CLUSTER_ID, + update_mask=TEST_CLUSTER_UPDATE_MASK, + ) + mock_client.assert_called_once() + mock_client.return_value.update_cluster.assert_called_once_with( + request=dict( + update_mask=TEST_CLUSTER_UPDATE_MASK, + cluster={ + "name": mock_client.return_value.cluster_path.return_value, + **TEST_UPDATED_CLUSTER, + }, + request_id=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_list_clusters(self, mock_client) -> None: + self.hook.list_clusters( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + ) + mock_client.assert_called_once() + mock_client.return_value.list_clusters.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + order_by=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_LOCATION) + + +class TestManagedKafkaWithoutDefaultProjectIdHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id + ): + self.hook = ManagedKafkaHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_create_cluster(self, mock_client) -> None: + self.hook.create_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster=TEST_CLUSTER, + cluster_id=TEST_CLUSTER_ID, + ) + 
mock_client.assert_called_once() + mock_client.return_value.create_cluster.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + cluster=TEST_CLUSTER, + cluster_id=TEST_CLUSTER_ID, + request_id=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_LOCATION) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_delete_cluster(self, mock_client) -> None: + self.hook.delete_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.delete_cluster.assert_called_once_with( + request=dict(name=mock_client.return_value.cluster_path.return_value, request_id=None), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_get_cluster(self, mock_client) -> None: + self.hook.get_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.get_cluster.assert_called_once_with( + request=dict( + name=mock_client.return_value.cluster_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_update_cluster(self, mock_client) -> None: + self.hook.update_cluster( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster=TEST_UPDATED_CLUSTER, + cluster_id=TEST_CLUSTER_ID, + update_mask=TEST_CLUSTER_UPDATE_MASK, + ) + 
mock_client.assert_called_once() + mock_client.return_value.update_cluster.assert_called_once_with( + request=dict( + update_mask=TEST_CLUSTER_UPDATE_MASK, + cluster={ + "name": mock_client.return_value.cluster_path.return_value, + **TEST_UPDATED_CLUSTER, + }, + request_id=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_list_clusters(self, mock_client) -> None: + self.hook.list_clusters( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + ) + mock_client.assert_called_once() + mock_client.return_value.list_clusters.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + page_size=None, + page_token=None, + filter=None, + order_by=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_LOCATION) diff --git a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py new file mode 100644 index 0000000000000..add83f74d56bf --- /dev/null +++ b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from airflow.providers.google.cloud.links.managed_kafka import ( + ApacheKafkaClusterLink, + ApacheKafkaClusterListLink, +) + +TEST_LOCATION = "test-location" +TEST_CLUSTER_ID = "test-cluster-id" +TEST_PROJECT_ID = "test-project-id" +EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME = "Apache Kafka Cluster" +EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY = "cluster_conf" +EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR = ( + "/managedkafka/{location}/clusters/{cluster_id}?project={project_id}" +) +EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_NAME = "Apache Kafka Cluster List" +EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_KEY = "cluster_list_conf" +EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_FORMAT_STR = "/managedkafka/clusters?project={project_id}" + + +class TestApacheKafkaClusterLink: + def test_class_attributes(self): + assert ApacheKafkaClusterLink.key == EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY + assert ApacheKafkaClusterLink.name == EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME + assert ApacheKafkaClusterLink.format_str == EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR + + def test_persist(self): + mock_context, mock_task_instance = ( + mock.MagicMock(), + mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), + ) + + ApacheKafkaClusterLink.persist( + context=mock_context, + task_instance=mock_task_instance, + cluster_id=TEST_CLUSTER_ID, + ) + + mock_task_instance.xcom_push.assert_called_once_with( + context=mock_context, + key=EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY, + value={ + "location": 
TEST_LOCATION, + "cluster_id": TEST_CLUSTER_ID, + "project_id": TEST_PROJECT_ID, + }, + ) + + +class TestApacheKafkaClusterListLink: + def test_class_attributes(self): + assert ApacheKafkaClusterListLink.key == EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_KEY + assert ApacheKafkaClusterListLink.name == EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_NAME + assert ApacheKafkaClusterListLink.format_str == EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_FORMAT_STR + + def test_persist(self): + mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock(project_id=TEST_PROJECT_ID) + + ApacheKafkaClusterListLink.persist( + context=mock_context, + task_instance=mock_task_instance, + ) + + mock_task_instance.xcom_push.assert_called_once_with( + context=mock_context, + key=EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_KEY, + value={ + "project_id": TEST_PROJECT_ID, + }, + ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py new file mode 100644 index 0000000000000..4b5bc5c71257d --- /dev/null +++ b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# Dotted-path template used to patch names inside the operators module under test.
MANAGED_KAFKA_PATH = "airflow.providers.google.cloud.operators.managed_kafka.{}"
TIMEOUT = 120
RETRY = mock.MagicMock(Retry)
METADATA = [("key", "value")]

TASK_ID = "test_task_id"
GCP_PROJECT = "test-project"
GCP_LOCATION = "test-location"
GCP_CONN_ID = "test-conn"
IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]

TEST_CLUSTER_ID: str = "test-cluster-id"
TEST_CLUSTER: dict = {
    "gcp_config": {
        "access_config": {
            "network_configs": {
                "subnet": "subnet_value",
            },
        },
    },
    "capacity_config": {
        "vcpu_count": 1094,
        "memory_bytes": 1311,
    },
}
TEST_CLUSTER_UPDATE_MASK: dict = {"paths": ["gcp_config.access_config.network_configs.subnet"]}
TEST_UPDATED_CLUSTER: dict = {
    "gcp_config": {
        "access_config": {
            "network_configs": {
                "subnet": "new_subnet_value",
            },
        },
    },
}


class TestManagedKafkaCreateClusterOperator:
    """Unit tests for ``ManagedKafkaCreateClusterOperator``."""

    # NOTE(review): ``types.Cluster.to_dict`` is patched presumably so the mocked hook
    # result does not have to be a real proto message -- confirm against the operator.
    @mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict"))
    @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
    def test_execute(self, mock_hook, to_dict_mock):
        """Check that ``execute`` builds the hook once and forwards its arguments to ``create_cluster``."""
        op = ManagedKafkaCreateClusterOperator(
            task_id=TASK_ID,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            cluster=TEST_CLUSTER,
            cluster_id=TEST_CLUSTER_ID,
            request_id=None,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )
        op.execute(context={"ti": mock.MagicMock()})
        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
        mock_hook.return_value.create_cluster.assert_called_once_with(
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            cluster=TEST_CLUSTER,
            cluster_id=TEST_CLUSTER_ID,
            request_id=None,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )


class TestManagedKafkaListClustersOperator:
    """Unit tests for ``ManagedKafkaListClustersOperator``."""

    # Patch decorators are applied bottom-up, so the hook mock is the first argument.
    @mock.patch(MANAGED_KAFKA_PATH.format("types.ListClustersResponse.to_dict"))
    @mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict"))
    @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
    def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock):
        """Check that ``execute`` forwards paging, filter and ordering arguments to ``list_clusters``."""
        page_token = "page_token"
        page_size = 42
        # Named ``cluster_filter`` so the local does not shadow the ``filter`` builtin;
        # the operator/hook keyword is still ``filter``.
        cluster_filter = "filter"
        order_by = "order_by"

        op = ManagedKafkaListClustersOperator(
            task_id=TASK_ID,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            page_size=page_size,
            page_token=page_token,
            filter=cluster_filter,
            order_by=order_by,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )
        op.execute(context={"ti": mock.MagicMock()})
        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
        mock_hook.return_value.list_clusters.assert_called_once_with(
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            page_size=page_size,
            page_token=page_token,
            filter=cluster_filter,
            order_by=order_by,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )


class TestManagedKafkaGetClusterOperator:
    """Unit tests for ``ManagedKafkaGetClusterOperator``."""

    @mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict"))
    @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
    def test_execute(self, mock_hook, to_dict_mock):
        """Check that ``execute`` builds the hook once and forwards its arguments to ``get_cluster``."""
        op = ManagedKafkaGetClusterOperator(
            task_id=TASK_ID,
            cluster_id=TEST_CLUSTER_ID,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )
        op.execute(context={"ti": mock.MagicMock()})
        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
        mock_hook.return_value.get_cluster.assert_called_once_with(
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            cluster_id=TEST_CLUSTER_ID,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )


class TestManagedKafkaUpdateClusterOperator:
    """Unit tests for ``ManagedKafkaUpdateClusterOperator``."""

    @mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict"))
    @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
    def test_execute(self, mock_hook, to_dict_mock):
        """Check that ``execute`` forwards the updated cluster and update mask to ``update_cluster``."""
        op = ManagedKafkaUpdateClusterOperator(
            task_id=TASK_ID,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster_id=TEST_CLUSTER_ID,
            cluster=TEST_UPDATED_CLUSTER,
            update_mask=TEST_CLUSTER_UPDATE_MASK,
            request_id=None,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )
        op.execute(context={"ti": mock.MagicMock()})
        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
        mock_hook.return_value.update_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster_id=TEST_CLUSTER_ID,
            cluster=TEST_UPDATED_CLUSTER,
            update_mask=TEST_CLUSTER_UPDATE_MASK,
            request_id=None,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )


class TestManagedKafkaDeleteClusterOperator:
    """Unit tests for ``ManagedKafkaDeleteClusterOperator``."""

    @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
    def test_execute(self, mock_hook):
        """Check that ``execute`` builds the hook once and forwards its arguments to ``delete_cluster``."""
        op = ManagedKafkaDeleteClusterOperator(
            task_id=TASK_ID,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            cluster_id=TEST_CLUSTER_ID,
            request_id=None,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )
        # Unlike the other tests, this one runs with an empty context (no "ti" entry).
        op.execute(context={})
        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
        mock_hook.return_value.delete_cluster.assert_called_once_with(
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            cluster_id=TEST_CLUSTER_ID,
            request_id=None,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )