Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow-core/docs/templates-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ Variable Type Description
``{{ params }}`` dict[str, Any] | The user-defined params. This can be overridden by the mapping
| passed to ``trigger_dag -c`` if ``dag_run_conf_overrides_params``
| is enabled in ``airflow.cfg``.
``{{ partition_key }}`` str | None | The partition key from the current :class:`~airflow.models.dagrun.DagRun`.
| Returns ``None`` if no partition key was set. Added in version 3.3.0.
``{{ var.value }}`` Airflow variables. See `Airflow Variables in Templates`_ below.
``{{ var.json }}`` Airflow variables. See `Airflow Variables in Templates`_ below.
``{{ conn }}`` Airflow connections. See `Airflow Connections in Templates`_ below.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@
prepare_virtualenv,
write_python_script,
)
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.providers.standard.version_compat import (
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_2_PLUS,
AIRFLOW_V_3_3_PLUS,
)
from airflow.utils import hashlib_wrapper
from airflow.utils.file import get_unique_dag_module_name

Expand Down Expand Up @@ -449,6 +453,8 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
}
if AIRFLOW_V_3_0_PLUS:
BASE_SERIALIZABLE_CONTEXT_KEYS.add("task_reschedule_count")
if AIRFLOW_V_3_3_PLUS:
BASE_SERIALIZABLE_CONTEXT_KEYS.add("partition_key")

PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
"data_interval_end",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 3)
AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)
AIRFLOW_V_3_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 3, 0)

# BaseOperator: Use 3.1+ due to xcom_push method missing in SDK BaseOperator 3.0.x
# This is needed for DecoratedOperator compatibility
Expand All @@ -50,6 +51,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
"AIRFLOW_V_3_3_PLUS",
"ArgNotSet",
"BaseOperator",
]
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Context(TypedDict, total=False):
map_index_template: NotRequired[str | None]
outlets: list
params: dict[str, Any]
partition_key: NotRequired[str | None]
prev_data_interval_start_success: NotRequired[DateTime | None]
prev_data_interval_end_success: NotRequired[DateTime | None]
prev_start_date_success: NotRequired[DateTime | None]
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def get_template_context(self) -> Context:
context_from_server: Context = {
# TODO: Assess if we need to pass these through timezone.coerce_datetime
"dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522
"partition_key": dag_run.partition_key,
"triggering_asset_events": TriggeringAssetEventsAccessor.build(
AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
for event in dag_run.consumed_asset_events
Expand Down
25 changes: 25 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,8 +1722,33 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s
"ts": "2024-12-01T01:00:00+00:00",
"ts_nodash": "20241201T010000",
"ts_nodash_with_tz": "20241201T010000+0000",
"partition_key": dr.partition_key,
}

def test_partition_key_in_context(self, create_runtime_ti, mock_supervisor_comms):
"""Test that partition_key from dag_run is exposed in the template context."""
task = BaseOperator(task_id="hello")
runtime_ti = create_runtime_ti(task=task, dag_id="basic_task")

dr = runtime_ti._ti_context_from_server.dag_run

mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult(
data_interval_end=dr.logical_date - timedelta(hours=1),
data_interval_start=dr.logical_date - timedelta(hours=2),
start_date=dr.start_date - timedelta(hours=1),
end_date=dr.start_date,
)

context = runtime_ti.get_template_context()

# Default: partition_key is None
assert context["partition_key"] is None

# Set partition_key on dag_run and verify it surfaces in context
dr.partition_key = "some-partition"
context = runtime_ti.get_template_context()
assert context["partition_key"] == "some-partition"

def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock_supervisor_comms):
"""Ensure lazy-loaded attributes are not resolved until accessed."""
task = BaseOperator(task_id="hello")
Expand Down
Loading