Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 64 additions & 54 deletions airflow-core/src/airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
import sys
from typing import TYPE_CHECKING, cast

from sqlalchemy import func, select
from rich import print as rich_print
from sqlalchemy import select

from airflow._shared.timezones import timezone
from airflow.api_fastapi.core_api.datamodels.dag_run import TriggerDAGRunPostBody
Expand Down Expand Up @@ -62,7 +63,7 @@
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from collections.abc import Iterator

from graphviz.dot import Dot
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -487,62 +488,60 @@ def iter_next_dagrun_info() -> Iterator[DagRunInfo | None]:


@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl dags list")
@suppress_logs_and_warning
@providers_configuration_loaded
def dag_list_dags(args) -> None:
    """
    Display Dags with or without stats at the command line.

    Reports any requested column that is not in ``DAG_DETAIL_FIELDS`` on stderr, then
    delegates the listing: ``_list_local_dags`` when ``args.local`` is set, otherwise
    ``_list_dags_from_api``.

    :param args: parsed CLI namespace; ``columns`` and ``local`` are read here, the rest
        is forwarded untouched to the selected lister.
    """
    cols = args.columns if args.columns else []

    # Invalid columns are only reported, not rejected: both listers filter the requested
    # columns against DAG_DETAIL_FIELDS again when building each output row.
    if invalid_cols := [c for c in cols if c not in DAG_DETAIL_FIELDS]:
        rich_print(
            f"[red][bold]Error:[/bold] Ignoring the following invalid columns: {invalid_cols}. "
            f"List of valid columns: {sorted(DAG_DETAIL_FIELDS)}",
            file=sys.stderr,
        )

    if args.local:
        _list_local_dags(args, cols=cols)
    else:
        _list_dags_from_api(args, cols=cols)

# Get import errors from the local area
if args.bundle_name:
manager = DagBundlesManager()
validate_dag_bundle_arg(args.bundle_name)
all_bundles = list(manager.get_all_dag_bundles())
bundles_to_search = set(args.bundle_name)

for bundle in all_bundles:
if bundle.name in bundles_to_search:
bundle_dagbag = BundleDagBag(
bundle.path, bundle_path=bundle.path, bundle_name=bundle.name
)
bundle_dagbag.collect_dags()
dags_list.extend(list(bundle_dagbag.dags.values()))
dagbag_import_errors += len(bundle_dagbag.import_errors)
else:
dagbag = DagBag()
dagbag.collect_dags()
dags_list.extend(list(dagbag.dags.values()))
dagbag_import_errors += len(dagbag.import_errors)
def _print_dag_import_error_warning() -> None:
    """Write the "Failed to load all files" error, with a pointer to the import-errors command, to stderr."""
    message = (
        "[red][bold]Error:[/bold] Failed to load all files. "
        "For details, run `airflow dags list-import-errors`"
    )
    rich_print(message, file=sys.stderr)


@provide_session
def _list_local_dags(args, cols: list[str] | tuple[str, ...], *, session: Session = NEW_SESSION) -> None:
"""List Dags parsed from local Dag bundles."""
dagbag_import_errors = 0
dags_list = []

if args.bundle_name:
manager = DagBundlesManager()
Comment thread
bugraoz93 marked this conversation as resolved.
validate_dag_bundle_arg(args.bundle_name)
all_bundles = list(manager.get_all_dag_bundles())
bundles_to_search = set(args.bundle_name)

