diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index 726798f1341e3..35a56e231e6bd 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -20,9 +20,9 @@ # necessarily exist at run time. See "Creating Custom @task Decorators" # documentation for more details. -from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload -from airflow.decorators.base import Function, Task, TaskDecorator +from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator from airflow.decorators.branch_python import branch_task from airflow.decorators.python import python_task from airflow.decorators.python_virtualenv import virtualenv_task @@ -68,7 +68,7 @@ class TaskDecoratorCollection: """ # [START mixin_for_typing] @overload - def python(self, python_callable: Function) -> Task[Function]: ... + def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... # [END mixin_for_typing] @overload def __call__( @@ -81,7 +81,7 @@ class TaskDecoratorCollection: ) -> TaskDecorator: """Aliasing ``python``; signature should match exactly.""" @overload - def __call__(self, python_callable: Function) -> Task[Function]: + def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: """Aliasing ``python``; signature should match exactly.""" @overload def virtualenv( @@ -122,7 +122,7 @@ class TaskDecoratorCollection: such as transmission a large amount of XCom to TaskAPI. """ @overload - def virtualenv(self, python_callable: Function) -> Task[Function]: ... + def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... @overload def branch(self, *, multiple_outputs: Optional[bool] = None, **kwargs) -> TaskDecorator: """Create a decorator to wrap the decorated callable into a BranchPythonOperator. @@ -134,7 +134,7 @@ class TaskDecoratorCollection: Dict will unroll to XCom values with keys as XCom keys. Defaults to False. """ @overload - def branch(self, python_callable: Function) -> Task[Function]: ... + def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... # [START decorator_signature] def docker( self, diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index a36d7d5d43a7b..0a4b75cecedb9 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import functools import inspect import re from typing import ( @@ -68,7 +67,7 @@ ) from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg -from airflow.typing_compat import Protocol +from airflow.typing_compat import ParamSpec, Protocol from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context from airflow.utils.task_group import TaskGroup, TaskGroupContext @@ -236,13 +235,15 @@ def _hook_apply_defaults(self, *args, **kwargs): return args, kwargs -Function = TypeVar("Function", bound=Callable) +FParams = ParamSpec("FParams") + +FReturn = TypeVar("FReturn") OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") @attr.define(slots=False) -class _TaskDecorator(Generic[Function, OperatorSubclass]): +class _TaskDecorator(Generic[FParams, FReturn, OperatorSubclass]): """ Helper class for providing dynamic task mapping to decorated functions. @@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]): :meta private: """ - function: Function = attr.ib() + function: Callable[FParams, FReturn] = attr.ib() operator_class: Type[OperatorSubclass] multiple_outputs: bool = attr.ib() kwargs: Dict[str, Any] = attr.ib(factory=dict) @@ -272,7 +273,7 @@ def __attrs_post_init__(self): raise TypeError(f"@{self.decorator_name} does not support methods") self.kwargs.setdefault('task_id', self.function.__name__) - def __call__(self, *args, **kwargs) -> XComArg: + def __call__(self, *args: "FParams.args", **kwargs: "FParams.kwargs") -> XComArg: op = self.operator_class( python_callable=self.function, op_args=args, @@ -285,7 +286,7 @@ def __call__(self, *args, **kwargs) -> XComArg: return XComArg(op) @property - def __wrapped__(self) -> Function: + def __wrapped__(self) -> Callable[FParams, FReturn]: return self.function @cached_property @@ -337,9 +338,7 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg: # to False to skip the checks on execution. return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) - def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg: - from airflow.models.xcom_arg import XComArg - + def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg: if not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) @@ -420,14 +419,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: ) return XComArg(operator=operator) - def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]": + def partial(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]": self._validate_arg_names("partial", kwargs) old_kwargs = self.kwargs.get("op_kwargs", {}) prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial") kwargs.update(old_kwargs) return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs}) - def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]": + def override(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]": return attr.evolve(self, kwargs={**self.kwargs, **kwargs}) @@ -506,7 +505,7 @@ def _render_if_not_already_resolved(key: str, value: Any): return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()} -class Task(Generic[Function]): +class Task(Generic[FParams, FReturn]): """Declaration of a @task-decorated callable for type-checking. An instance of this type inherits the call signature of the decorated @@ -517,18 +516,21 @@ class Task(Generic[Function]): This type is implemented by ``_TaskDecorator`` at runtime. """ - __call__: Function + __call__: Callable[FParams, XComArg] - function: Function + function: Callable[FParams, FReturn] @property - def __wrapped__(self) -> Function: + def __wrapped__(self) -> Callable[FParams, FReturn]: + ... + + def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]": ... def expand(self, **kwargs: "Mappable") -> XComArg: ... - def partial(self, **kwargs: Any) -> "Task[Function]": + def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg: ... @@ -536,7 +538,10 @@ class TaskDecorator(Protocol): """Type declaration for ``task_decorator_factory`` return type.""" @overload - def __call__(self, python_callable: Function) -> Task[Function]: + def __call__( # type: ignore[misc] + self, + python_callable: Callable[FParams, FReturn], + ) -> Task[FParams, FReturn]: """For the "bare decorator" ``@task`` case.""" @overload @@ -545,7 +550,7 @@ def __call__( *, multiple_outputs: Optional[bool] = None, **kwargs: Any, - ) -> Callable[[Function], Task[Function]]: + ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]: """For the decorator factory ``@task()`` case.""" @@ -556,16 +561,20 @@ def task_decorator_factory( decorated_operator_class: Type[BaseOperator], **kwargs, ) -> TaskDecorator: - """ - A factory that generates a wrapper that wraps a function into an Airflow operator. - Accepts kwargs for operator kwarg. Can be reused in a single DAG. + """Generate a wrapper that wraps a function into an Airflow operator. - :param python_callable: Function to decorate - :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to - multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. - :param decorated_operator_class: The operator that executes the logic needed to run the python function in - the correct environment + Can be reused in a single DAG. + :param python_callable: Function to decorate. + :param multiple_outputs: If set to True, the decorated function's return + value will be unrolled to multiple XCom values. Dict will unroll to XCom + values with its keys as XCom keys. If set to False (default), only at + most one XCom value is pushed. + :param decorated_operator_class: The operator that executes the logic needed + to run the python function in the correct environment. + + Other kwargs are directly forwarded to the underlying operator class when + it's instantiated. """ if multiple_outputs is None: multiple_outputs = cast(bool, attr.NOTHING) @@ -579,10 +588,13 @@ def task_decorator_factory( return cast(TaskDecorator, decorator) elif python_callable is not None: raise TypeError('No args allowed while using @task, use kwargs instead') - decorator_factory = functools.partial( - _TaskDecorator, - multiple_outputs=multiple_outputs, - operator_class=decorated_operator_class, - kwargs=kwargs, - ) + + def decorator_factory(python_callable): + return _TaskDecorator( + function=python_callable, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ) + return cast(TaskDecorator, decorator_factory) diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 90d4fb88de9c1..ef2ebed6ae0ef 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -328,7 +328,7 @@ def _get_ti_pod_labels( if include_try_number: labels.update(try_number=ti.try_number) # In the case of sub dags this is just useful - if context['dag'].is_subdag: + if context['dag'].parent_dag: labels['parent_dag_id'] = context['dag'].parent_dag.dag_id # Ensure that label is valid for Kube, # and if not truncate/remove invalid chars and replace with short hash. diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index 07a327a04596d..fd370502a5cde 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -87,7 +87,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: class JobRunInfo(TypedDict): """Type class for the ``job_run_info`` dictionary.""" - account_id: int + account_id: Optional[int] run_id: int diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 7c58ee7c5a704..77a1ab656e60f 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -30,7 +30,6 @@ from google.api_core.exceptions import Conflict from google.cloud.exceptions import GoogleCloudError -from pendulum.datetime import DateTime from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -723,22 +722,25 @@ def __init__( def execute(self, context: "Context") -> List[str]: # Define intervals and prefixes. try: - timespan_start = context["data_interval_start"] - timespan_end = context["data_interval_end"] + orig_start = context["data_interval_start"] + orig_end = context["data_interval_end"] except KeyError: - timespan_start = pendulum.instance(context["execution_date"]) + orig_start = pendulum.instance(context["execution_date"]) following_execution_date = context["dag"].following_schedule(context["execution_date"]) if following_execution_date is None: - timespan_end = None + orig_end = None else: - timespan_end = pendulum.instance(following_execution_date) - - if timespan_end is None: # Only possible in Airflow before 2.2. - self.log.warning("No following schedule found, setting timespan end to max %s", timespan_end) - timespan_end = DateTime.max - elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-perodic schedules. - self.log.warning("DAG schedule not periodic, setting timespan end to max %s", timespan_end) - timespan_end = DateTime.max + orig_end = pendulum.instance(following_execution_date) + + timespan_start = orig_start + if orig_end is None: # Only possible in Airflow before 2.2. + self.log.warning("No following schedule found, setting timespan end to max %s", orig_end) + timespan_end = pendulum.instance(datetime.datetime.max) + elif orig_start >= orig_end: # Airflow 2.2 sets start == end for non-perodic schedules. + self.log.warning("DAG schedule not periodic, setting timespan end to max %s", orig_end) + timespan_end = pendulum.instance(datetime.datetime.max) + else: + timespan_end = orig_end timespan_start = timespan_start.in_timezone(timezone.utc) timespan_end = timespan_end.in_timezone(timezone.utc) diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py index 3b0d4bdd1a5f3..340cf4fe131af 100644 --- a/airflow/providers/qubole/hooks/qubole.py +++ b/airflow/providers/qubole/hooks/qubole.py @@ -46,6 +46,7 @@ from airflow.utils.state import State if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context @@ -139,7 +140,7 @@ def __init__(self, *args, **kwargs) -> None: self.kwargs = kwargs self.cls = COMMAND_CLASSES[self.kwargs['command_type']] self.cmd: Optional[Command] = None - self.task_instance = None + self.task_instance: Optional["TaskInstance"] = None @staticmethod def handle_failure_retry(context) -> None: diff --git a/airflow/providers/salesforce/operators/bulk.py b/airflow/providers/salesforce/operators/bulk.py index 110ed685ea1fd..39de722032807 100644 --- a/airflow/providers/salesforce/operators/bulk.py +++ b/airflow/providers/salesforce/operators/bulk.py @@ -47,7 +47,7 @@ class SalesforceBulkOperator(BaseOperator): def __init__( self, *, - operation: Literal[available_operations], + operation: Literal['insert', 'update', 'upsert', 'delete', 'hard_delete'], object_name: str, payload: list, external_id_field: str = 'Id', diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py index 163889b8a2975..ec1846438fc67 100644 --- a/airflow/typing_compat.py +++ b/airflow/typing_compat.py @@ -21,10 +21,28 @@ codebase easier. """ -try: - # Literal, Protocol and TypedDict are only added to typing module starting from - # python 3.8 we can safely remove this shim import after Airflow drops - # support for <3.8 - from typing import Literal, Protocol, TypedDict, runtime_checkable # type: ignore -except ImportError: - from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable # type: ignore # noqa +__all__ = [ + "Literal", + "ParamSpec", + "Protocol", + "TypedDict", + "runtime_checkable", +] + +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol, TypedDict, runtime_checkable +else: + from typing_extensions import Protocol, TypedDict, runtime_checkable + +# Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]". +if sys.version_info >= (3, 9): + from typing import Literal +else: + from typing_extensions import Literal + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec