diff --git a/airflow-core/docs/templates-ref.rst b/airflow-core/docs/templates-ref.rst index c908e03da9411..b1609ff0c943e 100644 --- a/airflow-core/docs/templates-ref.rst +++ b/airflow-core/docs/templates-ref.rst @@ -85,6 +85,8 @@ Variable Type Description ``{{ params }}`` dict[str, Any] | The user-defined params. This can be overridden by the mapping | passed to ``trigger_dag -c`` if ``dag_run_conf_overrides_params`` | is enabled in ``airflow.cfg``. +``{{ partition_key }}`` str | None | The partition key from the current :class:`~airflow.models.dagrun.DagRun`. + | Returns ``None`` if no partition key was set. Added in version 3.3.0. ``{{ var.value }}`` Airflow variables. See `Airflow Variables in Templates`_ below. ``{{ var.json }}`` Airflow variables. See `Airflow Variables in Templates`_ below. ``{{ conn }}`` Airflow connections. See `Airflow Connections in Templates`_ below. diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 8d800d8fab11d..eedd9a68ad0ec 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -66,7 +66,11 @@ prepare_virtualenv, write_python_script, ) -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS +from airflow.providers.standard.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, +) from airflow.utils import hashlib_wrapper from airflow.utils.file import get_unique_dag_module_name @@ -449,6 +453,8 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): } if AIRFLOW_V_3_0_PLUS: BASE_SERIALIZABLE_CONTEXT_KEYS.add("task_reschedule_count") + if AIRFLOW_V_3_3_PLUS: + BASE_SERIALIZABLE_CONTEXT_KEYS.add("partition_key") PENDULUM_SERIALIZABLE_CONTEXT_KEYS = { "data_interval_end", diff --git a/providers/standard/src/airflow/providers/standard/version_compat.py b/providers/standard/src/airflow/providers/standard/version_compat.py index 769e790fb5972..db58809b127bd 100644 --- a/providers/standard/src/airflow/providers/standard/version_compat.py +++ b/providers/standard/src/airflow/providers/standard/version_compat.py @@ -36,6 +36,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) AIRFLOW_V_3_1_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 3) AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0) +AIRFLOW_V_3_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 3, 0) # BaseOperator: Use 3.1+ due to xcom_push method missing in SDK BaseOperator 3.0.x # This is needed for DecoratedOperator compatibility @@ -50,6 +51,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: "AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS", + "AIRFLOW_V_3_3_PLUS", "ArgNotSet", "BaseOperator", ] diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index fe48b8558c1fa..b7a63284608aa 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -59,6 +59,7 @@ class Context(TypedDict, total=False): map_index_template: NotRequired[str | None] outlets: list params: dict[str, Any] + partition_key: NotRequired[str | None] prev_data_interval_start_success: NotRequired[DateTime | None] prev_data_interval_end_success: NotRequired[DateTime | None] prev_start_date_success: NotRequired[DateTime | None] diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6210797b55853..06770634727fa 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -256,6 +256,7 @@ def get_template_context(self) -> Context: context_from_server: Context = { # TODO: Assess if we need to pass these through timezone.coerce_datetime "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522 + "partition_key": dag_run.partition_key, "triggering_asset_events": TriggeringAssetEventsAccessor.build( AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event) for event in dag_run.consumed_asset_events diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 5aeb009bd33da..04db2122eabd2 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1722,8 +1722,33 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s "ts": "2024-12-01T01:00:00+00:00", "ts_nodash": "20241201T010000", "ts_nodash_with_tz": "20241201T010000+0000", + "partition_key": dr.partition_key, } + def test_partition_key_in_context(self, create_runtime_ti, mock_supervisor_comms): + """Test that partition_key from dag_run is exposed in the template context.""" + task = BaseOperator(task_id="hello") + runtime_ti = create_runtime_ti(task=task, dag_id="basic_task") + + dr = runtime_ti._ti_context_from_server.dag_run + + mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult( + data_interval_end=dr.logical_date - timedelta(hours=1), + data_interval_start=dr.logical_date - timedelta(hours=2), + start_date=dr.start_date - timedelta(hours=1), + end_date=dr.start_date, + ) + + context = runtime_ti.get_template_context() + + # Default: partition_key is None + assert context["partition_key"] is None + + # Set partition_key on dag_run and verify it surfaces in context + dr.partition_key = "some-partition" + context = runtime_ti.get_template_context() + assert context["partition_key"] == "some-partition" + def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock_supervisor_comms): """Ensure lazy-loaded attributes are not resolved until accessed.""" task = BaseOperator(task_id="hello")