for bundle in all_bundles:
if bundle.name in bundles_to_search:
bundle_dagbag = BundleDagBag(bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
bundle_dagbag.collect_dags()
dags_list.extend(list(bundle_dagbag.dags.values()))
dagbag_import_errors += len(bundle_dagbag.import_errors)
else:
dags_list.extend(cast("DAG", sm.dag) for sm in session.scalars(select(SerializedDagModel)))
pie_stmt = select(func.count()).select_from(ParseImportError)
if args.bundle_name:
pie_stmt = pie_stmt.where(ParseImportError.bundle_name.in_(args.bundle_name))
dagbag_import_errors = session.scalar(pie_stmt) or 0
dagbag = DagBag()
dagbag.collect_dags()
dags_list.extend(list(dagbag.dags.values()))
dagbag_import_errors += len(dagbag.import_errors)

if dagbag_import_errors > 0:
from rich import print as rich_print

rich_print(
"[red][bold]Error:[/bold] Failed to load all files. "
"For details, run `airflow dags list-import-errors`",
file=sys.stderr,
)
_print_dag_import_error_warning()

def get_dag_detail(dag: DAG) -> dict:
if dag_model := DagModel.get_dagmodel(dag.dag_id, session=session):
Expand All @@ -553,22 +552,33 @@ def get_dag_detail(dag: DAG) -> dict:
return dag_detail
return {col: dag_detail[col] for col in cols if col in DAG_DETAIL_FIELDS}

def filter_dags_by_bundle(dags: Iterable[DAG], bundle_names: list[str] | None) -> Iterable[DAG]:
"""Filter DAGs based on the specified bundle name, if provided."""
if not bundle_names:
return dags
AirflowConsole().print_as(
data=sorted(dags_list, key=operator.attrgetter("dag_id")),
output=args.output,
mapper=get_dag_detail,
)

validate_dag_bundle_arg(bundle_names)
selected_dag_ids = set(
session.scalars(select(DagModel.dag_id).where(DagModel.bundle_name.in_(bundle_names)))
)
return (dag for dag in dags if dag.dag_id in selected_dag_ids)

@provide_api_client
def _list_dags_from_api(args, cols: list[str] | tuple[str, ...], api_client: Client = NEW_API_CLIENT) -> None:
    """
    List Dags through the Public API.

    Fetches the Dags via ``api_client.dags.list()``, optionally keeps only those whose
    ``bundle_name`` is in ``args.bundle_name``, warns on stderr when any remaining Dag has
    ``has_import_errors`` set, and prints the result sorted by ``dag_id``.

    :param args: parsed CLI namespace; ``bundle_name`` and ``output`` are read.
    :param cols: columns to print; an empty value prints every field of the response model.
    :param api_client: API client used to fetch the Dags.
    """
    dags_list = list(api_client.dags.list().dags)

    if args.bundle_name:
        # NOTE(review): the filter is applied client-side and, unlike _list_local_dags, the
        # names are not passed through validate_dag_bundle_arg, so an unknown bundle name
        # simply produces an empty listing -- confirm that this is intended.
        bundle_names = set(args.bundle_name)
        dags_list = [dag for dag in dags_list if dag.bundle_name in bundle_names]

    # NOTE(review): the warning depends only on the per-Dag flag of the Dags that are
    # listed; presumably a file that fails to import without yielding any Dag is not
    # reflected here -- verify against the API semantics.
    if any(dag.has_import_errors for dag in dags_list):
        _print_dag_import_error_warning()

    def get_dag_detail(dag) -> dict:
        dag_detail = dag.model_dump()
        if not cols:
            return dag_detail
        return {col: dag_detail[col] for col in cols if col in DAG_DETAIL_FIELDS}

    AirflowConsole().print_as(
        data=sorted(dags_list, key=operator.attrgetter("dag_id")),
        output=args.output,
        mapper=get_dag_detail,
    )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
DEPRECATED_CLI_COMMANDS = [
(dag_command.dag_trigger, ["dags", "trigger", "example_dag", "--run-id=x"], "airflowctl dags trigger"),
(dag_command.dag_delete, ["dags", "delete", "example_dag", "--yes"], "airflowctl dags delete"),
(dag_command.dag_list_dags, ["dags", "list"], "airflowctl dags list"),
(pool_command.pool_list, ["pools", "list"], "airflowctl pools list"),
(pool_command.pool_get, ["pools", "get", "foo"], "airflowctl pools get"),
(pool_command.pool_set, ["pools", "set", "foo", "1", "desc"], "airflowctl pools create"),
Expand Down
164 changes: 104 additions & 60 deletions airflow-core/tests/unit/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import pendulum
import pytest
import time_machine
from airflowctl.api.datamodels.generated import DAGResponse as AirflowCtlDAGResponse
from airflowctl.api.operations import ServerResponseError
from sqlalchemy import select

Expand Down Expand Up @@ -343,16 +344,6 @@ def test_cli_get_dag_details(self, stdout_capture):
for value in dag_details_values:
assert value in out

def test_cli_list_dags(self, stdout_capture):
args = self.parser.parse_args(["dags", "list", "--output", "json"])
with stdout_capture as temp_stdout:
dag_command.dag_list_dags(args)
out = temp_stdout.getvalue()
dag_list = json.loads(out)
for key in ["dag_id", "fileloc", "owners", "is_paused"]: # "bundle_name", "bundle_version"?
assert key in dag_list[0]
assert any("airflow/example_dags/example_complex.py" in d["fileloc"] for d in dag_list)

def test_cli_list_local_dags(self, stdout_capture):
# Clear the database
clear_db_dags()
Expand Down Expand Up @@ -390,42 +381,6 @@ def test_cli_list_local_dags_with_bundle_name(self, configure_testing_dag_bundle
# Rebuild Test DB for other tests
self.setup_class()

def test_cli_list_dags_custom_cols(self, stdout_capture):
args = self.parser.parse_args(
["dags", "list", "--output", "json", "--columns", "dag_id,last_parsed_time"]
)
with stdout_capture as temp_stdout:
dag_command.dag_list_dags(args)
out = temp_stdout.getvalue()
dag_list = json.loads(out)
for key in ["dag_id", "last_parsed_time"]:
assert key in dag_list[0]
for key in ["fileloc", "owners", "is_paused"]:
assert key not in dag_list[0]

def test_cli_list_dags_invalid_cols(self, stderr_capture):
args = self.parser.parse_args(["dags", "list", "--output", "json", "--columns", "dag_id,invalid_col"])
with stderr_capture as temp_stderr:
dag_command.dag_list_dags(args)
out = temp_stderr.getvalue()
assert "Ignoring the following invalid columns: ['invalid_col']" in out

@conf_vars({("core", "load_examples"): "false"})
def test_cli_list_dags_prints_import_errors(
self, configure_testing_dag_bundle, get_test_dag, stderr_capture
):
path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py"
get_test_dag("test_invalid_cron")

args = self.parser.parse_args(["dags", "list", "--output", "yaml", "--bundle-name", "testing"])

with configure_testing_dag_bundle(path_to_parse):
with stderr_capture as temp_stderr:
dag_command.dag_list_dags(args)
out = temp_stderr.getvalue()

assert "Failed to load all files." in out

@conf_vars({("core", "load_examples"): "false"})
def test_cli_list_dags_prints_local_import_errors(
self, configure_testing_dag_bundle, get_test_dag, stderr_capture
Expand All @@ -448,18 +403,6 @@ def test_cli_list_dags_prints_local_import_errors(
# Rebuild Test DB for other tests
self.setup_class()

@mock.patch("airflow.models.DagModel.get_dagmodel")
def test_list_dags_none_get_dagmodel(self, mock_get_dagmodel, stdout_capture):
mock_get_dagmodel.return_value = None
args = self.parser.parse_args(["dags", "list", "--output", "json"])
with stdout_capture as temp_stdout:
dag_command.dag_list_dags(args)
out = temp_stdout.getvalue()
dag_list = json.loads(out)
for key in ["dag_id", "fileloc", "owners", "is_paused"]:
assert key in dag_list[0]
assert any("airflow/example_dags/example_complex.py" in d["fileloc"] for d in dag_list)

def test_dagbag_dag_col(self, session):
dagbag = DBDagBag()
dag_details = dag_command._get_dagbag_dag_details(
Expand Down Expand Up @@ -1811,12 +1754,113 @@ def setup_class(cls):
cls.parser = cli_parser.get_parser()

@pytest.fixture(autouse=True)
def _default_api_responses(self, mock_cli_api_client):
    """
    Configure default mocked responses for Dag API client commands.

    ``dags.trigger`` gets a dict-returning ``model_dump`` and ``dags.list`` returns a
    single ``example_bash_operator`` Dag in the ``example_dags`` bundle; individual tests
    override these where they need different data.
    """
    mock_cli_api_client.dags.trigger.return_value.model_dump.return_value = {
        "dag_id": "example_bash_operator",
        "dag_run_id": "test_run",
    }
    mock_cli_api_client.dags.list.return_value.dags = [
        self._make_dag_response("example_bash_operator", "example_dags")
    ]

@staticmethod
def _make_dag_response(
    dag_id: str, bundle_name: str, has_import_errors: bool = False
) -> AirflowCtlDAGResponse:
    """Build an airflowctl ``DAGResponse`` for ``dag_id`` in ``bundle_name`` with fixed filler values."""
    # Fields that vary per call.
    payload = {
        "dag_id": dag_id,
        "dag_display_name": dag_id,
        "bundle_name": bundle_name,
        "relative_fileloc": f"{dag_id}.py",
        "fileloc": f"/dags/{dag_id}.py",
        "has_import_errors": has_import_errors,
    }
    # Constant values for the remaining fields of the response model.
    payload.update(
        is_paused=False,
        is_stale=False,
        last_parsed_time=datetime(2026, 6, 12, tzinfo=timezone.utc),
        bundle_version="1",
        timetable_partitioned=False,
        timetable_periodic=True,
        tags=[],
        max_active_tasks=16,
        max_consecutive_failed_dag_runs=0,
        has_task_concurrency_limits=False,
        owners=["airflow"],
        is_backfillable=True,
        file_token="file-token",
    )
    return AirflowCtlDAGResponse(**payload)

def test_list_dags(self, mock_cli_api_client, stdout_capture):
    """Without filters, the single Dag from the mocked API is printed as JSON."""
    parsed_args = self.parser.parse_args(["dags", "list", "--output", "json"])
    expected_fields = {
        "dag_id": "example_bash_operator",
        "fileloc": "/dags/example_bash_operator.py",
        "owners": ["airflow"],
        "bundle_name": "example_dags",
    }

    with stdout_capture as captured:
        dag_command.dag_list_dags(parsed_args)
        rows = json.loads(captured.getvalue())

    assert len(rows) == 1
    first_row = rows[0]
    for field, expected in expected_fields.items():
        assert first_row[field] == expected
    mock_cli_api_client.dags.list.assert_called_once_with()

def test_list_dags_custom_cols(self, mock_cli_api_client, stdout_capture):
    """``--columns`` restricts each printed row to exactly the requested fields."""
    cli_argv = ["dags", "list", "--output", "json", "--columns", "dag_id,last_parsed_time"]
    parsed_args = self.parser.parse_args(cli_argv)

    with stdout_capture as captured:
        dag_command.dag_list_dags(parsed_args)
        rows = json.loads(captured.getvalue())

    expected_row = {"dag_id": "example_bash_operator", "last_parsed_time": "2026-06-12 00:00:00+00:00"}
    assert rows == [expected_row]

def test_list_dags_filters_bundle_names(self, mock_cli_api_client, stdout_capture):
    """Repeated ``--bundle-name`` keeps only Dags from those bundles, sorted by ``dag_id``."""
    # Deliberately unsorted so that the ordering of the output is exercised as well.
    dag_bundle_pairs = (("dag_b", "bundle_b"), ("dag_a", "bundle_a"), ("dag_c", "bundle_c"))
    mock_cli_api_client.dags.list.return_value.dags = [
        self._make_dag_response(dag_id, bundle) for dag_id, bundle in dag_bundle_pairs
    ]

    cli_argv = ["dags", "list", "--output", "json"]
    for wanted_bundle in ("bundle_a", "bundle_c"):
        cli_argv += ["--bundle-name", wanted_bundle]
    parsed_args = self.parser.parse_args(cli_argv)

    with stdout_capture as captured:
        dag_command.dag_list_dags(parsed_args)
        rows = json.loads(captured.getvalue())

    listed_ids = [row["dag_id"] for row in rows]
    assert listed_ids == ["dag_a", "dag_c"]

def test_list_dags_invalid_cols(self, stderr_capture):
    """An unknown column name is reported on stderr."""
    cli_argv = ["dags", "list", "--output", "json", "--columns", "dag_id,invalid_col"]
    parsed_args = self.parser.parse_args(cli_argv)

    with stderr_capture as captured:
        dag_command.dag_list_dags(parsed_args)

    stderr_text = captured.getvalue()
    assert "Ignoring the following invalid columns: ['invalid_col']" in stderr_text

def test_list_dags_prints_matching_bundle_import_errors(self, mock_cli_api_client, stderr_capture):
    """A Dag flagged with import errors in the selected bundle triggers the stderr warning."""
    broken_dag = self._make_dag_response("dag_a", "bundle_a", has_import_errors=True)
    healthy_dag = self._make_dag_response("dag_b", "bundle_b")
    mock_cli_api_client.dags.list.return_value.dags = [broken_dag, healthy_dag]

    cli_argv = ["dags", "list", "--output", "json", "--bundle-name", "bundle_a"]
    parsed_args = self.parser.parse_args(cli_argv)

    with stderr_capture as captured:
        dag_command.dag_list_dags(parsed_args)

    stderr_text = captured.getvalue()
    assert "Failed to load all files." in stderr_text

def test_trigger_dag(self, mock_cli_api_client):
dag_command.dag_trigger(
Expand Down
Loading