diff --git a/airflow-core/src/airflow/config_templates/airflow_local_settings.py b/airflow-core/src/airflow/config_templates/airflow_local_settings.py index 5ac7f513d1f87..e06d3a4e7e263 100644 --- a/airflow-core/src/airflow/config_templates/airflow_local_settings.py +++ b/airflow-core/src/airflow/config_templates/airflow_local_settings.py @@ -20,12 +20,15 @@ from __future__ import annotations import os -from typing import Any +from typing import TYPE_CHECKING, Any from urllib.parse import urlsplit from airflow.configuration import conf from airflow.exceptions import AirflowException +if TYPE_CHECKING: + from airflow.logging_config import RemoteLogIO + LOG_LEVEL: str = conf.get_mandatory_value("logging", "LOGGING_LEVEL").upper() @@ -48,7 +51,7 @@ DAG_PROCESSOR_LOG_TARGET: str = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET") -BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "BASE_LOG_FOLDER") +BASE_LOG_FOLDER: str = os.path.expanduser(conf.get_mandatory_value("logging", "BASE_LOG_FOLDER")) PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value("scheduler", "CHILD_PROCESS_LOG_DIRECTORY") @@ -84,7 +87,7 @@ "task": { "class": "airflow.utils.log.file_task_handler.FileTaskHandler", "formatter": "airflow", - "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER), + "base_log_folder": BASE_LOG_FOLDER, "filters": ["mask_secrets"], }, }, @@ -126,6 +129,7 @@ ################## REMOTE_LOGGING: bool = conf.getboolean("logging", "remote_logging") +REMOTE_TASK_LOG: RemoteLogIO | None = None if REMOTE_LOGGING: ELASTICSEARCH_HOST: str | None = conf.get("elasticsearch", "HOST") @@ -137,64 +141,86 @@ # WASB buckets should start with "wasb" # HDFS path should start with "hdfs://" # just to help Airflow select correct handler - REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "REMOTE_BASE_LOG_FOLDER") - REMOTE_TASK_HANDLER_KWARGS = conf.getjson("logging", "REMOTE_TASK_HANDLER_KWARGS", fallback={}) + remote_base_log_folder: str = 
conf.get_mandatory_value("logging", "remote_base_log_folder") + remote_task_handler_kwargs = conf.getjson("logging", "remote_task_handler_kwargs", fallback={}) + if not isinstance(remote_task_handler_kwargs, dict): + raise ValueError( + "logging/remote_task_handler_kwargs must be a JSON object (a python dict), we got " + f"{type(remote_task_handler_kwargs)}" + ) + delete_local_copy = conf.getboolean("logging", "delete_local_logs") - if REMOTE_BASE_LOG_FOLDER.startswith("s3://"): - S3_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { - "task": { - "class": "airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler", - "formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), - "s3_log_folder": REMOTE_BASE_LOG_FOLDER, - }, - } + if remote_base_log_folder.startswith("s3://"): + from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO - DEFAULT_LOGGING_CONFIG["handlers"].update(S3_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith("cloudwatch://"): - url_parts = urlsplit(REMOTE_BASE_LOG_FOLDER) - CLOUDWATCH_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { - "task": { - "class": "airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler", - "formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), - "log_group_arn": url_parts.netloc + url_parts.path, - }, - } + REMOTE_TASK_LOG = S3RemoteLogIO( + **( + { + "base_log_folder": BASE_LOG_FOLDER, + "remote_base": remote_base_log_folder, + "delete_local_copy": delete_local_copy, + } + | remote_task_handler_kwargs + ) + ) + remote_task_handler_kwargs = {} - DEFAULT_LOGGING_CONFIG["handlers"].update(CLOUDWATCH_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith("gs://"): - key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None) - GCS_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { - "task": { - "class": "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler", - 
"formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), - "gcs_log_folder": REMOTE_BASE_LOG_FOLDER, - "gcp_key_path": key_path, - }, - } + elif remote_base_log_folder.startswith("cloudwatch://"): + from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO + + url_parts = urlsplit(remote_base_log_folder) + REMOTE_TASK_LOG = CloudWatchRemoteLogIO( + **( + { + "base_log_folder": BASE_LOG_FOLDER, + "remote_base": remote_base_log_folder, + "delete_local_copy": delete_local_copy, + "log_group_arn": url_parts.netloc + url_parts.path, + } + | remote_task_handler_kwargs + ) + ) + remote_task_handler_kwargs = {} + elif remote_base_log_folder.startswith("gs://"): + from airflow.providers.google.cloud.logs.gcs_task_handler import GCSRemoteLogIO + + key_path = conf.get_mandatory_value("logging", "google_key_path", fallback=None) + + REMOTE_TASK_LOG = GCSRemoteLogIO( + **( + { + "base_log_folder": BASE_LOG_FOLDER, + "remote_base": remote_base_log_folder, + "delete_local_copy": delete_local_copy, + "gcp_key_path": key_path, + } + | remote_task_handler_kwargs + ) + ) + remote_task_handler_kwargs = {} + elif remote_base_log_folder.startswith("wasb"): + from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO - DEFAULT_LOGGING_CONFIG["handlers"].update(GCS_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith("wasb"): wasb_log_container = conf.get_mandatory_value( "azure_remote_logging", "remote_wasb_log_container", fallback="airflow-logs" ) - WASB_REMOTE_HANDLERS: dict[str, dict[str, str | bool | None]] = { - "task": { - "class": "airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler", - "formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), - "wasb_log_folder": REMOTE_BASE_LOG_FOLDER, - "wasb_container": wasb_log_container, - }, - } - DEFAULT_LOGGING_CONFIG["handlers"].update(WASB_REMOTE_HANDLERS) - elif 
REMOTE_BASE_LOG_FOLDER.startswith("stackdriver://"): + REMOTE_TASK_LOG = WasbRemoteLogIO( + **( + { + "base_log_folder": BASE_LOG_FOLDER, + "remote_base": remote_base_log_folder, + "delete_local_copy": delete_local_copy, + "wasb_container": wasb_log_container, + } + | remote_task_handler_kwargs + ) + ) + remote_task_handler_kwargs = {} + elif remote_base_log_folder.startswith("stackdriver://"): key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None) # stackdriver:///airflow-tasks => airflow-tasks - log_name = urlsplit(REMOTE_BASE_LOG_FOLDER).path[1:] + log_name = urlsplit(remote_base_log_folder).path[1:] STACKDRIVER_REMOTE_HANDLERS = { "task": { "class": "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler", @@ -205,23 +231,27 @@ } DEFAULT_LOGGING_CONFIG["handlers"].update(STACKDRIVER_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith("oss://"): - OSS_REMOTE_HANDLERS = { - "task": { - "class": "airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler", - "formatter": "airflow", - "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER), - "oss_log_folder": REMOTE_BASE_LOG_FOLDER, - }, - } - DEFAULT_LOGGING_CONFIG["handlers"].update(OSS_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith("hdfs://"): + elif remote_base_log_folder.startswith("oss://"): + from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSRemoteLogIO + + REMOTE_TASK_LOG = OSSRemoteLogIO( + **( + { + "base_log_folder": BASE_LOG_FOLDER, + "remote_base": remote_base_log_folder, + "delete_local_copy": delete_local_copy, + } + | remote_task_handler_kwargs + ) + ) + remote_task_handler_kwargs = {} + elif remote_base_log_folder.startswith("hdfs://"): HDFS_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { "task": { "class": "airflow.providers.apache.hdfs.log.hdfs_task_handler.HdfsTaskHandler", "formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), - "hdfs_log_folder": 
REMOTE_BASE_LOG_FOLDER, + "base_log_folder": BASE_LOG_FOLDER, + "hdfs_log_folder": remote_base_log_folder, }, } DEFAULT_LOGGING_CONFIG["handlers"].update(HDFS_REMOTE_HANDLERS) @@ -240,7 +270,7 @@ "task": { "class": "airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler", "formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "base_log_folder": BASE_LOG_FOLDER, "end_of_log_mark": ELASTICSEARCH_END_OF_LOG_MARK, "host": ELASTICSEARCH_HOST, "frontend": ELASTICSEARCH_FRONTEND, @@ -270,7 +300,7 @@ "task": { "class": "airflow.providers.opensearch.log.os_task_handler.OpensearchTaskHandler", "formatter": "airflow", - "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "base_log_folder": BASE_LOG_FOLDER, "end_of_log_mark": OPENSEARCH_END_OF_LOG_MARK, "host": OPENSEARCH_HOST, "port": OPENSEARCH_PORT, @@ -290,4 +320,4 @@ "section 'elasticsearch' if you are using Elasticsearch. In the other case, " "'remote_base_log_folder' option in the 'logging' section." ) - DEFAULT_LOGGING_CONFIG["handlers"]["task"].update(REMOTE_TASK_HANDLER_KWARGS) + DEFAULT_LOGGING_CONFIG["handlers"]["task"].update(remote_task_handler_kwargs) diff --git a/airflow-core/src/airflow/logging/__init__.py b/airflow-core/src/airflow/logging/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow-core/src/airflow/logging/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow-core/src/airflow/logging/remote.py b/airflow-core/src/airflow/logging/remote.py new file mode 100644 index 0000000000000..8c55cba546495 --- /dev/null +++ b/airflow-core/src/airflow/logging/remote.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + import structlog.typing + + from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo + + +class RemoteLogIO(Protocol): + """Interface for remote task loggers.""" + + @property + def processors(self) -> tuple[structlog.typing.Processor, ...]: ... + + """ + List of structlog processors to install in the task write path. 
+ + This is useful if a remote logging provider wants to either transform the structured log messages as they + are being written to a file, or if you want to upload messages as they are generated. + """ + + def upload(self, path: os.PathLike | str) -> None: + """Upload the given log path to the remote storage.""" + ... + + def read(self, relative_path: str) -> tuple[LogSourceInfo, LogMessages | None]: + """Read logs from the given remote log path.""" + ... diff --git a/airflow-core/src/airflow/logging_config.py b/airflow-core/src/airflow/logging_config.py index f0497c57409ed..b0c0b35515599 100644 --- a/airflow-core/src/airflow/logging_config.py +++ b/airflow-core/src/airflow/logging_config.py @@ -20,39 +20,68 @@ import logging import warnings from logging.config import dictConfig +from typing import TYPE_CHECKING, Any from airflow.configuration import conf from airflow.exceptions import AirflowConfigException from airflow.utils.module_loading import import_string +if TYPE_CHECKING: + from airflow.logging.remote import RemoteLogIO + log = logging.getLogger(__name__) -def configure_logging(): +REMOTE_TASK_LOG: RemoteLogIO | None + + +def __getattr__(name: str): + if name == "REMOTE_TASK_LOG": + global REMOTE_TASK_LOG + load_logging_config() + return REMOTE_TASK_LOG + + +def load_logging_config() -> tuple[dict[str, Any], str]: """Configure & Validate Airflow Logging.""" - logging_class_path = "" - try: - logging_class_path = conf.get("logging", "logging_config_class") - except AirflowConfigException: - log.debug("Could not find key logging_config_class in config") + global REMOTE_TASK_LOG + fallback = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG" + logging_class_path = conf.get("logging", "logging_config_class", fallback=fallback) - if logging_class_path: - try: - logging_config = import_string(logging_class_path) + # Sometimes we end up with `""` as the value! 
+ logging_class_path = logging_class_path or fallback + + user_defined = logging_class_path != fallback + + try: + logging_config = import_string(logging_class_path) - # Make sure that the variable is in scope - if not isinstance(logging_config, dict): - raise ValueError("Logging Config should be of dict type") + # Make sure that the variable is in scope + if not isinstance(logging_config, dict): + raise ValueError("Logging Config should be of dict type") + if user_defined: log.info("Successfully imported user-defined logging config from %s", logging_class_path) - except Exception as err: - # Import default logging configurations. - raise ImportError(f"Unable to load custom logging from {logging_class_path} due to {err}") + + except Exception as err: + # Import default logging configurations. + raise ImportError( + f"Unable to load {'custom ' if user_defined else ''}logging config from {logging_class_path} due " + f"to: {type(err).__name__}:{err}" + ) else: - logging_class_path = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG" - logging_config = import_string(logging_class_path) - log.debug("Unable to load custom logging, using default config instead") + mod = logging_class_path.rsplit(".", 1)[0] + try: + remote_task_log = import_string(f"{mod}.REMOTE_TASK_LOG") + REMOTE_TASK_LOG = remote_task_log + except Exception as err: + log.info("Remote task logs will not be available due to an error: %s", err) + + return logging_config, logging_class_path + +def configure_logging(): + logging_config, logging_class_path = load_logging_config() try: # Ensure that the password masking filter is applied to the 'task' handler # no matter what the user did. 
diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index 2581c7f1e749b..d2f44fd38d39e 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -27,7 +27,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Union from urllib.parse import urljoin import pendulum @@ -46,6 +46,15 @@ from airflow.executors.base_executor import BaseExecutor from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.typing_compat import TypeAlias + + +# These types are similar, but have distinct names to make processing them less error prone +LogMessages: TypeAlias = Union[list["StructuredLogMessage"], list[str]] +"""The log messages themselves, either in already structured form, or a single string blob to be parsed later""" +LogSourceInfo: TypeAlias = list[str] +"""Information _about_ the log fetching process for display to a user""" +LogMetadata: TypeAlias = dict[str, Any] logger = logging.getLogger(__name__) @@ -124,31 +133,44 @@ def _parse_timestamp(line: str): return pendulum.parse(timestamp_str.strip("[]")) -def _parse_log_lines(lines: Iterable[str]) -> Iterable[tuple[datetime | None, int, StructuredLogMessage]]: +def _parse_log_lines( + lines: str | LogMessages, +) -> Iterable[tuple[datetime | None, int, StructuredLogMessage]]: from airflow.utils.timezone import coerce_datetime timestamp = None next_timestamp = None - for idx, line in enumerate(lines): + if isinstance(lines, str): + lines = lines.splitlines() + if isinstance(lines, list) and len(lines) and isinstance(lines[0], str): + # A list of content from each location.
It's a super odd format, but this is what we load + # [['a\nb\n'], ['c\nd\ne\n']] -> ['a', 'b', 'c', 'd', 'e'] + lines = itertools.chain.from_iterable(map(str.splitlines, lines)) # type: ignore[assignment,arg-type] + + # https://github.com/python/mypy/issues/8586 + for idx, line in enumerate[Union[str, StructuredLogMessage]](lines): if line: try: - # Try to parse it as json first - log = StructuredLogMessage.model_validate_json(line) + if isinstance(line, StructuredLogMessage): + log = line + else: + log = StructuredLogMessage.model_validate_json(line) except ValidationError: with suppress(Exception): # If we can't parse the timestamp, don't attach one to the row - next_timestamp = _parse_timestamp(line) - log = StructuredLogMessage(event=line, timestamp=next_timestamp) + if isinstance(line, str): + next_timestamp = _parse_timestamp(line) + log = StructuredLogMessage(event=str(line), timestamp=next_timestamp) if log.timestamp: log.timestamp = coerce_datetime(log.timestamp) timestamp = log.timestamp yield timestamp, idx, log -def _interleave_logs(*logs: str) -> Iterable[StructuredLogMessage]: +def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]: min_date = pendulum.datetime(2000, 1, 1) - records = itertools.chain.from_iterable(_parse_log_lines(log.splitlines()) for log in logs) + records = itertools.chain.from_iterable(_parse_log_lines(log) for log in logs) last = None for timestamp, _, msg in sorted(records, key=lambda x: (x[0] or min_date, x[1])): if msg != last or not timestamp: # dedupe @@ -372,13 +394,14 @@ def _read( # is needed to get correct log path. 
worker_log_rel_path = self._render_filename(ti, try_number) source_list: list[str] = [] - remote_logs: list[str] = [] + remote_logs: LogMessages | None = [] local_logs: list[str] = [] sources: list[str] = [] executor_logs: list[str] = [] - served_logs: list[str] = [] + served_logs: LogMessages = [] with suppress(NotImplementedError): sources, remote_logs = self._read_remote_logs(ti, try_number, metadata) + source_list.extend(sources) has_k8s_exec_pod = False if ti.state == TaskInstanceState.RUNNING: @@ -407,7 +430,7 @@ def _read( logs = list( _interleave_logs( *local_logs, - *remote_logs, + (remote_logs or []), *(executor_logs or []), *served_logs, ) @@ -558,7 +581,7 @@ def _read_from_local(worker_log_path: Path) -> tuple[list[str], list[str]]: logs = [file.read_text() for file in paths] return sources, logs - def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], list[str]]: + def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[LogSourceInfo, LogMessages]: sources = [] logs = [] try: @@ -590,7 +613,7 @@ def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], li logger.exception("Could not read served logs") return sources, logs - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]: + def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]: """ Implement in subclasses to read from the remote service. @@ -600,4 +623,20 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], l such as, "reading from x file". * Each element in the logs list should be the content of one file. 
""" - raise NotImplementedError + remote_io = None + try: + from airflow.logging_config import REMOTE_TASK_LOG + + remote_io = REMOTE_TASK_LOG + except Exception: + pass + + if remote_io is None: + # Import not found, or explicitly set to None + raise NotImplementedError + + # This living here is not really a good plan, but it just about works for now. + # Ideally we move all the read+combine logic in to TaskLogReader and out of the task handler. + path = self._render_filename(ti, try_number) + sources, logs = remote_io.read(path) + return sources, logs or [] diff --git a/airflow-core/tests/unit/core/test_logging_config.py b/airflow-core/tests/unit/core/test_logging_config.py index 28a1a5f249d05..62e82a1cffba3 100644 --- a/airflow-core/tests/unit/core/test_logging_config.py +++ b/airflow-core/tests/unit/core/test_logging_config.py @@ -214,23 +214,11 @@ def test_loading_valid_complex_local_settings(self): with patch.object(log, "info") as mock_info: configure_logging() - mock_info.assert_called_once_with( + mock_info.assert_any_call( "Successfully imported user-defined logging config from %s", f"etc.airflow.config.{SETTINGS_DEFAULT_NAME}.LOGGING_CONFIG", ) - # When we try to load a valid config - def test_loading_valid_local_settings(self): - with settings_context(SETTINGS_FILE_VALID): - from airflow.logging_config import configure_logging, log - - with patch.object(log, "info") as mock_info: - configure_logging() - mock_info.assert_called_once_with( - "Successfully imported user-defined logging config from %s", - f"{SETTINGS_DEFAULT_NAME}.LOGGING_CONFIG", - ) - # When we load an empty file, it should go to default def test_loading_no_local_settings(self): with settings_context(SETTINGS_FILE_EMPTY): @@ -239,23 +227,6 @@ def test_loading_no_local_settings(self): with pytest.raises(ImportError): configure_logging() - # When the key is not available in the configuration - def test_when_the_config_key_does_not_exists(self): - from airflow import logging_config - - with 
conf_vars({("logging", "logging_config_class"): None}): - with patch.object(logging_config.log, "debug") as mock_debug: - logging_config.configure_logging() - mock_debug.assert_any_call("Unable to load custom logging, using default config instead") - - # Just default - def test_loading_local_settings_without_logging_config(self): - from airflow.logging_config import configure_logging, log - - with patch.object(log, "debug") as mock_info: - configure_logging() - mock_info.assert_called_once_with("Unable to load custom logging, using default config instead") - def test_1_9_config(self): from airflow.logging_config import configure_logging @@ -269,9 +240,9 @@ def test_loading_remote_logging_with_wasb_handler(self): pytest.importorskip( "airflow.providers.microsoft.azure", reason="'microsoft.azure' provider not installed" ) + import airflow.logging_config from airflow.config_templates import airflow_local_settings - from airflow.logging_config import configure_logging - from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbTaskHandler + from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO with conf_vars( { @@ -281,10 +252,9 @@ def test_loading_remote_logging_with_wasb_handler(self): } ): importlib.reload(airflow_local_settings) - configure_logging() + airflow.logging_config.configure_logging() - logger = logging.getLogger("airflow.task") - assert isinstance(logger.handlers[0], WasbTaskHandler) + assert isinstance(airflow.logging_config.REMOTE_TASK_LOG, WasbRemoteLogIO) @pytest.mark.parametrize( "remote_base_log_folder, log_group_arn", @@ -307,8 +277,9 @@ def test_log_group_arns_remote_logging_with_cloudwatch_handler( self, remote_base_log_folder, log_group_arn ): """Test if the correct ARNs are configured for Cloudwatch""" + import airflow.logging_config from airflow.config_templates import airflow_local_settings - from airflow.logging_config import configure_logging + from 
airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO with conf_vars( { @@ -318,18 +289,18 @@ def test_log_group_arns_remote_logging_with_cloudwatch_handler( } ): importlib.reload(airflow_local_settings) - configure_logging() - assert ( - airflow_local_settings.DEFAULT_LOGGING_CONFIG["handlers"]["task"]["log_group_arn"] - == log_group_arn - ) + airflow.logging_config.configure_logging() + + remote_io = airflow.logging_config.REMOTE_TASK_LOG + assert isinstance(remote_io, CloudWatchRemoteLogIO) + assert remote_io.log_group_arn == log_group_arn def test_loading_remote_logging_with_kwargs(self): """Test if logging can be configured successfully with kwargs""" pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed") + import airflow.logging_config from airflow.config_templates import airflow_local_settings - from airflow.logging_config import configure_logging - from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler + from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO with conf_vars( { @@ -340,8 +311,8 @@ def test_loading_remote_logging_with_kwargs(self): } ): importlib.reload(airflow_local_settings) - configure_logging() + airflow.logging_config.configure_logging() - logger = logging.getLogger("airflow.task") - assert isinstance(logger.handlers[0], S3TaskHandler) - assert getattr(logger.handlers[0], "delete_local_copy") is True + task_log = airflow.logging_config.REMOTE_TASK_LOG + assert isinstance(task_log, S3RemoteLogIO) + assert getattr(task_log, "delete_local_copy") is True diff --git a/airflow-core/tests/unit/logging/__init__.py b/airflow-core/tests/unit/logging/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow-core/tests/unit/logging/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index 994b2c1be5ced..3dd0ba2dea36d 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -163,6 +163,7 @@ def _find_all_integration_folders() -> list[str]: "airflow-core/tests/unit/io", "airflow-core/tests/unit/lineage", "airflow-core/tests/unit/listeners", + "airflow-core/tests/unit/logging", "airflow-core/tests/unit/macros", "airflow-core/tests/unit/plugins", "airflow-core/tests/unit/secrets", diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py index 4d2e1bbdeeff4..9967c416c1fde 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py @@ -19,34 +19,57 @@ import contextlib import os -import pathlib import shutil from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING + +import attrs from airflow.configuration import conf from airflow.providers.alibaba.cloud.hooks.oss import OSSHook from airflow.utils.log.file_task_handler 
import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo -class OSSTaskHandler(FileTaskHandler, LoggingMixin): - """ - OSSTaskHandler is a python log handler that handles and reads task instance logs. - Extends airflow FileTaskHandler and uploads to and reads from OSS remote storage. - """ +@attrs.define(kw_only=True) +class OSSRemoteLogIO(LoggingMixin): # noqa: D101 + base_log_folder: Path = attrs.field(converter=Path) + remote_base: str = "" + delete_local_copy: bool = True - def __init__(self, base_log_folder, oss_log_folder, **kwargs): - self.log.info("Using oss_task_handler for remote logging...") - super().__init__(base_log_folder) - (self.bucket_name, self.base_folder) = OSSHook.parse_oss_url(oss_log_folder) - self.log_relative_path = "" - self._hook = None - self.closed = False - self.upload_on_close = True - self.delete_local_copy = kwargs.get( - "delete_local_copy", conf.getboolean("logging", "delete_local_logs") - ) + processors = () + + def upload(self, path: os.PathLike | str): + """Upload the given log path to the remote storage.""" + path = Path(path) + if path.is_absolute(): + local_loc = path + remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder)) + else: + local_loc = self.base_log_folder.joinpath(path) + remote_loc = os.path.join(self.remote_base, path) + + if local_loc.is_file(): + # read log and remove old logs to get just the latest additions + log = local_loc.read_text() + has_uploaded = self.oss_write(log, remote_loc) + if has_uploaded and self.delete_local_copy: + shutil.rmtree(os.path.dirname(local_loc)) + + @cached_property + def base_folder(self): + (_, base_folder) = OSSHook.parse_oss_url(self.remote_base) + return base_folder + + @cached_property + def bucket_name(self): + (bucket_name, _) = OSSHook.parse_oss_url(self.remote_base) + 
return bucket_name @cached_property def hook(self): @@ -63,71 +86,15 @@ def hook(self): remote_conn_id, ) - def set_context(self, ti): - """Set the context of the handler.""" - super().set_context(ti) - # Local location and remote location is needed to open and - # upload local log file to OSS remote storage. - self.log_relative_path = self._render_filename(ti, ti.try_number) - self.upload_on_close = not ti.raw - - # Clear the file first so that duplicate data is not uploaded - # when reusing the same path (e.g. with rescheduled sensors) - if self.upload_on_close: - with open(self.handler.baseFilename, "w"): - pass + def read(self, relative_path, ti: TaskInstance | None = None) -> tuple[LogSourceInfo, LogMessages | None]: + logs: list[str] = [] + messages = [relative_path] - def close(self): - """Close and upload local log file to remote storage OSS.""" - # When application exit, system shuts down all handlers by - # calling close method. Here we check if logger is already - # closed to prevent uploading the log to remote storage multiple - # times when `logging.shutdown` is called. - if self.closed: - return - - super().close() - - if not self.upload_on_close: - return - - local_loc = os.path.join(self.local_base, self.log_relative_path) - remote_loc = self.log_relative_path - if os.path.exists(local_loc): - # read log and remove old logs to get just the latest additions - log = pathlib.Path(local_loc).read_text() - oss_write = self.oss_write(log, remote_loc) - if oss_write and self.delete_local_copy: - shutil.rmtree(os.path.dirname(local_loc)) - - # Mark closed so we don't double write if close is called twice - self.closed = True - - def _read(self, ti, try_number, metadata=None): - """ - Read logs of given task instance and try_number from OSS remote storage. - - If failed, read the log from task instance host machine. 
- - :param ti: task instance object - :param try_number: task instance try_number to read logs from - :param metadata: log metadata, - can be used for steaming log reading and auto-tailing. - """ - # Explicitly getting log relative path is necessary as the given - # task instance might be different from task instance passed in - # set_context method. - log_relative_path = self._render_filename(ti, try_number) - remote_loc = log_relative_path - - if not self.oss_log_exists(remote_loc): - return super()._read(ti, try_number, metadata) - # If OSS remote file exists, we do not fetch logs from task instance - # local machine even if there are errors reading remote logs, as - # returned remote_log will contain error messages. - remote_log = self.oss_read(remote_loc, return_error=True) - log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" - return log, {"end_of_log": True} + if self.oss_log_exists(relative_path): + logs.append(self.oss_read(relative_path, return_error=True)) + return messages, logs + else: + return messages, None def oss_log_exists(self, remote_log_location): """ @@ -149,8 +116,8 @@ def oss_read(self, remote_log_location, return_error=False): :param return_error: if True, returns a string error message if an error occurs. Otherwise, returns '' when an error occurs. """ + oss_remote_log_location = f"{self.base_folder}/{remote_log_location}" try: - oss_remote_log_location = f"{self.base_folder}/{remote_log_location}" self.log.info("read remote log: %s", oss_remote_log_location) return self.hook.read_key(self.bucket_name, oss_remote_log_location) except Exception: @@ -188,3 +155,89 @@ def oss_write(self, log, remote_log_location, append=True) -> bool: ) return False return True + + +class OSSTaskHandler(FileTaskHandler, LoggingMixin): + """ + OSSTaskHandler is a python log handler that handles and reads task instance logs. + + Extends airflow FileTaskHandler and uploads to and reads from OSS remote storage. 
+ """ + + def __init__(self, base_log_folder, oss_log_folder, **kwargs): + self.log.info("Using oss_task_handler for remote logging...") + super().__init__(base_log_folder) + self.log_relative_path = "" + self._hook = None + self.closed = False + self.upload_on_close = True + self.delete_local_copy = kwargs.get( + "delete_local_copy", conf.getboolean("logging", "delete_local_logs") + ) + + self.io = OSSRemoteLogIO( + remote_base=oss_log_folder, + base_log_folder=base_log_folder, + delete_local_copy=kwargs.get( + "delete_local_copy", conf.getboolean("logging", "delete_local_logs") + ), + ) + + def set_context(self, ti): + """Set the context of the handler.""" + super().set_context(ti) + # Local location and remote location is needed to open and + # upload local log file to OSS remote storage. + self.log_relative_path = self._render_filename(ti, ti.try_number) + self.upload_on_close = not ti.raw + + # Clear the file first so that duplicate data is not uploaded + # when reusing the same path (e.g. with rescheduled sensors) + if self.upload_on_close: + with open(self.handler.baseFilename, "w"): + pass + + def close(self): + """Close and upload local log file to remote storage OSS.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + super().close() + + if not self.upload_on_close: + return + + self.io.upload(self.log_relative_path) + + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read(self, ti, try_number, metadata=None): + """ + Read logs of given task instance and try_number from OSS remote storage. + + If failed, read the log from task instance host machine. 
+ + :param ti: task instance object + :param try_number: task instance try_number to read logs from + :param metadata: log metadata, + can be used for steaming log reading and auto-tailing. + """ + # Explicitly getting log relative path is necessary as the given + # task instance might be different from task instance passed in + # set_context method. + log_relative_path = self._render_filename(ti, try_number) + remote_loc = log_relative_path + + if not self.oss_log_exists(remote_loc): + return super()._read(ti, try_number, metadata) + # If OSS remote file exists, we do not fetch logs from task instance + # local machine even if there are errors reading remote logs, as + # returned remote_log will contain error messages. + remote_log = self.oss_read(remote_loc, return_error=True) + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} diff --git a/providers/alibaba/tests/unit/alibaba/cloud/log/test_oss_task_handler.py b/providers/alibaba/tests/unit/alibaba/cloud/log/test_oss_task_handler.py index f3e73e33b0055..3a3f5b73639db 100644 --- a/providers/alibaba/tests/unit/alibaba/cloud/log/test_oss_task_handler.py +++ b/providers/alibaba/tests/unit/alibaba/cloud/log/test_oss_task_handler.py @@ -24,7 +24,7 @@ import pytest -from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSTaskHandler +from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSRemoteLogIO, OSSTaskHandler # noqa: F401 from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import datetime @@ -74,33 +74,33 @@ def test_hook(self, mock_service, mock_conf_get): mock_conf_get.return_value = "oss_default" # When - self.oss_task_handler.hook + self.oss_task_handler.io.hook # Then mock_conf_get.assert_called_once_with("logging", "REMOTE_LOG_CONN_ID") mock_service.assert_called_once_with(oss_conn_id="oss_default") - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + 
@mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_oss_log_exists(self, mock_service): - self.oss_task_handler.oss_log_exists("1.log") + self.oss_task_handler.io.oss_log_exists("1.log") mock_service.assert_called_once_with() mock_service.return_value.key_exist.assert_called_once_with(MOCK_BUCKET_NAME, "airflow/logs/1.log") - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_oss_read(self, mock_service): - self.oss_task_handler.oss_read("1.log") + self.oss_task_handler.io.oss_read("1.log") mock_service.assert_called_once_with() mock_service.return_value.read_key(MOCK_BUCKET_NAME, "airflow/logs/1.log") - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.oss_log_exists")) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.oss_log_exists")) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_oss_write_into_remote_existing_file_via_append(self, mock_service, mock_oss_log_exists): # Given mock_oss_log_exists.return_value = True mock_service.return_value.head_key.return_value.content_length = 1 # When - self.oss_task_handler.oss_write(MOCK_CONTENT, "1.log", append=True) + self.oss_task_handler.io.oss_write(MOCK_CONTENT, "1.log", append=True) # Then assert mock_service.call_count == 2 @@ -110,14 +110,14 @@ def test_oss_write_into_remote_existing_file_via_append(self, mock_service, mock MOCK_BUCKET_NAME, MOCK_CONTENT, "airflow/logs/1.log", 1 ) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.oss_log_exists")) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.oss_log_exists")) + 
@mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_oss_write_into_remote_non_existing_file_via_append(self, mock_service, mock_oss_log_exists): # Given mock_oss_log_exists.return_value = False # When - self.oss_task_handler.oss_write(MOCK_CONTENT, "1.log", append=True) + self.oss_task_handler.io.oss_write(MOCK_CONTENT, "1.log", append=True) # Then assert mock_service.call_count == 1 @@ -127,14 +127,14 @@ def test_oss_write_into_remote_non_existing_file_via_append(self, mock_service, MOCK_BUCKET_NAME, MOCK_CONTENT, "airflow/logs/1.log", 0 ) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.oss_log_exists")) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.oss_log_exists")) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_oss_write_into_remote_existing_file_not_via_append(self, mock_service, mock_oss_log_exists): # Given mock_oss_log_exists.return_value = True # When - self.oss_task_handler.oss_write(MOCK_CONTENT, "1.log", append=False) + self.oss_task_handler.io.oss_write(MOCK_CONTENT, "1.log", append=False) # Then assert mock_service.call_count == 1 @@ -144,14 +144,14 @@ def test_oss_write_into_remote_existing_file_not_via_append(self, mock_service, MOCK_BUCKET_NAME, MOCK_CONTENT, "airflow/logs/1.log", 0 ) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.oss_log_exists")) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.oss_log_exists")) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_oss_write_into_remote_non_existing_file_not_via_append(self, mock_service, mock_oss_log_exists): # Given mock_oss_log_exists.return_value = False # When - 
self.oss_task_handler.oss_write(MOCK_CONTENT, "1.log", append=False) + self.oss_task_handler.io.oss_write(MOCK_CONTENT, "1.log", append=False) # Then assert mock_service.call_count == 1 @@ -165,7 +165,7 @@ def test_oss_write_into_remote_non_existing_file_not_via_append(self, mock_servi "delete_local_copy, expected_existence_of_local_copy", [(True, False), (False, True)], ) - @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock) + @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.hook"), new_callable=PropertyMock) def test_close_with_delete_local_copy_conf( self, mock_service, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index b7383c8f46f3b..37cff5bca1941 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -17,20 +17,30 @@ # under the License. 
from __future__ import annotations +import copy +import json +import logging +import os from datetime import date, datetime, timedelta, timezone from functools import cached_property +from pathlib import Path from typing import TYPE_CHECKING, Any +import attrs import watchtower from airflow.configuration import conf from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: - from airflow.models import TaskInstance + import structlog.typing + + from airflow.models.taskinstance import TaskInstance + from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo def json_serialize_legacy(value: Any) -> str | None: @@ -62,6 +72,144 @@ def json_serialize(value: Any) -> str | None: return watchtower._json_serialize_default(value) +@attrs.define(kw_only=True) +class CloudWatchRemoteLogIO(LoggingMixin): # noqa: D101 + base_log_folder: Path = attrs.field(converter=Path) + remote_base: str = "" + delete_local_copy: bool = True + + log_group_arn: str + log_stream_name: str = "" + log_group: str = attrs.field(init=False, repr=False) + region_name: str = attrs.field(init=False, repr=False) + + @log_group.default + def _(self): + return self.log_group_arn.split(":")[6] + + @region_name.default + def _(self): + return self.log_group_arn.split(":")[3] + + @cached_property + def hook(self): + """Returns AwsLogsHook.""" + return AwsLogsHook( + aws_conn_id=conf.get("logging", "remote_log_conn_id"), region_name=self.region_name + ) + + @cached_property + def handler(self) -> watchtower.CloudWatchLogHandler: + _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None) + return watchtower.CloudWatchLogHandler( + log_group_name=self.log_group, + 
log_stream_name=self.log_stream_name, + use_queues=True, + boto3_client=self.hook.get_conn(), + json_serialize_default=_json_serialize or json_serialize_legacy, + ) + + @cached_property + def processors(self) -> tuple[structlog.typing.Processor, ...]: + from logging import getLogRecordFactory + + import structlog.stdlib + + logRecordFactory = getLogRecordFactory() + + def proc(logger: structlog.typing.WrappedLogger, method_name: str, event: structlog.typing.EventDict): + name = event.get("logger_name") or event.get("logger", "") + level = structlog.stdlib.NAME_TO_LEVEL.get(method_name.lower(), logging.INFO) + msg = copy.copy(event) + created = None + if ts := msg.pop("timestamp", None): + try: + created = datetime.fromisoformat(ts) + except Exception: + pass + record = logRecordFactory( + name, level, pathname="", lineno=0, msg=msg, args=(), exc_info=None, func=None, sinfo=None + ) + if created is not None: + ct = created.timestamp() + record.created = ct + record.msecs = int((ct - int(ct)) * 1000) + 0.0 # Copied from stdlib logging + self.handler.handle(record) + return event + + return (proc,) + + def close(self): + self.handler.close() + + def upload(self, path: os.PathLike | str): + # No-op, as we upload via the processor as we go + # But we need to give the handler time to finish off its business + self.close() + return + + def read(self, relative_path, ti: TaskInstance | None = None) -> tuple[LogSourceInfo, LogMessages | None]: + logs: LogMessages | None = [] + messages = [ + f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}" + ] + try: + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.log.file_task_handler import StructuredLogMessage + + logs = [ + StructuredLogMessage.model_validate(log) + for log in self.get_cloudwatch_logs(relative_path, ti) + ] + else: + logs = [self.get_cloudwatch_logs(relative_path, ti)] # type: ignore[arg-value] + except Exception as e: + logs = None + messages.append(str(e)) + + return messages, 
logs + + def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance | None): + """ + Return all logs from the given log stream. + + :param stream_name: name of the Cloudwatch log stream to get all logs from + :param task_instance: the task instance to get logs about + :return: string of all logs from the given log stream + """ + # If there is an end_date to the task instance, fetch logs until that date + 30 seconds + # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted + end_time = ( + None + if task_instance is None or task_instance.end_date is None + else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30)) + ) + events = self.hook.get_log_events( + log_group=self.log_group, + log_stream_name=stream_name, + end_time=end_time, + ) + if AIRFLOW_V_3_0_PLUS: + return list(self._event_to_dict(e) for e in events) + return "\n".join(self._event_to_str(event) for event in events) + + def _event_to_dict(self, event: dict) -> dict: + event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat() + message = event["message"] + try: + message = json.loads(message) + message["timestamp"] = event_dt + return message + except Exception: + return {"timestamp": event_dt, "event": message} + + def _event_to_str(self, event: dict) -> str: + event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc) + formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] + message = event["message"] + return f"[{formatted_event_dt}] {message}" + + class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): """ CloudwatchTaskHandler is a python log handler that handles and reads task instance logs. 
@@ -84,6 +232,11 @@ def __init__(self, base_log_folder: str, log_group_arn: str, **kwargs): self.region_name = split_arn[3] self.closed = False + self.io = CloudWatchRemoteLogIO( + base_log_folder=base_log_folder, + log_group_arn=log_group_arn, + ) + @cached_property def hook(self): """Returns AwsLogsHook.""" @@ -97,14 +250,9 @@ def _render_filename(self, ti, try_number): def set_context(self, ti: TaskInstance, *, identifier: str | None = None): super().set_context(ti) - _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None) - self.handler = watchtower.CloudWatchLogHandler( - log_group_name=self.log_group, - log_stream_name=self._render_filename(ti, ti.try_number), - use_queues=not getattr(ti, "is_trigger_log_context", False), - boto3_client=self.hook.get_conn(), - json_serialize_default=_json_serialize or json_serialize_legacy, - ) + self.io.log_stream_name = self._render_filename(ti, ti.try_number) + + self.handler = self.io.handler def close(self): """Close the handler responsible for the upload of the local log file to Cloudwatch.""" @@ -120,49 +268,9 @@ def close(self): # Mark closed so we don't double write if close is called twice self.closed = True - def _read(self, task_instance, try_number, metadata=None): + def _read_remote_logs( + self, task_instance, try_number, metadata=None + ) -> tuple[LogSourceInfo, LogMessages]: stream_name = self._render_filename(task_instance, try_number) - try: - return ( - f"*** Reading remote log from Cloudwatch log_group: {self.log_group} " - f"log_stream: {stream_name}.\n" - f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n", - {"end_of_log": True}, - ) - except Exception as e: - log = ( - f"*** Unable to read remote logs from Cloudwatch (log_group: {self.log_group}, log_stream: " - f"{stream_name})\n*** {e}\n\n" - ) - self.log.error(log) - local_log, metadata = super()._read(task_instance, try_number, metadata) - log += local_log - return log, 
metadata - - def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str: - """ - Return all logs from the given log stream. - - :param stream_name: name of the Cloudwatch log stream to get all logs from - :param task_instance: the task instance to get logs about - :return: string of all logs from the given log stream - """ - # If there is an end_date to the task instance, fetch logs until that date + 30 seconds - # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted - end_time = ( - None - if task_instance.end_date is None - else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30)) - ) - events = self.hook.get_log_events( - log_group=self.log_group, - log_stream_name=stream_name, - end_time=end_time, - ) - return "\n".join(self._event_to_str(event) for event in events) - - def _event_to_str(self, event: dict) -> str: - event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc) - formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] - message = event["message"] - return f"[{formatted_event_dt}] {message}" + messages, logs = self.io.read(stream_name, task_instance) + return messages, logs or [] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py b/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py index 4e1ce4655b70c..aaab3241ee969 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -24,6 +24,8 @@ from functools import cached_property from typing import TYPE_CHECKING +import attrs + from airflow.configuration import conf from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS @@ -32,28 +34,33 @@ if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance + from airflow.utils.log.file_task_handler import 
LogMessages, LogSourceInfo -class S3TaskHandler(FileTaskHandler, LoggingMixin): - """ - S3TaskHandler is a python log handler that handles and reads task instance logs. +@attrs.define +class S3RemoteLogIO(LoggingMixin): # noqa: D101 + remote_base: str + base_log_folder: pathlib.Path = attrs.field(converter=pathlib.Path) + delete_local_copy: bool - It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage. - """ + processors = () - trigger_should_wrap = True + def upload(self, path: os.PathLike | str): + """Upload the given log path to the remote storage.""" + path = pathlib.Path(path) + if path.is_absolute(): + local_loc = path + remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder)) + else: + local_loc = self.base_log_folder.joinpath(path) + remote_loc = os.path.join(self.remote_base, path) - def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs): - super().__init__(base_log_folder) - self.handler: logging.FileHandler | None = None - self.remote_base = s3_log_folder - self.log_relative_path = "" - self._hook = None - self.closed = False - self.upload_on_close = True - self.delete_local_copy = kwargs.get( - "delete_local_copy", conf.getboolean("logging", "delete_local_logs") - ) + if local_loc.is_file(): + # read log and remove old logs to get just the latest additions + log = local_loc.read_text() + has_uploaded = self.write(log, remote_loc) + if has_uploaded and self.delete_local_copy: + shutil.rmtree(os.path.dirname(local_loc)) @cached_property def hook(self): @@ -63,73 +70,6 @@ def hook(self): transfer_config_args={"use_threads": False}, ) - def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None: - super().set_context(ti, identifier=identifier) - # Local location and remote location is needed to open and - # upload local log file to S3 remote storage. 
- if TYPE_CHECKING: - assert self.handler is not None - - full_path = self.handler.baseFilename - self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix() - is_trigger_log_context = getattr(ti, "is_trigger_log_context", False) - self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None) - # Clear the file first so that duplicate data is not uploaded - # when reusing the same path (e.g. with rescheduled sensors) - if self.upload_on_close: - with open(self.handler.baseFilename, "w"): - pass - - def close(self): - """Close and upload local log file to remote storage S3.""" - # When application exit, system shuts down all handlers by - # calling close method. Here we check if logger is already - # closed to prevent uploading the log to remote storage multiple - # times when `logging.shutdown` is called. - if self.closed: - return - - super().close() - - if not self.upload_on_close: - return - - local_loc = os.path.join(self.local_base, self.log_relative_path) - remote_loc = os.path.join(self.remote_base, self.log_relative_path) - if os.path.exists(local_loc): - # read log and remove old logs to get just the latest additions - log = pathlib.Path(local_loc).read_text() - write_to_s3 = self.s3_write(log, remote_loc) - if write_to_s3 and self.delete_local_copy: - shutil.rmtree(os.path.dirname(local_loc)) - - # Mark closed so we don't double write if close is called twice - self.closed = True - - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]: - # Explicitly getting log relative path is necessary as the given - # task instance might be different than task instance passed in - # in set_context method. 
- worker_log_rel_path = self._render_filename(ti, try_number) - - logs = [] - messages = [] - bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, worker_log_rel_path)) - keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix) - if keys: - keys = sorted(f"s3://{bucket}/{key}" for key in keys) - if AIRFLOW_V_3_0_PLUS: - messages = keys - else: - messages.append("Found logs in s3:") - messages.extend(f" * {key}" for key in keys) - for key in keys: - logs.append(self.s3_read(key, return_error=True)) - else: - if not AIRFLOW_V_3_0_PLUS: - messages.append(f"No logs found on s3 for ti={ti}") - return messages, logs - def s3_log_exists(self, remote_log_location: str) -> bool: """ Check if remote_log_location exists in remote storage. @@ -158,7 +98,7 @@ def s3_read(self, remote_log_location: str, return_error: bool = False) -> str: return msg return "" - def s3_write( + def write( self, log: str, remote_log_location: str, @@ -168,7 +108,7 @@ def s3_write( """ Write the log to the remote_log_location; return `True` or fails silently and return `False`. - :param log: the log to write to the remote_log_location + :param log: the contents to write to the remote_log_location :param remote_log_location: the log's location in remote storage :param append: if False, any existing log file is overwritten. If True, the new log is appended to any existing logs. 
@@ -205,3 +145,96 @@ def s3_write( self.log.exception("Could not write logs to %s", remote_log_location) return False return True + + def read(self, relative_path: str) -> tuple[LogSourceInfo, LogMessages | None]: + logs: list[str] = [] + messages = [] + bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, relative_path)) + keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix) + if keys: + keys = sorted(f"s3://{bucket}/{key}" for key in keys) + if AIRFLOW_V_3_0_PLUS: + messages = keys + else: + messages.append("Found logs in s3:") + messages.extend(f" * {key}" for key in keys) + for key in keys: + logs.append(self.s3_read(key, return_error=True)) + return messages, logs + else: + return messages, None + + +class S3TaskHandler(FileTaskHandler, LoggingMixin): + """ + S3TaskHandler is a python log handler that handles and reads task instance logs. + + It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage. + """ + + def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs): + super().__init__(base_log_folder) + self.handler: logging.FileHandler | None = None + self.remote_base = s3_log_folder + self.log_relative_path = "" + self._hook = None + self.closed = False + self.upload_on_close = True + self.io = S3RemoteLogIO( + remote_base=s3_log_folder, + base_log_folder=base_log_folder, + delete_local_copy=kwargs.get( + "delete_local_copy", conf.getboolean("logging", "delete_local_logs") + ), + ) + + def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None: + super().set_context(ti, identifier=identifier) + # Local location and remote location is needed to open and + # upload local log file to S3 remote storage. 
+ if TYPE_CHECKING: + assert self.handler is not None + + full_path = self.handler.baseFilename + self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix() + is_trigger_log_context = getattr(ti, "is_trigger_log_context", False) + self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None) + # Clear the file first so that duplicate data is not uploaded + # when reusing the same path (e.g. with rescheduled sensors) + if self.upload_on_close: + with open(self.handler.baseFilename, "w"): + pass + + def close(self): + """Close and upload local log file to remote storage S3.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + super().close() + + if not self.upload_on_close: + return + + self.io.upload(self.log_relative_path) + + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]: + # Explicitly getting log relative path is necessary as the given + # task instance might be different than task instance passed in + # in set_context method. 
+ worker_log_rel_path = self._render_filename(ti, try_number) + + messages, logs = self.io.read(worker_log_rel_path) + + if logs is None: + logs = [] + if not AIRFLOW_V_3_0_PLUS: + messages.append(f"No logs found on s3 for ti={ti}") + + return messages, logs diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py index 48ea1a3cb48fe..dd765ab69458b 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py @@ -18,19 +18,24 @@ from __future__ import annotations import logging +import textwrap import time from datetime import datetime as dt, timedelta, timezone from unittest import mock -from unittest.mock import ANY, Mock, call +from unittest.mock import ANY, call import boto3 import pytest +import time_machine from moto import mock_aws +from pydantic import TypeAdapter +from pydantic_core import TzInfo from watchtower import CloudWatchLogHandler from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.log.cloudwatch_task_handler import ( + CloudWatchRemoteLogIO, CloudwatchTaskHandler, ) from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms @@ -39,7 +44,7 @@ from airflow.utils.timezone import datetime from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS def get_time_str(time_in_milliseconds): @@ -53,6 +58,110 @@ def logmock(): yield +# We only test this directly on Airflow 3 +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This path only works on Airflow 3") +class TestCloudRemoteLogIO: + # We use the cap_structlog so that our changes get reverted for us + 
@pytest.fixture(autouse=True) + def setup_tests(self, create_runtime_ti, tmp_path, monkeypatch): + import structlog + + import airflow.logging_config + import airflow.sdk.log + from airflow.sdk import BaseOperator + + task = BaseOperator(task_id="task_1") + self.ti = create_runtime_ti(task) + + self.remote_log_group = "log_group_name" + self.region_name = "us-west-2" + self.local_log_location = tmp_path / "local-cloudwatch-log-location" + self.local_log_location.mkdir() + + # The subject under test + self.subject = CloudWatchRemoteLogIO( + base_log_folder=self.local_log_location, + log_group_arn=f"arn:aws:logs:{self.region_name}:11111111:log-group:{self.remote_log_group}", + log_stream_name="dag_id=a/0.log", + ) + + conn = boto3.client("logs", region_name=self.region_name) + conn.create_log_group(logGroupName=self.remote_log_group) + conn.create_log_stream(logGroupName=self.remote_log_group, logStreamName=self.subject.log_stream_name) + + processors = structlog.get_config()["processors"] + old_processors = processors.copy() + + try: + # Modify `_Configuration.default_processors` set via `configure` but always + # keep the list instance intact to not break references held by bound + # loggers. 
+ + # Set up the right chain of processors so the event looks like we want for our full test + monkeypatch.setattr(airflow.logging_config, "REMOTE_TASK_LOG", self.subject) + procs, _ = airflow.sdk.log.logging_processors(enable_pretty_log=False) + processors.clear() + processors.extend(procs) + + # Replace the last "output" one with a DropEvent one instead - else we get the output on stdout + def drop(*args): + raise structlog.DropEvent() + + processors[-1] = drop + structlog.configure(processors=processors) + yield + finally: + # remove LogCapture and restore original processors + processors.clear() + processors.extend(old_processors) + structlog.configure(processors=old_processors) + + @time_machine.travel(datetime(2025, 3, 27, 21, 58, 1, 2345), tick=False) + def test_log_message(self): + import structlog + + log = structlog.get_logger() + log.info("Hi", foo="bar") + # We need to close in order to flush the logs etc. + self.subject.close() + + logs = self.subject.read("dag_id=a/0.log", self.ti) + + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.log.file_task_handler import StructuredLogMessage + + metadata, logs = logs + + results = TypeAdapter(list[StructuredLogMessage]).dump_python(logs) + assert metadata == [ + "Reading remote log from Cloudwatch log_group: log_group_name log_stream: dag_id=a/0.log" + ] + assert results == [ + { + "event": "Hi", + "foo": "bar", + "level": "info", + "timestamp": datetime(2025, 3, 27, 21, 58, 1, 2000, tzinfo=TzInfo(0)), + }, + ] + + def test_event_to_str(self): + handler = self.subject + current_time = int(time.time()) * 1000 + events = [ + {"timestamp": current_time - 2000, "message": "First"}, + {"timestamp": current_time - 1000, "message": "Second"}, + {"timestamp": current_time, "message": "Third"}, + ] + assert [handler._event_to_str(event) for event in events] == ( + [ + f"[{get_time_str(current_time - 2000)}] First", + f"[{get_time_str(current_time - 1000)}] Second", + f"[{get_time_str(current_time)}] Third", + ] + ) + + 
@pytest.mark.db_test class TestCloudwatchTaskHandler: @conf_vars({("logging", "remote_log_conn_id"): "aws_default"}) @@ -108,6 +217,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session): yield self.cloudwatch_task_handler.handler = None + del self.cloudwatch_task_handler def test_hook(self): assert isinstance(self.cloudwatch_task_handler.hook, AwsLogsHook) @@ -126,23 +236,8 @@ def test_write(self): handler.handle(message) mock_emit.assert_has_calls([call(message) for message in messages]) - def test_event_to_str(self): - handler = self.cloudwatch_task_handler - current_time = int(time.time()) * 1000 - events = [ - {"timestamp": current_time - 2000, "message": "First"}, - {"timestamp": current_time - 1000, "message": "Second"}, - {"timestamp": current_time, "message": "Third"}, - ] - assert [handler._event_to_str(event) for event in events] == ( - [ - f"[{get_time_str(current_time - 2000)}] First", - f"[{get_time_str(current_time - 1000)}] Second", - f"[{get_time_str(current_time)}] Third", - ] - ) - - def test_read(self): + @time_machine.travel(datetime(2025, 3, 27, 21, 58, 1, 2345), tick=False) + def test_read(self, monkeypatch): # Confirmed via AWS Support call: # CloudWatch events must be ordered chronologically otherwise # boto3 put_log_event API throws InvalidParameterException @@ -158,32 +253,48 @@ def test_read(self): {"timestamp": current_time, "message": "Third"}, ], ) + if AIRFLOW_V_2_10_PLUS: + monkeypatch.setattr(self.cloudwatch_task_handler, "_read_from_logs_server", lambda a, b: ([], [])) + msg_template = textwrap.dedent(""" + INFO - ::group::Log message source details + *** Reading remote log from Cloudwatch log_group: {} log_stream: {} + INFO - ::endgroup:: + {} + """)[1:][:-1] # Strip off leading and trailing new lines, but not spaces + else: + msg_template = textwrap.dedent(""" + *** Reading remote log from Cloudwatch log_group: {} log_stream: {} + {} + """).strip() - msg_template = "*** Reading remote log from Cloudwatch 
log_group: {} log_stream: {}.\n{}\n" - events = "\n".join( - [ - f"[{get_time_str(current_time - 2000)}] First", - f"[{get_time_str(current_time - 1000)}] Second", - f"[{get_time_str(current_time)}] Third", - ] - ) + logs, metadata = self.cloudwatch_task_handler.read(self.ti) if AIRFLOW_V_3_0_PLUS: - assert self.cloudwatch_task_handler.read(self.ti) == ( - msg_template.format(self.remote_log_group, self.remote_log_stream, events), - {"end_of_log": True}, - ) + from airflow.utils.log.file_task_handler import StructuredLogMessage + + results = TypeAdapter(list[StructuredLogMessage]).dump_python(logs) + assert results[-4:] == [ + {"event": "::endgroup::", "timestamp": None}, + {"event": "First", "timestamp": datetime(2025, 3, 27, 21, 57, 59)}, + {"event": "Second", "timestamp": datetime(2025, 3, 27, 21, 58, 0)}, + {"event": "Third", "timestamp": datetime(2025, 3, 27, 21, 58, 1)}, + ] + assert metadata["log_pos"] == 3 else: - assert self.cloudwatch_task_handler.read(self.ti) == ( + events = "\n".join( [ - [ - ( - "", - msg_template.format(self.remote_log_group, self.remote_log_stream, events), - ) - ] - ], - [{"end_of_log": True}], + f"[{get_time_str(current_time - 2000)}] First", + f"[{get_time_str(current_time - 1000)}] Second", + f"[{get_time_str(current_time)}] Third", + ] ) + assert logs == [ + [ + ( + "", + msg_template.format(self.remote_log_group, self.remote_log_stream, events), + ) + ] + ] @pytest.mark.parametrize( "end_date, expected_end_time", @@ -198,7 +309,7 @@ def test_read(self): @mock.patch.object(AwsLogsHook, "get_log_events") def test_get_cloudwatch_logs(self, mock_get_log_events, end_date, expected_end_time): self.ti.end_date = end_date - self.cloudwatch_task_handler.get_cloudwatch_logs(self.remote_log_stream, self.ti) + self.cloudwatch_task_handler.io.get_cloudwatch_logs(self.remote_log_stream, self.ti) mock_get_log_events.assert_called_once_with( log_group=self.remote_log_group, log_stream_name=self.remote_log_stream, @@ -253,10 +364,8 @@ def 
__repr__(self): mock.patch("watchtower.threading.Thread"), mock.patch("watchtower.queue.Queue") as mq, ): - mock_queue = Mock() - mq.return_value = mock_queue handler.handle(message) - mock_queue.put.assert_called_once_with( + mq.return_value.put.assert_called_once_with( {"message": expected_serialized_output, "timestamp": ANY} ) diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py index e44991ffe3058..45a39e33ad4e8 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py @@ -45,7 +45,7 @@ def s3mock(): @pytest.mark.db_test -class TestS3TaskHandler: +class TestS3RemoteLogIO: @conf_vars({("logging", "remote_log_conn_id"): "aws_default"}) @pytest.fixture(autouse=True) def setup_tests(self, create_log_template, tmp_path_factory, session): @@ -56,7 +56,8 @@ def setup_tests(self, create_log_template, tmp_path_factory, session): create_log_template("{try_number}.log") self.s3_task_handler = S3TaskHandler(self.local_log_location, self.remote_log_base) # Verify the hook now with the config override - assert self.s3_task_handler.hook is not None + self.subject = self.s3_task_handler.io + assert self.subject.hook is not None date = datetime(2016, 1, 1) self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) @@ -98,25 +99,123 @@ def setup_tests(self, create_log_template, tmp_path_factory, session): os.remove(self.s3_task_handler.handler.baseFilename) def test_hook(self): - assert isinstance(self.s3_task_handler.hook, S3Hook) - assert self.s3_task_handler.hook.transfer_config.use_threads is False + assert isinstance(self.subject.hook, S3Hook) + assert self.subject.hook.transfer_config.use_threads is False def test_log_exists(self): self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"") - assert 
self.s3_task_handler.s3_log_exists(self.remote_log_location) + assert self.subject.s3_log_exists(self.remote_log_location) def test_log_exists_none(self): - assert not self.s3_task_handler.s3_log_exists(self.remote_log_location) + assert not self.subject.s3_log_exists(self.remote_log_location) def test_log_exists_raises(self): - assert not self.s3_task_handler.s3_log_exists("s3://nonexistentbucket/foo") + assert not self.subject.s3_log_exists("s3://nonexistentbucket/foo") def test_log_exists_no_hook(self): - handler = S3TaskHandler(self.local_log_location, self.remote_log_base) + subject = S3TaskHandler(self.local_log_location, self.remote_log_base).io with mock.patch.object(S3Hook, "__init__", spec=S3Hook) as mock_hook: mock_hook.side_effect = ConnectionError("Fake: Failed to connect") with pytest.raises(ConnectionError, match="Fake: Failed to connect"): - handler.s3_log_exists(self.remote_log_location) + subject.s3_log_exists(self.remote_log_location) + + def test_s3_read_when_log_missing(self): + url = "s3://bucket/foo" + with mock.patch.object(self.subject.log, "error") as mock_error: + result = self.subject.s3_read(url, return_error=True) + msg = ( + f"Could not read logs from {url} with error: An error occurred (404) when calling the " + f"HeadObject operation: Not Found" + ) + assert result == msg + mock_error.assert_called_once_with(msg, exc_info=True) + + def test_read_raises_return_error(self): + url = "s3://nonexistentbucket/foo" + with mock.patch.object(self.subject.log, "error") as mock_error: + result = self.subject.s3_read(url, return_error=True) + msg = ( + f"Could not read logs from {url} with error: An error occurred (NoSuchBucket) when " + f"calling the HeadObject operation: The specified bucket does not exist" + ) + assert result == msg + mock_error.assert_called_once_with(msg, exc_info=True) + + def test_write(self): + with mock.patch.object(self.subject.log, "error") as mock_error: + self.subject.write("text", self.remote_log_location) + # We 
shouldn't expect any error logs in the default working case. + mock_error.assert_not_called() + body = boto3.resource("s3").Object("bucket", self.remote_log_key).get()["Body"].read() + + assert body == b"text" + + def test_write_existing(self): + self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"previous ") + self.subject.write("text", self.remote_log_location) + body = boto3.resource("s3").Object("bucket", self.remote_log_key).get()["Body"].read() + + assert body == b"previous \ntext" + + def test_write_raises(self): + url = "s3://nonexistentbucket/foo" + with mock.patch.object(self.subject.log, "error") as mock_error: + self.subject.write("text", url) + mock_error.assert_called_once_with("Could not write logs to %s", url, exc_info=True) + + +@pytest.mark.db_test +class TestS3TaskHandler: + @conf_vars({("logging", "remote_log_conn_id"): "aws_default"}) + @pytest.fixture(autouse=True) + def setup_tests(self, create_log_template, tmp_path_factory, session): + self.remote_log_base = "s3://bucket/remote/log/location" + self.remote_log_location = "s3://bucket/remote/log/location/1.log" + self.remote_log_key = "remote/log/location/1.log" + self.local_log_location = str(tmp_path_factory.mktemp("local-s3-log-location")) + create_log_template("{try_number}.log") + self.s3_task_handler = S3TaskHandler(self.local_log_location, self.remote_log_base) + # Verify the hook now with the config override + assert self.s3_task_handler.io.hook is not None + + date = datetime(2016, 1, 1) + self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) + task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) + if AIRFLOW_V_3_0_PLUS: + dag_run = DagRun( + dag_id=self.dag.dag_id, + logical_date=date, + run_id="test", + run_type="manual", + ) + else: + dag_run = DagRun( + dag_id=self.dag.dag_id, + execution_date=date, + run_id="test", + run_type="manual", + ) + session.add(dag_run) + session.commit() + session.refresh(dag_run) + 
+ self.ti = TaskInstance(task=task, run_id=dag_run.run_id) + self.ti.dag_run = dag_run + self.ti.try_number = 1 + self.ti.state = State.RUNNING + session.add(self.ti) + session.commit() + + self.conn = boto3.client("s3") + self.conn.create_bucket(Bucket="bucket") + yield + + self.dag.clear() + + session.query(DagRun).delete() + if self.s3_task_handler.handler: + with contextlib.suppress(Exception): + os.remove(self.s3_task_handler.handler.baseFilename) def test_set_context_raw(self): self.ti.raw = True @@ -137,7 +236,11 @@ def test_set_context_not_raw(self): mock_open().write.assert_not_called() def test_read(self): - self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\n") + # Test what happens when we have two log files to read + self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\nLine 2\n") + self.conn.put_object( + Bucket="bucket", Key=self.remote_log_key + ".trigger.log", Body=b"Log line 3\nLine 4\n" + ) ti = copy.copy(self.ti) ti.state = TaskInstanceState.SUCCESS log, metadata = self.s3_task_handler.read(ti) @@ -149,12 +252,15 @@ def test_read(self): assert expected_s3_uri in log[0].sources assert log[1].event == "::endgroup::" assert log[2].event == "Log line" - assert metadata == {"end_of_log": True, "log_pos": 1} + assert log[3].event == "Line 2" + assert log[4].event == "Log line 3" + assert log[5].event == "Line 4" + assert metadata == {"end_of_log": True, "log_pos": 4} else: actual = log[0][0][-1] assert f"*** Found logs in s3:\n*** * {expected_s3_uri}\n" in actual - assert actual.endswith("Log line") - assert metadata == [{"end_of_log": True, "log_pos": 8}] + assert actual.endswith("Line 4") + assert metadata == [{"end_of_log": True, "log_pos": 33}] def test_read_when_s3_log_missing(self): ti = copy.copy(self.ti) @@ -172,53 +278,6 @@ def test_read_when_s3_log_missing(self): assert expected in actual assert metadata[0] == {"end_of_log": True, "log_pos": 0} - def test_s3_read_when_log_missing(self): - 
handler = self.s3_task_handler - url = "s3://bucket/foo" - with mock.patch.object(handler.log, "error") as mock_error: - result = handler.s3_read(url, return_error=True) - msg = ( - f"Could not read logs from {url} with error: An error occurred (404) when calling the " - f"HeadObject operation: Not Found" - ) - assert result == msg - mock_error.assert_called_once_with(msg, exc_info=True) - - def test_read_raises_return_error(self): - handler = self.s3_task_handler - url = "s3://nonexistentbucket/foo" - with mock.patch.object(handler.log, "error") as mock_error: - result = handler.s3_read(url, return_error=True) - msg = ( - f"Could not read logs from {url} with error: An error occurred (NoSuchBucket) when " - f"calling the HeadObject operation: The specified bucket does not exist" - ) - assert result == msg - mock_error.assert_called_once_with(msg, exc_info=True) - - def test_write(self): - with mock.patch.object(self.s3_task_handler.log, "error") as mock_error: - self.s3_task_handler.s3_write("text", self.remote_log_location) - # We shouldn't expect any error logs in the default working case. 
- mock_error.assert_not_called() - body = boto3.resource("s3").Object("bucket", self.remote_log_key).get()["Body"].read() - - assert body == b"text" - - def test_write_existing(self): - self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"previous ") - self.s3_task_handler.s3_write("text", self.remote_log_location) - body = boto3.resource("s3").Object("bucket", self.remote_log_key).get()["Body"].read() - - assert body == b"previous \ntext" - - def test_write_raises(self): - handler = self.s3_task_handler - url = "s3://nonexistentbucket/foo" - with mock.patch.object(handler.log, "error") as mock_error: - handler.s3_write("text", url) - mock_error.assert_called_once_with("Could not write logs to %s", url, exc_info=True) - def test_close(self): self.s3_task_handler.set_context(self.ti) assert self.s3_task_handler.upload_on_close diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py index a5190cc613bf5..4092c82ec5494 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -25,6 +25,8 @@ from pathlib import Path from typing import TYPE_CHECKING +import attrs + # not sure why but mypy complains on missing `storage` but it is clearly there and is importable from google.cloud import storage # type: ignore[attr-defined] @@ -42,6 +44,7 @@ if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance + from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo _DEFAULT_SCOPESS = frozenset( [ @@ -52,6 +55,126 @@ logger = logging.getLogger(__name__) +@attrs.define +class GCSRemoteLogIO(LoggingMixin): # noqa: D101 + remote_base: str + base_log_folder: Path = attrs.field(converter=Path) + delete_local_copy: bool + + gcp_key_path: str | None + gcp_keyfile_dict: dict | None + scopes: Collection[str] | None + 
project_id: str + + def upload(self, path: os.PathLike): + """Upload the given log path to the remote storage.""" + path = Path(path) + if path.is_absolute(): + local_loc = path + remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder)) + else: + local_loc = self.base_log_folder.joinpath(path) + remote_loc = os.path.join(self.remote_base, path) + + if local_loc.is_file(): + # read log and remove old logs to get just the latest additions + log = local_loc.read_text() + has_uploaded = self.write(log, remote_loc) + if has_uploaded and self.delete_local_copy: + shutil.rmtree(os.path.dirname(local_loc)) + + @cached_property + def hook(self) -> GCSHook | None: + """Returns GCSHook if remote_log_conn_id configured.""" + conn_id = conf.get("logging", "remote_log_conn_id", fallback=None) + if conn_id: + try: + return GCSHook(gcp_conn_id=conn_id) + except AirflowNotFoundException: + pass + return None + + @cached_property + def client(self) -> storage.Client: + """Returns GCS Client.""" + if self.hook: + credentials, project_id = self.hook.get_credentials_and_project_id() + else: + credentials, project_id = get_credentials_and_project_id( + key_path=self.gcp_key_path, + keyfile_dict=self.gcp_keyfile_dict, + scopes=self.scopes, + disable_logging=True, + ) + return storage.Client( + credentials=credentials, + client_info=CLIENT_INFO, + project=self.project_id if self.project_id else project_id, + ) + + def write(self, log: str, remote_log_location: str) -> bool: + """ + Write the log to the remote location and return `True`; fail silently and return `False` on error. + + :param log: the log to write to the remote_log_location + :param remote_log_location: the log's location in remote storage + :return: whether the log is successfully written to remote location or not. 
+ """ + try: + blob = storage.Blob.from_string(remote_log_location, self.client) + old_log = blob.download_as_bytes().decode() + log = f"{old_log}\n{log}" if old_log else log + except Exception as e: + if not self.no_log_found(e): + self.log.warning("Error checking for previous log: %s", e) + try: + blob = storage.Blob.from_string(remote_log_location, self.client) + blob.upload_from_string(log, content_type="text/plain") + except Exception as e: + self.log.error("Could not write logs to %s: %s", remote_log_location, e) + return False + return True + + @staticmethod + def no_log_found(exc): + """ + Given exception, determine whether it is result of log not found. + + :meta private: + """ + return (exc.args and isinstance(exc.args[0], str) and "No such object" in exc.args[0]) or getattr( + exc, "resp", {} + ).get("status") == "404" + + def read(self, relative_path) -> tuple[LogSourceInfo, LogMessages | None]: + messages = [] + logs = [] + remote_loc = os.path.join(self.remote_base, relative_path) + uris = [] + bucket, prefix = _parse_gcs_url(remote_loc) + blobs = list(self.client.list_blobs(bucket_or_name=bucket, prefix=prefix)) + + if blobs: + uris = [f"gs://{bucket}/{b.name}" for b in blobs] + if AIRFLOW_V_3_0_PLUS: + messages = uris + else: + messages.extend(["Found remote logs:", *[f" * {x}" for x in sorted(uris)]]) + else: + return messages, None + + try: + for key in sorted(uris): + blob = storage.Blob.from_string(key, self.client) + remote_log = blob.download_as_bytes().decode() + if remote_log: + logs.append(remote_log) + except Exception as e: + if not AIRFLOW_V_3_0_PLUS: + messages.append(f"Unable to read remote log {e}") + return messages, logs + + class GCSTaskHandler(FileTaskHandler, LoggingMixin): """ GCSTaskHandler is a python log handler that handles and reads task instance logs. 
@@ -91,45 +214,19 @@ def __init__( ): super().__init__(base_log_folder) self.handler: logging.FileHandler | None = None - self.remote_base = gcs_log_folder self.log_relative_path = "" self.closed = False self.upload_on_close = True - self.gcp_key_path = gcp_key_path - self.gcp_keyfile_dict = gcp_keyfile_dict - self.scopes = gcp_scopes - self.project_id = project_id - self.delete_local_copy = kwargs.get( - "delete_local_copy", conf.getboolean("logging", "delete_local_logs") - ) - - @cached_property - def hook(self) -> GCSHook | None: - """Returns GCSHook if remote_log_conn_id configured.""" - conn_id = conf.get("logging", "remote_log_conn_id", fallback=None) - if conn_id: - try: - return GCSHook(gcp_conn_id=conn_id) - except AirflowNotFoundException: - pass - return None - - @cached_property - def client(self) -> storage.Client: - """Returns GCS Client.""" - if self.hook: - credentials, project_id = self.hook.get_credentials_and_project_id() - else: - credentials, project_id = get_credentials_and_project_id( - key_path=self.gcp_key_path, - keyfile_dict=self.gcp_keyfile_dict, - scopes=self.scopes, - disable_logging=True, - ) - return storage.Client( - credentials=credentials, - client_info=CLIENT_INFO, - project=self.project_id if self.project_id else project_id, + self.io = GCSRemoteLogIO( + base_log_folder=base_log_folder, + remote_base=gcs_log_folder, + delete_local_copy=kwargs.get( + "delete_local_copy", conf.getboolean("logging", "delete_local_logs") + ), + gcp_key_path=gcp_key_path, + gcp_keyfile_dict=gcp_keyfile_dict, + scopes=gcp_scopes, + project_id=project_id, ) def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None: @@ -159,91 +256,22 @@ def close(self): if not self.upload_on_close: return - local_loc = os.path.join(self.local_base, self.log_relative_path) - remote_loc = os.path.join(self.remote_base, self.log_relative_path) - if os.path.exists(local_loc): - # read log and remove old logs to get just the latest additions - with 
open(local_loc) as logfile: - log = logfile.read() - gcs_write = self.gcs_write(log, remote_loc) - if gcs_write and self.delete_local_copy: - shutil.rmtree(os.path.dirname(local_loc)) + self.io.upload(self.log_relative_path) # Mark closed so we don't double write if close is called twice self.closed = True - def _add_message(self, msg): - filename, lineno, func, stackinfo = logger.findCaller() - record = logging.LogRecord("", logging.INFO, filename, lineno, msg + "\n", None, None, func=func) - return self.format(record) + def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]: + # Explicitly getting log relative path is necessary as the given + # task instance might be different than task instance passed in + # in set_context method. + worker_log_rel_path = self._render_filename(ti, try_number) - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]: - # Explicitly getting log relative path is necessary because this method - # is called from webserver from TaskLogReader, where we don't call set_context - # and can read logs for different TIs in each request - messages = [] - logs = [] - worker_log_relative_path = self._render_filename(ti, try_number) - remote_loc = os.path.join(self.remote_base, worker_log_relative_path) - uris = [] - bucket, prefix = _parse_gcs_url(remote_loc) - blobs = list(self.client.list_blobs(bucket_or_name=bucket, prefix=prefix)) + messages, logs = self.io.read(worker_log_rel_path) - if blobs: - uris = [f"gs://{bucket}/{b.name}" for b in blobs] - if AIRFLOW_V_3_0_PLUS: - messages = uris - else: - messages.extend(["Found remote logs:", *[f" * {x}" for x in sorted(uris)]]) - else: + if logs is None: + logs = [] if not AIRFLOW_V_3_0_PLUS: messages.append(f"No logs found in GCS; ti={ti}") - try: - for key in sorted(uris): - blob = storage.Blob.from_string(key, self.client) - remote_log = blob.download_as_bytes().decode() - if remote_log: - logs.append(remote_log) - 
except Exception as e: - if not AIRFLOW_V_3_0_PLUS: - messages.append(f"Unable to read remote log {e}") - return messages, logs - - def gcs_write(self, log, remote_log_location) -> bool: - """ - Write the log to the remote location and return `True`; fail silently and return `False` on error. - - :param log: the log to write to the remote_log_location - :param remote_log_location: the log's location in remote storage - :return: whether the log is successfully written to remote location or not. - """ - try: - blob = storage.Blob.from_string(remote_log_location, self.client) - old_log = blob.download_as_bytes().decode() - log = f"{old_log}\n{log}" if old_log else log - except Exception as e: - if not self.no_log_found(e): - log += self._add_message( - f"Error checking for previous log; if exists, may be overwritten: {e}" - ) - self.log.warning("Error checking for previous log: %s", e) - try: - blob = storage.Blob.from_string(remote_log_location, self.client) - blob.upload_from_string(log, content_type="text/plain") - except Exception as e: - self.log.error("Could not write logs to %s: %s", remote_log_location, e) - return False - return True - @staticmethod - def no_log_found(exc): - """ - Given exception, determine whether it is result of log not found. 
- - :meta private: - """ - if (exc.args and isinstance(exc.args[0], str) and "No such object" in exc.args[0]) or getattr( - exc, "resp", {} - ).get("status") == "404": - return True - return False + return messages, logs diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py index b86a6fa2c2745..205bcfc5cb1bd 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py @@ -79,7 +79,7 @@ def test_client_conn_id_behavior(self, mock_get_cred, mock_client, mock_hook, co ) mock_get_cred.return_value = ("test_cred", "test_proj") with conf_vars({("logging", "remote_log_conn_id"): conn_id}): - return_value = self.gcs_task_handler.client + return_value = self.gcs_task_handler.io.client if conn_id: mock_hook.assert_called_once_with(gcp_conn_id="my_gcs_conn") mock_get_cred.assert_not_called() @@ -223,7 +223,7 @@ def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_cred assert caplog.record_tuples == [ ( - "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler", + "airflow.providers.google.cloud.log.gcs_task_handler.GCSRemoteLogIO", logging.ERROR, "Could not write logs to gs://bucket/remote/log/location/1.log: Failed to connect", ), @@ -261,13 +261,13 @@ def test_write_to_remote_on_close_failed_read_old_logs(self, mock_blob, mock_cli ) self.gcs_task_handler.close() - mock_blob.assert_has_calls( + mock_blob.from_string.assert_has_calls( [ - mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value), - mock.call.from_string().download_as_bytes(), - mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value), - mock.call.from_string().upload_from_string( - "MESSAGE\nError checking for previous log; if exists, may be overwritten: Fail to download\n", + 
mock.call("gs://bucket/remote/log/location/1.log", mock_client.return_value), + mock.call().download_as_bytes(), + mock.call("gs://bucket/remote/log/location/1.log", mock_client.return_value), + mock.call().upload_from_string( + "MESSAGE\n", content_type="text/plain", ), ], diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/log/wasb_task_handler.py index 8130ef73c6d26..ca8b9bd3aa685 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/log/wasb_task_handler.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/log/wasb_task_handler.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING +import attrs from azure.core.exceptions import HttpResponseError from airflow.configuration import conf @@ -34,34 +35,35 @@ import logging from airflow.models.taskinstance import TaskInstance + from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo -class WasbTaskHandler(FileTaskHandler, LoggingMixin): - """ - WasbTaskHandler is a python log handler that handles and reads task instance logs. +@attrs.define +class WasbRemoteLogIO(LoggingMixin): # noqa: D101 + remote_base: str + base_log_folder: Path = attrs.field(converter=Path) + delete_local_copy: bool - It extends airflow FileTaskHandler and uploads to and reads from Wasb remote storage. 
- """ + wasb_container: str - trigger_should_wrap = True + processors = () - def __init__( - self, - base_log_folder: str, - wasb_log_folder: str, - wasb_container: str, - **kwargs, - ) -> None: - super().__init__(base_log_folder) - self.handler: logging.FileHandler | None = None - self.wasb_container = wasb_container - self.remote_base = wasb_log_folder - self.log_relative_path = "" - self.closed = False - self.upload_on_close = True - self.delete_local_copy = kwargs.get( - "delete_local_copy", conf.getboolean("logging", "delete_local_logs") - ) + def upload(self, path: str | os.PathLike): + """Upload the given log path to the remote storage.""" + path = Path(path) + if path.is_absolute(): + local_loc = path + remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder)) + else: + local_loc = self.base_log_folder.joinpath(path) + remote_loc = os.path.join(self.remote_base, path) + + if local_loc.is_file(): + # read log and remove old logs to get just the latest additions + log = local_loc.read_text() + has_uploaded = self.write(log, remote_loc) + if has_uploaded and self.delete_local_copy: + shutil.rmtree(os.path.dirname(local_loc)) @cached_property def hook(self): @@ -81,53 +83,13 @@ def hook(self): ) return None - def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None: - super().set_context(ti, identifier=identifier) - # Local location and remote location is needed to open and - # upload local log file to Wasb remote storage. 
- if TYPE_CHECKING: - assert self.handler is not None - - full_path = self.handler.baseFilename - self.log_relative_path = Path(full_path).relative_to(self.local_base).as_posix() - is_trigger_log_context = getattr(ti, "is_trigger_log_context", False) - self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None) - - def close(self) -> None: - """Close and upload local log file to remote storage Wasb.""" - # When application exit, system shuts down all handlers by - # calling close method. Here we check if logger is already - # closed to prevent uploading the log to remote storage multiple - # times when `logging.shutdown` is called. - if self.closed: - return - - super().close() - - if not self.upload_on_close: - return - - local_loc = os.path.join(self.local_base, self.log_relative_path) - remote_loc = os.path.join(self.remote_base, self.log_relative_path) - if os.path.exists(local_loc): - # read log and remove old logs to get just the latest additions - with open(local_loc) as logfile: - log = logfile.read() - wasb_write = self.wasb_write(log, remote_loc, append=True) - - if wasb_write and self.delete_local_copy: - shutil.rmtree(os.path.dirname(local_loc)) - # Mark closed so we don't double write if close is called twice - self.closed = True - - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]: + def read(self, relative_path) -> tuple[LogSourceInfo, LogMessages | None]: messages = [] logs = [] - worker_log_relative_path = self._render_filename(ti, try_number) # TODO: fix this - "relative path" i.e currently REMOTE_BASE_LOG_FOLDER should start with "wasb" # unlike others with shceme in URL itself to identify the correct handler. # This puts limitations on ways users can name the base_path. 
- prefix = os.path.join(self.remote_base, worker_log_relative_path) + prefix = os.path.join(self.remote_base, relative_path) blob_names = [] try: blob_names = self.hook.get_blobs_list(container_name=self.wasb_container, prefix=prefix) @@ -143,8 +105,7 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], l else: messages.extend(["Found remote logs:", *[f" * {x}" for x in sorted(uris)]]) else: - if not AIRFLOW_V_3_0_PLUS: - messages.append(f"No logs found in WASB; ti={ti}") + return messages, None for name in sorted(blob_names): remote_log = "" @@ -191,7 +152,7 @@ def wasb_read(self, remote_log_location: str, return_error: bool = False): return msg return "" - def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> bool: + def write(self, log: str, remote_log_location: str, append: bool = True) -> bool: """ Write the log to the remote_log_location. Fails silently if no hook was created. @@ -210,3 +171,80 @@ def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> self.log.exception("Could not write logs to %s", remote_log_location) return False return True + + +class WasbTaskHandler(FileTaskHandler, LoggingMixin): + """ + WasbTaskHandler is a python log handler that handles and reads task instance logs. + + It extends airflow FileTaskHandler and uploads to and reads from Wasb remote storage. 
+ """ + + trigger_should_wrap = True + + def __init__( + self, + base_log_folder: str, + wasb_log_folder: str, + wasb_container: str, + **kwargs, + ) -> None: + super().__init__(base_log_folder) + self.handler: logging.FileHandler | None = None + self.log_relative_path = "" + self.closed = False + self.upload_on_close = True + self.io = WasbRemoteLogIO( + base_log_folder=base_log_folder, + remote_base=wasb_log_folder, + wasb_container=wasb_container, + delete_local_copy=kwargs.get( + "delete_local_copy", conf.getboolean("logging", "delete_local_logs") + ), + ) + + def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None: + super().set_context(ti, identifier=identifier) + # Local location and remote location is needed to open and + # upload local log file to Wasb remote storage. + if TYPE_CHECKING: + assert self.handler is not None + + full_path = self.handler.baseFilename + self.log_relative_path = Path(full_path).relative_to(self.local_base).as_posix() + is_trigger_log_context = getattr(ti, "is_trigger_log_context", False) + self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None) + + def close(self) -> None: + """Close and upload local log file to remote storage Wasb.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + super().close() + + if not self.upload_on_close: + return + + self.io.upload(self.log_relative_path) + + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]: + # Explicitly getting log relative path is necessary as the given + # task instance might be different than task instance passed in + # in set_context method. 
+ worker_log_rel_path = self._render_filename(ti, try_number) + + messages, logs = self.io.read(worker_log_rel_path) + + if logs is None: + logs = [] + if not AIRFLOW_V_3_0_PLUS: + messages.append(f"No logs found in WASB; ti={ti}") + + return messages, logs diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py b/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py index 8f48c96002870..ba7fe5f1b2a4c 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py @@ -26,7 +26,7 @@ from azure.common import AzureHttpError from airflow.providers.microsoft.azure.hooks.wasb import WasbHook -from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbTaskHandler +from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO, WasbTaskHandler from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import datetime @@ -74,16 +74,16 @@ def setup_method(self): @conf_vars({("logging", "remote_log_conn_id"): "wasb_default"}) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") def test_hook(self, mock_service): - assert isinstance(self.wasb_task_handler.hook, WasbHook) + assert isinstance(self.wasb_task_handler.io.hook, WasbHook) @conf_vars({("logging", "remote_log_conn_id"): "wasb_default"}) def test_hook_warns(self): handler = self.wasb_task_handler - with mock.patch.object(handler.log, "exception") as mock_exc: + with mock.patch.object(handler.io.log, "exception") as mock_exc: with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") as mock_hook: mock_hook.side_effect = AzureHttpError("failed to connect", 404) # Initialize the hook - handler.hook + handler.io.hook assert "Could not create a WasbHook with connection id '%s'" in mock_exc.call_args.args[0] @@ -100,7 +100,7 @@ def 
test_set_context_not_raw(self, ti): def test_wasb_log_exists(self, mock_hook): instance = mock_hook.return_value instance.check_for_blob.return_value = True - self.wasb_task_handler.wasb_log_exists(self.remote_log_location) + self.wasb_task_handler.io.wasb_log_exists(self.remote_log_location) mock_hook.return_value.check_for_blob.assert_called_once_with( self.container_name, self.remote_log_location ) @@ -110,7 +110,7 @@ def test_wasb_read(self, mock_hook_cls, ti): mock_hook = mock_hook_cls.return_value mock_hook.get_blobs_list.return_value = ["abc/hello.log"] mock_hook.read_file.return_value = "Log line" - assert self.wasb_task_handler.wasb_read(self.remote_log_location) == "Log line" + assert self.wasb_task_handler.io.wasb_read(self.remote_log_location) == "Log line" ti = copy.copy(ti) ti.state = TaskInstanceState.SUCCESS @@ -143,31 +143,31 @@ def test_wasb_read(self, mock_hook_cls, ti): ) def test_wasb_read_raises(self, mock_hook): handler = self.wasb_task_handler - with mock.patch.object(handler.log, "error") as mock_error: - handler.wasb_read(self.remote_log_location, return_error=True) + with mock.patch.object(handler.io.log, "error") as mock_error: + handler.io.wasb_read(self.remote_log_location, return_error=True) mock_error.assert_called_once_with( "Could not read logs from remote/log/location/1.log", exc_info=True, ) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") - @mock.patch.object(WasbTaskHandler, "wasb_read") - @mock.patch.object(WasbTaskHandler, "wasb_log_exists") + @mock.patch.object(WasbRemoteLogIO, "wasb_read") + @mock.patch.object(WasbRemoteLogIO, "wasb_log_exists") def test_write_log(self, mock_log_exists, mock_wasb_read, mock_hook): mock_log_exists.return_value = True mock_wasb_read.return_value = "" - self.wasb_task_handler.wasb_write("text", self.remote_log_location) + self.wasb_task_handler.io.write("text", self.remote_log_location) mock_hook.return_value.load_string.assert_called_once_with( "text", self.container_name, 
self.remote_log_location, overwrite=True ) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") - @mock.patch.object(WasbTaskHandler, "wasb_read") - @mock.patch.object(WasbTaskHandler, "wasb_log_exists") + @mock.patch.object(WasbRemoteLogIO, "wasb_read") + @mock.patch.object(WasbRemoteLogIO, "wasb_log_exists") def test_write_on_existing_log(self, mock_log_exists, mock_wasb_read, mock_hook): mock_log_exists.return_value = True mock_wasb_read.return_value = "old log" - self.wasb_task_handler.wasb_write("text", self.remote_log_location) + self.wasb_task_handler.io.write("text", self.remote_log_location) mock_hook.return_value.load_string.assert_called_once_with( "old log\ntext", self.container_name, @@ -177,18 +177,18 @@ def test_write_on_existing_log(self, mock_log_exists, mock_wasb_read, mock_hook) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") def test_write_when_append_is_false(self, mock_hook): - self.wasb_task_handler.wasb_write("text", self.remote_log_location, False) + self.wasb_task_handler.io.write("text", self.remote_log_location, False) mock_hook.return_value.load_string.assert_called_once_with( "text", self.container_name, self.remote_log_location, overwrite=True ) def test_write_raises(self): handler = self.wasb_task_handler - with mock.patch.object(handler.log, "error") as mock_error: + with mock.patch.object(handler.io.log, "error") as mock_error: with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") as mock_hook: mock_hook.return_value.load_string.side_effect = AzureHttpError("failed to connect", 404) - handler.wasb_write("text", self.remote_log_location, append=False) + handler.io.write("text", self.remote_log_location, append=False) mock_error.assert_called_once_with( "Could not write logs to %s", "remote/log/location/1.log", exc_info=True @@ -198,7 +198,7 @@ def test_write_raises(self): "delete_local_copy, expected_existence_of_local_copy", [(True, False), (False, True)], ) - 
@mock.patch("airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler.wasb_write") + @mock.patch.object(WasbRemoteLogIO, "write") def test_close_with_delete_local_logs_conf( self, wasb_write_mock, diff --git a/scripts/ci/pre_commit/check_tests_in_right_folders.py b/scripts/ci/pre_commit/check_tests_in_right_folders.py index 73846a7fe2f08..dd3fc4eb5521b 100755 --- a/scripts/ci/pre_commit/check_tests_in_right_folders.py +++ b/scripts/ci/pre_commit/check_tests_in_right_folders.py @@ -53,6 +53,7 @@ "jobs", "lineage", "listeners", + "logging", "macros", "models", "notifications", @@ -107,7 +108,7 @@ ): console.print( "[red]The file is in a wrong folder. Make sure to move it to the right folder " - "listed in `./script/ci/pre_commit/check_tests_in_right_folders.py` " + "listed in `./scripts/ci/pre_commit/check_tests_in_right_folders.py` " "or create new folder and add it to the script if you know what you are doing.[/]" ) console.print(file) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index bcc5dcafa71c2..d5d738f52546b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -121,6 +121,8 @@ class RuntimeTaskInstance(TaskInstance): start_date: AwareDatetime """Start date of the task instance.""" + end_date: AwareDatetime | None = None + is_mapped: bool | None = None """True if the original task was mapped.""" diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 5ce1e2ec5ed52..d209a14a7f4c4 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -34,6 +34,8 @@ if TYPE_CHECKING: from structlog.typing import EventDict, ExcInfo, FilteringBoundLogger, Processor + from airflow.logging_config import RemoteLogIO + __all__ = [ "configure_logging", @@ -226,10 +228,15 @@ def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: ( 
dict_tracebacks, structlog.processors.UnicodeDecoder(), - json, ), ) + # Include the remote logging provider for tasks if there are any we need (such as upload to Cloudwatch) + if (remote := load_remote_log_handler()) and (remote_processors := getattr(remote, "processors")): + processors.extend(remote_processors) + + processors.append(json) + return processors, { "timestamper": timestamper, "exc_group_processor": exc_group_processor, @@ -253,6 +260,7 @@ def configure_logging( from airflow.configuration import conf log_level = conf.get("logging", "logging_level", fallback="INFO") + lvl = structlog.stdlib.NAME_TO_LEVEL[log_level.lower()] if enable_pretty_log: @@ -285,6 +293,12 @@ def configure_logging( drop_positional_args, ] + if (remote := load_remote_log_handler()) and (remote_processors := getattr(remote, "processors")): + # Ensure we add in any remote log processor before we add `console` or `json` formatter so these get + # called with the event_dict as a dict still + color_formatter.extend(remote_processors) + std_lib_formatter.extend(remote_processors) + wrapper_class = structlog.make_filtering_bound_logger(lvl) if enable_pretty_log: if output is not None and not isinstance(output, TextIO): @@ -476,25 +490,14 @@ def init_log_file(local_relative_path: str) -> Path: return full_path -def load_remote_log_handler() -> logging.Handler | None: - from airflow.logging_config import configure_logging as airflow_configure_logging - from airflow.utils.log.log_reader import TaskLogReader +def load_remote_log_handler() -> RemoteLogIO | None: + import airflow.logging_config - try: - airflow_configure_logging() - - return TaskLogReader().log_handler - finally: - # This is a _monstrosity_ but put our logging back immediately... 
- configure_logging() + return airflow.logging_config.REMOTE_TASK_LOG def upload_to_remote(logger: FilteringBoundLogger): - # We haven't yet switched the Remote log handlers over, they are still wired up in providers as - # logging.Handlers (but we should re-write most of them to just be the upload and read instead of full - # variants.) In the mean time, lets just create the right handler directly from airflow.configuration import conf - from airflow.utils.log.file_task_handler import FileTaskHandler raw_logger = getattr(logger, "_logger") @@ -512,18 +515,8 @@ def upload_to_remote(logger: FilteringBoundLogger): relative_path = Path(fname).relative_to(base_log_folder) handler = load_remote_log_handler() - if not isinstance(handler, FileTaskHandler): - logger.warning( - "Airflow core logging is not using a FileTaskHandler, can't upload logs to remote", - handler=type(handler), - ) + if not handler: return - # This is a _monstrosity_, and super fragile, but we don't want to do the base FileTaskHandler - # set_context() which opens a real FH again. (And worse, in some cases it _truncates_ the file too). This - # is just for the first Airflow 3 betas, but we will re-write a better remote log interface that isn't - # tied to being a logging Handler. - handler.log_relative_path = relative_path.as_posix() # type: ignore[attr-defined] - handler.upload_on_close = True # type: ignore[attr-defined] - - handler.close() + log_relative_path = relative_path.as_posix() + handler.upload(log_relative_path) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 0f37c532660e3..6760ea5195990 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -66,6 +66,7 @@ class RuntimeTaskInstanceProtocol(Protocol): max_tries: int hostname: str | None = None start_date: AwareDatetime + end_date: AwareDatetime | None = None def xcom_pull( self,