diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index c67af4005bdc6..cd317ad819979 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,9 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from flask import request
+from sqlalchemy import and_, func
+from sqlalchemy.orm.session import Session
 
-# TODO(mik-laj): We have to implement it.
-#     Do you want to help? Please look at: sshttps://github.com/apache/airflow/issues/8134
+from airflow.api_connexion import parameters
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.schemas.xcom_schema import (
+    XComCollection, XComCollectionItemSchema, XComCollectionSchema, xcom_collection_item_schema,
+    xcom_collection_schema,
+)
+from airflow.models import DagRun as DR, XCom
+from airflow.utils.session import provide_session
 
 
 def delete_xcom_entry():
@@ -26,18 +35,67 @@ def delete_xcom_entry():
     raise NotImplementedError("Not implemented yet.")
 
 
-def get_xcom_entries():
+@provide_session
+def get_xcom_entries(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    session: Session
+) -> XComCollectionSchema:
     """
     Get all XCom values
+
+    A path parameter equal to ``~`` acts as a wildcard for that field.
+    ``total_entries`` counts every entry matching the filters, before
+    ``offset``/``limit`` are applied.
     """
-    raise NotImplementedError("Not implemented yet.")
+    # Query-string values are strings; cast so offset/limit are integers.
+    offset = int(request.args.get(parameters.page_offset, 0))
+    limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
+    # Query.join() returns a new Query, so keep the result. Join on dag_id as
+    # well as execution_date: execution_date alone is not unique across DAGs.
+    query = session.query(XCom).join(
+        DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date)
+    )
+    if dag_id != '~':
+        query = query.filter(XCom.dag_id == dag_id)
+    if task_id != '~':
+        query = query.filter(XCom.task_id == task_id)
+    if dag_run_id != '~':
+        query = query.filter(DR.run_id == dag_run_id)
+    # Count the filtered rows, not the whole table, and do it before paging.
+    total_entries = query.with_entities(func.count(XCom.key)).scalar()
+    query = query.order_by(
+        XCom.execution_date, XCom.task_id, XCom.dag_id, XCom.key
+    )
+    query = query.offset(offset).limit(limit)
+    return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
 
 
-def get_xcom_entry():
+@provide_session
+def get_xcom_entry(
+    dag_id: str,
+    task_id: str,
+    dag_run_id: str,
+    xcom_key: str,
+    session: Session
+) -> XComCollectionItemSchema:
     """
     Get an XCom entry
+
+    Raises NotFound when no entry matches dag_id, task_id, dag_run_id and xcom_key.
     """
-    raise NotImplementedError("Not implemented yet.")
+    query = session.query(XCom)
+    query = query.filter(and_(XCom.dag_id == dag_id,
+                              XCom.task_id == task_id,
+                              XCom.key == xcom_key))
+    query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date))
+    query = query.filter(DR.run_id == dag_run_id)
+
+    query_object = query.one_or_none()
+    if not query_object:
+        raise NotFound("XCom entry not found")
+    return xcom_collection_item_schema.dump(query_object)
 
 
 def patch_xcom_entry():
diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py
new file mode 100644
index 0000000000000..5adc36da34da6
--- /dev/null
+++ b/airflow/api_connexion/schemas/xcom_schema.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.models import XCom
+
+
+class XComCollectionItemSchema(SQLAlchemySchema):
+    """
+    Schema for an XCom item inside a collection (every field except ``value``)
+    """
+
+    class Meta:
+        """ Meta """
+        model = XCom
+
+    key = auto_field()
+    timestamp = auto_field()
+    execution_date = auto_field()
+    task_id = auto_field()
+    dag_id = auto_field()
+
+
+class XComSchema(XComCollectionItemSchema):
+    """
+    XCom schema: the collection item fields plus ``value``
+    """
+
+    value = auto_field()
+
+
+class XComCollection(NamedTuple):
+    """ List of XComs together with the total entry count"""
+    xcom_entries: List[XCom]
+    total_entries: int
+
+
+class XComCollectionSchema(Schema):
+    """ Schema for an XComCollection"""
+    xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema))
+    total_entries = fields.Int()
+
+
+xcom_schema = XComSchema(strict=True)
+xcom_collection_item_schema = XComCollectionItemSchema(strict=True)
+xcom_collection_schema = XComCollectionSchema(strict=True)
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index 2eabdcd0e8a7a..3f36e602d8534 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -17,60 +17,247 @@
 import unittest
 
 import pytest
