diff --git a/providers/dbt/cloud/docs/operators.rst b/providers/dbt/cloud/docs/operators.rst index 3eeb5b04dbb59..24702d5d5634e 100644 --- a/providers/dbt/cloud/docs/operators.rst +++ b/providers/dbt/cloud/docs/operators.rst @@ -82,6 +82,17 @@ via the ``additional_run_config`` dictionary. :start-after: [START howto_operator_dbt_cloud_run_job_async] :end-before: [END howto_operator_dbt_cloud_run_job_async] +You can also trigger a dbt Cloud job without providing the ``job_id``. Instead, you can identify the job +by providing the ``project_name``, ``environment_name``, and ``job_name``. +Please note that it will only work if the above three parameters uniquely identify a job in your account +(i.e. you cannot have two jobs with the same name in the same project and environment). + +.. exampleinclude:: /../../providers/dbt/cloud/tests/system/dbt/cloud/example_dbt_cloud.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_dbt_cloud_run_job_without_job_id] + :end-before: [END howto_operator_dbt_cloud_run_job_without_job_id] + .. _howto/operator:DbtCloudJobRunSensor: Poll for status of a dbt Cloud Job run diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py index 815f6779dbeaa..66967654fb3ea 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py @@ -135,6 +135,10 @@ class DbtCloudJobRunException(AirflowException): """An exception that indicates a job run failed to complete.""" +class DbtCloudResourceLookupError(AirflowException): + """Exception raised when a dbt Cloud resource cannot be uniquely identified.""" + + T = TypeVar("T", bound=Any) @@ -356,14 +360,23 @@ def get_account(self, account_id: int | None = None) -> Response: return self._run_and_get_response(endpoint=f"{account_id}/") @fallback_to_default_account - def list_projects(self, account_id: int | None = None) -> list[Response]: + def list_projects( + self, account_id: int | None = None, name_contains: str | None = None + ) -> list[Response]: """ Retrieve metadata for all projects tied to a specified dbt Cloud account. :param account_id: Optional. The ID of a dbt Cloud account. + :param name_contains: Optional. The case-insensitive substring of a dbt Cloud project name to filter by. :return: List of request responses. """ - return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True, api_version="v3") + payload = {"name__icontains": name_contains} if name_contains else None + return self._run_and_get_response( + endpoint=f"{account_id}/projects/", + payload=payload, + paginate=True, + api_version="v3", + ) @fallback_to_default_account def get_project(self, project_id: int, account_id: int | None = None) -> Response: @@ -376,27 +389,73 @@ def get_project(self, project_id: int, account_id: int | None = None) -> Respons """ return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/", api_version="v3") + @fallback_to_default_account + def list_environments( + self, project_id: int, *, name_contains: str | None = None, account_id: int | None = None + ) -> list[Response]: + """ + Retrieve metadata for all environments tied to a specified dbt Cloud project. + + :param project_id: The ID of a dbt Cloud project. + :param name_contains: Optional. The case-insensitive substring of a dbt Cloud environment name to filter by. + :param account_id: Optional. The ID of a dbt Cloud account. + :return: List of request responses. + """ + payload = {"name__icontains": name_contains} if name_contains else None + return self._run_and_get_response( + endpoint=f"{account_id}/projects/{project_id}/environments/", + payload=payload, + paginate=True, + api_version="v3", + ) + + @fallback_to_default_account + def get_environment( + self, project_id: int, environment_id: int, *, account_id: int | None = None + ) -> Response: + """ + Retrieve metadata for a specific project's environment. + + :param project_id: The ID of a dbt Cloud project. + :param environment_id: The ID of a dbt Cloud environment. + :param account_id: Optional. The ID of a dbt Cloud account. + :return: The request response. + """ + return self._run_and_get_response( + endpoint=f"{account_id}/projects/{project_id}/environments/{environment_id}/", api_version="v3" + ) + @fallback_to_default_account def list_jobs( self, account_id: int | None = None, order_by: str | None = None, project_id: int | None = None, + environment_id: int | None = None, + name_contains: str | None = None, ) -> list[Response]: """ Retrieve metadata for all jobs tied to a specified dbt Cloud account. If a ``project_id`` is supplied, only jobs pertaining to this project will be retrieved. + If an ``environment_id`` is supplied, only jobs pertaining to this environment will be retrieved. :param account_id: Optional. The ID of a dbt Cloud account. :param order_by: Optional. Field to order the result by. Use '-' to indicate reverse order. For example, to use reverse order by the run ID use ``order_by=-id``. - :param project_id: The ID of a dbt Cloud project. + :param project_id: Optional. The ID of a dbt Cloud project. + :param environment_id: Optional. The ID of a dbt Cloud environment. + :param name_contains: Optional. The case-insensitive substring of a dbt Cloud job name to filter by. :return: List of request responses. """ + payload = {"order_by": order_by, "project_id": project_id} + if environment_id: + payload["environment_id"] = environment_id + if name_contains: + payload["name__icontains"] = name_contains return self._run_and_get_response( endpoint=f"{account_id}/jobs/", - payload={"order_by": order_by, "project_id": project_id}, + payload=payload, paginate=True, ) @@ -411,6 +470,72 @@ def get_job(self, job_id: int, account_id: int | None = None) -> Response: """ return self._run_and_get_response(endpoint=f"{account_id}/jobs/{job_id}") + @fallback_to_default_account + def get_job_by_name( + self, *, project_name: str, environment_name: str, job_name: str, account_id: int | None = None + ) -> dict: + """ + Retrieve metadata for a specific job by combination of project, environment, and job name. + + Raises DbtCloudResourceLookupError if the job is not found or cannot be uniquely identified by provided parameters. + + :param project_name: The name of a dbt Cloud project. + :param environment_name: The name of a dbt Cloud environment. + :param job_name: The name of a dbt Cloud job. + :param account_id: Optional. The ID of a dbt Cloud account. + :return: The details of a job. + """ + # get project_id using project_name + list_projects_responses = self.list_projects(name_contains=project_name, account_id=account_id) + # flatten & filter the list of responses to find the exact match + projects = [ + project + for response in list_projects_responses + for project in response.json()["data"] + if project["name"] == project_name + ] + if len(projects) != 1: + raise DbtCloudResourceLookupError(f"Found {len(projects)} projects with name `{project_name}`.") + project_id = projects[0]["id"] + + # get environment_id using project_id and environment_name + list_environments_responses = self.list_environments( + project_id=project_id, name_contains=environment_name, account_id=account_id + ) + # flatten & filter the list of responses to find the exact match + environments = [ + env + for response in list_environments_responses + for env in response.json()["data"] + if env["name"] == environment_name + ] + if len(environments) != 1: + raise DbtCloudResourceLookupError( + f"Found {len(environments)} environments with name `{environment_name}` in project `{project_name}`." + ) + environment_id = environments[0]["id"] + + # get job using project_id, environment_id and job_name + list_jobs_responses = self.list_jobs( + project_id=project_id, + environment_id=environment_id, + name_contains=job_name, + account_id=account_id, + ) + # flatten & filter the list of responses to find the exact match + jobs = [ + job + for response in list_jobs_responses + for job in response.json()["data"] + if job["name"] == job_name + ] + if len(jobs) != 1: + raise DbtCloudResourceLookupError( + f"Found {len(jobs)} jobs with name `{job_name}` in environment `{environment_name}` in project `{project_name}`." + ) + + return jobs[0] + @fallback_to_default_account def trigger_job_run( self, diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py index 8795ebf0ca714..d9aab8e9e2f67 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py @@ -57,7 +57,10 @@ class DbtCloudRunJobOperator(BaseOperator): :ref:`howto/operator:DbtCloudRunJobOperator` :param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud. - :param job_id: The ID of a dbt Cloud job. + :param job_id: The ID of a dbt Cloud job. Required if project_name, environment_name, and job_name are not provided. + :param project_name: Optional. The name of a dbt Cloud project. Used only if ``job_id`` is None. + :param environment_name: Optional. The name of a dbt Cloud environment. Used only if ``job_id`` is None. + :param job_name: Optional. The name of a dbt Cloud job. Used only if ``job_id`` is None. :param account_id: Optional. The ID of a dbt Cloud account. :param trigger_reason: Optional. Description of the reason to trigger the job. Defaults to "Triggered via Apache Airflow by task in the DAG." @@ -86,6 +89,9 @@ class DbtCloudRunJobOperator(BaseOperator): template_fields = ( "dbt_cloud_conn_id", "job_id", + "project_name", + "environment_name", + "job_name", "account_id", "trigger_reason", "steps_override", @@ -99,7 +105,10 @@ def __init__( self, *, dbt_cloud_conn_id: str = DbtCloudHook.default_conn_name, - job_id: int, + job_id: int | None = None, + project_name: str | None = None, + environment_name: str | None = None, + job_name: str | None = None, account_id: int | None = None, trigger_reason: str | None = None, steps_override: list[str] | None = None, @@ -117,6 +126,9 @@ def __init__( self.dbt_cloud_conn_id = dbt_cloud_conn_id self.account_id = account_id self.job_id = job_id + self.project_name = project_name + self.environment_name = environment_name + self.job_name = job_name self.trigger_reason = trigger_reason self.steps_override = steps_override self.schema_override = schema_override @@ -135,6 +147,18 @@ def execute(self, context: Context): f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG." ) + if self.job_id is None: + if not all([self.project_name, self.environment_name, self.job_name]): + raise ValueError( + "Either job_id or project_name, environment_name, and job_name must be provided." + ) + self.job_id = self.hook.get_job_by_name( + account_id=self.account_id, + project_name=self.project_name, + environment_name=self.environment_name, + job_name=self.job_name, + )["id"] + non_terminal_runs = None if self.reuse_existing_run: non_terminal_runs = self.hook.get_job_runs( diff --git a/providers/dbt/cloud/tests/provider_tests/dbt/cloud/hooks/test_dbt.py b/providers/dbt/cloud/tests/provider_tests/dbt/cloud/hooks/test_dbt.py index 590f1b677f10c..1a11662019b46 100644 --- a/providers/dbt/cloud/tests/provider_tests/dbt/cloud/hooks/test_dbt.py +++ b/providers/dbt/cloud/tests/provider_tests/dbt/cloud/hooks/test_dbt.py @@ -17,11 +17,13 @@ from __future__ import annotations import json +from copy import deepcopy from datetime import timedelta from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +from requests.models import Response from airflow.exceptions import AirflowException from airflow.models.connection import Connection @@ -30,6 +32,7 @@ DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus, + DbtCloudResourceLookupError, TokenAuth, fallback_to_default_account, ) @@ -47,12 +50,47 @@ EXTRA_PROXIES = {"proxies": {"https": "http://myproxy:1234"}} TOKEN = "token" PROJECT_ID = 33333 +PROJECT_NAME = "project_name" +ENVIRONMENT_ID = 44444 +ENVIRONMENT_NAME = "environment_name" JOB_ID = 4444 +JOB_NAME = "job_name" RUN_ID = 5555 BASE_URL = "https://cloud.getdbt.com/" SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/" +DEFAULT_LIST_PROJECTS_RESPONSE = { + "data": [ + { + "id": PROJECT_ID, + "name": PROJECT_NAME, + } + ] +} +DEFAULT_LIST_ENVIRONMENTS_RESPONSE = { + "data": [ + { + "id": ENVIRONMENT_ID, + "name": ENVIRONMENT_NAME, + } + ] +} +DEFAULT_LIST_JOBS_RESPONSE = { + "data": [ + { + "id": JOB_ID, + "name": JOB_NAME, + } + ] +} + + +def mock_response_json(response: dict): + run_response = MagicMock(**response, spec=Response) + run_response.json.return_value = response + return run_response + class TestDbtCloudJobRunStatus: valid_job_run_statuses = [ @@ -247,7 +285,32 @@ def test_list_projects(self, mock_http_run, mock_paginate, conn_id, account_id): _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_not_called() hook._paginate.assert_called_once_with( - endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None, proxies=None + endpoint=f"api/v3/accounts/{_account_id}/projects/", + payload=None, + proxies=None, + ) + + @pytest.mark.parametrize( + argnames="conn_id, account_id, name_contains", + argvalues=[(ACCOUNT_ID_CONN, None, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID, PROJECT_NAME)], + ids=["default_account", "explicit_account"], + ) + @patch.object(DbtCloudHook, "run") + @patch.object(DbtCloudHook, "_paginate") + def test_list_projects_with_payload( + self, mock_http_run, mock_paginate, conn_id, account_id, name_contains + ): + hook = DbtCloudHook(conn_id) + hook.list_projects(account_id=account_id, name_contains=name_contains) + + assert hook.method == "GET" + + _account_id = account_id or DEFAULT_ACCOUNT_ID + hook.run.assert_not_called() + hook._paginate.assert_called_once_with( + endpoint=f"api/v3/accounts/{_account_id}/projects/", + payload={"name__icontains": name_contains} if name_contains else None, + proxies=None, ) @pytest.mark.parametrize( @@ -269,6 +332,71 @@ def test_get_project(self, mock_http_run, mock_paginate, conn_id, account_id): ) hook._paginate.assert_not_called() + @pytest.mark.parametrize( + argnames="conn_id, account_id", + argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], + ids=["default_account", "explicit_account"], + ) + @patch.object(DbtCloudHook, "run") + @patch.object(DbtCloudHook, "_paginate") + def test_list_environments(self, mock_http_run, mock_paginate, conn_id, account_id): + hook = DbtCloudHook(conn_id) + hook.list_environments(project_id=PROJECT_ID, account_id=account_id) + + assert hook.method == "GET" + + _account_id = account_id or DEFAULT_ACCOUNT_ID + hook.run.assert_not_called() + hook._paginate.assert_called_once_with( + endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/environments/", + payload=None, + proxies=None, + ) + + @pytest.mark.parametrize( + argnames="conn_id, account_id, name_contains", + argvalues=[(ACCOUNT_ID_CONN, None, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID, ENVIRONMENT_NAME)], + ids=["default_account", "explicit_account"], + ) + @patch.object(DbtCloudHook, "run") + @patch.object(DbtCloudHook, "_paginate") + def test_list_environments_with_payload( + self, mock_http_run, mock_paginate, conn_id, account_id, name_contains + ): + hook = DbtCloudHook(conn_id) + hook.list_environments(project_id=PROJECT_ID, account_id=account_id, name_contains=name_contains) + + assert hook.method == "GET" + + _account_id = account_id or DEFAULT_ACCOUNT_ID + hook.run.assert_not_called() + hook._paginate.assert_called_once_with( + endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/environments/", + payload={"name__icontains": name_contains} if name_contains else None, + proxies=None, + ) + + @pytest.mark.parametrize( + argnames="conn_id, account_id", + argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], + ids=["default_account", "explicit_account"], + ) + @patch.object(DbtCloudHook, "run") + @patch.object(DbtCloudHook, "_paginate") + def test_get_environment(self, mock_http_run, mock_paginate, conn_id, account_id): + hook = DbtCloudHook(conn_id) + hook.get_environment(project_id=PROJECT_ID, environment_id=ENVIRONMENT_ID, account_id=account_id) + + assert hook.method == "GET" + + _account_id = account_id or DEFAULT_ACCOUNT_ID + hook.run.assert_called_once_with( + endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/environments/{ENVIRONMENT_ID}/", + data=None, + extra_options=None, + ) + hook._paginate.assert_not_called() + @pytest.mark.parametrize( argnames="conn_id, account_id", argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], @@ -299,14 +427,25 @@ def test_list_jobs(self, mock_http_run, mock_paginate, conn_id, account_id): @patch.object(DbtCloudHook, "_paginate") def test_list_jobs_with_payload(self, mock_http_run, mock_paginate, conn_id, account_id): hook = DbtCloudHook(conn_id) - hook.list_jobs(project_id=PROJECT_ID, account_id=account_id, order_by="-id") + hook.list_jobs( + project_id=PROJECT_ID, + account_id=account_id, + order_by="-id", + environment_id=ENVIRONMENT_ID, + name_contains=JOB_NAME, + ) assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID hook._paginate.assert_called_once_with( endpoint=f"api/v2/accounts/{_account_id}/jobs/", - payload={"order_by": "-id", "project_id": PROJECT_ID}, + payload={ + "order_by": "-id", + "project_id": PROJECT_ID, + "environment_id": ENVIRONMENT_ID, + "name__icontains": JOB_NAME, + }, proxies=None, ) hook.run.assert_not_called() @@ -330,6 +469,118 @@ def test_get_job(self, mock_http_run, mock_paginate, conn_id, account_id): ) hook._paginate.assert_not_called() + @patch.object(DbtCloudHook, "list_jobs", return_value=[mock_response_json(DEFAULT_LIST_JOBS_RESPONSE)]) + @patch.object( + DbtCloudHook, + "list_environments", + return_value=[mock_response_json(DEFAULT_LIST_ENVIRONMENTS_RESPONSE)], + ) + @patch.object( + DbtCloudHook, "list_projects", return_value=[mock_response_json(DEFAULT_LIST_PROJECTS_RESPONSE)] + ) + def test_get_job_by_name_returns_response( + self, mock_list_projects, mock_list_environments, mock_list_jobs + ): + hook = DbtCloudHook(ACCOUNT_ID_CONN) + job_details = hook.get_job_by_name( + project_name=PROJECT_NAME, + environment_name=ENVIRONMENT_NAME, + job_name=JOB_NAME, + account_id=None, + ) + + assert job_details == DEFAULT_LIST_JOBS_RESPONSE["data"][0] + + @pytest.mark.parametrize( + argnames="project_name, environment_name, job_name", + argvalues=[ + ("dummy_name", ENVIRONMENT_NAME, JOB_NAME), + (PROJECT_NAME, "dummy_name", JOB_NAME), + (PROJECT_NAME, ENVIRONMENT_NAME, JOB_NAME.upper()), + (None, ENVIRONMENT_NAME, JOB_NAME), + (PROJECT_NAME, "", JOB_NAME), + ("", "", ""), + ], + ) + @patch.object(DbtCloudHook, "list_jobs", return_value=[mock_response_json(DEFAULT_LIST_JOBS_RESPONSE)]) + @patch.object( + DbtCloudHook, + "list_environments", + return_value=[mock_response_json(DEFAULT_LIST_ENVIRONMENTS_RESPONSE)], + ) + @patch.object( + DbtCloudHook, "list_projects", return_value=[mock_response_json(DEFAULT_LIST_PROJECTS_RESPONSE)] + ) + def test_get_job_by_incorrect_name_raises_exception( + self, + mock_list_projects, + mock_list_environments, + mock_list_jobs, + project_name, + environment_name, + job_name, + ): + hook = DbtCloudHook(ACCOUNT_ID_CONN) + with pytest.raises(DbtCloudResourceLookupError, match="Found 0"): + hook.get_job_by_name( + project_name=project_name, + environment_name=environment_name, + job_name=job_name, + account_id=None, + ) + + @pytest.mark.parametrize("duplicated", ["projects", "environments", "jobs"]) + def test_get_job_by_duplicate_name_raises_exception(self, duplicated): + hook = DbtCloudHook(ACCOUNT_ID_CONN) + mock_list_jobs_response = deepcopy(DEFAULT_LIST_JOBS_RESPONSE) + mock_list_environments_response = deepcopy(DEFAULT_LIST_ENVIRONMENTS_RESPONSE) + mock_list_projects_response = deepcopy(DEFAULT_LIST_PROJECTS_RESPONSE) + + if duplicated == "projects": + mock_list_projects_response["data"].append( + { + "id": PROJECT_ID + 1, + "name": PROJECT_NAME, + } + ) + elif duplicated == "environments": + mock_list_environments_response["data"].append( + { + "id": ENVIRONMENT_ID + 1, + "name": ENVIRONMENT_NAME, + } + ) + elif duplicated == "jobs": + mock_list_jobs_response["data"].append( + { + "id": JOB_ID + 1, + "name": JOB_NAME, + } + ) + + with ( + patch.object( + DbtCloudHook, "list_jobs", return_value=[mock_response_json(mock_list_jobs_response)] + ), + patch.object( + DbtCloudHook, + "list_environments", + return_value=[mock_response_json(mock_list_environments_response)], + ), + patch.object( + DbtCloudHook, + "list_projects", + return_value=[mock_response_json(mock_list_projects_response)], + ), + ): + with pytest.raises(DbtCloudResourceLookupError, match=f"Found 2 {duplicated}"): + hook.get_job_by_name( + project_name=PROJECT_NAME, + environment_name=ENVIRONMENT_NAME, + job_name=JOB_NAME, + account_id=None, + ) + @pytest.mark.parametrize( argnames="conn_id, account_id", argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], diff --git a/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py b/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py index 315e770d84dbd..8791a09a5ba71 100644 --- a/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py +++ b/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py @@ -43,7 +43,11 @@ ACCOUNT_ID = 22222 TOKEN = "token" PROJECT_ID = 33333 +PROJECT_NAME = "project_name" +ENVIRONMENT_ID = 44444 +ENVIRONMENT_NAME = "environment_name" JOB_ID = 4444 +JOB_NAME = "job_name" RUN_ID = 5555 EXPECTED_JOB_RUN_OP_EXTRA_LINK = ( "https://cloud.getdbt.com/#/accounts/{account_id}/projects/{project_id}/runs/{run_id}/" @@ -75,6 +79,12 @@ } ] } +DEFAULT_ACCOUNT_JOB_RESPONSE = { + "data": { + "id": JOB_ID, + "account_id": DEFAULT_ACCOUNT_ID, + } +} def mock_response_json(response: dict): @@ -200,6 +210,95 @@ def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_jo dbt_op.execute(MagicMock()) assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger" + @patch( + "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_by_name", + return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RESPONSE), + ) + @patch( + "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status", + return_value=DbtCloudJobRunStatus.SUCCESS.value, + ) + @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection") + @patch( + "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run", + return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RUN_RESPONSE), + ) + def test_dbt_run_job_by_name( + self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, mock_job_by_name + ): + """ + Test alternative way to run a job by project, + environment and job name instead of job id. + """ + dbt_op = DbtCloudRunJobOperator( + dbt_cloud_conn_id=ACCOUNT_ID_CONN, + task_id=TASK_ID, + project_name=PROJECT_NAME, + environment_name=ENVIRONMENT_NAME, + job_name=JOB_NAME, + check_interval=1, + timeout=3, + dag=self.dag, + ) + dbt_op.execute(MagicMock()) + mock_trigger_job_run.assert_called_once() + + @pytest.mark.parametrize( + argnames="project_name, environment_name, job_name", + argvalues=[ + (None, ENVIRONMENT_NAME, JOB_NAME), + (PROJECT_NAME, "", JOB_NAME), + (PROJECT_NAME, ENVIRONMENT_NAME, None), + ("", "", ""), + ], + ) + @patch( + "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_by_name", + return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RESPONSE), + ) + @patch( + "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status", + return_value=DbtCloudJobRunStatus.SUCCESS.value, + ) + @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection") + @patch( + "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run", + return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RUN_RESPONSE), + ) + def test_dbt_run_job_by_incorrect_name_raises_exception( + self, + mock_trigger_job_run, + mock_dbt_hook, + mock_job_run_status, + mock_job_by_name, + project_name, + environment_name, + job_name, + ): + """ + Test alternative way to run a job by project, + environment and job name instead of job id. + + This test is to check if the operator raises an exception + when the project, environment or job name is missing. + """ + dbt_op = DbtCloudRunJobOperator( + dbt_cloud_conn_id=ACCOUNT_ID_CONN, + task_id=TASK_ID, + project_name=project_name, + environment_name=environment_name, + job_name=job_name, + check_interval=1, + timeout=3, + dag=self.dag, + ) + with pytest.raises( + ValueError, + match="Either job_id or project_name, environment_name, and job_name must be provided.", + ): + dbt_op.execute(MagicMock()) + mock_trigger_job_run.assert_not_called() + @patch.object( DbtCloudHook, "trigger_job_run", return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RUN_RESPONSE) ) diff --git a/providers/dbt/cloud/tests/system/dbt/cloud/example_dbt_cloud.py b/providers/dbt/cloud/tests/system/dbt/cloud/example_dbt_cloud.py index 9f0b9ed06de75..a9a2fb257ae99 100644 --- a/providers/dbt/cloud/tests/system/dbt/cloud/example_dbt_cloud.py +++ b/providers/dbt/cloud/tests/system/dbt/cloud/example_dbt_cloud.py @@ -67,6 +67,17 @@ ) # [END howto_operator_dbt_cloud_run_job_async] + # [START howto_operator_dbt_cloud_run_job_without_job_id] + trigger_job_run3 = DbtCloudRunJobOperator( + task_id="trigger_job_run3", + project_name="my_dbt_project", + environment_name="prod", + job_name="my_dbt_job", + check_interval=10, + timeout=300, + ) + # [END howto_operator_dbt_cloud_run_job_without_job_id] + # [START howto_operator_dbt_cloud_run_job_sensor] job_run_sensor = DbtCloudJobRunSensor( task_id="job_run_sensor", run_id=trigger_job_run2.output, timeout=20