From b24a358cebf969cbf5d024a0c74f3a8fd02e68c2 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Wed, 9 Jul 2025 23:47:35 +0100 Subject: [PATCH 1/3] Bugfix: return DataflowJobMessagesSensor and DataflowJobAutoScalingEventsSensor result with xcom_value --- .../google/cloud/sensors/dataflow.py | 23 ++++++++++++++----- .../google/cloud/sensors/test_dataflow.py | 4 ++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py index bc71934e9dc87..99ef2c45d707f 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py @@ -38,6 +38,7 @@ ) from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.providers.google.version_compat import BaseSensorOperator +from airflow.sdk import PokeReturnValue if TYPE_CHECKING: from airflow.utils.context import Context @@ -342,7 +343,7 @@ def __init__( self.deferrable = deferrable self.poll_interval = poll_interval - def poke(self, context: Context) -> bool: + def poke(self, context: Context) -> PokeReturnValue | bool: if self.fail_on_terminal_state: job = self.hook.get_job( job_id=self.job_id, @@ -359,8 +360,13 @@ def poke(self, context: Context) -> bool: project_id=self.project_id, location=self.location, ) - - return result if self.callback is None else self.callback(result) + result = result if self.callback is None else self.callback(result) + if bool(result): + return PokeReturnValue( + is_done=True, + xcom_value=result, + ) + return False def execute(self, context: Context) -> Any: """Airflow runs this method on the worker and defers using the trigger.""" @@ -464,7 +470,7 @@ def __init__( self.deferrable = deferrable self.poll_interval = poll_interval - def poke(self, context: Context) -> bool: + def poke(self, context: Context) -> PokeReturnValue | bool: if self.fail_on_terminal_state: job = self.hook.get_job( job_id=self.job_id, @@ -481,8 +487,13 @@ def poke(self, context: Context) -> bool: project_id=self.project_id, location=self.location, ) - - return result if self.callback is None else self.callback(result) + result = result if self.callback is None else self.callback(result) + if bool(result): + return PokeReturnValue( + is_done=True, + xcom_value=result, + ) + return False def execute(self, context: Context) -> Any: """Airflow runs this method on the worker and defers using the trigger.""" diff --git a/providers/google/tests/unit/google/cloud/sensors/test_dataflow.py b/providers/google/tests/unit/google/cloud/sensors/test_dataflow.py index 873780d0c353a..9b9a0fcff9f99 100644 --- a/providers/google/tests/unit/google/cloud/sensors/test_dataflow.py +++ b/providers/google/tests/unit/google/cloud/sensors/test_dataflow.py @@ -376,7 +376,7 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): results = task.poke(mock.MagicMock()) - assert callback.return_value == results + assert callback.return_value == results.xcom_value mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, @@ -552,7 +552,7 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): results = task.poke(mock.MagicMock()) - assert callback.return_value == results + assert callback.return_value == results.xcom_value mock_hook.assert_called_once_with( gcp_conn_id=TEST_GCP_CONN_ID, From bf7afe4fd2daeaf7b97317415365ee8019e21564 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Wed, 9 Jul 2025 23:49:51 +0100 Subject: [PATCH 2/3] Bugfix: return DataflowJobMessagesSensor and DataflowJobAutoScalingEventsSensor result with xcom_value --- .../src/airflow/providers/google/cloud/sensors/dataflow.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py index 99ef2c45d707f..542573c28a127 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py @@ -361,6 +361,10 @@ def poke(self, context: Context) -> PokeReturnValue | bool: location=self.location, ) result = result if self.callback is None else self.callback(result) + + if isinstance(result, PokeReturnValue): + return result + if bool(result): return PokeReturnValue( is_done=True, @@ -488,6 +492,9 @@ def poke(self, context: Context) -> PokeReturnValue | bool: location=self.location, ) result = result if self.callback is None else self.callback(result) + if isinstance(result, PokeReturnValue): + return result + if bool(result): return PokeReturnValue( is_done=True, From c47a18560ad175abf8c11e2dfbe651f3bd1e3ec2 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Wed, 9 Jul 2025 23:55:10 +0100 Subject: [PATCH 3/3] Bugfix: return DataflowJobMessagesSensor and DataflowJobAutoScalingEventsSensor result with xcom_value --- .../src/airflow/providers/google/cloud/sensors/dataflow.py | 3 +-- .../google/src/airflow/providers/google/version_compat.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py index 542573c28a127..c9f3a642802b0 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py @@ -37,8 +37,7 @@ DataflowJobStatusTrigger, ) from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID -from airflow.providers.google.version_compat import BaseSensorOperator -from airflow.sdk import PokeReturnValue +from airflow.providers.google.version_compat import BaseSensorOperator, PokeReturnValue if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/google/src/airflow/providers/google/version_compat.py b/providers/google/src/airflow/providers/google/version_compat.py index 45bbf473aff69..345946d1d4530 100644 --- a/providers/google/src/airflow/providers/google/version_compat.py +++ b/providers/google/src/airflow/providers/google/version_compat.py @@ -52,10 +52,11 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: from airflow.sdk import ( BaseOperatorLink, BaseSensorOperator, + PokeReturnValue, ) else: from airflow.models import BaseOperatorLink # type: ignore[no-redef] - from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] + from airflow.sensors.base import BaseSensorOperator, PokeReturnValue # type: ignore[no-redef] # Explicitly export these imports to protect them from being removed by linters __all__ = [ @@ -65,4 +66,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: "BaseOperator", "BaseSensorOperator", "BaseOperatorLink", + "PokeReturnValue", ]