+from parameterized import parameterized
 
+from airflow.models import DagRun as DR, XCom
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session, provide_session
+from airflow.utils.types import DagRunType
 from airflow.www import app
 
 
-class TesXComEndpoint(unittest.TestCase):
+class TestXComEndpoint(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
+        """
+        Create the test client and start from empty XCom and DagRun tables
+        """
         self.client = self.app.test_client()  # type:ignore
+        # clear existing xcoms
+        with create_session() as session:
+            session.query(XCom).delete()
+            session.query(DR).delete()
 
+    def tearDown(self) -> None:
+        """
+        Remove the XComs and DagRuns created by the test
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+            session.query(DR).delete()
 
-class TestDeleteXComEntry(TesXComEndpoint):
+
+class TestDeleteXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.delete(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
         )
         assert response.status_code == 204
 
 
-class TestGetXComEntry(TesXComEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
+class TestGetXComEntry(TestXComEndpoint):
+
+    @provide_session
+    def test_should_response_200(self, session):
+        dag_id = 'test-dag-id'
+        task_id = 'test-task-id'
+        execution_date = '2005-04-02T00:00:00+00:00'
+        xcom_key = 'test-xcom-key'
+        execution_date_parsed = parse_execution_date(execution_date)
+        xcom_model = XCom(key=xcom_key,
+                          execution_date=execution_date_parsed,
+                          task_id=task_id,
+                          dag_id=dag_id,
+                          timestamp=execution_date_parsed)
+        dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        dagrun = DR(dag_id=dag_id,
+                    run_id=dag_run_id,
+                    execution_date=execution_date_parsed,
+                    start_date=execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        session.add(xcom_model)
+        session.add(dagrun)
+        session.commit()
         response = self.client.get(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}"
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(
+            response.json,
+            {
+                'dag_id': dag_id,
+                'execution_date': execution_date,
+                'key': xcom_key,
+                'task_id': task_id,
+                'timestamp': execution_date
+            }
         )
-        assert response.status_code == 200
 
 
-class TestGetXComEntries(TesXComEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
+class TestGetXComEntries(TestXComEndpoint):
+    @provide_session
+    def test_should_response_200(self, session):
+        dag_id = 'test-dag-id'
+        task_id = 'test-task-id'
+        execution_date = '2005-04-02T00:00:00+00:00'
+        execution_date_parsed = parse_execution_date(execution_date)
+        xcom_model_1 = XCom(key='test-xcom-key-1',
+                            execution_date=execution_date_parsed,
+                            task_id=task_id,
+                            dag_id=dag_id,
+                            timestamp=execution_date_parsed)
+        xcom_model_2 = XCom(key='test-xcom-key-2',
+                            execution_date=execution_date_parsed,
+                            task_id=task_id,
+                            dag_id=dag_id,
+                            timestamp=execution_date_parsed)
+        dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        dagrun = DR(dag_id=dag_id,
+                    run_id=dag_run_id,
+                    execution_date=execution_date_parsed,
+                    start_date=execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        xcom_models = [xcom_model_1, xcom_model_2]
+        session.add_all(xcom_models)
+        session.add(dagrun)
+        session.commit()
         response = self.client.get(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/"
+            f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries"
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(
+            response.json,
+            {
+                'xcom_entries': [
+                    {
+                        'dag_id': dag_id,
+                        'execution_date': execution_date,
+                        'key': 'test-xcom-key-1',
+                        'task_id': task_id,
+                        'timestamp': execution_date
+                    },
+                    {
+                        'dag_id': dag_id,
+                        'execution_date': execution_date,
+                        'key': 'test-xcom-key-2',
+                        'task_id': task_id,
+                        'timestamp': execution_date
+                    }
+                ],
+                'total_entries': 2,
+            }
         )
+
+
+class TestPaginationGetXComEntries(TestXComEndpoint):
+
+    def setUp(self):
+        super().setUp()
+        self.dag_id = 'test-dag-id'
+        self.task_id = 'test-task-id'
+        self.execution_date = '2005-04-02T00:00:00+00:00'
+        self.execution_date_parsed = parse_execution_date(self.execution_date)
+        self.dag_run_id = DR.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
+
+    @parameterized.expand(
+        [
+            (
+                "limit=1",
+                ["TEST_XCOM_KEY1"],
+            ),
+            (
+                "limit=2",
+                ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"],
+            ),
+            (
+                "offset=5",
+                [
+                    "TEST_XCOM_KEY5",
+                    "TEST_XCOM_KEY6",
+                    "TEST_XCOM_KEY7",
+                    "TEST_XCOM_KEY8",
+                    "TEST_XCOM_KEY9",
+                ]
+            ),
+            (
+                "offset=0",
+                [
+                    "TEST_XCOM_KEY1",
+                    "TEST_XCOM_KEY10",
+                    "TEST_XCOM_KEY2",
+                    "TEST_XCOM_KEY3",
+                    "TEST_XCOM_KEY4",
+                    "TEST_XCOM_KEY5",
+                    "TEST_XCOM_KEY6",
+                    "TEST_XCOM_KEY7",
+                    "TEST_XCOM_KEY8",
+                    "TEST_XCOM_KEY9"
+                ]
+            ),
+            (
+                "limit=1&offset=5",
+                ["TEST_XCOM_KEY5"],
+            ),
+            (
+                "limit=1&offset=1",
+                ["TEST_XCOM_KEY10"],
+            ),
+            (
+                "limit=2&offset=2",
+                ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
+            ),
+        ]
+    )
+    @provide_session
+    def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
+        url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
+        url = url.format(dag_id=self.dag_id,
+                         dag_run_id=self.dag_run_id,
+                         task_id=self.task_id,
+                         query_params=query_params)
+        dagrun = DR(dag_id=self.dag_id,
+                    run_id=self.dag_run_id,
+                    execution_date=self.execution_date_parsed,
+                    start_date=self.execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        xcom_models = self._create_xcoms(10)
+        session.add_all(xcom_models)
+        session.add(dagrun)
+        session.commit()
+        response = self.client.get(url)
-        assert response.status_code == 200
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(response.json["total_entries"], 10)
+        xcom_keys = [entry["key"] for entry in response.json["xcom_entries"]]
+        self.assertEqual(xcom_keys, expected_xcom_ids)
+
+    def _create_xcoms(self, count):
+        return [XCom(
+            key=f'TEST_XCOM_KEY{i}',
+            execution_date=self.execution_date_parsed,
+            task_id=self.task_id,
+            dag_id=self.dag_id,
+            timestamp=self.execution_date_parsed,
+        ) for i in range(1, count + 1)]
 
 
-class TestPatchXComEntry(TesXComEndpoint):
+class TestPatchXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.patch(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries"
         )
         assert response.status_code == 200
 
 
-class TestPostXComEntry(TesXComEndpoint):
+class TestPostXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.post(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
         )
         assert response.status_code == 200
diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py
new file mode 100644
index 0000000000000..d66c8ce58f89b
--- /dev/null
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+
+from sqlalchemy import or_
+
+from airflow.api_connexion.schemas.xcom_schema import (
+    XComCollection, xcom_collection_item_schema, xcom_collection_schema, xcom_schema,
+)
+from airflow.models import XCom
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session, provide_session
+
+
+class TestXComSchemaBase(unittest.TestCase):
+
+    def setUp(self):
+        """
+        Delete any XComs left in the database before the test runs
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+
+    def tearDown(self) -> None:
+        """
+        Delete the XComs created by the test
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+
+
+class TestXComCollectionItemSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
+    @provide_session
+    def test_serialize(self, session):
+        xcom_model = XCom(
+            key='test_key',
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
+            task_id='test_task_id',
+            dag_id='test_dag',
+        )
+        session.add(xcom_model)
+        session.commit()
+        xcom_model = session.query(XCom).first()
+        serialized_xcom = xcom_collection_item_schema.dump(xcom_model)
+        self.assertEqual(
+            serialized_xcom[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time,
+                'execution_date': self.default_time,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+            }
+        )
+
+    def test_deserialize(self):
+        xcom_dump = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+        }
+        result = xcom_collection_item_schema.load(xcom_dump)
+        self.assertEqual(
+            result[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time_parsed,
+                'execution_date': self.default_time_parsed,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+            }
+        )
+
+
+class TestXComCollectionSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time_1 = '2005-04-02T21:00:00+00:00'
+        self.default_time_2 = '2005-04-02T21:01:00+00:00'
+        self.time_1 = parse_execution_date(self.default_time_1)
+        self.time_2 = parse_execution_date(self.default_time_2)
+
+    @provide_session
+    def test_serialize(self, session):
+        xcom_model_1 = XCom(
+            key='test_key_1',
+            timestamp=self.time_1,
+            execution_date=self.time_1,
+            task_id='test_task_id_1',
+            dag_id='test_dag_1',
+        )
+        xcom_model_2 = XCom(
+            key='test_key_2',
+            timestamp=self.time_2,
+            execution_date=self.time_2,
+            task_id='test_task_id_2',
+            dag_id='test_dag_2',
+        )
+        xcom_models = [xcom_model_1, xcom_model_2]
+        session.add_all(xcom_models)
+        session.commit()
+        xcom_models_query = session.query(XCom).filter(
+            or_(XCom.execution_date == self.time_1, XCom.execution_date == self.time_2)
+        ).order_by(XCom.key)  # fixed order: the expected list below is positional
+        xcom_models_queried = xcom_models_query.all()
+        serialized_xcoms = xcom_collection_schema.dump(XComCollection(
+            xcom_entries=xcom_models_queried,
+            total_entries=xcom_models_query.count(),
+        ))
+        self.assertEqual(
+            serialized_xcoms[0],
+            {
+                'xcom_entries': [
+                    {
+                        'key': 'test_key_1',
+                        'timestamp': self.default_time_1,
+                        'execution_date': self.default_time_1,
+                        'task_id': 'test_task_id_1',
+                        'dag_id': 'test_dag_1',
+                    },
+                    {
+                        'key': 'test_key_2',
+                        'timestamp': self.default_time_2,
+                        'execution_date': self.default_time_2,
+                        'task_id': 'test_task_id_2',
+                        'dag_id': 'test_dag_2',
+                    }
+                ],
+                'total_entries': len(xcom_models),
+            }
+        )
+
+
+class TestXComSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
+    @provide_session
+    def test_serialize(self, session):
+        xcom_model = XCom(
+            key='test_key',
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
+            task_id='test_task_id',
+            dag_id='test_dag',
+            value=b'test_binary',
+        )
+        session.add(xcom_model)
+        session.commit()
+        xcom_model = session.query(XCom).first()
+        serialized_xcom = xcom_schema.dump(xcom_model)
+        self.assertEqual(
+            serialized_xcom[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time,
+                'execution_date': self.default_time,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+                'value': 'test_binary',
+            }
+        )
+
+    def test_deserialize(self):
+        xcom_dump = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+            'value': b'test_binary',
+        }
+        result = xcom_schema.load(xcom_dump)
+        self.assertEqual(
+            result[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time_parsed,
+                'execution_date': self.default_time_parsed,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+                'value': 'test_binary',
+            }
+        )