From 2b24dda57532570a6a1a8b50055f80a4edcb2bc6 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 20 Jul 2022 12:02:05 +0800 Subject: [PATCH 1/2] Improve taskflow annotation with ParamSpec This is made available with our latest Mypy update. ParamSpec allows us to more accurately type a decorated task to return an XComArg (while still being correctly typed to accept the same arguments as the decorated function). This allows us to provide autocompletion for XComArg operations, such as map() and zip() introduced in AIP-42. --- airflow/decorators/__init__.pyi | 12 ++--- airflow/decorators/base.py | 78 +++++++++++++++++++-------------- airflow/typing_compat.py | 32 +++++++++++--- 3 files changed, 76 insertions(+), 46 deletions(-) diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index 726798f1341e3..35a56e231e6bd 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -20,9 +20,9 @@ # necessarily exist at run time. See "Creating Custom @task Decorators" # documentation for more details. -from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload -from airflow.decorators.base import Function, Task, TaskDecorator +from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator from airflow.decorators.branch_python import branch_task from airflow.decorators.python import python_task from airflow.decorators.python_virtualenv import virtualenv_task @@ -68,7 +68,7 @@ class TaskDecoratorCollection: """ # [START mixin_for_typing] @overload - def python(self, python_callable: Function) -> Task[Function]: ... + def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... # [END mixin_for_typing] @overload def __call__( @@ -81,7 +81,7 @@ class TaskDecoratorCollection: ) -> TaskDecorator: """Aliasing ``python``; signature should match exactly.""" @overload - def __call__(self, python_callable: Function) -> Task[Function]: + def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: """Aliasing ``python``; signature should match exactly.""" @overload def virtualenv( @@ -122,7 +122,7 @@ class TaskDecoratorCollection: such as transmission a large amount of XCom to TaskAPI. """ @overload - def virtualenv(self, python_callable: Function) -> Task[Function]: ... + def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... @overload def branch(self, *, multiple_outputs: Optional[bool] = None, **kwargs) -> TaskDecorator: """Create a decorator to wrap the decorated callable into a BranchPythonOperator. @@ -134,7 +134,7 @@ class TaskDecoratorCollection: Dict will unroll to XCom values with keys as XCom keys. Defaults to False. """ @overload - def branch(self, python_callable: Function) -> Task[Function]: ... + def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... # [START decorator_signature] def docker( self, diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index a36d7d5d43a7b..0a4b75cecedb9 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import functools import inspect import re from typing import ( @@ -68,7 +67,7 @@ ) from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg -from airflow.typing_compat import Protocol +from airflow.typing_compat import ParamSpec, Protocol from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context from airflow.utils.task_group import TaskGroup, TaskGroupContext @@ -236,13 +235,15 @@ def _hook_apply_defaults(self, *args, **kwargs): return args, kwargs -Function = TypeVar("Function", bound=Callable) +FParams = ParamSpec("FParams") + +FReturn = TypeVar("FReturn") OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") @attr.define(slots=False) -class _TaskDecorator(Generic[Function, OperatorSubclass]): +class _TaskDecorator(Generic[FParams, FReturn, OperatorSubclass]): """ Helper class for providing dynamic task mapping to decorated functions. @@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]): :meta private: """ - function: Function = attr.ib() + function: Callable[FParams, FReturn] = attr.ib() operator_class: Type[OperatorSubclass] multiple_outputs: bool = attr.ib() kwargs: Dict[str, Any] = attr.ib(factory=dict) @@ -272,7 +273,7 @@ def __attrs_post_init__(self): raise TypeError(f"@{self.decorator_name} does not support methods") self.kwargs.setdefault('task_id', self.function.__name__) - def __call__(self, *args, **kwargs) -> XComArg: + def __call__(self, *args: "FParams.args", **kwargs: "FParams.kwargs") -> XComArg: op = self.operator_class( python_callable=self.function, op_args=args, @@ -285,7 +286,7 @@ def __call__(self, *args, **kwargs) -> XComArg: return XComArg(op) @property - def __wrapped__(self) -> Function: + def __wrapped__(self) -> Callable[FParams, FReturn]: return self.function @cached_property @@ -337,9 +338,7 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg: # to False to skip the checks on execution. return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) - def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg: - from airflow.models.xcom_arg import XComArg - + def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg: if not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) @@ -420,14 +419,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: ) return XComArg(operator=operator) - def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]": + def partial(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]": self._validate_arg_names("partial", kwargs) old_kwargs = self.kwargs.get("op_kwargs", {}) prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial") kwargs.update(old_kwargs) return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs}) - def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]": + def override(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]": return attr.evolve(self, kwargs={**self.kwargs, **kwargs}) @@ -506,7 +505,7 @@ def _render_if_not_already_resolved(key: str, value: Any): return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()} -class Task(Generic[Function]): +class Task(Generic[FParams, FReturn]): """Declaration of a @task-decorated callable for type-checking. An instance of this type inherits the call signature of the decorated @@ -517,18 +516,21 @@ class Task(Generic[Function]): This type is implemented by ``_TaskDecorator`` at runtime. """ - __call__: Function + __call__: Callable[FParams, XComArg] - function: Function + function: Callable[FParams, FReturn] @property - def __wrapped__(self) -> Function: + def __wrapped__(self) -> Callable[FParams, FReturn]: + ... + + def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]": ... def expand(self, **kwargs: "Mappable") -> XComArg: ... - def partial(self, **kwargs: Any) -> "Task[Function]": + def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg: ... @@ -536,7 +538,10 @@ class TaskDecorator(Protocol): """Type declaration for ``task_decorator_factory`` return type.""" @overload - def __call__(self, python_callable: Function) -> Task[Function]: + def __call__( # type: ignore[misc] + self, + python_callable: Callable[FParams, FReturn], + ) -> Task[FParams, FReturn]: """For the "bare decorator" ``@task`` case.""" @overload @@ -545,7 +550,7 @@ def __call__( *, multiple_outputs: Optional[bool] = None, **kwargs: Any, - ) -> Callable[[Function], Task[Function]]: + ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]: """For the decorator factory ``@task()`` case.""" @@ -556,16 +561,20 @@ def task_decorator_factory( decorated_operator_class: Type[BaseOperator], **kwargs, ) -> TaskDecorator: - """ - A factory that generates a wrapper that wraps a function into an Airflow operator. - Accepts kwargs for operator kwarg. Can be reused in a single DAG. + """Generate a wrapper that wraps a function into an Airflow operator. - :param python_callable: Function to decorate - :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to - multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. - :param decorated_operator_class: The operator that executes the logic needed to run the python function in - the correct environment + Can be reused in a single DAG. + :param python_callable: Function to decorate. + :param multiple_outputs: If set to True, the decorated function's return + value will be unrolled to multiple XCom values. Dict will unroll to XCom + values with its keys as XCom keys. If set to False (default), only at + most one XCom value is pushed. + :param decorated_operator_class: The operator that executes the logic needed + to run the python function in the correct environment. + + Other kwargs are directly forwarded to the underlying operator class when + it's instantiated. """ if multiple_outputs is None: multiple_outputs = cast(bool, attr.NOTHING) @@ -579,10 +588,13 @@ def task_decorator_factory( return cast(TaskDecorator, decorator) elif python_callable is not None: raise TypeError('No args allowed while using @task, use kwargs instead') - decorator_factory = functools.partial( - _TaskDecorator, - multiple_outputs=multiple_outputs, - operator_class=decorated_operator_class, - kwargs=kwargs, - ) + + def decorator_factory(python_callable): + return _TaskDecorator( + function=python_callable, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ) + return cast(TaskDecorator, decorator_factory) diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py index 163889b8a2975..ec1846438fc67 100644 --- a/airflow/typing_compat.py +++ b/airflow/typing_compat.py @@ -21,10 +21,28 @@ codebase easier. """ -try: - # Literal, Protocol and TypedDict are only added to typing module starting from - # python 3.8 we can safely remove this shim import after Airflow drops - # support for <3.8 - from typing import Literal, Protocol, TypedDict, runtime_checkable # type: ignore -except ImportError: - from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable # type: ignore # noqa +__all__ = [ + "Literal", + "ParamSpec", + "Protocol", + "TypedDict", + "runtime_checkable", +] + +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol, TypedDict, runtime_checkable +else: + from typing_extensions import Protocol, TypedDict, runtime_checkable + +# Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]". +if sys.version_info >= (3, 9): + from typing import Literal +else: + from typing_extensions import Literal + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec From dcfa461ed3a6d2ea0da726ed7058324a0b779e65 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 25 Jul 2022 12:22:12 +0800 Subject: [PATCH 2/2] More type fixes --- .../kubernetes/operators/kubernetes_pod.py | 2 +- airflow/providers/dbt/cloud/hooks/dbt.py | 2 +- .../providers/google/cloud/operators/gcs.py | 28 ++++++++++--------- airflow/providers/qubole/hooks/qubole.py | 3 +- .../providers/salesforce/operators/bulk.py | 2 +- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 90d4fb88de9c1..ef2ebed6ae0ef 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -328,7 +328,7 @@ def _get_ti_pod_labels( if include_try_number: labels.update(try_number=ti.try_number) # In the case of sub dags this is just useful - if context['dag'].is_subdag: + if context['dag'].parent_dag: labels['parent_dag_id'] = context['dag'].parent_dag.dag_id # Ensure that label is valid for Kube, # and if not truncate/remove invalid chars and replace with short hash. diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index 07a327a04596d..fd370502a5cde 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -87,7 +87,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: class JobRunInfo(TypedDict): """Type class for the ``job_run_info`` dictionary.""" - account_id: int + account_id: Optional[int] run_id: int diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 7c58ee7c5a704..77a1ab656e60f 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -30,7 +30,6 @@ from google.api_core.exceptions import Conflict from google.cloud.exceptions import GoogleCloudError -from pendulum.datetime import DateTime from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -723,22 +722,25 @@ def __init__( def execute(self, context: "Context") -> List[str]: # Define intervals and prefixes. try: - timespan_start = context["data_interval_start"] - timespan_end = context["data_interval_end"] + orig_start = context["data_interval_start"] + orig_end = context["data_interval_end"] except KeyError: - timespan_start = pendulum.instance(context["execution_date"]) + orig_start = pendulum.instance(context["execution_date"]) following_execution_date = context["dag"].following_schedule(context["execution_date"]) if following_execution_date is None: - timespan_end = None + orig_end = None else: - timespan_end = pendulum.instance(following_execution_date) - - if timespan_end is None: # Only possible in Airflow before 2.2. - self.log.warning("No following schedule found, setting timespan end to max %s", timespan_end) - timespan_end = DateTime.max - elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-perodic schedules. - self.log.warning("DAG schedule not periodic, setting timespan end to max %s", timespan_end) - timespan_end = DateTime.max + orig_end = pendulum.instance(following_execution_date) + + timespan_start = orig_start + if orig_end is None: # Only possible in Airflow before 2.2. + self.log.warning("No following schedule found, setting timespan end to max %s", orig_end) + timespan_end = pendulum.instance(datetime.datetime.max) + elif orig_start >= orig_end: # Airflow 2.2 sets start == end for non-perodic schedules. + self.log.warning("DAG schedule not periodic, setting timespan end to max %s", orig_end) + timespan_end = pendulum.instance(datetime.datetime.max) + else: + timespan_end = orig_end timespan_start = timespan_start.in_timezone(timezone.utc) timespan_end = timespan_end.in_timezone(timezone.utc) diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py index 3b0d4bdd1a5f3..340cf4fe131af 100644 --- a/airflow/providers/qubole/hooks/qubole.py +++ b/airflow/providers/qubole/hooks/qubole.py @@ -46,6 +46,7 @@ from airflow.utils.state import State if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context @@ -139,7 +140,7 @@ def __init__(self, *args, **kwargs) -> None: self.kwargs = kwargs self.cls = COMMAND_CLASSES[self.kwargs['command_type']] self.cmd: Optional[Command] = None - self.task_instance = None + self.task_instance: Optional["TaskInstance"] = None @staticmethod def handle_failure_retry(context) -> None: diff --git a/airflow/providers/salesforce/operators/bulk.py b/airflow/providers/salesforce/operators/bulk.py index 110ed685ea1fd..39de722032807 100644 --- a/airflow/providers/salesforce/operators/bulk.py +++ b/airflow/providers/salesforce/operators/bulk.py @@ -47,7 +47,7 @@ class SalesforceBulkOperator(BaseOperator): def __init__( self, *, - operation: Literal[available_operations], + operation: Literal['insert', 'update', 'upsert', 'delete', 'hard_delete'], object_name: str, payload: list, external_id_field: str = 'Id',