From 6502d2b65325bae31b5a586f4b685c2a71201c1f Mon Sep 17 00:00:00 2001 From: dirrao Date: Tue, 27 Aug 2024 22:03:25 +0530 Subject: [PATCH 01/10] airflow.models.xcom deprecations removed --- airflow/models/baseoperator.py | 6 +- airflow/models/taskinstance.py | 15 +- airflow/models/xcom.py | 233 +++----------------------- tests/models/test_xcom.py | 166 ------------------ tests/providers/microsoft/conftest.py | 3 - 5 files changed, 28 insertions(+), 395 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 7d18aa1474451..2347289428512 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1592,7 +1592,6 @@ def xcom_push( context: Any, key: str, value: Any, - execution_date: datetime | None = None, ) -> None: """ Make an XCom available for tasks to pull. @@ -1601,11 +1600,8 @@ def xcom_push( :param key: A key for the XCom :param value: A value for the XCom. The value is pickled and stored in the database. - :param execution_date: if provided, the XCom will not be visible until - this date. This can be used, for example, to send a message to a - task on a future date without it being immediately visible. """ - context["ti"].xcom_push(key=key, value=value, execution_date=execution_date) + context["ti"].xcom_push(key=key, value=value) @staticmethod @provide_session diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 0d633f8bf3d38..049a01954dd47 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3473,7 +3473,6 @@ def xcom_push( self, key: str, value: Any, - execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: """ @@ -3483,19 +3482,7 @@ def xcom_push( :param value: Value to store. What types are possible depends on whether ``enable_xcom_pickling`` is true or not. If so, this can be any picklable object; only be JSON-serializable may be used otherwise. - :param execution_date: Deprecated parameter that has no effect. - """ - if execution_date is not None: - self_execution_date = self.get_dagrun(session).execution_date - if execution_date < self_execution_date: - raise ValueError( - f"execution_date can not be in the past (current execution_date is " - f"{self_execution_date}; received {execution_date})" - ) - elif execution_date is not None: - message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - + """ XCom.set( key=key, value=value, diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 9829f11fbbde7..c342233302d9b 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -21,8 +21,6 @@ import json import logging import pickle -import warnings -from functools import wraps from typing import TYPE_CHECKING, Any, Iterable, cast, overload from sqlalchemy import ( @@ -40,15 +38,13 @@ from sqlalchemy.dialects.mysql import LONGBLOB from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Query, reconstructor, relationship -from sqlalchemy.orm.exc import NoResultFound from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf -from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies from airflow.utils import timezone from airflow.utils.db import LazySelectSequence -from airflow.utils.helpers import exactly_one, is_container +from airflow.utils.helpers import is_container from airflow.utils.json import XComDecoder, XComEncoder from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -64,9 +60,6 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: - import datetime - - import pendulum from sqlalchemy.engine import Row from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause @@ -150,9 +143,6 @@ def set( """ Store an XCom value. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - :param key: Key to store the XCom. :param value: XCom value to store. :param dag_id: DAG ID. @@ -164,23 +154,6 @@ def set( created for this function. """ - @overload - @classmethod - def set( - cls, - key: str, - value: Any, - task_id: str, - dag_id: str, - execution_date: datetime.datetime, - session: Session = NEW_SESSION, - ) -> None: - """ - Store an XCom value. - - :sphinx-autoapi-skip: - """ - @classmethod @internal_api_call @provide_session @@ -190,7 +163,6 @@ def set( value: Any, task_id: str, dag_id: str, - execution_date: datetime.datetime | None = None, session: Session = NEW_SESSION, *, run_id: str | None = None, @@ -203,27 +175,12 @@ def set( """ from airflow.models.dagrun import DagRun - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError( - f"Exactly one of run_id or execution_date must be passed. " - f"Passed execution_date={execution_date}, run_id={run_id}" - ) + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - if run_id is None: - message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - try: - dag_run_id, run_id = ( - session.query(DagRun.id, DagRun.run_id) - .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) - .one() - ) - except NoResultFound: - raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None - else: - dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() - if dag_run_id is None: - raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") + dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() + if dag_run_id is None: + raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") # Seamlessly resolve LazySelectSequence to a list. This intends to work # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if @@ -242,7 +199,7 @@ def set( "return value" if key == XCOM_RETURN_KEY else f"value {key}", task_id, dag_id, - run_id or execution_date, + run_id, ) value = list(value) @@ -333,9 +290,6 @@ def get_one( If there are no results, *None* is returned. If multiple XCom entries match the criteria, an arbitrary one is returned. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - .. seealso:: ``get_value()`` is a convenience function if you already have a structured TaskInstance or TaskInstanceKey object available. @@ -355,28 +309,10 @@ def get_one( created for this function. """ - @overload - @staticmethod - @internal_api_call - def get_one( - execution_date: datetime.datetime, - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - :sphinx-autoapi-skip: - """ - @staticmethod @provide_session @internal_api_call def get_one( - execution_date: datetime.datetime | None = None, key: str | None = None, task_id: str | None = None, dag_id: str | None = None, @@ -391,38 +327,19 @@ def get_one( :sphinx-autoapi-skip: """ - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError("Exactly one of run_id or execution_date must be passed") - - if run_id: - query = BaseXCom.get_many( - run_id=run_id, - key=key, - task_ids=task_id, - dag_ids=dag_id, - map_indexes=map_index, - include_prior_dates=include_prior_dates, - limit=1, - session=session, - ) - elif execution_date is not None: - message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - query = BaseXCom.get_many( - execution_date=execution_date, - key=key, - task_ids=task_id, - dag_ids=dag_id, - map_indexes=map_index, - include_prior_dates=include_prior_dates, - limit=1, - session=session, - ) - else: - raise RuntimeError("Should not happen?") + if not run_id: + raise ValueError("run_id must be passed") + + query = BaseXCom.get_many( + run_id=run_id, + key=key, + task_ids=task_id, + dag_ids=dag_id, + map_indexes=map_index, + include_prior_dates=include_prior_dates, + limit=1, + session=session, + ) result = query.with_entities(BaseXCom.value).first() if result: @@ -448,9 +365,6 @@ def get_many( This function returns an SQLAlchemy query of full XCom objects. If you just want one stored value, use :meth:`get_one` instead. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - :param run_id: DAG run ID for the task. :param key: A key for the XComs. If provided, only XComs with matching keys will be returned. Pass *None* (default) to remove the filter. @@ -468,32 +382,12 @@ def get_many( :param limit: Limiting returning XComs """ - @overload - @staticmethod - @internal_api_call - def get_many( - execution_date: datetime.datetime, - key: str | None = None, - task_ids: str | Iterable[str] | None = None, - dag_ids: str | Iterable[str] | None = None, - map_indexes: int | Iterable[int] | None = None, - include_prior_dates: bool = False, - limit: int | None = None, - session: Session = NEW_SESSION, - ) -> Query: - """ - Composes a query to get one or more XCom entries. - - :sphinx-autoapi-skip: - """ - # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation # mode will result in a crash - Resulting Query object cannot be **really** serialized # TODO(potiuk) - document it in AIP-44 docs @staticmethod @provide_session def get_many( - execution_date: datetime.datetime | None = None, key: str | None = None, task_ids: str | Iterable[str] | None = None, dag_ids: str | Iterable[str] | None = None, @@ -511,14 +405,8 @@ def get_many( """ from airflow.models.dagrun import DagRun - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError( - f"Exactly one of run_id or execution_date must be passed. " - f"Passed execution_date={execution_date}, run_id={run_id}" - ) - if execution_date is not None: - message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") query = session.query(BaseXCom).join(BaseXCom.dag_run) @@ -545,13 +433,8 @@ def get_many( query = query.filter(BaseXCom.map_index == map_indexes) if include_prior_dates: - if execution_date is not None: - query = query.filter(DagRun.execution_date <= execution_date) - else: - dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery() - query = query.filter(BaseXCom.execution_date <= dr.c.execution_date) - elif execution_date is not None: - query = query.filter(DagRun.execution_date == execution_date) + dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery() + query = query.filter(BaseXCom.execution_date <= dr.c.execution_date) else: query = query.filter(BaseXCom.run_id == run_id) @@ -592,9 +475,6 @@ def clear( """ Clear all XCom data from the database for the given task instance. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - :param dag_id: ID of DAG to clear the XCom for. :param task_id: ID of task to clear the XCom for. :param run_id: ID of DAG run to clear the XCom for. @@ -604,26 +484,10 @@ def clear( created for this function. """ - @overload - @staticmethod - @internal_api_call - def clear( - execution_date: pendulum.DateTime, - dag_id: str, - task_id: str, - session: Session = NEW_SESSION, - ) -> None: - """ - Clear all XCom data from the database for the given task instance. - - :sphinx-autoapi-skip: - """ - @staticmethod @provide_session @internal_api_call def clear( - execution_date: pendulum.DateTime | None = None, dag_id: str | None = None, task_id: str | None = None, session: Session = NEW_SESSION, @@ -636,8 +500,6 @@ def clear( :sphinx-autoapi-skip: """ - from airflow.models import DagRun - # Given the historic order of this function (execution_date was first argument) to add a new optional # param we need to add default values for everything :( if dag_id is None: @@ -645,20 +507,8 @@ def clear( if task_id is None: raise TypeError("clear() missing required argument: task_id") - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError( - f"Exactly one of run_id or execution_date must be passed. " - f"Passed execution_date={execution_date}, run_id={run_id}" - ) - - if execution_date is not None: - message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - run_id = ( - session.query(DagRun.run_id) - .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) - .scalar() - ) + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) if map_index is not None: @@ -747,33 +597,6 @@ def _process_row(row: Row) -> Any: return XCom.deserialize_value(row) -def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None: - """ - Patch a custom ``serialize_value`` to accept the modern signature. - - To give custom XCom backends more flexibility with how they store values, we - now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``. - In order to maintain compatibility with custom XCom backends written with - the old signature, we check the signature and, if necessary, patch with a - method that ignores kwargs the backend does not accept. - """ - old_serializer = clazz.serialize_value - - @wraps(old_serializer) - def _shim(**kwargs): - kwargs = {k: kwargs.get(k) for k in params} - warnings.warn( - f"Method `serialize_value` in XCom backend {XCom.__name__} is using outdated signature and" - f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be " - f"removed in a future release.", - RemovedInAirflow3Warning, - stacklevel=1, - ) - return old_serializer(**kwargs) - - clazz.serialize_value = _shim # type: ignore[assignment] - - def _get_function_params(function) -> list[str]: """ Return the list of variables names of a function. @@ -801,10 +624,6 @@ def resolve_xcom_backend() -> type[BaseXCom]: raise TypeError( f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`." ) - base_xcom_params = _get_function_params(BaseXCom.serialize_value) - xcom_params = _get_function_params(clazz.serialize_value) - if set(base_xcom_params) != set(xcom_params): - _patch_outdated_serializer(clazz=clazz, params=xcom_params) return clazz diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index e9db3d946d8e6..07533ec944be8 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -214,33 +214,6 @@ def test_get_one_custom_backend_no_use_orm_deserialize_value(self, task_instance assert value == {"key": "value"} XCom.orm_deserialize_value.assert_not_called() - @pytest.mark.skip_if_database_isolation_mode - @conf_vars({("core", "enable_xcom_pickling"): "False"}) - @mock.patch("airflow.models.xcom.conf.getimport") - def test_set_serialize_call_old_signature(self, get_import, task_instance): - """ - When XCom.serialize_value takes only param ``value``, other kwargs should be ignored. - """ - serialize_watcher = MagicMock() - - class OldSignatureXCom(BaseXCom): - @staticmethod - def serialize_value(value, **kwargs): - serialize_watcher(value=value, **kwargs) - return json.dumps(value).encode("utf-8") - - get_import.return_value = OldSignatureXCom - - XCom = resolve_xcom_backend() - XCom.set( - key=XCOM_RETURN_KEY, - value={"my_xcom_key": "my_xcom_value"}, - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - run_id=task_instance.run_id, - ) - serialize_watcher.assert_called_once_with(value={"my_xcom_key": "my_xcom_value"}) - @pytest.mark.skip_if_database_isolation_mode @conf_vars({("core", "enable_xcom_pickling"): "False"}) @mock.patch("airflow.models.xcom.conf.getimport") @@ -335,19 +308,6 @@ def test_xcom_get_one(self, session, task_instance): ) assert stored_value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_get_one") - def test_xcom_get_one_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - stored_value = XCom.get_one( - key="xcom_1", - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=task_instance.execution_date, - session=session, - ) - assert stored_value == {"key": "value"} - @pytest.fixture def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom): date1 = timezone.datetime(2021, 12, 3, 4, 56) @@ -376,24 +336,6 @@ def test_xcom_get_one_from_prior_date(self, session, tis_for_xcom_get_one_from_p ) assert retrieved_value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode - def test_xcom_get_one_from_prior_with_execution_date( - self, - session, - tis_for_xcom_get_one_from_prior_date, - ): - _, ti2 = tis_for_xcom_get_one_from_prior_date - with pytest.deprecated_call(): - retrieved_value = XCom.get_one( - execution_date=ti2.execution_date, - key="xcom_1", - task_id="task_1", - dag_id="dag", - include_prior_dates=True, - session=session, - ) - assert retrieved_value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode @pytest.fixture def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simple_json_xcom): @@ -413,21 +355,6 @@ def test_xcom_get_many_single_argument_value(self, session, task_instance): assert stored_xcoms[0].key == "xcom_1" assert stored_xcoms[0].value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value") - def test_xcom_get_many_single_argument_value_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - stored_xcoms = XCom.get_many( - execution_date=task_instance.execution_date, - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - session=session, - ).all() - assert len(stored_xcoms) == 1 - assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode @pytest.fixture def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_json_xcom): @@ -448,20 +375,6 @@ def test_xcom_get_many_multiple_tasks(self, session, task_instance): sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))] assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}] - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks") - def test_xcom_get_many_multiple_tasks_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - stored_xcoms = XCom.get_many( - execution_date=task_instance.execution_date, - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=["task_1", "task_2"], - session=session, - ) - sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))] - assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}] - @pytest.fixture def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_simple_json_xcom): date1 = timezone.datetime(2021, 12, 3, 4, 56) @@ -488,27 +401,6 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}] assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date] - @pytest.mark.skip_if_database_isolation_mode - def test_xcom_get_many_from_prior_dates_with_execution_date( - self, - session, - tis_for_xcom_get_many_from_prior_dates, - ): - ti1, ti2 = tis_for_xcom_get_many_from_prior_dates - with pytest.deprecated_call(): - stored_xcoms = XCom.get_many( - execution_date=ti2.execution_date, - key="xcom_1", - dag_ids="dag", - task_ids="task_1", - include_prior_dates=True, - session=session, - ) - - # The retrieved XComs should be ordered by logical date, latest first. - assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}] - assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date] - @pytest.mark.usefixtures("setup_xcom_pickling") class TestXComSet: @@ -528,24 +420,6 @@ def test_xcom_set(self, session, task_instance): assert stored_xcoms[0].task_id == "task_1" assert stored_xcoms[0].execution_date == task_instance.execution_date - @pytest.mark.skip_if_database_isolation_mode - def test_xcom_set_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - XCom.set( - key="xcom_1", - value={"key": "value"}, - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=task_instance.execution_date, - session=session, - ) - stored_xcoms = session.query(XCom).all() - assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} - assert stored_xcoms[0].dag_id == "dag" - assert stored_xcoms[0].task_id == "task_1" - assert stored_xcoms[0].execution_date == task_instance.execution_date - @pytest.fixture def setup_for_xcom_set_again_replace(self, task_instance, push_simple_json_xcom): push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key1": "value1"}) @@ -563,21 +437,6 @@ def test_xcom_set_again_replace(self, session, task_instance): ) assert session.query(XCom).one().value == {"key2": "value2"} - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_set_again_replace") - def test_xcom_set_again_replace_with_execution_date(self, session, task_instance): - assert session.query(XCom).one().value == {"key1": "value1"} - with pytest.deprecated_call(): - XCom.set( - key="xcom_1", - value={"key2": "value2"}, - dag_id=task_instance.dag_id, - task_id="task_1", - execution_date=task_instance.execution_date, - session=session, - ) - assert session.query(XCom).one().value == {"key2": "value2"} - @pytest.mark.usefixtures("setup_xcom_pickling") class TestXComClear: @@ -598,19 +457,6 @@ def test_xcom_clear(self, mock_purge, session, task_instance): assert session.query(XCom).count() == 0 assert mock_purge.call_count == 0 if is_db_isolation_mode() else 1 - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_clear") - def test_xcom_clear_with_execution_date(self, session, task_instance): - assert session.query(XCom).count() == 1 - with pytest.deprecated_call(): - XCom.clear( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=task_instance.execution_date, - session=session, - ) - assert session.query(XCom).count() == 0 - @pytest.mark.usefixtures("setup_for_xcom_clear") def test_xcom_clear_different_run(self, session, task_instance): XCom.clear( @@ -620,15 +466,3 @@ def test_xcom_clear_different_run(self, session, task_instance): session=session, ) assert session.query(XCom).count() == 1 - - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_clear") - def test_xcom_clear_different_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - XCom.clear( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=timezone.utcnow(), - session=session, - ) - assert session.query(XCom).count() == 1 diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index c77dd7747d19d..a33e803df8c09 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -111,8 +111,6 @@ def mock_response(status_code, content: Any = None, headers: dict | None = None) def mock_context(task) -> Context: - from datetime import datetime - from airflow.models import TaskInstance from airflow.utils.session import NEW_SESSION from airflow.utils.state import TaskInstanceState @@ -150,7 +148,6 @@ def xcom_push( self, key: str, value: Any, - execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value From f740c11466554fc0b12d3dc2e000f30492faf0b5 Mon Sep 17 00:00:00 2001 From: dirrao Date: Wed, 28 Aug 2024 10:11:21 +0530 Subject: [PATCH 02/10] airflow.models.xcom deprecations removed --- airflow/models/xcom.py | 64 +++++++++++++++++++ .../amazon/aws/links/test_base_aws.py | 1 - .../cloud/operators/test_bigquery_dts.py | 4 +- .../google/cloud/operators/test_dataproc.py | 28 +++----- tests/providers/yandex/links/test_yq.py | 1 - 5 files changed, 74 insertions(+), 24 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index c342233302d9b..13df9fe84e28c 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -154,6 +154,22 @@ def set( created for this function. """ + @overload + @classmethod + def set( + cls, + key: str, + value: Any, + task_id: str, + dag_id: str, + session: Session = NEW_SESSION, + ) -> None: + """ + Store an XCom value. + + :sphinx-autoapi-skip: + """ + @classmethod @internal_api_call @provide_session @@ -309,6 +325,22 @@ def get_one( created for this function. """ + @overload + @staticmethod + @internal_api_call + def get_one( + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, + include_prior_dates: bool = False, + session: Session = NEW_SESSION, + ) -> Any | None: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + :sphinx-autoapi-skip: + """ + @staticmethod @provide_session @internal_api_call @@ -382,6 +414,24 @@ def get_many( :param limit: Limiting returning XComs """ + @overload + @staticmethod + @internal_api_call + def get_many( + key: str | None = None, + task_ids: str | Iterable[str] | None = None, + dag_ids: str | Iterable[str] | None = None, + map_indexes: int | Iterable[int] | None = None, + include_prior_dates: bool = False, + limit: int | None = None, + session: Session = NEW_SESSION, + ) -> Query: + """ + Composes a query to get one or more XCom entries. + + :sphinx-autoapi-skip: + """ + # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation # mode will result in a crash - Resulting Query object cannot be **really** serialized # TODO(potiuk) - document it in AIP-44 docs @@ -484,6 +534,20 @@ def clear( created for this function. """ + @overload + @staticmethod + @internal_api_call + def clear( + dag_id: str, + task_id: str, + session: Session = NEW_SESSION, + ) -> None: + """ + Clear all XCom data from the database for the given task instance. + + :sphinx-autoapi-skip: + """ + @staticmethod @provide_session @internal_api_call diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py index 446d584edf358..ba5f6e0014b36 100644 --- a/tests/providers/amazon/aws/links/test_base_aws.py +++ b/tests/providers/amazon/aws/links/test_base_aws.py @@ -76,7 +76,6 @@ def test_persist(self, region_name, aws_partition, keywords, expected_value): ti = mock_context["ti"] ti.xcom_push.assert_called_once_with( - execution_date=None, key=XCOM_KEY, value=expected_value, ) diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py index b3145151d3da6..b89d3d7463e6b 100644 --- a/tests/providers/google/cloud/operators/test_bigquery_dts.py +++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py @@ -71,7 +71,7 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) - ti.xcom_push.assert_called_with(execution_date=None, key="transfer_config_id", value="1a2b3c") + ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c") assert "secret_access_key" not in return_value.get("params", {}) assert "access_key_id" not in return_value.get("params", {}) @@ -126,7 +126,7 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) - ti.xcom_push.assert_called_with(execution_date=None, key="run_id", value="123") + ti.xcom_push.assert_called_with(key="run_id", value="123") @mock.patch( f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook", diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 1d1f2a1ef818a..f711d45235817 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -441,7 +441,7 @@ class DataprocJobTestBase(DataprocTestBase): @classmethod def setup_class(cls): cls.extra_links_expected_calls = [ - call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED), + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), ] @@ -451,7 +451,7 @@ class DataprocClusterTestBase(DataprocTestBase): def setup_class(cls): super().setup_class() cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(execution_date=None, key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) + call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) ] @@ -761,7 +761,6 @@ def test_execute(self, mock_hook, to_dict_mock): self.mock_ti.xcom_push.assert_called_once_with( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @@ -811,7 +810,6 @@ def test_execute_in_gke(self, mock_hook, to_dict_mock): self.mock_ti.xcom_push.assert_called_once_with( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @@ -1096,7 +1094,7 @@ class TestDataprocClusterScaleOperator(DataprocClusterTestBase): def setup_class(cls): super().setup_class() cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) + call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) ] def test_deprecation_warning(self): @@ -1145,7 +1143,6 @@ def test_execute(self, mock_hook): self.mock_ti.xcom_push.assert_called_once_with( key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, - execution_date=None, ) @@ -1310,9 +1307,7 @@ def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock class TestDataprocSubmitJobOperator(DataprocJobTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): - xcom_push_call = call.ti.xcom_push( - execution_date=None, key="dataproc_job", value=DATAPROC_JOB_EXPECTED - ) + xcom_push_call = call.ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) wait_for_job_call = call.hook().wait_for_job( job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None ) @@ -1358,9 +1353,7 @@ def test_execute(self, mock_hook): job_id=TEST_JOB_ID, project_id=GCP_PROJECT, region=GCP_REGION, timeout=None ) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_async(self, mock_hook): @@ -1398,9 +1391,7 @@ def test_execute_async(self, mock_hook): ) mock_hook.return_value.wait_for_job.assert_not_called() - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) @@ -1636,7 +1627,6 @@ def test_execute(self, mock_hook): self.mock_ti.xcom_push.assert_called_once_with( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, ) def test_missing_region_parameter(self): @@ -2400,7 +2390,7 @@ class TestDataProcSparkOperator(DataprocJobTestBase): @classmethod def setup_class(cls): cls.extra_links_expected_calls = [ - call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED), + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), ] @@ -2446,9 +2436,7 @@ def test_execute(self, mock_hook, mock_uuid): assert self.job == job op.execute(context=self.mock_context) - self.mock_ti.xcom_push.assert_called_once_with( - key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None - ) + self.mock_ti.xcom_push.assert_called_once_with(key="conf", value=DATAPROC_JOB_CONF_EXPECTED) # Test whether xcom push occurs before polling for job self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) diff --git a/tests/providers/yandex/links/test_yq.py b/tests/providers/yandex/links/test_yq.py index 06f1e83939f9a..05de903120a22 100644 --- a/tests/providers/yandex/links/test_yq.py +++ b/tests/providers/yandex/links/test_yq.py @@ -35,7 +35,6 @@ def test_persist(): ti = mock_context["ti"] ti.xcom_push.assert_called_once_with( - execution_date=None, key="web_link", value="g.com", ) From 924f8f1112850e7be8c85f9c29eda1c6f64bebd1 Mon Sep 17 00:00:00 2001 From: dirrao Date: Wed, 28 Aug 2024 11:49:03 +0530 Subject: [PATCH 03/10] airflow.models.xcom deprecations removed --- airflow/models/xcom.py | 157 ++---------------- .../serialization/pydantic/taskinstance.py | 3 - 2 files changed, 10 insertions(+), 150 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 13df9fe84e28c..736e203d3a74f 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -21,7 +21,7 @@ import json import logging import pickle -from typing import TYPE_CHECKING, Any, Iterable, cast, overload +from typing import TYPE_CHECKING, Any, Iterable, cast from sqlalchemy import ( Column, @@ -127,8 +127,9 @@ def __repr__(self): return f'' return f'' - @overload @classmethod + @internal_api_call + @provide_session def set( cls, key: str, @@ -153,42 +154,6 @@ def set( :param session: Database session. If not given, a new session will be created for this function. """ - - @overload - @classmethod - def set( - cls, - key: str, - value: Any, - task_id: str, - dag_id: str, - session: Session = NEW_SESSION, - ) -> None: - """ - Store an XCom value. - - :sphinx-autoapi-skip: - """ - - @classmethod - @internal_api_call - @provide_session - def set( - cls, - key: str, - value: Any, - task_id: str, - dag_id: str, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - map_index: int = -1, - ) -> None: - """ - Store an XCom value. - - :sphinx-autoapi-skip: - """ from airflow.models.dagrun import DagRun if not run_id: @@ -284,8 +249,8 @@ def get_value( session=session, ) - @overload @staticmethod + @provide_session @internal_api_call def get_one( *, @@ -295,6 +260,7 @@ def get_one( run_id: str | None = None, map_index: int | None = None, session: Session = NEW_SESSION, + include_prior_dates: bool = False, ) -> Any | None: """ Retrieve an XCom value, optionally meeting certain criteria. @@ -324,41 +290,6 @@ def get_one( :param session: Database session. If not given, a new session will be created for this function. """ - - @overload - @staticmethod - @internal_api_call - def get_one( - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - :sphinx-autoapi-skip: - """ - - @staticmethod - @provide_session - @internal_api_call - def get_one( - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - map_index: int | None = None, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - :sphinx-autoapi-skip: - """ if not run_id: raise ValueError("run_id must be passed") @@ -378,8 +309,11 @@ def get_one( return XCom.deserialize_value(result) return None - @overload + # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation + # mode will result in a crash - Resulting Query object cannot be **really** serialized + # TODO(potiuk) - document it in AIP-44 docs @staticmethod + @provide_session def get_many( *, run_id: str, @@ -413,46 +347,6 @@ def get_many( created for this function. :param limit: Limiting returning XComs """ - - @overload - @staticmethod - @internal_api_call - def get_many( - key: str | None = None, - task_ids: str | Iterable[str] | None = None, - dag_ids: str | Iterable[str] | None = None, - map_indexes: int | Iterable[int] | None = None, - include_prior_dates: bool = False, - limit: int | None = None, - session: Session = NEW_SESSION, - ) -> Query: - """ - Composes a query to get one or more XCom entries. - - :sphinx-autoapi-skip: - """ - - # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation - # mode will result in a crash - Resulting Query object cannot be **really** serialized - # TODO(potiuk) - document it in AIP-44 docs - @staticmethod - @provide_session - def get_many( - key: str | None = None, - task_ids: str | Iterable[str] | None = None, - dag_ids: str | Iterable[str] | None = None, - map_indexes: int | Iterable[int] | None = None, - include_prior_dates: bool = False, - limit: int | None = None, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - ) -> Query: - """ - Composes a query to get one or more XCom entries. - - :sphinx-autoapi-skip: - """ from airflow.models.dagrun import DagRun if not run_id: @@ -511,8 +405,8 @@ def purge(xcom: XCom, session: Session) -> None: """Purge an XCom entry from underlying storage implementations.""" pass - @overload @staticmethod + @provide_session @internal_api_call def clear( *, @@ -533,37 +427,6 @@ def clear( :param session: Database session. If not given, a new session will be created for this function. """ - - @overload - @staticmethod - @internal_api_call - def clear( - dag_id: str, - task_id: str, - session: Session = NEW_SESSION, - ) -> None: - """ - Clear all XCom data from the database for the given task instance. - - :sphinx-autoapi-skip: - """ - - @staticmethod - @provide_session - @internal_api_call - def clear( - dag_id: str | None = None, - task_id: str | None = None, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - map_index: int | None = None, - ) -> None: - """ - Clear all XCom data from the database for the given task instance. - - :sphinx-autoapi-skip: - """ # Given the historic order of this function (execution_date was first argument) to add a new optional # param we need to add default values for everything :( if dag_id is None: diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 0dcb7880eba7d..549b03680df83 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -203,7 +203,6 @@ def xcom_push( self, key: str, value: Any, - execution_date: datetime | None = None, session: Session | None = None, ) -> None: """ @@ -211,13 +210,11 @@ def xcom_push( :param key: the key to identify the XCom value :param value: the value of the XCom - :param execution_date: the execution date to push the XCom for """ return TaskInstance.xcom_push( self=self, # type: ignore[arg-type] key=key, value=value, - execution_date=execution_date, session=session, ) From fcc871d8107b6867d31f50f7f648703348908072 Mon Sep 17 00:00:00 2001 From: dirrao Date: Wed, 28 Aug 2024 13:27:39 +0530 Subject: [PATCH 04/10] airflow.models.xcom deprecations removed --- tests/providers/yandex/operators/test_yq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/providers/yandex/operators/test_yq.py b/tests/providers/yandex/operators/test_yq.py index e342a2f961dfd..7da885bd67b3a 100644 --- a/tests/providers/yandex/operators/test_yq.py +++ b/tests/providers/yandex/operators/test_yq.py @@ -94,7 +94,6 @@ def test_execute_query(self, mock_get_connection): call( key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", - execution_date=None, ), ] ) From f601eabf4e690fa37d395f5b3d365a647c193d31 Mon Sep 17 00:00:00 2001 From: dirrao Date: Fri, 30 Aug 2024 17:21:18 +0530 Subject: [PATCH 05/10] xcom backward compitability tests support --- airflow/models/xcom.py | 5 +- .../amazon/aws/links/test_base_aws.py | 12 ++- .../cloud/operators/test_bigquery_dts.py | 11 ++- .../google/cloud/operators/test_dataproc.py | 83 ++++++++++++++----- tests/providers/yandex/links/test_yq.py | 12 ++- tests/providers/yandex/operators/test_yq.py | 29 +++++-- tests/test_utils/compat.py | 1 + 7 files changed, 111 insertions(+), 42 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 736e203d3a74f..87c72d5bf7f53 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -257,7 +257,7 @@ def get_one( key: str | None = None, dag_id: str | None = None, task_id: str | None = None, - run_id: str | None = None, + run_id: str, map_index: int | None = None, session: Session = NEW_SESSION, include_prior_dates: bool = False, @@ -290,9 +290,6 @@ def get_one( :param session: Database session. If not given, a new session will be created for this function. """ - if not run_id: - raise ValueError("run_id must be passed") - query = BaseXCom.get_many( run_id=run_id, key=key, diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py index ba5f6e0014b36..1afcfea0a826f 100644 --- a/tests/providers/amazon/aws/links/test_base_aws.py +++ b/tests/providers/amazon/aws/links/test_base_aws.py @@ -25,6 +25,7 @@ from airflow.models.xcom import XCom from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink from airflow.serialization.serialized_objects import SerializedDAG +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.mock_operators import MockOperator if TYPE_CHECKING: @@ -75,10 +76,13 @@ def test_persist(self, region_name, aws_partition, keywords, expected_value): ) ti = mock_context["ti"] - ti.xcom_push.assert_called_once_with( - key=XCOM_KEY, - value=expected_value, - ) + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_once_with( + key=XCOM_KEY, + value=expected_value, + ) + else: + ti.xcom_push.assert_called_once_with(key=XCOM_KEY, value=expected_value, execution_date=None) def test_disable_xcom_push(self): mock_context = mock.MagicMock() diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py index b89d3d7463e6b..f44479bbce9e4 100644 --- a/tests/providers/google/cloud/operators/test_bigquery_dts.py +++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py @@ -27,6 +27,7 @@ BigQueryDataTransferServiceStartTransferRunsOperator, BigQueryDeleteDataTransferConfigOperator, ) +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS PROJECT_ID = "id" @@ -71,7 +72,10 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) - ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c") + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c") + else: + ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c", execution_date=None) assert "secret_access_key" not in return_value.get("params", {}) assert "access_key_id" not in return_value.get("params", {}) @@ -126,7 +130,10 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) - ti.xcom_push.assert_called_with(key="run_id", value="123") + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_with(key="run_id", value="123") + else: + ti.xcom_push.assert_called_with(key="run_id", value="123", execution_date=None) @mock.patch( f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook", diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index f711d45235817..d0d68fc0d97ad 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -79,7 +79,7 @@ from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.timezone import datetime -from tests.test_utils.compat import AIRFLOW_VERSION +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_VERSION from tests.test_utils.db import clear_db_runs, clear_db_xcom AIRFLOW_VERSION_LABEL = "v" + str(AIRFLOW_VERSION).replace(".", "-").replace("+", "-") @@ -758,10 +758,17 @@ def test_execute(self, mock_hook, to_dict_mock): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation()) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None, + ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -807,10 +814,17 @@ def test_execute_in_gke(self, mock_hook, to_dict_mock): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation()) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None, + ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -1140,10 +1154,17 @@ def test_execute(self, mock_hook): # Test whether xcom push occurs before cluster is updated self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - self.mock_ti.xcom_push.assert_called_once_with( - key="conf", - value=DATAPROC_CLUSTER_CONF_EXPECTED, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None, + ) @pytest.mark.db_test @@ -1353,7 +1374,12 @@ def test_execute(self, mock_hook): job_id=TEST_JOB_ID, project_id=GCP_PROJECT, region=GCP_REGION, timeout=None ) - self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_async(self, mock_hook): @@ -1391,7 +1417,12 @@ def test_execute_async(self, mock_hook): ) mock_hook.return_value.wait_for_job.assert_not_called() - self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) @@ -1624,10 +1655,17 @@ def test_execute(self, mock_hook): # Test whether the xcom push happens before updating the cluster self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None, + ) def test_missing_region_parameter(self): with pytest.raises(AirflowException): @@ -2436,7 +2474,12 @@ def test_execute(self, mock_hook, mock_uuid): assert self.job == job op.execute(context=self.mock_context) - self.mock_ti.xcom_push.assert_called_once_with(key="conf", value=DATAPROC_JOB_CONF_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with(key="conf", value=DATAPROC_JOB_CONF_EXPECTED) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None + ) # Test whether xcom push occurs before polling for job self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) diff --git a/tests/providers/yandex/links/test_yq.py b/tests/providers/yandex/links/test_yq.py index 05de903120a22..d46862f1c737f 100644 --- a/tests/providers/yandex/links/test_yq.py +++ b/tests/providers/yandex/links/test_yq.py @@ -23,6 +23,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XCom from airflow.providers.yandex.links.yq import YQLink +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.mock_operators import MockOperator yandexcloud = pytest.importorskip("yandexcloud") @@ -34,10 +35,13 @@ def test_persist(): YQLink.persist(context=mock_context, task_instance=MockOperator(task_id="test_task_id"), web_link="g.com") ti = mock_context["ti"] - ti.xcom_push.assert_called_once_with( - key="web_link", - value="g.com", - ) + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_once_with( + key="web_link", + value="g.com", + ) + else: + ti.xcom_push.assert_called_once_with(key="web_link", value="g.com", execution_date=None) def test_default_link(): diff --git a/tests/providers/yandex/operators/test_yq.py b/tests/providers/yandex/operators/test_yq.py index 7da885bd67b3a..034f0505517ba 100644 --- a/tests/providers/yandex/operators/test_yq.py +++ b/tests/providers/yandex/operators/test_yq.py @@ -22,6 +22,8 @@ import pytest +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + yandexcloud = pytest.importorskip("yandexcloud") import responses @@ -89,14 +91,25 @@ def test_execute_query(self, mock_get_connection): results = operator.execute(context) assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - context["ti"].xcom_push.assert_has_calls( - [ - call( - key="web_link", - value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", - ), - ] - ) + if AIRFLOW_V_3_0_PLUS: + context["ti"].xcom_push.assert_has_calls( + [ + call( + key="web_link", + value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", + ), + ] + ) + else: + context["ti"].xcom_push.assert_has_calls( + [ + call( + key="web_link", + value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", + execution_date=None, + ), + ] + ) responses.get( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", diff --git a/tests/test_utils/compat.py b/tests/test_utils/compat.py index b09973903b4e2..5daf429cf641f 100644 --- a/tests/test_utils/compat.py +++ b/tests/test_utils/compat.py @@ -46,6 +46,7 @@ AIRFLOW_V_2_8_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.8.0") AIRFLOW_V_2_9_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.9.0") AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") +AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") try: from airflow.models.baseoperatorlink import BaseOperatorLink From 0507332ebba2944603de9aadcfacd74fc04cb965 Mon Sep 17 00:00:00 2001 From: dirrao Date: Sat, 31 Aug 2024 09:29:40 +0530 Subject: [PATCH 06/10] xcom backward compitability tests support --- airflow/models/taskinstance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 049a01954dd47..165f5c7987305 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -79,7 +79,6 @@ AirflowSkipException, AirflowTaskTerminated, AirflowTaskTimeout, - RemovedInAirflow3Warning, TaskDeferralError, TaskDeferred, UnmappableXComLengthPushed, From b6dd849195d567eaec12b195117b00716058089b Mon Sep 17 00:00:00 2001 From: dirrao Date: Sat, 31 Aug 2024 11:40:13 +0530 Subject: [PATCH 07/10] xcom backward compitability tests support --- .../google/cloud/operators/test_dataproc.py | 35 +++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index d0d68fc0d97ad..fbc71e4b7e5fd 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -440,19 +440,31 @@ def tearDownClass(cls): class DataprocJobTestBase(DataprocTestBase): @classmethod def setup_class(cls): - cls.extra_links_expected_calls = [ + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls = [ call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), ] + else: + cls.extra_links_expected_calls = [ + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] class DataprocClusterTestBase(DataprocTestBase): @classmethod def setup_class(cls): super().setup_class() - cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) - ] + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) + ] + else: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None) + ] class TestsClusterGenerator: @@ -1107,9 +1119,14 @@ class TestDataprocClusterScaleOperator(DataprocClusterTestBase): @classmethod def setup_class(cls): super().setup_class() - cls.extra_links_expected_calls_base = [ + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls_base = [ call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) ] + else: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None) + ] def test_deprecation_warning(self): with pytest.warns(AirflowProviderDeprecationWarning) as warnings: @@ -2427,10 +2444,16 @@ def test_builder(self, mock_hook, mock_uuid): class TestDataProcSparkOperator(DataprocJobTestBase): @classmethod def setup_class(cls): - cls.extra_links_expected_calls = [ + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls = [ call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), ] + else: + cls.extra_links_expected_calls = [ + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] From 094256bde6393cc1063b8218d45768d2d4ac4495 Mon Sep 17 00:00:00 2001 From: dirrao Date: Sat, 31 Aug 2024 11:58:02 +0530 Subject: [PATCH 08/10] xcom backward compitability tests support --- .../google/cloud/operators/test_dataproc.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index fbc71e4b7e5fd..557765d698af1 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -442,9 +442,9 @@ class DataprocJobTestBase(DataprocTestBase): def setup_class(cls): if AIRFLOW_V_3_0_PLUS: cls.extra_links_expected_calls = [ - call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), - call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), - ] + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] else: cls.extra_links_expected_calls = [ call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None), @@ -462,8 +462,9 @@ def setup_class(cls): ] else: cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None) + call.ti.xcom_push( + key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, execution_date=None + ) ] @@ -1121,8 +1122,8 @@ def setup_class(cls): super().setup_class() if AIRFLOW_V_3_0_PLUS: cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) - ] + call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) + ] else: cls.extra_links_expected_calls_base = [ call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None) @@ -2446,9 +2447,9 @@ class TestDataProcSparkOperator(DataprocJobTestBase): def setup_class(cls): if AIRFLOW_V_3_0_PLUS: cls.extra_links_expected_calls = [ - call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), - call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), - ] + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] else: cls.extra_links_expected_calls = [ call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None), From 862c56072a6a32f3708071110aab5665eecfacf4 Mon Sep 17 00:00:00 2001 From: dirrao Date: Sat, 31 Aug 2024 12:14:46 +0530 Subject: [PATCH 09/10] xcom backward compitability tests support --- tests/providers/google/cloud/operators/test_dataproc.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 557765d698af1..58b38125ee1d3 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -1346,7 +1346,12 @@ def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock class TestDataprocSubmitJobOperator(DataprocJobTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): - xcom_push_call = call.ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + xcom_push_call = call.ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + else: + xcom_push_call = call.ti.xcom_push( + key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None + ) wait_for_job_call = call.hook().wait_for_job( job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None ) From 062ddddb6e1f942f4211e3fa0dceeaff9e774d24 Mon Sep 17 00:00:00 2001 From: dirrao Date: Sun, 1 Sep 2024 18:13:38 +0530 Subject: [PATCH 10/10] xcom backward compitability tests support --- tests/providers/microsoft/conftest.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index a33e803df8c09..8a258735291f1 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -144,12 +144,7 @@ def xcom_pull( return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") - def xcom_push( - self, - key: str, - value: Any, - session: Session = NEW_SESSION, - ) -> None: + def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None: values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value values["ti"] = MockedTaskInstance(task=task)