From aa75b2047fe99c824da10d9776c7175e37f12097 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Thu, 26 Aug 2021 03:15:55 +0300 Subject: [PATCH 01/19] add endpoint for aborting a dag run #15888 --- .../endpoints/dag_run_endpoint.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index feb6d24cf3247..512c62340a617 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -28,11 +28,13 @@ dagrun_schema, dagruns_batch_form_schema, ) +from airflow import DAG from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import State from airflow.utils.types import DagRunType +from airflow.exceptions import DagRunNotFound @security.requires_access( @@ -270,4 +272,28 @@ def post_dag_run(dag_id, session): detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun logical date: '{logical_date}' already exists" ) - raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists") + +@security.requires_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE) + ] +) +@provide_session +def post_abort_dag_run(dag_id, dag_run_id): + """Set a state of a dag run.""" + dag: DAG = current_app.dag_bag.get_dag(dag_id) + if not dag: + raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found") + + dag_run: DagRun = dag.get_dagrun(run_id=dag_run_id) + if not dag_run: + error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' + raise DagRunNotFound(error_message) + + dag_run.set_state(state=State.FAILED) + task_instances = dag_run.get_task_instances() + for ti in task_instances: + dag.set_task_instance_state(task_id=ti.task_id, execution_date=ti.execution_date, state=State.FAILED) + return dagrun_schema.dump(dag_run) From ddd62a0918d7ea87dde8a2fabdc518f65c235dc8 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Thu, 26 Aug 2021 03:36:36 +0300 Subject: [PATCH 02/19] update api --- airflow/api_connexion/openapi/v1.yaml | 37 +++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 2ce6804bbd888..bf7d75b481073 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -604,6 +604,43 @@ paths: '404': $ref: '#/components/responses/NotFound' + /dags/{dag_id}/dagRuns/abort: + parameters: + - $ref: '#/components/parameters/DAGID' + - $ref: '#/components/parameters/DAGRunID' + + post: + summary: Abort a DAG run + description: > + Sets a DAG run and all of its task instances to a failed state + x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint + operationId: post_abort_dag_run + tags: [DAGRun] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRun' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRun' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthenticated' + '409': + $ref: '#/components/responses/AlreadyExists' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + /dags/~/dagRuns/list: post: summary: List DAG runs (batch) From c85dc9977b458d2903b88a8a6204848ab1d09b5e Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Thu, 26 Aug 2021 12:38:48 +0300 Subject: [PATCH 03/19] change endpoint to accept any state --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 11 ++++------- airflow/api_connexion/openapi/v1.yaml | 6 +++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index fd317151f216c..8215b91c1a567 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -283,19 +283,16 @@ def post_dag_run(dag_id, session): ] ) @provide_session -def post_abort_dag_run(dag_id, dag_run_id): +def set_dag_run_state(dag_id, run_id, state): """Set a state of a dag run.""" dag: DAG = current_app.dag_bag.get_dag(dag_id) if not dag: raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found") - dag_run: DagRun = dag.get_dagrun(run_id=dag_run_id) + dag_run: DagRun = dag.get_dagrun(run_id=run_id) if not dag_run: - error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' + error_message = f'Dag Run id {run_id} not found in dag {dag_id}' raise DagRunNotFound(error_message) - dag_run.set_state(state=State.FAILED) - task_instances = dag_run.get_task_instances() - for ti in task_instances: - dag.set_task_instance_state(task_id=ti.task_id, execution_date=ti.execution_date, state=State.FAILED) + dag_run.set_state(state=state) return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index bf7d75b481073..ba9bf7e494277 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -610,11 +610,11 @@ paths: - $ref: '#/components/parameters/DAGRunID' post: - summary: Abort a DAG run + summary: Set the state of a DAG run description: > - Sets a DAG run and all of its task instances to a failed state + Set the state of a DAG run x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint - operationId: post_abort_dag_run + operationId: set_dag_run_state tags: [DAGRun] requestBody: required: true From e2ef141a3af55dee76ce92ca0ad4cbf73046a084 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Fri, 27 Aug 2021 03:13:02 +0300 Subject: [PATCH 04/19] created tests for endpoint --- .../endpoints/dag_run_endpoint.py | 27 ++-- airflow/api_connexion/openapi/v1.yaml | 42 +++++-- .../api_connexion/schemas/dag_run_schema.py | 8 ++ .../endpoints/test_dag_run_endpoint.py | 115 ++++++++++++++++++ 4 files changed, 171 insertions(+), 21 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 8215b91c1a567..e0a3e974272a5 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -28,11 +28,10 @@ dagrun_schema, dagruns_batch_form_schema, ) -from airflow import DAG from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import State, DagRunState from airflow.utils.types import DagRunType from airflow.exceptions import DagRunNotFound @@ -278,21 +277,27 @@ def post_dag_run(dag_id, session): @security.requires_access( [ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE) + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN) ] ) @provide_session -def set_dag_run_state(dag_id, run_id, state): +def post_set_dag_run_state(dag_id, session) -> dict: """Set a state of a dag run.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) - if not dag: - raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found") + try: + post_body = dagrun_schema.load(request.json, session=session, unknown="include") + except ValidationError as err: + raise BadRequest(detail=str(err)) + dag_run_id, state = post_body['run_id'], post_body['state'] + dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + if dag_run is None: + raise NotFound( + "DAGRun not found", + detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", + ) - dag_run: DagRun = dag.get_dagrun(run_id=run_id) if not dag_run: - error_message = f'Dag Run id {run_id} not found in dag {dag_id}' + error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise DagRunNotFound(error_message) - dag_run.set_state(state=state) + dag_run.set_state(state=DagRunState(state.lower()).value) return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index ba9bf7e494277..5e847467a9449 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -604,24 +604,23 @@ paths: '404': $ref: '#/components/responses/NotFound' - /dags/{dag_id}/dagRuns/abort: + /dags/{dag_id}/dagRuns/updateDagRunState: parameters: - $ref: '#/components/parameters/DAGID' - - $ref: '#/components/parameters/DAGRunID' post: - summary: Set the state of a DAG run + summary: Set a state of DAG run description: > - Set the state of a DAG run + Set a state of DAG run x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint - operationId: set_dag_run_state - tags: [DAGRun] + operationId: post_set_dag_run_state + tags: [UpdateDagRunState] requestBody: required: true content: application/json: schema: - $ref: '#/components/schemas/DAGRun' + $ref: '#/components/schemas/UpdateDagRunState' responses: '200': @@ -631,11 +630,9 @@ paths: schema: $ref: '#/components/schemas/DAGRun' '400': - $ref: '#/components/responses/BadRequest' + $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthenticated' - '409': - $ref: '#/components/responses/AlreadyExists' '403': $ref: '#/components/responses/PermissionDenied' '404': @@ -2033,6 +2030,31 @@ components: required: - dag_id + UpdateDagRunState: + type: object + properties: + dag_run_id: + type: string + nullable: true + description: | + Run ID. + + The value of this field can be set only when creating the object. If you try to modify the + field of an existing object, the request fails with an BAD_REQUEST error. + + If not provided, a value will be generated based on execution_date. + + If the specified dag_run_id is in use, the creation request fails with an ALREADY_EXISTS error. + + This together with DAG_ID are a unique key. + dag_id: + type: string + readOnly: true + state: + $ref: '#/components/schemas/DagState' + required: + - dag_id + DAGRunCollection: type: object description: Collection of DAG runs. diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index 62ac172ce5525..a5c6dc4d9e7a6 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -104,6 +104,14 @@ def autofill(self, data, **kwargs): return data +class SetDagRunStateFormSchema(Schema): + """Schema for handling the request of setting state of DAG run""" + + run_id = auto_field(data_key='dag_run_id') + dag_id = auto_field(dump_only=True) + state = DagStateField(dump_only=True) + + class DAGRunCollection(NamedTuple): """List of DAGRuns with metadata""" diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 71b52747dcff7..9fc31181a8bfd 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -19,12 +19,14 @@ import pytest from parameterized import parameterized +from freezegun import freeze_time from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DAG, DagModel, DagRun from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session +from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -1154,3 +1156,116 @@ def test_should_raises_403_unauthorized(self): environ_overrides={'REMOTE_USER': "test_view_dags"}, ) assert response.status_code == 403 + + +class TestPostSetDagRunState(TestDagRunEndpoint): + + @parameterized.expand([ + ("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "failed"), + ("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "success") + ]) + @freeze_time(TestDagRunEndpoint.default_time) + def test_should_respond_200(self, dag_id, dag_run_id, state): + test_time = timezone.parse(self.default_time) + with create_session() as session: + dag = DagModel(dag_id=dag_id) + dag_run = DagRun( + dag_id=dag_id, + run_id=dag_run_id, + state=DagRunState.RUNNING, + run_type=DagRunType.MANUAL, + execution_date=test_time, + start_date=test_time, + external_trigger=True + ) + session.add(dag) + session.add(dag_run) + + request_json = { + "dag_run_id": dag_run_id, + "state": state + } + + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + json=request_json, + environ_overrides={"REMOTE_USER": "test"}, + ) + + assert response.status_code == 200 + assert response.json == { + 'conf': {}, + 'dag_id': 'TEST_DAG_ID', + 'dag_run_id': 'TEST_DAG_RUN_ID_1', + 'end_date': self.default_time, + 'execution_date': self.default_time, + 'external_trigger': True, + 'logical_date': self.default_time, + 'start_date': self.default_time, + 'state': state + } + + def test_should_response_400_for_non_existing_dag_run_state(self): + test_time = timezone.parse(self.default_time) + with create_session() as session, freeze_time(self.default_time): + dag_id = "TEST_DAG_ID" + dag_run_id = "TEST_DAG_RUN_ID_1" + dag = DagModel(dag_id=dag_id) + dag_run = DagRun( + dag_id=dag_id, + run_id=dag_run_id, + state=DagRunState.RUNNING, + run_type=DagRunType.MANUAL, + execution_date=test_time, + start_date=test_time, + external_trigger=True + ) + session.add(dag) + session.add(dag_run) + + request_json = { + "dag_run_id": dag_run_id, + "state": "madeUpState" + } + + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + json=request_json, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + 'detail': "'madeUpState' is not one of ['queued', 'running', 'success', 'failed'] - 'state'", + 'status': 400, + 'title': 'Bad Request', + 'type': EXCEPTIONS_LINK_MAP[400] + } + + def test_should_raises_401_unauthenticated(self, session): + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + json={ + "run_id": "TEST_DAG_RUN_ID_1", + "state": 'success', + }, + ) + + assert_401(response) + + def test_should_raise_403_forbidden(self): + response = self.client.get( + "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + environ_overrides={'REMOTE_USER': "test_no_permissions"}, + ) + assert response.status_code == 403 + + def test_should_respond_404(self): + response = self.client.get( + "api/v1/dags/INVALID_DAG_ID/dagRuns/updateDagRunState", + json={ + "run_id": "TEST_DAG_RUN_ID_1", + "state": 'success', + }, + environ_overrides={"REMOTE_USER": "test"} + ) + assert response.status_code == 404 From 0ad1bb38822a1244075dc75227524f989b316ef7 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Fri, 27 Aug 2021 22:22:12 +0300 Subject: [PATCH 05/19] merge session as well --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index e0a3e974272a5..0056806a2ef80 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -281,14 +281,14 @@ def post_dag_run(dag_id, session): ] ) @provide_session -def post_set_dag_run_state(dag_id, session) -> dict: +def post_set_dag_run_state(dag_id: str, session) -> dict: """Set a state of a dag run.""" try: post_body = dagrun_schema.load(request.json, session=session, unknown="include") except ValidationError as err: raise BadRequest(detail=str(err)) dag_run_id, state = post_body['run_id'], post_body['state'] - dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + dag_run: DagRun = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() if dag_run is None: raise NotFound( "DAGRun not found", @@ -300,4 +300,5 @@ def post_set_dag_run_state(dag_id, session) -> dict: raise DagRunNotFound(error_message) dag_run.set_state(state=DagRunState(state.lower()).value) + session.merge(dag_run) return dagrun_schema.dump(dag_run) From e1308fa18d1e31b6201d70bcdd6e94060305b1fb Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Fri, 27 Aug 2021 23:14:56 +0300 Subject: [PATCH 06/19] Change type hint of dag run variable to optional Co-authored-by: Tzu-ping Chung --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 0056806a2ef80..f62215298882d 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -288,7 +288,7 @@ def post_set_dag_run_state(dag_id: str, session) -> dict: except ValidationError as err: raise BadRequest(detail=str(err)) dag_run_id, state = post_body['run_id'], post_body['state'] - dag_run: DagRun = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + dag_run: Optional[DagRun] = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() if dag_run is None: raise NotFound( "DAGRun not found", From 3c3b4a56b6f8f4f762509674eaa481b26c7743f5 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Fri, 27 Aug 2021 23:37:56 +0300 Subject: [PATCH 07/19] add import --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index f62215298882d..90c51aa88c179 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -17,6 +17,7 @@ from flask import current_app, g, request from marshmallow import ValidationError from sqlalchemy import or_ +from typing import Optional from airflow._vendor.connexion import NoContent from airflow.api_connexion import security From 0c72cb79b3fee74703bab479ca9de21eec309986 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Sat, 28 Aug 2021 20:21:00 +0300 Subject: [PATCH 08/19] PR requested changes --- .../endpoints/dag_run_endpoint.py | 10 ++------- airflow/api_connexion/openapi/v1.yaml | 22 +++---------------- .../endpoints/test_dag_run_endpoint.py | 21 +++++++++--------- 3 files changed, 15 insertions(+), 38 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 90c51aa88c179..a61a202d07486 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -282,24 +282,18 @@ def post_dag_run(dag_id, session): ] ) @provide_session -def post_set_dag_run_state(dag_id: str, session) -> dict: +def post_set_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: """Set a state of a dag run.""" try: post_body = dagrun_schema.load(request.json, session=session, unknown="include") except ValidationError as err: raise BadRequest(detail=str(err)) - dag_run_id, state = post_body['run_id'], post_body['state'] dag_run: Optional[DagRun] = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() if dag_run is None: - raise NotFound( - "DAGRun not found", - detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", - ) - - if not dag_run: error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise DagRunNotFound(error_message) + state = post_body['state'] dag_run.set_state(state=DagRunState(state.lower()).value) session.merge(dag_run) return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 5e847467a9449..525b712badeea 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -604,9 +604,10 @@ paths: '404': $ref: '#/components/responses/NotFound' - /dags/{dag_id}/dagRuns/updateDagRunState: + /dags/{dag_id}/dagRuns/{dag_run_id}/state: parameters: - $ref: '#/components/parameters/DAGID' + - $ref: '#/components/parameters/DAGRunID' post: summary: Set a state of DAG run @@ -2033,27 +2034,10 @@ components: UpdateDagRunState: type: object properties: - dag_run_id: - type: string - nullable: true - description: | - Run ID. - - The value of this field can be set only when creating the object. If you try to modify the - field of an existing object, the request fails with an BAD_REQUEST error. - - If not provided, a value will be generated based on execution_date. - - If the specified dag_run_id is in use, the creation request fails with an ALREADY_EXISTS error. - - This together with DAG_ID are a unique key. - dag_id: - type: string - readOnly: true state: $ref: '#/components/schemas/DagState' required: - - dag_id + - state DAGRunCollection: type: object diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 9fc31181a8bfd..618ee8b677b81 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1182,12 +1182,11 @@ def test_should_respond_200(self, dag_id, dag_run_id, state): session.add(dag_run) request_json = { - "dag_run_id": dag_run_id, "state": state } response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", json=request_json, environ_overrides={"REMOTE_USER": "test"}, ) @@ -1224,12 +1223,11 @@ def test_should_response_400_for_non_existing_dag_run_state(self): session.add(dag_run) request_json = { - "dag_run_id": dag_run_id, "state": "madeUpState" } response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", json=request_json, environ_overrides={"REMOTE_USER": "test"}, ) @@ -1243,9 +1241,8 @@ def test_should_response_400_for_non_existing_dag_run_state(self): def test_should_raises_401_unauthenticated(self, session): response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", json={ - "run_id": "TEST_DAG_RUN_ID_1", "state": 'success', }, ) @@ -1253,17 +1250,19 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) def test_should_raise_403_forbidden(self): - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/updateDagRunState", + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + json={ + "state": 'success', + }, environ_overrides={'REMOTE_USER': "test_no_permissions"}, ) assert response.status_code == 403 def test_should_respond_404(self): - response = self.client.get( - "api/v1/dags/INVALID_DAG_ID/dagRuns/updateDagRunState", + response = self.client.post( + "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", json={ - "run_id": "TEST_DAG_RUN_ID_1", "state": 'success', }, environ_overrides={"REMOTE_USER": "test"} From 6cb54fce0cb970703012447362199403ed1e7ae7 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Sat, 28 Aug 2021 22:34:19 +0300 Subject: [PATCH 09/19] apply PR suggestion Co-authored-by: Tzu-ping Chung --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index a61a202d07486..ae33d84a28108 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -284,14 +284,14 @@ def post_dag_run(dag_id, session): @provide_session def post_set_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: """Set a state of a dag run.""" - try: - post_body = dagrun_schema.load(request.json, session=session, unknown="include") - except ValidationError as err: - raise BadRequest(detail=str(err)) dag_run: Optional[DagRun] = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() if dag_run is None: error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise DagRunNotFound(error_message) + try: + post_body = dagrun_schema.load(request.json, session=session) + except ValidationError as err: + raise BadRequest(detail=str(err)) state = post_body['state'] dag_run.set_state(state=DagRunState(state.lower()).value) From ea92dd26697511f7a3617681c0424374cda92053 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Sat, 28 Aug 2021 22:55:24 +0300 Subject: [PATCH 10/19] fix tests and remove schema load --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index ae33d84a28108..9d4e554313c88 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -284,16 +284,13 @@ def post_dag_run(dag_id, session): @provide_session def post_set_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: """Set a state of a dag run.""" - dag_run: Optional[DagRun] = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + dag_run: Optional[DagRun] = session.query(DagRun) \ + .filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() if dag_run is None: error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' - raise DagRunNotFound(error_message) - try: - post_body = dagrun_schema.load(request.json, session=session) - except ValidationError as err: - raise BadRequest(detail=str(err)) + raise NotFound(error_message) - state = post_body['state'] + state = request.json['state'] dag_run.set_state(state=DagRunState(state.lower()).value) session.merge(dag_run) return dagrun_schema.dump(dag_run) From 5dd182bca691a9150aede7fa4ed5f2a71e96c51d Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Sun, 29 Aug 2021 01:14:18 +0300 Subject: [PATCH 11/19] fix schema and run pre commit --- .../endpoints/dag_run_endpoint.py | 23 +++++++++------ .../api_connexion/schemas/dag_run_schema.py | 7 ++--- .../endpoints/test_dag_run_endpoint.py | 28 ++++++++----------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 9d4e554313c88..bae21016183f1 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional + from flask import current_app, g, request from marshmallow import ValidationError from sqlalchemy import or_ -from typing import Optional from airflow._vendor.connexion import NoContent from airflow.api_connexion import security @@ -28,13 +29,14 @@ dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema, + set_dagrun_state_form_schema, ) +from airflow.exceptions import DagRunNotFound from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.session import provide_session -from airflow.utils.state import State, DagRunState +from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType -from airflow.exceptions import DagRunNotFound @security.requires_access( @@ -278,19 +280,24 @@ def post_dag_run(dag_id, session): @security.requires_access( [ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN) + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), ] ) @provide_session def post_set_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: """Set a state of a dag run.""" - dag_run: Optional[DagRun] = session.query(DagRun) \ - .filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + dag_run: Optional[DagRun] = ( + session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + ) if dag_run is None: error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise NotFound(error_message) + try: + post_body = set_dagrun_state_form_schema.load(request.json, session=session) + except ValidationError as err: + raise BadRequest(detail=str(err)) - state = request.json['state'] - dag_run.set_state(state=DagRunState(state.lower()).value) + state = post_body['state'] + dag_run.set_state(state=state) session.merge(dag_run) return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index a5c6dc4d9e7a6..b71ced23ae813 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -104,12 +104,10 @@ def autofill(self, data, **kwargs): return data -class SetDagRunStateFormSchema(Schema): +class SetDagRunStateFormSchema(SQLAlchemySchema): """Schema for handling the request of setting state of DAG run""" - run_id = auto_field(data_key='dag_run_id') - dag_id = auto_field(dump_only=True) - state = DagStateField(dump_only=True) + state = DagStateField() class DAGRunCollection(NamedTuple): @@ -149,4 +147,5 @@ class Meta: dagrun_schema = DAGRunSchema() dagrun_collection_schema = DAGRunCollectionSchema() +set_dagrun_state_form_schema = SetDagRunStateFormSchema() dagruns_batch_form_schema = DagRunsBatchFormSchema() diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 618ee8b677b81..ef8fd24927d7e 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -18,8 +18,8 @@ from unittest import mock import pytest -from parameterized import parameterized from freezegun import freeze_time +from parameterized import parameterized from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DAG, DagModel, DagRun @@ -1159,11 +1159,9 @@ def test_should_raises_403_unauthorized(self): class TestPostSetDagRunState(TestDagRunEndpoint): - - @parameterized.expand([ - ("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "failed"), - ("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "success") - ]) + @parameterized.expand( + [("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "failed"), ("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "success")] + ) @freeze_time(TestDagRunEndpoint.default_time) def test_should_respond_200(self, dag_id, dag_run_id, state): test_time = timezone.parse(self.default_time) @@ -1176,14 +1174,12 @@ def test_should_respond_200(self, dag_id, dag_run_id, state): run_type=DagRunType.MANUAL, execution_date=test_time, start_date=test_time, - external_trigger=True + external_trigger=True, ) session.add(dag) session.add(dag_run) - request_json = { - "state": state - } + request_json = {"state": state} response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", @@ -1201,7 +1197,7 @@ def test_should_respond_200(self, dag_id, dag_run_id, state): 'external_trigger': True, 'logical_date': self.default_time, 'start_date': self.default_time, - 'state': state + 'state': state, } def test_should_response_400_for_non_existing_dag_run_state(self): @@ -1217,14 +1213,12 @@ def test_should_response_400_for_non_existing_dag_run_state(self): run_type=DagRunType.MANUAL, execution_date=test_time, start_date=test_time, - external_trigger=True + external_trigger=True, ) session.add(dag) session.add(dag_run) - request_json = { - "state": "madeUpState" - } + request_json = {"state": "madeUpState"} response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", @@ -1236,7 +1230,7 @@ def test_should_response_400_for_non_existing_dag_run_state(self): 'detail': "'madeUpState' is not one of ['queued', 'running', 'success', 'failed'] - 'state'", 'status': 400, 'title': 'Bad Request', - 'type': EXCEPTIONS_LINK_MAP[400] + 'type': EXCEPTIONS_LINK_MAP[400], } def test_should_raises_401_unauthenticated(self, session): @@ -1265,6 +1259,6 @@ def test_should_respond_404(self): json={ "state": 'success', }, - environ_overrides={"REMOTE_USER": "test"} + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 404 From 5a6b7a8a9c921ac935076c1ace42c952d8f76c0d Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Wed, 1 Sep 2021 01:29:44 +0300 Subject: [PATCH 12/19] PR feedback --- .../endpoints/dag_run_endpoint.py | 6 +-- airflow/api_connexion/openapi/v1.yaml | 10 +++-- .../api_connexion/schemas/dag_run_schema.py | 7 ++-- .../endpoints/test_dag_run_endpoint.py | 40 ++++++++++--------- 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index bae21016183f1..74aaf38eeb31e 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -284,7 +284,7 @@ def post_dag_run(dag_id, session): ] ) @provide_session -def post_set_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: +def update_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: """Set a state of a dag run.""" dag_run: Optional[DagRun] = ( session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() @@ -293,11 +293,11 @@ def post_set_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise NotFound(error_message) try: - post_body = set_dagrun_state_form_schema.load(request.json, session=session) + post_body = set_dagrun_state_form_schema.load(request.json) except ValidationError as err: raise BadRequest(detail=str(err)) state = post_body['state'] - dag_run.set_state(state=state) + dag_run.set_state(state=DagRunState(state).value) session.merge(dag_run) return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 525b712badeea..65e9ff1269bfa 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -614,7 +614,7 @@ paths: description: > Set a state of DAG run x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint - operationId: post_set_dag_run_state + operationId: update_dag_run_state tags: [UpdateDagRunState] requestBody: required: true @@ -2035,9 +2035,11 @@ components: type: object properties: state: - $ref: '#/components/schemas/DagState' - required: - - state + description: The state to set this DagRun + type: string + enum: + - success + - failed DAGRunCollection: type: object diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index b71ced23ae813..5f8a170c538e0 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -18,7 +18,7 @@ import json from typing import List, NamedTuple -from marshmallow import fields, post_dump, pre_load +from marshmallow import fields, post_dump, pre_load, validate from marshmallow.schema import Schema from marshmallow.validate import Range from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -29,6 +29,7 @@ from airflow.api_connexion.schemas.enum_schemas import DagStateField from airflow.models.dagrun import DagRun from airflow.utils import timezone +from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -104,10 +105,10 @@ def autofill(self, data, **kwargs): return data -class SetDagRunStateFormSchema(SQLAlchemySchema): +class SetDagRunStateFormSchema(Schema): """Schema for handling the request of setting state of DAG run""" - state = DagStateField() + state = DagStateField(validate=validate.OneOf([State.SUCCESS, State.FAILED])) class DAGRunCollection(NamedTuple): diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index ef8fd24927d7e..7459c1e1029a1 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1159,11 +1159,11 @@ def test_should_raises_403_unauthorized(self): class TestPostSetDagRunState(TestDagRunEndpoint): - @parameterized.expand( - [("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "failed"), ("TEST_DAG_ID", "TEST_DAG_RUN_ID_1", "success")] - ) + @parameterized.expand([("failed",), ("success",)]) @freeze_time(TestDagRunEndpoint.default_time) - def test_should_respond_200(self, dag_id, dag_run_id, state): + def test_should_respond_200(self, state): + dag_id = "TEST_DAG_ID" + dag_run_id = "TEST_DAG_RUN_ID_1" test_time = timezone.parse(self.default_time) with create_session() as session: dag = DagModel(dag_id=dag_id) @@ -1200,9 +1200,11 @@ def test_should_respond_200(self, dag_id, dag_run_id, state): 'state': state, } - def test_should_response_400_for_non_existing_dag_run_state(self): + @parameterized.expand([("running",), ("queued",)]) + @freeze_time(TestDagRunEndpoint.default_time) + def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state): test_time = timezone.parse(self.default_time) - with create_session() as session, freeze_time(self.default_time): + with create_session() as session: dag_id = "TEST_DAG_ID" dag_run_id = "TEST_DAG_RUN_ID_1" dag = DagModel(dag_id=dag_id) @@ -1218,20 +1220,20 @@ def test_should_response_400_for_non_existing_dag_run_state(self): session.add(dag) session.add(dag_run) - request_json = {"state": "madeUpState"} + request_json = {"state": invalid_state} - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", - json=request_json, - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 400 - assert response.json == { - 'detail': "'madeUpState' is not one of ['queued', 'running', 'success', 'failed'] - 'state'", - 'status': 400, - 'title': 'Bad Request', - 'type': EXCEPTIONS_LINK_MAP[400], - } + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + json=request_json, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + 'detail': f"'{invalid_state}' is not one of ['success', 'failed'] - 'state'", + 'status': 400, + 'title': 'Bad Request', + 'type': EXCEPTIONS_LINK_MAP[400], + } def test_should_raises_401_unauthenticated(self, session): response = self.client.post( From 09f1d4fe7ed449e7abc6fad616445db9f9bc6a0d Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Wed, 1 Sep 2021 16:43:47 +0300 Subject: [PATCH 13/19] More PR feedback --- .../endpoints/dag_run_endpoint.py | 2 +- airflow/api_connexion/openapi/v1.yaml | 3 +- .../api_connexion/schemas/dag_run_schema.py | 2 +- .../endpoints/test_dag_run_endpoint.py | 34 ++++--------------- 4 files changed, 9 insertions(+), 32 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 74aaf38eeb31e..f721b8607bb61 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -298,6 +298,6 @@ def update_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: raise BadRequest(detail=str(err)) state = post_body['state'] - dag_run.set_state(state=DagRunState(state).value) + dag_run.set_state(state=DagRunState(state)) session.merge(dag_run) return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 65e9ff1269bfa..26c56c49a5181 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -611,8 +611,7 @@ paths: post: summary: Set a state of DAG run - description: > - Set a state of DAG run + description: Set a state of DAG run x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint operationId: update_dag_run_state tags: [UpdateDagRunState] diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index 5f8a170c538e0..92a4125e55d65 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -108,7 +108,7 @@ def autofill(self, data, **kwargs): class SetDagRunStateFormSchema(Schema): """Schema for handling the request of setting state of DAG run""" - state = DagStateField(validate=validate.OneOf([State.SUCCESS, State.FAILED])) + state = DagStateField(validate=validate.OneOf([State.SUCCESS.value, State.FAILED.value])) class DAGRunCollection(NamedTuple): diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 7459c1e1029a1..ef3b0eab7687c 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -18,7 +18,6 @@ from unittest import mock import pytest -from freezegun import freeze_time from parameterized import parameterized from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP @@ -26,7 +25,6 @@ from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session -from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -1160,22 +1158,12 @@ def test_should_raises_403_unauthorized(self): class TestPostSetDagRunState(TestDagRunEndpoint): @parameterized.expand([("failed",), ("success",)]) - @freeze_time(TestDagRunEndpoint.default_time) - def test_should_respond_200(self, state): + @pytest.fixture(scope="module") + def test_should_respond_200(self, state, dag_maker): dag_id = "TEST_DAG_ID" - dag_run_id = "TEST_DAG_RUN_ID_1" - test_time = timezone.parse(self.default_time) with create_session() as session: dag = DagModel(dag_id=dag_id) - dag_run = DagRun( - dag_id=dag_id, - run_id=dag_run_id, - state=DagRunState.RUNNING, - run_type=DagRunType.MANUAL, - execution_date=test_time, - start_date=test_time, - external_trigger=True, - ) + dag_run = dag_maker.create_dagrun() session.add(dag) session.add(dag_run) @@ -1201,22 +1189,12 @@ def test_should_respond_200(self, state): } @parameterized.expand([("running",), ("queued",)]) - @freeze_time(TestDagRunEndpoint.default_time) - def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state): - test_time = timezone.parse(self.default_time) + @pytest.fixture(scope="module") + def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, dag_maker): with create_session() as session: dag_id = "TEST_DAG_ID" - dag_run_id = "TEST_DAG_RUN_ID_1" dag = DagModel(dag_id=dag_id) - dag_run = DagRun( - dag_id=dag_id, - run_id=dag_run_id, - state=DagRunState.RUNNING, - run_type=DagRunType.MANUAL, - execution_date=test_time, - start_date=test_time, - external_trigger=True, - ) + dag_run = dag_maker.create_dagrun() session.add(dag) session.add(dag_run) From 9caf0a8043502871d0ba6dd28b1a825b09b5ba0e Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Wed, 1 Sep 2021 17:31:10 +0300 Subject: [PATCH 14/19] use DagRunState instead of State --- airflow/api_connexion/schemas/dag_run_schema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index 92a4125e55d65..fa51f10ed48c0 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -29,7 +29,7 @@ from airflow.api_connexion.schemas.enum_schemas import DagStateField from airflow.models.dagrun import DagRun from airflow.utils import timezone -from airflow.utils.state import State +from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -108,7 +108,7 @@ def autofill(self, data, **kwargs): class SetDagRunStateFormSchema(Schema): """Schema for handling the request of setting state of DAG run""" - state = DagStateField(validate=validate.OneOf([State.SUCCESS.value, State.FAILED.value])) + state = DagStateField(validate=validate.OneOf([DagRunState.SUCCESS.value, DagRunState.FAILED.value])) class DAGRunCollection(NamedTuple): From 65fa088ec23384abfbf04506e17a8f8055db286f Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Wed, 1 Sep 2021 19:47:54 +0300 Subject: [PATCH 15/19] Apply suggestions from code review PR suggestions in batch Co-authored-by: Ephraim Anierobi --- .../endpoints/test_dag_run_endpoint.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index ef3b0eab7687c..8105a2463bd71 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1158,19 +1158,17 @@ def test_should_raises_403_unauthorized(self): class TestPostSetDagRunState(TestDagRunEndpoint): @parameterized.expand([("failed",), ("success",)]) - @pytest.fixture(scope="module") def test_should_respond_200(self, state, dag_maker): dag_id = "TEST_DAG_ID" - with create_session() as session: - dag = DagModel(dag_id=dag_id) - dag_run = dag_maker.create_dagrun() - session.add(dag) - session.add(dag_run) + dag_run_id = 'TEST_DAG_RUN_ID' + with dag_maker(dag_id): + DummyOperator(task_id='task_id') + dag_maker.create_dagrun(run_id=dag_run_id) request_json = {"state": state} response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/state", json=request_json, environ_overrides={"REMOTE_USER": "test"}, ) @@ -1178,8 +1176,8 @@ def test_should_respond_200(self, state, dag_maker): assert response.status_code == 200 assert response.json == { 'conf': {}, - 'dag_id': 'TEST_DAG_ID', - 'dag_run_id': 'TEST_DAG_RUN_ID_1', + 'dag_id': dag_id, + 'dag_run_id': dag_run_id, 'end_date': self.default_time, 'execution_date': self.default_time, 'external_trigger': True, From fb11ba5c42c2547e6550480eca1c5db14ea76f99 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Wed, 1 Sep 2021 23:05:22 +0300 Subject: [PATCH 16/19] better tests --- .../endpoints/test_dag_run_endpoint.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 8105a2463bd71..e4dd7f8a4b2ef 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -18,10 +18,12 @@ from unittest import mock import pytest +from freezegun import freeze_time from parameterized import parameterized from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DAG, DagModel, DagRun +from airflow.operators.dummy import DummyOperator from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session @@ -1157,7 +1159,8 @@ def test_should_raises_403_unauthorized(self): class TestPostSetDagRunState(TestDagRunEndpoint): - @parameterized.expand([("failed",), ("success",)]) + @pytest.mark.parametrize("state", ["failed", "success"]) + @freeze_time(TestDagRunEndpoint.default_time) def test_should_respond_200(self, state, dag_maker): dag_id = "TEST_DAG_ID" dag_run_id = 'TEST_DAG_RUN_ID' @@ -1179,22 +1182,21 @@ def test_should_respond_200(self, state, dag_maker): 'dag_id': dag_id, 'dag_run_id': dag_run_id, 'end_date': self.default_time, - 'execution_date': self.default_time, - 'external_trigger': True, - 'logical_date': self.default_time, - 'start_date': self.default_time, + 'execution_date': dag_maker.start_date.isoformat(), + 'external_trigger': False, + 'logical_date': dag_maker.start_date.isoformat(), + 'start_date': dag_maker.start_date.isoformat(), 'state': state, } - @parameterized.expand([("running",), ("queued",)]) - @pytest.fixture(scope="module") + @pytest.mark.parametrize('invalid_state', ["running", "queued"]) + @freeze_time(TestDagRunEndpoint.default_time) def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, dag_maker): - with create_session() as session: - dag_id = "TEST_DAG_ID" - dag = DagModel(dag_id=dag_id) - dag_run = dag_maker.create_dagrun() - session.add(dag) - session.add(dag_run) + dag_id = "TEST_DAG_ID" + dag_run_id = 'TEST_DAG_RUN_ID' + with dag_maker(dag_id): + DummyOperator(task_id='task_id') + dag_maker.create_dagrun(run_id=dag_run_id) request_json = {"state": invalid_state} From 12e43d6c3653aefdf730d1ff88d99d612a5c7bcc Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Thu, 2 Sep 2021 12:09:40 +0300 Subject: [PATCH 17/19] change post action to patch --- airflow/api_connexion/openapi/v1.yaml | 63 +++++++++---------- .../endpoints/test_dag_run_endpoint.py | 20 +++--- 2 files changed, 39 insertions(+), 44 deletions(-) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 26c56c49a5181..cec61b3f340af 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -604,40 +604,6 @@ paths: '404': $ref: '#/components/responses/NotFound' - /dags/{dag_id}/dagRuns/{dag_run_id}/state: - parameters: - - $ref: '#/components/parameters/DAGID' - - $ref: '#/components/parameters/DAGRunID' - - post: - summary: Set a state of DAG run - description: Set a state of DAG run - x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint - operationId: update_dag_run_state - tags: [UpdateDagRunState] - requestBody: - required: true - content: - application/json: - schema: - $ref: '#/components/schemas/UpdateDagRunState' - - responses: - '200': - description: Success. - content: - application/json: - schema: - $ref: '#/components/schemas/DAGRun' - '400': - $ref: '#/components/responses/BadRequest' - '401': - $ref: '#/components/responses/Unauthenticated' - '403': - $ref: '#/components/responses/PermissionDenied' - '404': - $ref: '#/components/responses/NotFound' - /dags/~/dagRuns/list: post: summary: List DAG runs (batch) @@ -708,6 +674,35 @@ paths: '404': $ref: '#/components/responses/NotFound' + patch: + summary: Set a state of DAG run + description: Set a state of DAG run + x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint + operationId: update_dag_run_state + tags: [ UpdateDagRunState ] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateDagRunState' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRun' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + /eventLogs: get: summary: List log entries diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index e4dd7f8a4b2ef..47d89cba7c44a 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1170,8 +1170,8 @@ def test_should_respond_200(self, state, dag_maker): request_json = {"state": state} - response = self.client.post( - f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/state", + response = self.client.patch( + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", json=request_json, environ_overrides={"REMOTE_USER": "test"}, ) @@ -1200,8 +1200,8 @@ def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, request_json = {"state": invalid_state} - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + response = self.client.patch( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", json=request_json, environ_overrides={"REMOTE_USER": "test"}, ) @@ -1214,8 +1214,8 @@ def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, } def test_should_raises_401_unauthenticated(self, session): - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + response = self.client.patch( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", json={ "state": 'success', }, @@ -1224,8 +1224,8 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) def test_should_raise_403_forbidden(self): - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + response = self.client.patch( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", json={ "state": 'success', }, @@ -1234,8 +1234,8 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 def test_should_respond_404(self): - response = self.client.post( - "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/state", + response = self.client.patch( + "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", json={ "state": 'success', }, From 229dd1ac756e2c2a7893ba3b065d4bf718dbf6da Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Thu, 2 Sep 2021 12:35:38 +0300 Subject: [PATCH 18/19] change summary and description of endpoint Co-authored-by: Tzu-ping Chung --- airflow/api_connexion/openapi/v1.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index cec61b3f340af..8dae5f2fbc6b6 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -675,8 +675,8 @@ paths: $ref: '#/components/responses/NotFound' patch: - summary: Set a state of DAG run - description: Set a state of DAG run + summary: Modify a DAG run + description: Modify a DAG run x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint operationId: update_dag_run_state tags: [ UpdateDagRunState ] From 003e7720f139e23c3f86c5a4ec3592ac00975b78 Mon Sep 17 00:00:00 2001 From: bbenshalom Date: Thu, 2 Sep 2021 14:13:36 +0300 Subject: [PATCH 19/19] fix for flake8 --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index f721b8607bb61..e816aac27e157 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -31,7 +31,6 @@ dagruns_batch_form_schema, set_dagrun_state_form_schema, ) -from airflow.exceptions import DagRunNotFound from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.session import provide_session