diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 7d18aa1474451..2347289428512 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1592,7 +1592,6 @@ def xcom_push( context: Any, key: str, value: Any, - execution_date: datetime | None = None, ) -> None: """ Make an XCom available for tasks to pull. @@ -1601,11 +1600,8 @@ def xcom_push( :param key: A key for the XCom :param value: A value for the XCom. The value is pickled and stored in the database. - :param execution_date: if provided, the XCom will not be visible until - this date. This can be used, for example, to send a message to a - task on a future date without it being immediately visible. """ - context["ti"].xcom_push(key=key, value=value, execution_date=execution_date) + context["ti"].xcom_push(key=key, value=value) @staticmethod @provide_session diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 0d633f8bf3d38..165f5c7987305 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -79,7 +79,6 @@ AirflowSkipException, AirflowTaskTerminated, AirflowTaskTimeout, - RemovedInAirflow3Warning, TaskDeferralError, TaskDeferred, UnmappableXComLengthPushed, @@ -3473,7 +3472,6 @@ def xcom_push( self, key: str, value: Any, - execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: """ @@ -3483,19 +3481,7 @@ def xcom_push( :param value: Value to store. What types are possible depends on whether ``enable_xcom_pickling`` is true or not. If so, this can be any picklable object; only be JSON-serializable may be used otherwise. - :param execution_date: Deprecated parameter that has no effect. - """ - if execution_date is not None: - self_execution_date = self.get_dagrun(session).execution_date - if execution_date < self_execution_date: - raise ValueError( - f"execution_date can not be in the past (current execution_date is " - f"{self_execution_date}; received {execution_date})" - ) - elif execution_date is not None: - message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - + """ XCom.set( key=key, value=value, diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 9829f11fbbde7..87c72d5bf7f53 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -21,9 +21,7 @@ import json import logging import pickle -import warnings -from functools import wraps -from typing import TYPE_CHECKING, Any, Iterable, cast, overload +from typing import TYPE_CHECKING, Any, Iterable, cast from sqlalchemy import ( Column, @@ -40,15 +38,13 @@ from sqlalchemy.dialects.mysql import LONGBLOB from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Query, reconstructor, relationship -from sqlalchemy.orm.exc import NoResultFound from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf -from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies from airflow.utils import timezone from airflow.utils.db import LazySelectSequence -from airflow.utils.helpers import exactly_one, is_container +from airflow.utils.helpers import is_container from airflow.utils.json import XComDecoder, XComEncoder from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -64,9 +60,6 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: - import datetime - - import pendulum from sqlalchemy.engine import Row from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause @@ -134,8 +127,9 @@ def __repr__(self): return f'' return f'' - @overload @classmethod + @internal_api_call + @provide_session def set( cls, key: str, @@ -150,9 +144,6 @@ def set( """ Store an XCom value. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - :param key: Key to store the XCom. :param value: XCom value to store. :param dag_id: DAG ID. @@ -163,67 +154,14 @@ def set( :param session: Database session. If not given, a new session will be created for this function. """ - - @overload - @classmethod - def set( - cls, - key: str, - value: Any, - task_id: str, - dag_id: str, - execution_date: datetime.datetime, - session: Session = NEW_SESSION, - ) -> None: - """ - Store an XCom value. - - :sphinx-autoapi-skip: - """ - - @classmethod - @internal_api_call - @provide_session - def set( - cls, - key: str, - value: Any, - task_id: str, - dag_id: str, - execution_date: datetime.datetime | None = None, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - map_index: int = -1, - ) -> None: - """ - Store an XCom value. - - :sphinx-autoapi-skip: - """ from airflow.models.dagrun import DagRun - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError( - f"Exactly one of run_id or execution_date must be passed. " - f"Passed execution_date={execution_date}, run_id={run_id}" - ) + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - if run_id is None: - message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - try: - dag_run_id, run_id = ( - session.query(DagRun.id, DagRun.run_id) - .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) - .one() - ) - except NoResultFound: - raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None - else: - dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() - if dag_run_id is None: - raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") + dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() + if dag_run_id is None: + raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") # Seamlessly resolve LazySelectSequence to a list. This intends to work # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if @@ -242,7 +180,7 @@ def set( "return value" if key == XCOM_RETURN_KEY else f"value {key}", task_id, dag_id, - run_id or execution_date, + run_id, ) value = list(value) @@ -311,17 +249,18 @@ def get_value( session=session, ) - @overload @staticmethod + @provide_session @internal_api_call def get_one( *, key: str | None = None, dag_id: str | None = None, task_id: str | None = None, - run_id: str | None = None, + run_id: str, map_index: int | None = None, session: Session = NEW_SESSION, + include_prior_dates: bool = False, ) -> Any | None: """ Retrieve an XCom value, optionally meeting certain criteria. @@ -333,9 +272,6 @@ def get_one( If there are no results, *None* is returned. If multiple XCom entries match the criteria, an arbitrary one is returned. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - .. seealso:: ``get_value()`` is a convenience function if you already have a structured TaskInstance or TaskInstanceKey object available. @@ -354,83 +290,27 @@ def get_one( :param session: Database session. If not given, a new session will be created for this function. """ - - @overload - @staticmethod - @internal_api_call - def get_one( - execution_date: datetime.datetime, - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - :sphinx-autoapi-skip: - """ - - @staticmethod - @provide_session - @internal_api_call - def get_one( - execution_date: datetime.datetime | None = None, - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - map_index: int | None = None, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - :sphinx-autoapi-skip: - """ - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError("Exactly one of run_id or execution_date must be passed") - - if run_id: - query = BaseXCom.get_many( - run_id=run_id, - key=key, - task_ids=task_id, - dag_ids=dag_id, - map_indexes=map_index, - include_prior_dates=include_prior_dates, - limit=1, - session=session, - ) - elif execution_date is not None: - message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - query = BaseXCom.get_many( - execution_date=execution_date, - key=key, - task_ids=task_id, - dag_ids=dag_id, - map_indexes=map_index, - include_prior_dates=include_prior_dates, - limit=1, - session=session, - ) - else: - raise RuntimeError("Should not happen?") + query = BaseXCom.get_many( + run_id=run_id, + key=key, + task_ids=task_id, + dag_ids=dag_id, + map_indexes=map_index, + include_prior_dates=include_prior_dates, + limit=1, + session=session, + ) result = query.with_entities(BaseXCom.value).first() if result: return XCom.deserialize_value(result) return None - @overload + # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation + # mode will result in a crash - Resulting Query object cannot be **really** serialized + # TODO(potiuk) - document it in AIP-44 docs @staticmethod + @provide_session def get_many( *, run_id: str, @@ -448,9 +328,6 @@ def get_many( This function returns an SQLAlchemy query of full XCom objects. If you just want one stored value, use :meth:`get_one` instead. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - :param run_id: DAG run ID for the task. :param key: A key for the XComs. If provided, only XComs with matching keys will be returned. Pass *None* (default) to remove the filter. @@ -467,58 +344,10 @@ def get_many( created for this function. :param limit: Limiting returning XComs """ - - @overload - @staticmethod - @internal_api_call - def get_many( - execution_date: datetime.datetime, - key: str | None = None, - task_ids: str | Iterable[str] | None = None, - dag_ids: str | Iterable[str] | None = None, - map_indexes: int | Iterable[int] | None = None, - include_prior_dates: bool = False, - limit: int | None = None, - session: Session = NEW_SESSION, - ) -> Query: - """ - Composes a query to get one or more XCom entries. - - :sphinx-autoapi-skip: - """ - - # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation - # mode will result in a crash - Resulting Query object cannot be **really** serialized - # TODO(potiuk) - document it in AIP-44 docs - @staticmethod - @provide_session - def get_many( - execution_date: datetime.datetime | None = None, - key: str | None = None, - task_ids: str | Iterable[str] | None = None, - dag_ids: str | Iterable[str] | None = None, - map_indexes: int | Iterable[int] | None = None, - include_prior_dates: bool = False, - limit: int | None = None, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - ) -> Query: - """ - Composes a query to get one or more XCom entries. - - :sphinx-autoapi-skip: - """ from airflow.models.dagrun import DagRun - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError( - f"Exactly one of run_id or execution_date must be passed. " - f"Passed execution_date={execution_date}, run_id={run_id}" - ) - if execution_date is not None: - message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") query = session.query(BaseXCom).join(BaseXCom.dag_run) @@ -545,13 +374,8 @@ def get_many( query = query.filter(BaseXCom.map_index == map_indexes) if include_prior_dates: - if execution_date is not None: - query = query.filter(DagRun.execution_date <= execution_date) - else: - dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery() - query = query.filter(BaseXCom.execution_date <= dr.c.execution_date) - elif execution_date is not None: - query = query.filter(DagRun.execution_date == execution_date) + dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery() + query = query.filter(BaseXCom.execution_date <= dr.c.execution_date) else: query = query.filter(BaseXCom.run_id == run_id) @@ -578,8 +402,8 @@ def purge(xcom: XCom, session: Session) -> None: """Purge an XCom entry from underlying storage implementations.""" pass - @overload @staticmethod + @provide_session @internal_api_call def clear( *, @@ -592,9 +416,6 @@ def clear( """ Clear all XCom data from the database for the given task instance. - A deprecated form of this function accepts ``execution_date`` instead of - ``run_id``. The two arguments are mutually exclusive. - :param dag_id: ID of DAG to clear the XCom for. :param task_id: ID of task to clear the XCom for. :param run_id: ID of DAG run to clear the XCom for. @@ -603,41 +424,6 @@ def clear( :param session: Database session. If not given, a new session will be created for this function. """ - - @overload - @staticmethod - @internal_api_call - def clear( - execution_date: pendulum.DateTime, - dag_id: str, - task_id: str, - session: Session = NEW_SESSION, - ) -> None: - """ - Clear all XCom data from the database for the given task instance. - - :sphinx-autoapi-skip: - """ - - @staticmethod - @provide_session - @internal_api_call - def clear( - execution_date: pendulum.DateTime | None = None, - dag_id: str | None = None, - task_id: str | None = None, - session: Session = NEW_SESSION, - *, - run_id: str | None = None, - map_index: int | None = None, - ) -> None: - """ - Clear all XCom data from the database for the given task instance. - - :sphinx-autoapi-skip: - """ - from airflow.models import DagRun - # Given the historic order of this function (execution_date was first argument) to add a new optional # param we need to add default values for everything :( if dag_id is None: @@ -645,20 +431,8 @@ def clear( if task_id is None: raise TypeError("clear() missing required argument: task_id") - if not exactly_one(execution_date is not None, run_id is not None): - raise ValueError( - f"Exactly one of run_id or execution_date must be passed. " - f"Passed execution_date={execution_date}, run_id={run_id}" - ) - - if execution_date is not None: - message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead." - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) - run_id = ( - session.query(DagRun.run_id) - .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) - .scalar() - ) + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) if map_index is not None: @@ -747,33 +521,6 @@ def _process_row(row: Row) -> Any: return XCom.deserialize_value(row) -def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None: - """ - Patch a custom ``serialize_value`` to accept the modern signature. - - To give custom XCom backends more flexibility with how they store values, we - now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``. - In order to maintain compatibility with custom XCom backends written with - the old signature, we check the signature and, if necessary, patch with a - method that ignores kwargs the backend does not accept. - """ - old_serializer = clazz.serialize_value - - @wraps(old_serializer) - def _shim(**kwargs): - kwargs = {k: kwargs.get(k) for k in params} - warnings.warn( - f"Method `serialize_value` in XCom backend {XCom.__name__} is using outdated signature and" - f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be " - f"removed in a future release.", - RemovedInAirflow3Warning, - stacklevel=1, - ) - return old_serializer(**kwargs) - - clazz.serialize_value = _shim # type: ignore[assignment] - - def _get_function_params(function) -> list[str]: """ Return the list of variables names of a function. @@ -801,10 +548,6 @@ def resolve_xcom_backend() -> type[BaseXCom]: raise TypeError( f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`." ) - base_xcom_params = _get_function_params(BaseXCom.serialize_value) - xcom_params = _get_function_params(clazz.serialize_value) - if set(base_xcom_params) != set(xcom_params): - _patch_outdated_serializer(clazz=clazz, params=xcom_params) return clazz diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 0dcb7880eba7d..549b03680df83 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -203,7 +203,6 @@ def xcom_push( self, key: str, value: Any, - execution_date: datetime | None = None, session: Session | None = None, ) -> None: """ @@ -211,13 +210,11 @@ def xcom_push( :param key: the key to identify the XCom value :param value: the value of the XCom - :param execution_date: the execution date to push the XCom for """ return TaskInstance.xcom_push( self=self, # type: ignore[arg-type] key=key, value=value, - execution_date=execution_date, session=session, ) diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index e9db3d946d8e6..07533ec944be8 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -214,33 +214,6 @@ def test_get_one_custom_backend_no_use_orm_deserialize_value(self, task_instance assert value == {"key": "value"} XCom.orm_deserialize_value.assert_not_called() - @pytest.mark.skip_if_database_isolation_mode - @conf_vars({("core", "enable_xcom_pickling"): "False"}) - @mock.patch("airflow.models.xcom.conf.getimport") - def test_set_serialize_call_old_signature(self, get_import, task_instance): - """ - When XCom.serialize_value takes only param ``value``, other kwargs should be ignored. - """ - serialize_watcher = MagicMock() - - class OldSignatureXCom(BaseXCom): - @staticmethod - def serialize_value(value, **kwargs): - serialize_watcher(value=value, **kwargs) - return json.dumps(value).encode("utf-8") - - get_import.return_value = OldSignatureXCom - - XCom = resolve_xcom_backend() - XCom.set( - key=XCOM_RETURN_KEY, - value={"my_xcom_key": "my_xcom_value"}, - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - run_id=task_instance.run_id, - ) - serialize_watcher.assert_called_once_with(value={"my_xcom_key": "my_xcom_value"}) - @pytest.mark.skip_if_database_isolation_mode @conf_vars({("core", "enable_xcom_pickling"): "False"}) @mock.patch("airflow.models.xcom.conf.getimport") @@ -335,19 +308,6 @@ def test_xcom_get_one(self, session, task_instance): ) assert stored_value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_get_one") - def test_xcom_get_one_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - stored_value = XCom.get_one( - key="xcom_1", - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=task_instance.execution_date, - session=session, - ) - assert stored_value == {"key": "value"} - @pytest.fixture def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom): date1 = timezone.datetime(2021, 12, 3, 4, 56) @@ -376,24 +336,6 @@ def test_xcom_get_one_from_prior_date(self, session, tis_for_xcom_get_one_from_p ) assert retrieved_value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode - def test_xcom_get_one_from_prior_with_execution_date( - self, - session, - tis_for_xcom_get_one_from_prior_date, - ): - _, ti2 = tis_for_xcom_get_one_from_prior_date - with pytest.deprecated_call(): - retrieved_value = XCom.get_one( - execution_date=ti2.execution_date, - key="xcom_1", - task_id="task_1", - dag_id="dag", - include_prior_dates=True, - session=session, - ) - assert retrieved_value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode @pytest.fixture def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simple_json_xcom): @@ -413,21 +355,6 @@ def test_xcom_get_many_single_argument_value(self, session, task_instance): assert stored_xcoms[0].key == "xcom_1" assert stored_xcoms[0].value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value") - def test_xcom_get_many_single_argument_value_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - stored_xcoms = XCom.get_many( - execution_date=task_instance.execution_date, - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - session=session, - ).all() - assert len(stored_xcoms) == 1 - assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} - @pytest.mark.skip_if_database_isolation_mode @pytest.fixture def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_json_xcom): @@ -448,20 +375,6 @@ def test_xcom_get_many_multiple_tasks(self, session, task_instance): sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))] assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}] - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks") - def test_xcom_get_many_multiple_tasks_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - stored_xcoms = XCom.get_many( - execution_date=task_instance.execution_date, - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=["task_1", "task_2"], - session=session, - ) - sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))] - assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}] - @pytest.fixture def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_simple_json_xcom): date1 = timezone.datetime(2021, 12, 3, 4, 56) @@ -488,27 +401,6 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}] assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date] - @pytest.mark.skip_if_database_isolation_mode - def test_xcom_get_many_from_prior_dates_with_execution_date( - self, - session, - tis_for_xcom_get_many_from_prior_dates, - ): - ti1, ti2 = tis_for_xcom_get_many_from_prior_dates - with pytest.deprecated_call(): - stored_xcoms = XCom.get_many( - execution_date=ti2.execution_date, - key="xcom_1", - dag_ids="dag", - task_ids="task_1", - include_prior_dates=True, - session=session, - ) - - # The retrieved XComs should be ordered by logical date, latest first. - assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}] - assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date] - @pytest.mark.usefixtures("setup_xcom_pickling") class TestXComSet: @@ -528,24 +420,6 @@ def test_xcom_set(self, session, task_instance): assert stored_xcoms[0].task_id == "task_1" assert stored_xcoms[0].execution_date == task_instance.execution_date - @pytest.mark.skip_if_database_isolation_mode - def test_xcom_set_with_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - XCom.set( - key="xcom_1", - value={"key": "value"}, - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=task_instance.execution_date, - session=session, - ) - stored_xcoms = session.query(XCom).all() - assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} - assert stored_xcoms[0].dag_id == "dag" - assert stored_xcoms[0].task_id == "task_1" - assert stored_xcoms[0].execution_date == task_instance.execution_date - @pytest.fixture def setup_for_xcom_set_again_replace(self, task_instance, push_simple_json_xcom): push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key1": "value1"}) @@ -563,21 +437,6 @@ def test_xcom_set_again_replace(self, session, task_instance): ) assert session.query(XCom).one().value == {"key2": "value2"} - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_set_again_replace") - def test_xcom_set_again_replace_with_execution_date(self, session, task_instance): - assert session.query(XCom).one().value == {"key1": "value1"} - with pytest.deprecated_call(): - XCom.set( - key="xcom_1", - value={"key2": "value2"}, - dag_id=task_instance.dag_id, - task_id="task_1", - execution_date=task_instance.execution_date, - session=session, - ) - assert session.query(XCom).one().value == {"key2": "value2"} - @pytest.mark.usefixtures("setup_xcom_pickling") class TestXComClear: @@ -598,19 +457,6 @@ def test_xcom_clear(self, mock_purge, session, task_instance): assert session.query(XCom).count() == 0 assert mock_purge.call_count == 0 if is_db_isolation_mode() else 1 - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_clear") - def test_xcom_clear_with_execution_date(self, session, task_instance): - assert session.query(XCom).count() == 1 - with pytest.deprecated_call(): - XCom.clear( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=task_instance.execution_date, - session=session, - ) - assert session.query(XCom).count() == 0 - @pytest.mark.usefixtures("setup_for_xcom_clear") def test_xcom_clear_different_run(self, session, task_instance): XCom.clear( @@ -620,15 +466,3 @@ def test_xcom_clear_different_run(self, session, task_instance): session=session, ) assert session.query(XCom).count() == 1 - - @pytest.mark.skip_if_database_isolation_mode - @pytest.mark.usefixtures("setup_for_xcom_clear") - def test_xcom_clear_different_execution_date(self, session, task_instance): - with pytest.deprecated_call(): - XCom.clear( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - execution_date=timezone.utcnow(), - session=session, - ) - assert session.query(XCom).count() == 1 diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py index 446d584edf358..1afcfea0a826f 100644 --- a/tests/providers/amazon/aws/links/test_base_aws.py +++ b/tests/providers/amazon/aws/links/test_base_aws.py @@ -25,6 +25,7 @@ from airflow.models.xcom import XCom from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink from airflow.serialization.serialized_objects import SerializedDAG +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.mock_operators import MockOperator if TYPE_CHECKING: @@ -75,11 +76,13 @@ def test_persist(self, region_name, aws_partition, keywords, expected_value): ) ti = mock_context["ti"] - ti.xcom_push.assert_called_once_with( - execution_date=None, - key=XCOM_KEY, - value=expected_value, - ) + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_once_with( + key=XCOM_KEY, + value=expected_value, + ) + else: + ti.xcom_push.assert_called_once_with(key=XCOM_KEY, value=expected_value, execution_date=None) def test_disable_xcom_push(self): mock_context = mock.MagicMock() diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py index b3145151d3da6..f44479bbce9e4 100644 --- a/tests/providers/google/cloud/operators/test_bigquery_dts.py +++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py @@ -27,6 +27,7 @@ BigQueryDataTransferServiceStartTransferRunsOperator, BigQueryDeleteDataTransferConfigOperator, ) +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS PROJECT_ID = "id" @@ -71,7 +72,10 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) - ti.xcom_push.assert_called_with(execution_date=None, key="transfer_config_id", value="1a2b3c") + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c") + else: + ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c", execution_date=None) assert "secret_access_key" not in return_value.get("params", {}) assert "access_key_id" not in return_value.get("params", {}) @@ -126,7 +130,10 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) - ti.xcom_push.assert_called_with(execution_date=None, key="run_id", value="123") + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_with(key="run_id", value="123") + else: + ti.xcom_push.assert_called_with(key="run_id", value="123", execution_date=None) @mock.patch( f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook", diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 1d1f2a1ef818a..58b38125ee1d3 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -79,7 +79,7 @@ from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.timezone import datetime -from tests.test_utils.compat import AIRFLOW_VERSION +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_VERSION from tests.test_utils.db import clear_db_runs, clear_db_xcom AIRFLOW_VERSION_LABEL = "v" + str(AIRFLOW_VERSION).replace(".", "-").replace("+", "-") @@ -440,19 +440,32 @@ def tearDownClass(cls): class DataprocJobTestBase(DataprocTestBase): @classmethod def setup_class(cls): - cls.extra_links_expected_calls = [ - call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED), - call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), - ] + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls = [ + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] + else: + cls.extra_links_expected_calls = [ + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] class DataprocClusterTestBase(DataprocTestBase): @classmethod def setup_class(cls): super().setup_class() - cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(execution_date=None, key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) - ] + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) + ] + else: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push( + key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, execution_date=None + ) + ] class TestsClusterGenerator: @@ -758,11 +771,17 @@ def test_execute(self, mock_hook, to_dict_mock): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation()) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None, + ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -808,11 +827,17 @@ def test_execute_in_gke(self, mock_hook, to_dict_mock): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation()) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None, + ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -1095,9 +1120,14 @@ class TestDataprocClusterScaleOperator(DataprocClusterTestBase): @classmethod def setup_class(cls): super().setup_class() - cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) - ] + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) + ] + else: + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None) + ] def test_deprecation_warning(self): with pytest.warns(AirflowProviderDeprecationWarning) as warnings: @@ -1142,11 +1172,17 @@ def test_execute(self, mock_hook): # Test whether xcom push occurs before cluster is updated self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - self.mock_ti.xcom_push.assert_called_once_with( - key="conf", - value=DATAPROC_CLUSTER_CONF_EXPECTED, - execution_date=None, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None, + ) @pytest.mark.db_test @@ -1310,9 +1346,12 @@ def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock class TestDataprocSubmitJobOperator(DataprocJobTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): - xcom_push_call = call.ti.xcom_push( - execution_date=None, key="dataproc_job", value=DATAPROC_JOB_EXPECTED - ) + if AIRFLOW_V_3_0_PLUS: + xcom_push_call = call.ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + else: + xcom_push_call = call.ti.xcom_push( + key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None + ) wait_for_job_call = call.hook().wait_for_job( job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None ) @@ -1358,9 +1397,12 @@ def test_execute(self, mock_hook): job_id=TEST_JOB_ID, project_id=GCP_PROJECT, region=GCP_REGION, timeout=None ) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_async(self, mock_hook): @@ -1398,9 +1440,12 @@ def test_execute_async(self, mock_hook): ) mock_hook.return_value.wait_for_job.assert_not_called() - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) @@ -1633,11 +1678,17 @@ def test_execute(self, mock_hook): # Test whether the xcom push happens before updating the cluster self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + execution_date=None, + ) def test_missing_region_parameter(self): with pytest.raises(AirflowException): @@ -2399,10 +2450,16 @@ def test_builder(self, mock_hook, mock_uuid): class TestDataProcSparkOperator(DataprocJobTestBase): @classmethod def setup_class(cls): - cls.extra_links_expected_calls = [ - call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED), - call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), - ] + if AIRFLOW_V_3_0_PLUS: + cls.extra_links_expected_calls = [ + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] + else: + cls.extra_links_expected_calls = [ + call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None), + call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT), + ] main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] @@ -2446,9 +2503,12 @@ def test_execute(self, mock_hook, mock_uuid): assert self.job == job op.execute(context=self.mock_context) - self.mock_ti.xcom_push.assert_called_once_with( - key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None - ) + if AIRFLOW_V_3_0_PLUS: + self.mock_ti.xcom_push.assert_called_once_with(key="conf", value=DATAPROC_JOB_CONF_EXPECTED) + else: + self.mock_ti.xcom_push.assert_called_once_with( + key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None + ) # Test whether xcom push occurs before polling for job self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index c77dd7747d19d..8a258735291f1 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -111,8 +111,6 @@ def mock_response(status_code, content: Any = None, headers: dict | None = None) def mock_context(task) -> Context: - from datetime import datetime - from airflow.models import TaskInstance from airflow.utils.session import NEW_SESSION from airflow.utils.state import TaskInstanceState @@ -146,13 +144,7 @@ def xcom_pull( return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") - def xcom_push( - self, - key: str, - value: Any, - execution_date: datetime | None = None, - session: Session = NEW_SESSION, - ) -> None: + def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None: values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value values["ti"] = MockedTaskInstance(task=task) diff --git a/tests/providers/yandex/links/test_yq.py b/tests/providers/yandex/links/test_yq.py index 06f1e83939f9a..d46862f1c737f 100644 --- a/tests/providers/yandex/links/test_yq.py +++ b/tests/providers/yandex/links/test_yq.py @@ -23,6 +23,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XCom from airflow.providers.yandex.links.yq import YQLink +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.mock_operators import MockOperator yandexcloud = pytest.importorskip("yandexcloud") @@ -34,11 +35,13 @@ def test_persist(): YQLink.persist(context=mock_context, task_instance=MockOperator(task_id="test_task_id"), web_link="g.com") ti = mock_context["ti"] - ti.xcom_push.assert_called_once_with( - execution_date=None, - key="web_link", - value="g.com", - ) + if AIRFLOW_V_3_0_PLUS: + ti.xcom_push.assert_called_once_with( + key="web_link", + value="g.com", + ) + else: + ti.xcom_push.assert_called_once_with(key="web_link", value="g.com", execution_date=None) def test_default_link(): diff --git a/tests/providers/yandex/operators/test_yq.py b/tests/providers/yandex/operators/test_yq.py index e342a2f961dfd..034f0505517ba 100644 --- a/tests/providers/yandex/operators/test_yq.py +++ b/tests/providers/yandex/operators/test_yq.py @@ -22,6 +22,8 @@ import pytest +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + yandexcloud = pytest.importorskip("yandexcloud") import responses @@ -89,15 +91,25 @@ def test_execute_query(self, mock_get_connection): results = operator.execute(context) assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - context["ti"].xcom_push.assert_has_calls( - [ - call( - key="web_link", - value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", - execution_date=None, - ), - ] - ) + if AIRFLOW_V_3_0_PLUS: + context["ti"].xcom_push.assert_has_calls( + [ + call( + key="web_link", + value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", + ), + ] + ) + else: + context["ti"].xcom_push.assert_has_calls( + [ + call( + key="web_link", + value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", + execution_date=None, + ), + ] + ) responses.get( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", diff --git a/tests/test_utils/compat.py b/tests/test_utils/compat.py index b09973903b4e2..5daf429cf641f 100644 --- a/tests/test_utils/compat.py +++ b/tests/test_utils/compat.py @@ -46,6 +46,7 @@ AIRFLOW_V_2_8_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.8.0") AIRFLOW_V_2_9_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.9.0") AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") +AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") try: from airflow.models.baseoperatorlink import BaseOperatorLink