From b54b0d35c0b27c834c8a99ede4e585408e3c9c52 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 14 Feb 2026 19:48:02 -0600 Subject: [PATCH 1/4] AIP-76: Implement SequenceMapper for partition key validation Add SequenceMapper, a PartitionMapper subclass that validates incoming partition keys against a predefined list of allowed values. This is the first half of #44145 (PartitionBySequence and PartitionByProduct). related: #44145 --- .../src/airflow/partition_mappers/sequence.py | 42 ++++++++++++++++ .../src/airflow/serialization/encoders.py | 6 +++ .../unit/partition_mappers/test_sequence.py | 48 +++++++++++++++++++ .../serialization/test_serialized_objects.py | 25 ++++++++++ task-sdk/docs/api.rst | 2 + task-sdk/src/airflow/sdk/__init__.py | 3 ++ task-sdk/src/airflow/sdk/__init__.pyi | 2 + .../definitions/partition_mappers/sequence.py | 26 ++++++++++ 8 files changed, 154 insertions(+) create mode 100644 airflow-core/src/airflow/partition_mappers/sequence.py create mode 100644 airflow-core/tests/unit/partition_mappers/test_sequence.py create mode 100644 task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py diff --git a/airflow-core/src/airflow/partition_mappers/sequence.py b/airflow-core/src/airflow/partition_mappers/sequence.py new file mode 100644 index 0000000000000..71baa6944b335 --- /dev/null +++ b/airflow-core/src/airflow/partition_mappers/sequence.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +from airflow.partition_mappers.base import PartitionMapper + + +class SequenceMapper(PartitionMapper): + """Partition mapper that validates keys against a defined sequence.""" + + def __init__(self, sequence: list[str]) -> None: + self.sequence = sequence + self._valid_keys = frozenset(sequence) + + def to_downstream(self, key: str) -> str: + if key not in self._valid_keys: + raise ValueError(f"Key {key!r} not in sequence {self.sequence}") + return key + + def serialize(self) -> dict[str, Any]: + return {"sequence": self.sequence} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: + return cls(sequence=data["sequence"]) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 0242b6ee61d37..62587448e5d63 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -45,6 +45,7 @@ PartitionMapper, ProductMapper, QuarterlyMapper, + SequenceMapper, WeeklyMapper, YearlyMapper, ) @@ -375,6 +376,7 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: QuarterlyMapper: "airflow.partition_mappers.temporal.QuarterlyMapper", YearlyMapper: "airflow.partition_mappers.temporal.YearlyMapper", ProductMapper: "airflow.partition_mappers.product.ProductMapper", + SequenceMapper: "airflow.partition_mappers.sequence.SequenceMapper", } @functools.singledispatchmethod @@ -416,6 +418,10 @@ def _(self, partition_mapper: ProductMapper) -> dict[str, Any]: "mappers": [encode_partition_mapper(m) for m in partition_mapper.mappers], } + @serialize_partition_mapper.register + def _(self, partition_mapper: SequenceMapper) -> dict[str, Any]: + return {"sequence": partition_mapper.sequence} + _serializer = _Serializer() diff --git a/airflow-core/tests/unit/partition_mappers/test_sequence.py b/airflow-core/tests/unit/partition_mappers/test_sequence.py new file mode 100644 index 0000000000000..a17002301c5cf --- /dev/null +++ b/airflow-core/tests/unit/partition_mappers/test_sequence.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.partition_mappers.sequence import SequenceMapper + + +class TestSequenceMapper: + def test_to_downstream(self): + pm = SequenceMapper(["us", "eu", "apac"]) + assert pm.to_downstream("us") == "us" + assert pm.to_downstream("eu") == "eu" + + def test_to_downstream_invalid_key(self): + pm = SequenceMapper(["us", "eu"]) + with pytest.raises(ValueError, match="not in sequence"): + pm.to_downstream("apac") + + def test_serialize(self): + pm = SequenceMapper(["a", "b", "c"]) + assert pm.serialize() == {"sequence": ["a", "b", "c"]} + + def test_deserialize(self): + pm = SequenceMapper.deserialize({"sequence": ["x", "y"]}) + assert isinstance(pm, SequenceMapper) + assert pm.sequence == ["x", "y"] + + def test_empty_sequence(self): + pm = SequenceMapper([]) + assert pm.serialize() == {"sequence": []} + with pytest.raises(ValueError, match="not in sequence"): + pm.to_downstream("any") diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index e001def29a9e3..fe7492d52eaab 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -877,6 +877,31 @@ def test_decode_product_mapper(): assert core_pm.to_downstream("2024-06-15T10:30:00|2024-06-15T10:30:00") == "2024-06-15T10|2024-06-15" +def test_encode_sequence_mapper(): + from airflow.sdk import SequenceMapper + from airflow.serialization.encoders import encode_partition_mapper + + partition_mapper = SequenceMapper(["us", "eu", "apac"]) + assert encode_partition_mapper(partition_mapper) == { + Encoding.TYPE: "airflow.partition_mappers.sequence.SequenceMapper", + Encoding.VAR: {"sequence": ["us", "eu", "apac"]}, + } + + +def test_decode_sequence_mapper(): + from airflow.partition_mappers.sequence import SequenceMapper as CoreSequenceMapper + from airflow.sdk import SequenceMapper + from airflow.serialization.decoders import decode_partition_mapper + from airflow.serialization.encoders import encode_partition_mapper + + partition_mapper = SequenceMapper(["us", "eu", "apac"]) + encoded_pm = encode_partition_mapper(partition_mapper) + core_pm = decode_partition_mapper(encoded_pm) + + assert isinstance(core_pm, CoreSequenceMapper) + assert core_pm.sequence == ["us", "eu", "apac"] + + class TestSerializedBaseOperator: # ensure the default logging config is used for this test, no matter what ran before @pytest.mark.usefixtures("reset_logging_config") diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 88cc4205374b3..8116ecd604c49 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -217,6 +217,8 @@ Partition Mapper .. autoapiclass:: airflow.sdk.ProductMapper +.. autoapiclass:: airflow.sdk.SequenceMapper + I/O Helpers ----------- .. autoapiclass:: airflow.sdk.ObjectStoragePath diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 669c87e019ae2..f893613ae8d8e 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -63,6 +63,7 @@ "PokeReturnValue", "ProductMapper", "QuarterlyMapper", + "SequenceMapper", "SkipMixin", "SyncCallback", "TaskGroup", @@ -124,6 +125,7 @@ from airflow.sdk.definitions.partition_mappers.base import PartitionMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper + from airflow.sdk.definitions.partition_mappers.sequence import SequenceMapper from airflow.sdk.definitions.partition_mappers.temporal import ( DailyMapper, HourlyMapper, @@ -203,6 +205,7 @@ "ProductMapper": ".definitions.partition_mappers.product", "QuarterlyMapper": ".definitions.partition_mappers.temporal", "SecretCache": ".execution_time.cache", + "SequenceMapper": ".definitions.partition_mappers.sequence", "SkipMixin": ".bases.skipmixin", "SyncCallback": ".definitions.callback", "TaskGroup": ".definitions.taskgroup", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index ed9943700b588..e458916209340 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -64,6 +64,7 @@ from airflow.sdk.definitions.param import Param as Param from airflow.sdk.definitions.partition_mappers.base import PartitionMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper +from airflow.sdk.definitions.partition_mappers.sequence import SequenceMapper from airflow.sdk.definitions.partition_mappers.temporal import ( DailyMapper, HourlyMapper, @@ -139,6 +140,7 @@ __all__ = [ "ProductMapper", "QuarterlyMapper", "SecretCache", + "SequenceMapper", "SkipMixin", "TaskGroup", "TaskInstanceState", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py new file mode 100644 index 0000000000000..2ef9799c126d1 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.sdk.definitions.partition_mappers.base import PartitionMapper + + +class SequenceMapper(PartitionMapper): + """Partition mapper that validates keys against a defined sequence.""" + + def __init__(self, sequence: list[str]) -> None: + self.sequence = sequence From 88526548eb07ebd2ec0917fe78179018e0dce6d5 Mon Sep 17 00:00:00 2001 From: Anish Date: Wed, 4 Mar 2026 00:18:10 -0600 Subject: [PATCH 2/4] added example and clean up --- .../example_dags/example_asset_partition.py | 41 +++++++++++++++++++ .../src/airflow/partition_mappers/sequence.py | 3 +- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 69c9d883e63ad..6156c9b3512f3 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -28,6 +28,7 @@ IdentityMapper, PartitionedAssetTimetable, ProductMapper, + SequenceMapper, YearlyMapper, asset, task, @@ -184,3 +185,43 @@ def aggregate_sales(dag_run=None): print(dag_run.partition_key) aggregate_sales() + + +region_raw_stats = Asset(uri="file://incoming/player-stats/by-region.csv", name="region_raw_stats") + + +with DAG( + dag_id="ingest_region_stats", + schedule=None, + tags=["player-stats", "regional"], +): + """ + Ingest player statistics per region. + + Externally triggered with partition_key set to a region code (us, eu, apac). + """ + + @task(outlets=[region_raw_stats]) + def ingest_region(): + """Materialize player statistics for a single region partition.""" + pass + + ingest_region() + + +@asset( + uri="file://analytics/player-stats/regional-breakdown.csv", + schedule=PartitionedAssetTimetable( + assets=region_raw_stats, + default_partition_mapper=SequenceMapper(["us", "eu", "apac"]), + ), + tags=["player-stats", "regional"], +) +def regional_stats_breakdown(): + """ + Aggregate regional player statistics. + + This asset demonstrates SequenceMapper, which validates that upstream partition + keys belong to a fixed set of values (us, eu, apac) rather than time-based partitions. + """ + pass diff --git a/airflow-core/src/airflow/partition_mappers/sequence.py b/airflow-core/src/airflow/partition_mappers/sequence.py index 71baa6944b335..6670dbabf2d5b 100644 --- a/airflow-core/src/airflow/partition_mappers/sequence.py +++ b/airflow-core/src/airflow/partition_mappers/sequence.py @@ -27,10 +27,9 @@ class SequenceMapper(PartitionMapper): def __init__(self, sequence: list[str]) -> None: self.sequence = sequence - self._valid_keys = frozenset(sequence) def to_downstream(self, key: str) -> str: - if key not in self._valid_keys: + if key not in self.sequence: raise ValueError(f"Key {key!r} not in sequence {self.sequence}") return key From 6f575278bf9676e09c672cc253baa5385fedc4b2 Mon Sep 17 00:00:00 2001 From: Anish Date: Wed, 4 Mar 2026 01:09:48 -0600 Subject: [PATCH 3/4] fix spelling tests --- .../src/airflow/example_dags/example_asset_partition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 6156c9b3512f3..1c9f06371f107 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -198,7 +198,7 @@ def aggregate_sales(dag_run=None): """ Ingest player statistics per region. - Externally triggered with partition_key set to a region code (us, eu, apac). + Externally triggered with partition_key set to a region code (``us``, ``eu``, ``apac``). """ @task(outlets=[region_raw_stats]) @@ -222,6 +222,6 @@ def regional_stats_breakdown(): Aggregate regional player statistics. This asset demonstrates SequenceMapper, which validates that upstream partition - keys belong to a fixed set of values (us, eu, apac) rather than time-based partitions. + keys belong to a fixed set of values (``us``, ``eu``, ``apac``) rather than time-based partitions. """ pass From 09e105a322f167ebf27f9f5798e786111a43e9b2 Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 5 Mar 2026 01:15:48 -0600 Subject: [PATCH 4/4] Rename SequenceMapper to AllowedKeyMapper --- .../example_dags/example_asset_partition.py | 8 ++--- .../{sequence.py => allowed_key.py} | 16 +++++----- .../src/airflow/serialization/encoders.py | 8 ++--- .../{test_sequence.py => test_allowed_key.py} | 30 +++++++++---------- .../serialization/test_serialized_objects.py | 22 +++++++------- task-sdk/docs/api.rst | 2 +- task-sdk/src/airflow/sdk/__init__.py | 6 ++-- task-sdk/src/airflow/sdk/__init__.pyi | 4 +-- .../{sequence.py => allowed_key.py} | 8 ++--- 9 files changed, 52 insertions(+), 52 deletions(-) rename airflow-core/src/airflow/partition_mappers/{sequence.py => allowed_key.py} (71%) rename airflow-core/tests/unit/partition_mappers/{test_sequence.py => test_allowed_key.py} (60%) rename task-sdk/src/airflow/sdk/definitions/partition_mappers/{sequence.py => allowed_key.py} (80%) diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 1c9f06371f107..ecc9d7427af0c 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -21,6 +21,7 @@ from airflow.sdk import ( DAG, + AllowedKeyMapper, Asset, CronPartitionTimetable, DailyMapper, @@ -28,7 +29,6 @@ IdentityMapper, PartitionedAssetTimetable, ProductMapper, - SequenceMapper, YearlyMapper, asset, task, @@ -213,7 +213,7 @@ def ingest_region(): uri="file://analytics/player-stats/regional-breakdown.csv", schedule=PartitionedAssetTimetable( assets=region_raw_stats, - default_partition_mapper=SequenceMapper(["us", "eu", "apac"]), + default_partition_mapper=AllowedKeyMapper(["us", "eu", "apac"]), ), tags=["player-stats", "regional"], ) @@ -221,7 +221,7 @@ def regional_stats_breakdown(): """ Aggregate regional player statistics. - This asset demonstrates SequenceMapper, which validates that upstream partition - keys belong to a fixed set of values (``us``, ``eu``, ``apac``) rather than time-based partitions. + This asset demonstrates AllowedKeyMapper, which validates that upstream partition + keys belong to a fixed set of allowed values (``us``, ``eu``, ``apac``) rather than time-based partitions. """ pass diff --git a/airflow-core/src/airflow/partition_mappers/sequence.py b/airflow-core/src/airflow/partition_mappers/allowed_key.py similarity index 71% rename from airflow-core/src/airflow/partition_mappers/sequence.py rename to airflow-core/src/airflow/partition_mappers/allowed_key.py index 6670dbabf2d5b..8b560f426aa84 100644 --- a/airflow-core/src/airflow/partition_mappers/sequence.py +++ b/airflow-core/src/airflow/partition_mappers/allowed_key.py @@ -22,20 +22,20 @@ from airflow.partition_mappers.base import PartitionMapper -class SequenceMapper(PartitionMapper): - """Partition mapper that validates keys against a defined sequence.""" +class AllowedKeyMapper(PartitionMapper): + """Partition mapper that validates keys against a set of allowed keys.""" - def __init__(self, sequence: list[str]) -> None: - self.sequence = sequence + def __init__(self, allowed_keys: list[str]) -> None: + self.allowed_keys = allowed_keys def to_downstream(self, key: str) -> str: - if key not in self.sequence: - raise ValueError(f"Key {key!r} not in sequence {self.sequence}") + if key not in self.allowed_keys: + raise ValueError(f"Key {key!r} not in allowed keys {self.allowed_keys}") return key def serialize(self) -> dict[str, Any]: - return {"sequence": self.sequence} + return {"allowed_keys": self.allowed_keys} @classmethod def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: - return cls(sequence=data["sequence"]) + return cls(allowed_keys=data["allowed_keys"]) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 62587448e5d63..7db97b844f6a6 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -28,6 +28,7 @@ from airflow._shared.module_loading import qualname from airflow.partition_mappers.base import PartitionMapper as CorePartitionMapper from airflow.sdk import ( + AllowedKeyMapper, Asset, AssetAlias, AssetAll, @@ -45,7 +46,6 @@ PartitionMapper, ProductMapper, QuarterlyMapper, - SequenceMapper, WeeklyMapper, YearlyMapper, ) @@ -376,7 +376,7 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: QuarterlyMapper: "airflow.partition_mappers.temporal.QuarterlyMapper", YearlyMapper: "airflow.partition_mappers.temporal.YearlyMapper", ProductMapper: "airflow.partition_mappers.product.ProductMapper", - SequenceMapper: "airflow.partition_mappers.sequence.SequenceMapper", + AllowedKeyMapper: "airflow.partition_mappers.allowed_key.AllowedKeyMapper", } @functools.singledispatchmethod @@ -419,8 +419,8 @@ def _(self, partition_mapper: ProductMapper) -> dict[str, Any]: } @serialize_partition_mapper.register - def _(self, partition_mapper: SequenceMapper) -> dict[str, Any]: - return {"sequence": partition_mapper.sequence} + def _(self, partition_mapper: AllowedKeyMapper) -> dict[str, Any]: + return {"allowed_keys": partition_mapper.allowed_keys} _serializer = _Serializer() diff --git a/airflow-core/tests/unit/partition_mappers/test_sequence.py b/airflow-core/tests/unit/partition_mappers/test_allowed_key.py similarity index 60% rename from airflow-core/tests/unit/partition_mappers/test_sequence.py rename to airflow-core/tests/unit/partition_mappers/test_allowed_key.py index a17002301c5cf..a04b22e48e9b4 100644 --- a/airflow-core/tests/unit/partition_mappers/test_sequence.py +++ b/airflow-core/tests/unit/partition_mappers/test_allowed_key.py @@ -18,31 +18,31 @@ import pytest -from airflow.partition_mappers.sequence import SequenceMapper +from airflow.partition_mappers.allowed_key import AllowedKeyMapper -class TestSequenceMapper: +class TestAllowedKeyMapper: def test_to_downstream(self): - pm = SequenceMapper(["us", "eu", "apac"]) + pm = AllowedKeyMapper(["us", "eu", "apac"]) assert pm.to_downstream("us") == "us" assert pm.to_downstream("eu") == "eu" def test_to_downstream_invalid_key(self): - pm = SequenceMapper(["us", "eu"]) - with pytest.raises(ValueError, match="not in sequence"): + pm = AllowedKeyMapper(["us", "eu"]) + with pytest.raises(ValueError, match="not in allowed keys"): pm.to_downstream("apac") def test_serialize(self): - pm = SequenceMapper(["a", "b", "c"]) - assert pm.serialize() == {"sequence": ["a", "b", "c"]} + pm = AllowedKeyMapper(["a", "b", "c"]) + assert pm.serialize() == {"allowed_keys": ["a", "b", "c"]} def test_deserialize(self): - pm = SequenceMapper.deserialize({"sequence": ["x", "y"]}) - assert isinstance(pm, SequenceMapper) - assert pm.sequence == ["x", "y"] - - def test_empty_sequence(self): - pm = SequenceMapper([]) - assert pm.serialize() == {"sequence": []} - with pytest.raises(ValueError, match="not in sequence"): + pm = AllowedKeyMapper.deserialize({"allowed_keys": ["x", "y"]}) + assert isinstance(pm, AllowedKeyMapper) + assert pm.allowed_keys == ["x", "y"] + + def test_empty_allowed_keys(self): + pm = AllowedKeyMapper([]) + assert pm.serialize() == {"allowed_keys": []} + with pytest.raises(ValueError, match="not in allowed keys"): pm.to_downstream("any") diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index fe7492d52eaab..a814de07d2e05 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -877,29 +877,29 @@ def test_decode_product_mapper(): assert core_pm.to_downstream("2024-06-15T10:30:00|2024-06-15T10:30:00") == "2024-06-15T10|2024-06-15" -def test_encode_sequence_mapper(): - from airflow.sdk import SequenceMapper +def test_encode_allowed_key_mapper(): + from airflow.sdk import AllowedKeyMapper from airflow.serialization.encoders import encode_partition_mapper - partition_mapper = SequenceMapper(["us", "eu", "apac"]) + partition_mapper = AllowedKeyMapper(["us", "eu", "apac"]) assert encode_partition_mapper(partition_mapper) == { - Encoding.TYPE: "airflow.partition_mappers.sequence.SequenceMapper", - Encoding.VAR: {"sequence": ["us", "eu", "apac"]}, + Encoding.TYPE: "airflow.partition_mappers.allowed_key.AllowedKeyMapper", + Encoding.VAR: {"allowed_keys": ["us", "eu", "apac"]}, } -def test_decode_sequence_mapper(): - from airflow.partition_mappers.sequence import SequenceMapper as CoreSequenceMapper - from airflow.sdk import SequenceMapper +def test_decode_allowed_key_mapper(): + from airflow.partition_mappers.allowed_key import AllowedKeyMapper as CoreAllowedKeyMapper + from airflow.sdk import AllowedKeyMapper from airflow.serialization.decoders import decode_partition_mapper from airflow.serialization.encoders import encode_partition_mapper - partition_mapper = SequenceMapper(["us", "eu", "apac"]) + partition_mapper = AllowedKeyMapper(["us", "eu", "apac"]) encoded_pm = encode_partition_mapper(partition_mapper) core_pm = decode_partition_mapper(encoded_pm) - assert isinstance(core_pm, CoreSequenceMapper) - assert core_pm.sequence == ["us", "eu", "apac"] + assert isinstance(core_pm, CoreAllowedKeyMapper) + assert core_pm.allowed_keys == ["us", "eu", "apac"] class TestSerializedBaseOperator: diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 8116ecd604c49..0565eeca0fb65 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -217,7 +217,7 @@ Partition Mapper .. autoapiclass:: airflow.sdk.ProductMapper -.. autoapiclass:: airflow.sdk.SequenceMapper +.. autoapiclass:: airflow.sdk.AllowedKeyMapper I/O Helpers ----------- diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index f893613ae8d8e..38cc29f357c31 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -20,6 +20,7 @@ __all__ = [ "__version__", + "AllowedKeyMapper", "Asset", "AssetAlias", "AssetAll", @@ -63,7 +64,6 @@ "PokeReturnValue", "ProductMapper", "QuarterlyMapper", - "SequenceMapper", "SkipMixin", "SyncCallback", "TaskGroup", @@ -122,10 +122,10 @@ from airflow.sdk.definitions.decorators.task_group import task_group from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.param import Param, ParamsDict + from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper from airflow.sdk.definitions.partition_mappers.base import PartitionMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper - from airflow.sdk.definitions.partition_mappers.sequence import SequenceMapper from airflow.sdk.definitions.partition_mappers.temporal import ( DailyMapper, HourlyMapper, @@ -161,6 +161,7 @@ conf: AirflowSDKConfigParser __lazy_imports: dict[str, str] = { + "AllowedKeyMapper": ".definitions.partition_mappers.allowed_key", "Asset": ".definitions.asset", "AssetAlias": ".definitions.asset", "AssetAll": ".definitions.asset", @@ -205,7 +206,6 @@ "ProductMapper": ".definitions.partition_mappers.product", "QuarterlyMapper": ".definitions.partition_mappers.temporal", "SecretCache": ".execution_time.cache", - "SequenceMapper": ".definitions.partition_mappers.sequence", "SkipMixin": ".bases.skipmixin", "SyncCallback": ".definitions.callback", "TaskGroup": ".definitions.taskgroup", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index e458916209340..d7d503297dca0 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -61,10 +61,10 @@ from airflow.sdk.definitions.decorators import setup as setup, task as task, tea from airflow.sdk.definitions.decorators.task_group import task_group as task_group from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier, Label as Label from airflow.sdk.definitions.param import Param as Param +from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper from airflow.sdk.definitions.partition_mappers.base import PartitionMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper -from airflow.sdk.definitions.partition_mappers.sequence import SequenceMapper from airflow.sdk.definitions.partition_mappers.temporal import ( DailyMapper, HourlyMapper, @@ -100,6 +100,7 @@ conf: AirflowSDKConfigParser __all__ = [ "__version__", + "AllowedKeyMapper", "Asset", "AssetAlias", "AssetAll", @@ -140,7 +141,6 @@ __all__ = [ "ProductMapper", "QuarterlyMapper", "SecretCache", - "SequenceMapper", "SkipMixin", "TaskGroup", "TaskInstanceState", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py similarity index 80% rename from task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py rename to task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py index 2ef9799c126d1..7d860bd879674 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/sequence.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/allowed_key.py @@ -19,8 +19,8 @@ from airflow.sdk.definitions.partition_mappers.base import PartitionMapper -class SequenceMapper(PartitionMapper): - """Partition mapper that validates keys against a defined sequence.""" +class AllowedKeyMapper(PartitionMapper): + """Partition mapper that validates keys against a set of allowed keys.""" - def __init__(self, sequence: list[str]) -> None: - self.sequence = sequence + def __init__(self, allowed_keys: list[str]) -> None: + self.allowed_keys = allowed_keys