From 6ae1aa2cd03ceb58bc08fbcd81374cdc856f70ee Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 24 Feb 2025 17:53:15 -0500 Subject: [PATCH 01/31] Add hooks for task submission and resolution --- hamilton/lifecycle/api.py | 135 ++++++++++++++++++++++++++++++++++--- hamilton/lifecycle/base.py | 109 +++++++++++++++++++++++------- 2 files changed, 210 insertions(+), 34 deletions(-) diff --git a/hamilton/lifecycle/api.py b/hamilton/lifecycle/api.py index ce810e5d8..c1c2d9f69 100644 --- a/hamilton/lifecycle/api.py +++ b/hamilton/lifecycle/api.py @@ -32,9 +32,11 @@ BasePostTaskExecute, BasePostTaskExpand, BasePostTaskGroup, + BasePostTaskResolution, BasePreGraphExecute, BasePreNodeExecute, BasePreTaskExecute, + BasePreTaskSubmission, BaseValidateGraph, BaseValidateNode, ) @@ -371,10 +373,121 @@ def run_after_graph_execution( pass +class TaskSubmissionHook(BasePreTaskSubmission, abc.ABC): + """Implement this to hook into the task submission process. Tasks are submitted to an executor, + which then controls how and where the nodes associated with the task are run.""" + + @override + def pre_task_submission( + self, + *, + run_id: str, + task_id: str, + nodes: List["node.Node"], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + spawning_task_id: Optional[str], + purpose: NodeGroupPurpose, + ): + self.run_before_task_submission( + run_id=run_id, + task_id=task_id, + nodes=nodes, + inputs=inputs, + overrides=overrides, + spawning_task_id=spawning_task_id, + purpose=purpose, + ) + + @abc.abstractmethod + def run_before_task_submission( + self, + *, + run_id: str, + task_id: str, + nodes: List["node.Node"], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + spawning_task_id: Optional[str], + purpose: NodeGroupPurpose, + **future_kwargs, + ): + """Runs prior to a task being submitted to an executor. By definition this is run *outside* + of the task executor, on the process that executed the driver. + + :param run_id: ID of the run this is under. + :param task_id: ID of the task we're launching. + :param nodes: Nodes that are part of this task + :param inputs: Inputs to the task + :param overrides: Overrides passed to the task + :param spawning_task_id: ID of the task that spawned this task + :param purpose: Purpose of the current task group + :param future_kwargs: Reserved for backwards compatibility. + """ + pass + + +class TaskResolutionHook(BasePostTaskResolution, abc.ABC): + """Implement this to hook into the task resolution process. Tasks are submitted to an executor, + which then returns a task future to be resolved at a later time when the task is complete.""" + + @override + def post_task_resolution( + self, + *, + run_id: str, + task_id: str, + nodes: List["node.Node"], + result: Any, + success: bool, + error: Optional[Exception], + spawning_task_id: Optional[str], + purpose: NodeGroupPurpose, + ): + self.run_after_task_resolution( + run_id=run_id, + task_id=task_id, + nodes=nodes, + result=result, + success=success, + error=error, + spawning_task_id=spawning_task_id, + purpose=purpose, + ) + + @abc.abstractmethod + def run_after_task_resolution( + self, + *, + run_id: str, + task_id: str, + nodes: List["node.Node"], + result: Any, + success: bool, + error: Optional[Exception], + spawning_task_id: Optional[str], + purpose: NodeGroupPurpose, + **future_kwargs, + ): + """Runs after a task (future) has been resolved into a returns value. By definition this is + run *outside* of the task executor,on the process that executed the driver. + + :param run_id: ID of the run this is under. + :param task_id: ID of the task that was just executed. + :param nodes: Nodes that were part of this task + :param result: Result of the task + :param success: Whether the task was successful + :param error: The error the task threw, if any + :param spawning_task_id: ID of the task that spawned this task + :param purpose: Purpose of the current task group + :param future_kwargs: Reserved for backwards compatibility. + """ + pass + + class TaskExecutionHook(BasePreTaskExecute, BasePostTaskExecute, abc.ABC): - """Implement this to run something after task execution. Tasks are tols used to group nodes. - Note that this is currently run *inside* the task, although we do not guarantee where it will be run - (it could easily move to outside the task).""" + """Implement this to hook into the task execution process. Tasks consist of a group of one or + more nodes that are run on a task executor.""" def pre_task_execute( self, @@ -433,9 +546,8 @@ def run_before_task_execution( purpose: NodeGroupPurpose, **future_kwargs, ): - """Implement this to run something after task execution. Tasks are tols used to group nodes. - Note that this is currently run *inside* the task, although we do not guarantee where it will be run - (it could easily move to outside the task). + """Runs prior to any of the nodes associated with a task. By definition this is run *inside* + of the executor and therefore may be run on separate or distributed processes. :param task_id: ID of the task we're launching. :param run_id: ID of the run this is under. @@ -462,7 +574,8 @@ def run_after_task_execution( purpose: NodeGroupPurpose, **future_kwargs, ): - """Implement this to run something after task execution. See note in run_before_task_execution. + """Runs after all of the nodes associated with a task have been executed. By definition this + is run *inside* of the executor and therefore may be run on separate or distributed processes. :param task_id: ID of the task that was just executed :param run_id: ID of the run this was under. @@ -651,7 +764,9 @@ def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, A @abc.abstractmethod def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs): - """Hook that is called after task grouping. + """Runs after task grouping. This allows you to capture information about the which tasks were + created for a given run. + :param run_id: ID of the run, unique in scope of the driver. :param task_ids: List of tasks that were grouped together. :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility. @@ -662,7 +777,9 @@ def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_ def run_after_task_expansion( self, *, run_id: str, task_id: str, parameters: Dict[str, Any], **future_kwargs ): - """Hook that is called after task expansion. + """Runs after task expansion in Parallelize/Collect blocks. This allows you to capture information + about the task that was expanded. + :param run_id: ID of the run, unique in scope of the driver. :param task_id: ID of the task that was expanded. :param parameters: Parameters that were passed to the task. diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 460aa09a9..616e9a1af 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -410,6 +410,60 @@ async def pre_graph_execute( pass +@lifecycle.base_hook("post_task_group") +class BasePostTaskGroup(abc.ABC): + @abc.abstractmethod + def post_task_group(self, *, run_id: str, task_ids: List[str]): + """Hook that is called immediately after a task group is created. Note that this is only useful in dynamic + execution, although we reserve the right to add this back into the standard hamilton execution pattern. + + :param run_id: ID of the run, unique in scope of the driver. + :param task_ids: IDs of tasks that are in the group.""" + pass + + +@lifecycle.base_hook("post_task_expand") +class BasePostTaskExpand(abc.ABC): + @abc.abstractmethod + def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): + """Hook that is called immediately after a task is expanded into separate task. Note that this is only useful + in dynamic execution. + + :param run_id: ID of the run, unique in scope of the driver. + :param task_id: ID of the task. + :param parameters: Parameters that are being passed to each of the expanded tasks.""" + pass + + +@lifecycle.base_hook("pre_task_submission") +class BasePreTaskSubmission(abc.ABC): + @abc.abstractmethod + def pre_task_submission( + self, + *, + run_id: str, + task_id: str, + nodes: List["node.Node"], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + spawning_task_id: Optional[str], + purpose: NodeGroupPurpose, + ): + """Hook that is called immediately prior to task submission to an executor as a task future. + Note that this is only useful in dynamic execution, although we reserve the right to add this back + into the standard hamilton execution pattern. + + :param run_id: ID of the run, unique in scope of the driver. + :param task_id: ID of the task. + :param nodes: Nodes that are being executed. + :param inputs: Inputs to the task. + :param overrides: Overrides to task execution. + :param spawning_task_id: ID of the task that spawned this task. + :param purpose: Purpose of the current task group. + """ + pass + + @lifecycle.base_hook("pre_task_execute") class BasePreTaskExecute(abc.ABC): @abc.abstractmethod @@ -626,31 +680,6 @@ async def post_node_execute( pass -@lifecycle.base_hook("post_task_group") -class BasePostTaskGroup(abc.ABC): - @abc.abstractmethod - def post_task_group(self, *, run_id: str, task_ids: List[str]): - """Hook that is called immediately after a task group is created. Note that this is only useful in dynamic - execution, although we reserve the right to add this back into the standard hamilton execution pattern. - - :param run_id: ID of the run, unique in scope of the driver. - :param task_ids: IDs of tasks that are in the group.""" - pass - - -@lifecycle.base_hook("post_task_expand") -class BasePostTaskExpand(abc.ABC): - @abc.abstractmethod - def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): - """Hook that is called immediately after a task is expanded into separate task. Note that this is only useful - in dynamic execution. - - :param run_id: ID of the run, unique in scope of the driver. - :param task_id: ID of the task. - :param parameters: Parameters that are being passed to each of the expanded tasks.""" - pass - - @lifecycle.base_hook("post_task_execute") class BasePostTaskExecute(abc.ABC): @abc.abstractmethod @@ -711,6 +740,36 @@ async def post_task_execute( pass +@lifecycle.base_hook("post_task_resolution") +class BasePostTaskResolution(abc.ABC): + @abc.abstractmethod + def post_task_resolution( + self, + *, + run_id: str, + task_id: str, + nodes: List["node.Node"], + result: Any, + success: bool, + error: Exception, + spawning_task_id: Optional[str], + purpose: NodeGroupPurpose, + ): + """Hook called immediately after a task future (as submitted to a task executor) is resolved. + Note that this is only useful in dynamic execution, although we reserve the right to add this + back into the standard hamilton execution pattern. + + :param run_id: ID of the run, unique in scope of the driver. + :param task_id: ID of the task + :param result: Return value of the task (from task future resolution). + :param success: Whether or not the task executed successfully + :param error: The error that was raised, if any + :param spawning_task_id: ID of the task that spawned this task + :param purpose: Purpose of the current task group + """ + pass + + @lifecycle.base_hook("post_graph_execute") class BasePostGraphExecute(abc.ABC): @abc.abstractmethod From 13b8f98065148b1ad0a3716ed9a3f3f12a323fd6 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 24 Feb 2025 17:53:47 -0500 Subject: [PATCH 02/31] Make `TaskImplementation` hashable by task ID --- hamilton/execution/grouping.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/hamilton/execution/grouping.py b/hamilton/execution/grouping.py index ed68ffd06..4b91f7f42 100644 --- a/hamilton/execution/grouping.py +++ b/hamilton/execution/grouping.py @@ -154,6 +154,14 @@ def __post_init__(self): super(TaskImplementation, self).__post_init__() self.task_id = self.determine_task_id(self.base_id, self.spawning_task_id, self.group_id) + def __hash__(self) -> int: + return hash(self.task_id) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TaskImplementation): + return False + return self.task_id == other.task_id + class GroupingStrategy(abc.ABC): """Base class for grouping nodes""" From fda3f64d913a197d99687e78f6f4883e8bfd1061 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 24 Feb 2025 17:54:19 -0500 Subject: [PATCH 03/31] Add calls to task submission and resolution hooks --- hamilton/execution/executors.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index dc7827274..c1d7aec9c 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -388,6 +388,17 @@ def run_graph_to_completion( if next_task is not None: task_executor = execution_manager.get_executor_for_task(next_task) if task_executor.can_submit_task(): + if next_task.adapter.does_hook("pre_task_submission", is_async=False): + next_task.adapter.call_all_lifecycle_hooks_sync( + "pre_task_submission", + run_id=next_task.run_id, + task_id=next_task.task_id, + nodes=next_task.nodes, + inputs=next_task.dynamic_inputs, + overrides=next_task.overrides, + spawning_task_id=next_task.spawning_task_id, + purpose=next_task.purpose, + ) try: submitted = task_executor.submit_task(next_task) except Exception as e: @@ -396,7 +407,7 @@ def run_graph_to_completion( f"{[item.name for item in next_task.nodes]}" ) raise e - task_futures[next_task.task_id] = submitted + task_futures[next_task] = submitted else: # Whoops, back on the queue # We should probably wait a bit here, but for now we're going to keep @@ -404,12 +415,24 @@ def run_graph_to_completion( execution_state.reject_task(task_to_reject=next_task) # update all the tasks in flight # copy so we can modify - for task_name, task_future in task_futures.copy().items(): + for task, task_future in task_futures.copy().items(): state = task_future.get_state() result = task_future.get_result() - execution_state.update_task_state(task_name, state, result) + execution_state.update_task_state(task.task_id, state, result) if TaskState.is_terminal(state): - del task_futures[task_name] + if task.adapter.does_hook("post_task_resolution", is_async=False): + task.adapter.call_all_lifecycle_hooks_sync( + "post_task_resolution", + run_id=task.run_id, + task_id=task.task_id, + nodes=task.nodes, + success=state == TaskState.SUCCESSFUL, + error=None, # FIXME -- we should get the error here + result=result, + spawning_task_id=task.spawning_task_id, + purpose=task.purpose, + ) + del task_futures[task] logger.info(f"Graph is done, graph state is {execution_state.get_graph_state()}") finally: execution_manager.finalize() From d6df6e4c733b506283898941d08a36f9db1bce6e Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 2 Mar 2025 14:43:19 -0500 Subject: [PATCH 04/31] Add tests for task submission and resolution hooks --- .../lifecycle_adapters_for_testing.py | 33 ++++++ ...ifecycle_adapters_end_to_end_task_based.py | 101 ++++++++++++++++++ 2 files changed, 134 insertions(+) diff --git a/tests/lifecycle/lifecycle_adapters_for_testing.py b/tests/lifecycle/lifecycle_adapters_for_testing.py index 6a5f01c42..32fbe9fef 100644 --- a/tests/lifecycle/lifecycle_adapters_for_testing.py +++ b/tests/lifecycle/lifecycle_adapters_for_testing.py @@ -18,10 +18,12 @@ BasePostTaskExecute, BasePostTaskExpand, BasePostTaskGroup, + BasePostTaskResolution, BasePreDoAnythingHook, BasePreGraphExecute, BasePreNodeExecute, BasePreTaskExecute, + BasePreTaskSubmission, BaseValidateGraph, BaseValidateNode, LifecycleAdapterSet, @@ -151,6 +153,37 @@ def post_node_execute( pass +class TrackingPreTaskSubmissionHook(ExtendToTrackCalls, BasePreTaskSubmission): + def pre_task_submission( + self, + *, + run_id: str, + task_id: str, + nodes: List[Node], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + spawning_task_id: str | None, + purpose: NodeGroupPurpose, + ): + pass + + +class TrackingPostTaskResolutionHook(ExtendToTrackCalls, BasePostTaskResolution): + def post_task_resolution( + self, + *, + run_id: str, + task_id: str, + nodes: List[Node], + result: Any, + success: bool, + error: Exception, + spawning_task_id: str | None, + purpose: NodeGroupPurpose, + ): + pass + + class TrackingPostTaskExecuteHook(ExtendToTrackCalls, BasePostTaskExecute): def post_task_execute( self, diff --git a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py index 2f3294cae..1dd739d3d 100644 --- a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py +++ b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py @@ -16,10 +16,12 @@ BasePostTaskExecute, BasePostTaskExpand, BasePostTaskGroup, + BasePostTaskResolution, BasePreDoAnythingHook, BasePreGraphExecute, BasePreNodeExecute, BasePreTaskExecute, + BasePreTaskSubmission, ) from hamilton.node import Node @@ -31,7 +33,9 @@ TrackingPostTaskExecuteHook, TrackingPostTaskExpandHook, TrackingPostTaskGroupHook, + TrackingPostTaskResolutionHook, TrackingPreNodeExecuteHook, + TrackingPreTaskSubmissionHook, ) if TYPE_CHECKING: @@ -242,6 +246,72 @@ def test_individual_post_task_expand_hook(): assert len(relevant_calls[0].bound_kwargs["run_id"]) > 10 # Should be UUID(ish)... +def test_individual_pre_task_submission_hook(): + hook_name = "pre_task_submission" + hook = TrackingPreTaskSubmissionHook(name=hook_name) + dr = _sample_driver(hook) + dr.execute(["output"], inputs={"n_iters_input": 5}) + relevant_calls = [item for item in hook.calls if item.name == hook_name] + assert len(relevant_calls) == 10 + spawning_task_ids = Counter([item.bound_kwargs["spawning_task_id"] for item in relevant_calls]) + assert spawning_task_ids == {"expand-parallel_over": 5, None: 5} + purposes = Counter([item.bound_kwargs["purpose"] for item in relevant_calls]) + assert purposes == { + NodeGroupPurpose.EXECUTE_BLOCK: 5, + NodeGroupPurpose.EXECUTE_SINGLE: 3, + NodeGroupPurpose.EXPAND_UNORDERED: 1, + NodeGroupPurpose.GATHER: 1, + } + nodes = {node.name for item in relevant_calls for node in item.bound_kwargs["nodes"]} + assert nodes == { + "parallel_over", + "n_iters", + "processed", + "more_processed", + "collect", + "output", + "n_iters_input", + } + + +def test_individual_post_task_resolution_hook(): + hook_name = "post_task_resolution" + hook = TrackingPostTaskResolutionHook(name=hook_name) + dr = _sample_driver(hook) + dr.execute(["output"], inputs={"n_iters_input": 5}) + relevant_calls = [item for item in hook.calls if item.name == hook_name] + assert len(relevant_calls) == 10 + spawning_task_ids = Counter([item.bound_kwargs["spawning_task_id"] for item in relevant_calls]) + assert spawning_task_ids == {"expand-parallel_over": 5, None: 5} + purposes = Counter([item.bound_kwargs["purpose"] for item in relevant_calls]) + assert purposes == { + NodeGroupPurpose.EXECUTE_BLOCK: 5, + NodeGroupPurpose.EXECUTE_SINGLE: 3, + NodeGroupPurpose.EXPAND_UNORDERED: 1, + NodeGroupPurpose.GATHER: 1, + } + nodes = {node.name for item in relevant_calls for node in item.bound_kwargs["nodes"]} + assert nodes == { + "parallel_over", + "n_iters", + "processed", + "more_processed", + "collect", + "output", + "n_iters_input", + } + results = { + item.bound_kwargs["result"]["more_processed"] + for item in relevant_calls + if "more_processed" in item.bound_kwargs["result"] # only block execute results + } + assert results == {0, 1, 16, 81, 256} + success = {item.bound_kwargs["success"] for item in relevant_calls} + assert success == {True} + errors = {item.bound_kwargs["error"] for item in relevant_calls} + assert errors == {None} + + def test_multi_hook(): class MultiHook( BasePreDoAnythingHook, @@ -255,6 +325,8 @@ class MultiHook( BasePostGraphExecute, BasePostTaskGroup, BasePostTaskExpand, + BasePreTaskSubmission, + BasePostTaskResolution, ExtendToTrackCalls, ): def pre_task_execute( @@ -342,6 +414,33 @@ def post_task_group(self, run_id: str, task_ids: List[str]): def post_task_expand(self, run_id: str, task_id: str, parameters: Dict[str, Any]): pass + def pre_task_submission( + self, + *, + run_id: str, + task_id: str, + nodes: List[Node], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + spawning_task_id: str | None, + purpose: NodeGroupPurpose, + ): + pass + + def post_task_resolution( + self, + *, + run_id: str, + task_id: str, + nodes: List[Node], + result: Any, + success: bool, + error: Exception, + spawning_task_id: str | None, + purpose: NodeGroupPurpose, + ): + pass + multi_hook = MultiHook(name="multi_hook") dr = _sample_driver(multi_hook) @@ -358,6 +457,8 @@ def post_task_expand(self, run_id: str, task_id: str, parameters: Dict[str, Any] "post_node_execute": 14, "post_task_execute": 10, "post_graph_execute": 1, + "pre_task_submission": 10, + "post_task_resolution": 10, "post_task_group": 1, "post_task_expand": 1, } From 9892959f486c2b4ffd4e7eb8403f81e55bea1a64 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 2 Mar 2025 14:51:15 -0500 Subject: [PATCH 05/31] Add docs for task submission and resolution hooks --- docs/reference/lifecycle-hooks/TaskResolutionHook.rst | 9 +++++++++ docs/reference/lifecycle-hooks/TaskSubmissionHook.rst | 9 +++++++++ docs/reference/lifecycle-hooks/index.rst | 2 ++ 3 files changed, 20 insertions(+) create mode 100644 docs/reference/lifecycle-hooks/TaskResolutionHook.rst create mode 100644 docs/reference/lifecycle-hooks/TaskSubmissionHook.rst diff --git a/docs/reference/lifecycle-hooks/TaskResolutionHook.rst b/docs/reference/lifecycle-hooks/TaskResolutionHook.rst new file mode 100644 index 000000000..92e2d8a3d --- /dev/null +++ b/docs/reference/lifecycle-hooks/TaskResolutionHook.rst @@ -0,0 +1,9 @@ +================================ +lifecycle.api.TaskResolutionHook +================================ + + +.. autoclass:: hamilton.lifecycle.api.TaskResolutionHook + :special-members: __init__ + :members: + :inherited-members: diff --git a/docs/reference/lifecycle-hooks/TaskSubmissionHook.rst b/docs/reference/lifecycle-hooks/TaskSubmissionHook.rst new file mode 100644 index 000000000..1d8c09316 --- /dev/null +++ b/docs/reference/lifecycle-hooks/TaskSubmissionHook.rst @@ -0,0 +1,9 @@ +================================ +lifecycle.api.TaskSubmissionHook +================================ + + +.. autoclass:: hamilton.lifecycle.api.TaskSubmissionHook + :special-members: __init__ + :members: + :inherited-members: diff --git a/docs/reference/lifecycle-hooks/index.rst b/docs/reference/lifecycle-hooks/index.rst index c29cc95c6..8b97ea265 100644 --- a/docs/reference/lifecycle-hooks/index.rst +++ b/docs/reference/lifecycle-hooks/index.rst @@ -24,6 +24,8 @@ looking forward. NodeExecutionMethod StaticValidator GraphConstructionHook + TaskSubmissionHook + TaskResolutionHook TaskExecutionHook TaskGroupingHook From 2e31305b3dfab0d2701a385f52b648041108b93f Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 2 Mar 2025 14:53:51 -0500 Subject: [PATCH 06/31] Update comment 'FIXME' -> 'TODO' --- hamilton/execution/executors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index c1d7aec9c..f1f7c59bc 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -427,7 +427,7 @@ def run_graph_to_completion( task_id=task.task_id, nodes=task.nodes, success=state == TaskState.SUCCESSFUL, - error=None, # FIXME -- we should get the error here + error=None, # TODO -- we could get the error from the task future result=result, spawning_task_id=task.spawning_task_id, purpose=task.purpose, From 7e9b70b444402f000739ab3d74b89acb44654d46 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 2 Mar 2025 15:27:39 -0500 Subject: [PATCH 07/31] Convert `str | None` to `Optional[str]` --- tests/lifecycle/lifecycle_adapters_for_testing.py | 4 ++-- .../test_lifecycle_adapters_end_to_end_task_based.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/lifecycle/lifecycle_adapters_for_testing.py b/tests/lifecycle/lifecycle_adapters_for_testing.py index 32fbe9fef..cc10eedff 100644 --- a/tests/lifecycle/lifecycle_adapters_for_testing.py +++ b/tests/lifecycle/lifecycle_adapters_for_testing.py @@ -162,7 +162,7 @@ def pre_task_submission( nodes: List[Node], inputs: Dict[str, Any], overrides: Dict[str, Any], - spawning_task_id: str | None, + spawning_task_id: Optional[str], purpose: NodeGroupPurpose, ): pass @@ -178,7 +178,7 @@ def post_task_resolution( result: Any, success: bool, error: Exception, - spawning_task_id: str | None, + spawning_task_id: Optional[str], purpose: NodeGroupPurpose, ): pass diff --git a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py index 1dd739d3d..dd1aaec90 100644 --- a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py +++ b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py @@ -422,7 +422,7 @@ def pre_task_submission( nodes: List[Node], inputs: Dict[str, Any], overrides: Dict[str, Any], - spawning_task_id: str | None, + spawning_task_id: Optional[str], purpose: NodeGroupPurpose, ): pass @@ -436,7 +436,7 @@ def post_task_resolution( result: Any, success: bool, error: Exception, - spawning_task_id: str | None, + spawning_task_id: Optional[str], purpose: NodeGroupPurpose, ): pass From 4efee2c1a8f7dde2008cec250033d160a9aaedb6 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 2 Mar 2025 15:29:17 -0500 Subject: [PATCH 08/31] Fix docstring typos --- hamilton/lifecycle/api.py | 2 +- hamilton/lifecycle/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hamilton/lifecycle/api.py b/hamilton/lifecycle/api.py index c1c2d9f69..278873f43 100644 --- a/hamilton/lifecycle/api.py +++ b/hamilton/lifecycle/api.py @@ -469,7 +469,7 @@ def run_after_task_resolution( purpose: NodeGroupPurpose, **future_kwargs, ): - """Runs after a task (future) has been resolved into a returns value. By definition this is + """Runs after a task (future) has been resolved into a return value. By definition this is run *outside* of the task executor,on the process that executed the driver. :param run_id: ID of the run this is under. diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 616e9a1af..243796889 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -426,7 +426,7 @@ def post_task_group(self, *, run_id: str, task_ids: List[str]): class BasePostTaskExpand(abc.ABC): @abc.abstractmethod def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): - """Hook that is called immediately after a task is expanded into separate task. Note that this is only useful + """Hook that is called immediately after a task is expanded into parallelizable tasks. Note that this is only useful in dynamic execution. :param run_id: ID of the run, unique in scope of the driver. From a19de6e1b85c797bcaeb9ceefb06c9076639618f Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Fri, 14 Mar 2025 22:17:38 -0400 Subject: [PATCH 09/31] Make `TaskFuture` a protocol, add `TaskFutureWrappingFunction` --- hamilton/execution/executors.py | 49 +++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index f1f7c59bc..75a766ba5 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -10,7 +10,7 @@ from concurrent.futures.process import ProcessPoolExecutor from concurrent.futures import Executor, Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Protocol from hamilton import node from hamilton.execution.graph_functions import execute_subdag @@ -20,13 +20,17 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class TaskFuture: +class TaskFuture(Protocol): """Simple representation of a future. TODO -- add cancel(). This a clean wrapper over a python future, and we may end up just using that at some point.""" - get_state: Callable[[], TaskState] - get_result: Callable[[], Any] + def get_state(self) -> TaskState: + """Returns the state of the task.""" + ... + + def get_result(self) -> Any: + """Returns the result of the task.""" + ... class TaskExecutor(abc.ABC): @@ -162,6 +166,36 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]: return final_retval +@dataclasses.dataclass +class TaskFutureWrappingFunction(TaskFuture): + """Wraps a python function call in a TaskFuture.""" + + def __init__(self, function: Callable[[], Any]): + self.function = function + self._results = None + self._done = False + self._exception = None + + def get_state(self): + try: + self._results = self.function() + except Exception as e: + logger.exception("Task failed") + self._exception = e + return TaskState.FAILED + finally: + self._done = True + return TaskState.SUCCESSFUL + + def get_result(self): + if self._exception is not None: + raise self._exception + if not self._done: + self._results = self.function() + self._done = True + return self._results + + class SynchronousLocalTaskExecutor(TaskExecutor): """Basic synchronous/local task executor that runs tasks in the same process, at submit time.""" @@ -172,9 +206,7 @@ def submit_task(self, task: TaskImplementation) -> TaskFuture: :param task: Task to submit :return: Future associated with this task """ - # No error management for now - result = base_execute_task(task) - return TaskFuture(get_state=lambda: TaskState.SUCCESSFUL, get_result=lambda: result) + return TaskFutureWrappingFunction(functools.partial(base_execute_task, task)) def can_submit_task(self) -> bool: """We can always submit a task as the task submission is blocking! @@ -190,6 +222,7 @@ def finalize(self): pass +@dataclasses.dataclass class TaskFutureWrappingPythonFuture(TaskFuture): """Wraps a python future in a TaskFuture""" From e9b7a4e4e32bc924bea3f27f57fb079f0acfab9d Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Fri, 14 Mar 2025 22:18:00 -0400 Subject: [PATCH 10/31] Add better exception handling; clean up code --- hamilton/execution/executors.py | 51 +++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index 75a766ba5..890f35c1e 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -416,7 +416,7 @@ def run_graph_to_completion( execution_manager.init() try: while not GraphState.is_terminal(execution_state.get_graph_state()): - # get the next task from the queue + # Get the next task from the queue next_task = execution_state.release_next_task() if next_task is not None: task_executor = execution_manager.get_executor_for_task(next_task) @@ -442,30 +442,37 @@ def run_graph_to_completion( raise e task_futures[next_task] = submitted else: - # Whoops, back on the queue - # We should probably wait a bit here, but for now we're going to keep - # burning through + # TODO: Investigate a backoff strategy, now for add back on the queue execution_state.reject_task(task_to_reject=next_task) - # update all the tasks in flight - # copy so we can modify + + # Update all the tasks in flight (copy so we can modify) for task, task_future in task_futures.copy().items(): + result, error = None, None state = task_future.get_state() - result = task_future.get_result() - execution_state.update_task_state(task.task_id, state, result) - if TaskState.is_terminal(state): - if task.adapter.does_hook("post_task_resolution", is_async=False): - task.adapter.call_all_lifecycle_hooks_sync( - "post_task_resolution", - run_id=task.run_id, - task_id=task.task_id, - nodes=task.nodes, - success=state == TaskState.SUCCESSFUL, - error=None, # TODO -- we could get the error from the task future - result=result, - spawning_task_id=task.spawning_task_id, - purpose=task.purpose, - ) - del task_futures[task] + try: + result = task_future.get_result() + except Exception as e: + logger.exception( + f"Exception resolving task {task.task_id}, with nodes: " + f"{[item.name for item in task.nodes]}" + ) + raise e + finally: + execution_state.update_task_state(task.task_id, state, result) + if TaskState.is_terminal(state): + if task.adapter.does_hook("post_task_resolution", is_async=False): + task.adapter.call_all_lifecycle_hooks_sync( + "post_task_resolution", + run_id=task.run_id, + task_id=task.task_id, + nodes=task.nodes, + success=state == TaskState.SUCCESSFUL, + error=error, + result=result, + spawning_task_id=task.spawning_task_id, + purpose=task.purpose, + ) + del task_futures[task] logger.info(f"Graph is done, graph state is {execution_state.get_graph_state()}") finally: execution_manager.finalize() From a8b042a7e05f4ae1c00457c018a8b2f1713a977d Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Fri, 14 Mar 2025 22:18:11 -0400 Subject: [PATCH 11/31] Add imports for task and submission hooks --- hamilton/lifecycle/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hamilton/lifecycle/__init__.py b/hamilton/lifecycle/__init__.py index 856bf0057..9b864443b 100644 --- a/hamilton/lifecycle/__init__.py +++ b/hamilton/lifecycle/__init__.py @@ -10,6 +10,8 @@ StaticValidator, TaskExecutionHook, TaskGroupingHook, + TaskResolutionHook, + TaskSubmissionHook, ) from .base import LifecycleAdapter # noqa: F401 from .default import ( # noqa: F401 @@ -43,4 +45,6 @@ "TaskGroupingHook", "FunctionInputOutputTypeChecker", "NoEdgeAndInputTypeChecking", + "TaskResolutionHook", + "TaskSubmissionHook", ] From 627bf1b8913a2cd6613154de4f024d456c9b2164 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 16 Mar 2025 22:37:43 -0400 Subject: [PATCH 12/31] Make sure finally block raises exception --- hamilton/execution/executors.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index 890f35c1e..76b3238c7 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -442,7 +442,8 @@ def run_graph_to_completion( raise e task_futures[next_task] = submitted else: - # TODO: Investigate a backoff strategy, now for add back on the queue + # Whoops, back on the queue. We should probably wait a bit here, but for + # now we're going to keep burning through execution_state.reject_task(task_to_reject=next_task) # Update all the tasks in flight (copy so we can modify) @@ -456,7 +457,7 @@ def run_graph_to_completion( f"Exception resolving task {task.task_id}, with nodes: " f"{[item.name for item in task.nodes]}" ) - raise e + error = e finally: execution_state.update_task_state(task.task_id, state, result) if TaskState.is_terminal(state): @@ -473,6 +474,8 @@ def run_graph_to_completion( purpose=task.purpose, ) del task_futures[task] + if error: + raise error logger.info(f"Graph is done, graph state is {execution_state.get_graph_state()}") finally: execution_manager.finalize() From c29eb1afafcc2c26c0e8fb402cc008f8be8ceae2 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 16 Mar 2025 22:52:17 -0400 Subject: [PATCH 13/31] Add comments regarding task counts --- ...ifecycle_adapters_end_to_end_task_based.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py index dd1aaec90..412ed5556 100644 --- a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py +++ b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py @@ -252,15 +252,15 @@ def test_individual_pre_task_submission_hook(): dr = _sample_driver(hook) dr.execute(["output"], inputs={"n_iters_input": 5}) relevant_calls = [item for item in hook.calls if item.name == hook_name] - assert len(relevant_calls) == 10 + assert len(relevant_calls) == 10 # Total number of tasks spawning_task_ids = Counter([item.bound_kwargs["spawning_task_id"] for item in relevant_calls]) - assert spawning_task_ids == {"expand-parallel_over": 5, None: 5} + assert spawning_task_ids == {"expand-parallel_over": 5, None: 5} # Number of parallel tasks purposes = Counter([item.bound_kwargs["purpose"] for item in relevant_calls]) assert purposes == { - NodeGroupPurpose.EXECUTE_BLOCK: 5, - NodeGroupPurpose.EXECUTE_SINGLE: 3, - NodeGroupPurpose.EXPAND_UNORDERED: 1, - NodeGroupPurpose.GATHER: 1, + NodeGroupPurpose.EXPAND_UNORDERED: 1, # Expanding task - 'parallel_over' + NodeGroupPurpose.EXECUTE_BLOCK: 5, # Tasks group from 'parallel_over' + NodeGroupPurpose.GATHER: 1, # Gathering task - 'collect' + NodeGroupPurpose.EXECUTE_SINGLE: 3, # All other tasks outside parallelization - single node } nodes = {node.name for item in relevant_calls for node in item.bound_kwargs["nodes"]} assert nodes == { @@ -280,15 +280,15 @@ def test_individual_post_task_resolution_hook(): dr = _sample_driver(hook) dr.execute(["output"], inputs={"n_iters_input": 5}) relevant_calls = [item for item in hook.calls if item.name == hook_name] - assert len(relevant_calls) == 10 + assert len(relevant_calls) == 10 # Total number of tasks spawning_task_ids = Counter([item.bound_kwargs["spawning_task_id"] for item in relevant_calls]) - assert spawning_task_ids == {"expand-parallel_over": 5, None: 5} + assert spawning_task_ids == {"expand-parallel_over": 5, None: 5} # Number of parallel tasks purposes = Counter([item.bound_kwargs["purpose"] for item in relevant_calls]) assert purposes == { - NodeGroupPurpose.EXECUTE_BLOCK: 5, - NodeGroupPurpose.EXECUTE_SINGLE: 3, - NodeGroupPurpose.EXPAND_UNORDERED: 1, - NodeGroupPurpose.GATHER: 1, + NodeGroupPurpose.EXPAND_UNORDERED: 1, # Expanding task - 'parallel_over' + NodeGroupPurpose.EXECUTE_BLOCK: 5, # Tasks group from 'parallel_over' + NodeGroupPurpose.GATHER: 1, # Gathering task - 'collect' + NodeGroupPurpose.EXECUTE_SINGLE: 3, # All other tasks outside parallelization - single node } nodes = {node.name for item in relevant_calls for node in item.bound_kwargs["nodes"]} assert nodes == { From d9754b8c636fab1785a0bebaa20fae8e1dbc0fbc Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 16 Mar 2025 23:00:22 -0400 Subject: [PATCH 14/31] Add submission and resolution hooks to `LifecycleAdapter` union --- hamilton/lifecycle/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 243796889..83741bfbb 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -842,6 +842,8 @@ def do_build_result(self, *, outputs: Any) -> Any: BasePreGraphExecuteAsync, BasePostTaskGroup, BasePostTaskExpand, + BasePreTaskSubmission, + BasePostTaskResolution, BasePreTaskExecute, BasePreTaskExecuteAsync, BasePreNodeExecute, From 87e87436769ce29edb9879ef03186220abf20275 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 20 Mar 2025 19:31:44 -0400 Subject: [PATCH 15/31] Check cache in `TaskFutureWrappingFunction.get_state` --- hamilton/execution/executors.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index 76b3238c7..ec21de9a9 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -177,14 +177,17 @@ def __init__(self, function: Callable[[], Any]): self._exception = None def get_state(self): - try: - self._results = self.function() - except Exception as e: - logger.exception("Task failed") - self._exception = e + if self._exception is not None: return TaskState.FAILED - finally: - self._done = True + if not self._done: + try: + self._results = self.function() + except Exception as e: + logger.exception("Task failed") + self._exception = e + return TaskState.FAILED + finally: + self._done = True return TaskState.SUCCESSFUL def get_result(self): From a5ba41330a8c82eda178d8bc1c1356ba88b0d35f Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sat, 22 Mar 2025 07:17:46 -0400 Subject: [PATCH 16/31] Rename 'task resolution' to 'task return' --- ...skResolutionHook.rst => TaskReturnHook.rst} | 4 ++-- docs/reference/lifecycle-hooks/index.rst | 2 +- hamilton/execution/executors.py | 4 ++-- hamilton/lifecycle/__init__.py | 4 ++-- hamilton/lifecycle/api.py | 18 +++++++++--------- hamilton/lifecycle/base.py | 16 ++++++++-------- .../lifecycle_adapters_for_testing.py | 6 +++--- ...lifecycle_adapters_end_to_end_task_based.py | 16 ++++++++-------- 8 files changed, 35 insertions(+), 35 deletions(-) rename docs/reference/lifecycle-hooks/{TaskResolutionHook.rst => TaskReturnHook.rst} (59%) diff --git a/docs/reference/lifecycle-hooks/TaskResolutionHook.rst b/docs/reference/lifecycle-hooks/TaskReturnHook.rst similarity index 59% rename from docs/reference/lifecycle-hooks/TaskResolutionHook.rst rename to docs/reference/lifecycle-hooks/TaskReturnHook.rst index 92e2d8a3d..a8466c513 100644 --- a/docs/reference/lifecycle-hooks/TaskResolutionHook.rst +++ b/docs/reference/lifecycle-hooks/TaskReturnHook.rst @@ -1,9 +1,9 @@ ================================ -lifecycle.api.TaskResolutionHook +lifecycle.api.TaskReturnHook ================================ -.. autoclass:: hamilton.lifecycle.api.TaskResolutionHook +.. autoclass:: hamilton.lifecycle.api.TaskReturnHook :special-members: __init__ :members: :inherited-members: diff --git a/docs/reference/lifecycle-hooks/index.rst b/docs/reference/lifecycle-hooks/index.rst index 8b97ea265..d67247836 100644 --- a/docs/reference/lifecycle-hooks/index.rst +++ b/docs/reference/lifecycle-hooks/index.rst @@ -25,7 +25,7 @@ looking forward. StaticValidator GraphConstructionHook TaskSubmissionHook - TaskResolutionHook + TaskReturnHook TaskExecutionHook TaskGroupingHook diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index ec21de9a9..443c97f7f 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -464,9 +464,9 @@ def run_graph_to_completion( finally: execution_state.update_task_state(task.task_id, state, result) if TaskState.is_terminal(state): - if task.adapter.does_hook("post_task_resolution", is_async=False): + if task.adapter.does_hook("post_task_return", is_async=False): task.adapter.call_all_lifecycle_hooks_sync( - "post_task_resolution", + "post_task_return", run_id=task.run_id, task_id=task.task_id, nodes=task.nodes, diff --git a/hamilton/lifecycle/__init__.py b/hamilton/lifecycle/__init__.py index 9b864443b..68f04859d 100644 --- a/hamilton/lifecycle/__init__.py +++ b/hamilton/lifecycle/__init__.py @@ -10,7 +10,7 @@ StaticValidator, TaskExecutionHook, TaskGroupingHook, - TaskResolutionHook, + TaskReturnHook, TaskSubmissionHook, ) from .base import LifecycleAdapter # noqa: F401 @@ -45,6 +45,6 @@ "TaskGroupingHook", "FunctionInputOutputTypeChecker", "NoEdgeAndInputTypeChecking", - "TaskResolutionHook", + "TaskReturnHook", "TaskSubmissionHook", ] diff --git a/hamilton/lifecycle/api.py b/hamilton/lifecycle/api.py index 278873f43..98db5a9c1 100644 --- a/hamilton/lifecycle/api.py +++ b/hamilton/lifecycle/api.py @@ -32,7 +32,7 @@ BasePostTaskExecute, BasePostTaskExpand, BasePostTaskGroup, - BasePostTaskResolution, + BasePostTaskReturn, BasePreGraphExecute, BasePreNodeExecute, BasePreTaskExecute, @@ -427,12 +427,12 @@ def run_before_task_submission( pass -class TaskResolutionHook(BasePostTaskResolution, abc.ABC): - """Implement this to hook into the task resolution process. Tasks are submitted to an executor, - which then returns a task future to be resolved at a later time when the task is complete.""" +class TaskReturnHook(BasePostTaskReturn, abc.ABC): + """Implement this to hook into the task return process. Tasks are submitted to an executor, + which executes the task and returns the results (or raises an error).""" @override - def post_task_resolution( + def post_task_return( self, *, run_id: str, @@ -444,7 +444,7 @@ def post_task_resolution( spawning_task_id: Optional[str], purpose: NodeGroupPurpose, ): - self.run_after_task_resolution( + self.run_after_task_return( run_id=run_id, task_id=task_id, nodes=nodes, @@ -456,7 +456,7 @@ def post_task_resolution( ) @abc.abstractmethod - def run_after_task_resolution( + def run_after_task_return( self, *, run_id: str, @@ -469,8 +469,8 @@ def run_after_task_resolution( purpose: NodeGroupPurpose, **future_kwargs, ): - """Runs after a task (future) has been resolved into a return value. By definition this is - run *outside* of the task executor,on the process that executed the driver. + """Runs after a task has been returned from a executor. By definition this is run *outside* + of the task executor, on the process that executed the driver. :param run_id: ID of the run this is under. :param task_id: ID of the task that was just executed. diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 83741bfbb..f19d1af4b 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -740,10 +740,10 @@ async def post_task_execute( pass -@lifecycle.base_hook("post_task_resolution") -class BasePostTaskResolution(abc.ABC): +@lifecycle.base_hook("post_task_return") +class BasePostTaskReturn(abc.ABC): @abc.abstractmethod - def post_task_resolution( + def post_task_return( self, *, run_id: str, @@ -755,13 +755,13 @@ def post_task_resolution( spawning_task_id: Optional[str], purpose: NodeGroupPurpose, ): - """Hook called immediately after a task future (as submitted to a task executor) is resolved. - Note that this is only useful in dynamic execution, although we reserve the right to add this - back into the standard hamilton execution pattern. + """Hook called immediately after a task returns from an executor. Note that this is only + useful in dynamic execution, although we reserve the right to add this back into the + standard hamilton execution pattern. :param run_id: ID of the run, unique in scope of the driver. :param task_id: ID of the task - :param result: Return value of the task (from task future resolution). + :param result: Return value of the task. :param success: Whether or not the task executed successfully :param error: The error that was raised, if any :param spawning_task_id: ID of the task that spawned this task @@ -843,7 +843,7 @@ def do_build_result(self, *, outputs: Any) -> Any: BasePostTaskGroup, BasePostTaskExpand, BasePreTaskSubmission, - BasePostTaskResolution, + BasePostTaskReturn, BasePreTaskExecute, BasePreTaskExecuteAsync, BasePreNodeExecute, diff --git a/tests/lifecycle/lifecycle_adapters_for_testing.py b/tests/lifecycle/lifecycle_adapters_for_testing.py index cc10eedff..62ff5cdd2 100644 --- a/tests/lifecycle/lifecycle_adapters_for_testing.py +++ b/tests/lifecycle/lifecycle_adapters_for_testing.py @@ -18,7 +18,7 @@ BasePostTaskExecute, BasePostTaskExpand, BasePostTaskGroup, - BasePostTaskResolution, + BasePostTaskReturn, BasePreDoAnythingHook, BasePreGraphExecute, BasePreNodeExecute, @@ -168,8 +168,8 @@ def pre_task_submission( pass -class TrackingPostTaskResolutionHook(ExtendToTrackCalls, BasePostTaskResolution): - def post_task_resolution( +class TrackingPostTaskReturnHook(ExtendToTrackCalls, BasePostTaskReturn): + def post_task_return( self, *, run_id: str, diff --git a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py index 412ed5556..be22dd129 100644 --- a/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py +++ b/tests/lifecycle/test_lifecycle_adapters_end_to_end_task_based.py @@ -16,7 +16,7 @@ BasePostTaskExecute, BasePostTaskExpand, BasePostTaskGroup, - BasePostTaskResolution, + BasePostTaskReturn, BasePreDoAnythingHook, BasePreGraphExecute, BasePreNodeExecute, @@ -33,7 +33,7 @@ TrackingPostTaskExecuteHook, TrackingPostTaskExpandHook, TrackingPostTaskGroupHook, - TrackingPostTaskResolutionHook, + TrackingPostTaskReturnHook, TrackingPreNodeExecuteHook, TrackingPreTaskSubmissionHook, ) @@ -274,9 +274,9 @@ def test_individual_pre_task_submission_hook(): } -def test_individual_post_task_resolution_hook(): - hook_name = "post_task_resolution" - hook = TrackingPostTaskResolutionHook(name=hook_name) +def test_individual_post_task_return_hook(): + hook_name = "post_task_return" + hook = TrackingPostTaskReturnHook(name=hook_name) dr = _sample_driver(hook) dr.execute(["output"], inputs={"n_iters_input": 5}) relevant_calls = [item for item in hook.calls if item.name == hook_name] @@ -326,7 +326,7 @@ class MultiHook( BasePostTaskGroup, BasePostTaskExpand, BasePreTaskSubmission, - BasePostTaskResolution, + BasePostTaskReturn, ExtendToTrackCalls, ): def pre_task_execute( @@ -427,7 +427,7 @@ def pre_task_submission( ): pass - def post_task_resolution( + def post_task_return( self, *, run_id: str, @@ -458,7 +458,7 @@ def post_task_resolution( "post_task_execute": 10, "post_graph_execute": 1, "pre_task_submission": 10, - "post_task_resolution": 10, + "post_task_return": 10, "post_task_group": 1, "post_task_expand": 1, } From 46c7f6a009f9cf615c6d7230a9ba73da82d0f6ee Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 16 Mar 2025 22:41:40 -0400 Subject: [PATCH 17/31] Add context-aware logging adapters (plus helper functions) --- hamilton/plugins/h_logging.py | 432 ++++++++++++++++++++++++++++++++++ 1 file changed, 432 insertions(+) create mode 100644 hamilton/plugins/h_logging.py diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py new file mode 100644 index 000000000..bc32cf726 --- /dev/null +++ b/hamilton/plugins/h_logging.py @@ -0,0 +1,432 @@ +"""Synchronous/asynchronous adapter and functions for context-aware logging with Hamilton.""" + +import logging +import sys +from contextvars import ContextVar +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional, Set, Union + +from hamilton.graph_types import HamiltonNode +from hamilton.lifecycle.api import ( + GraphExecutionHook, + NodeExecutionHook, + TaskExecutionHook, + TaskGroupingHook, + TaskResolutionHook, + TaskSubmissionHook, +) +from hamilton.lifecycle.base import BasePostNodeExecuteAsync, BasePreNodeExecute +from hamilton.node import Node + +try: + from typing import override +except ImportError: + override = lambda x: x # noqa E731 + + +if sys.version_info >= (3, 12): + LoggerAdapter = logging.LoggerAdapter[logging.Logger] +else: + if TYPE_CHECKING: + LoggerAdapter = logging.LoggerAdapter[logging.Logger] + else: + LoggerAdapter = logging.LoggerAdapter + + +@dataclass(frozen=True) +class _LoggingContext: + """Represents the current logging context.""" + + graph: Optional[str] = None + node: Optional[str] = None + task: Optional[str] = None + + +# Context variables for context-aware logging +_local_context = ContextVar("context", default=_LoggingContext()) + + +def get_logger(name: Optional[str] = None) -> "ContextLogger": + """Returns a context-aware logger for the specified name (created if necessary). + + :param name: Name of the logger, defaults to root logger if not provided. + """ + logger = logging.getLogger(name) + return ContextLogger(logger) + + +class ContextLogger(LoggerAdapter): + """Custom logger adapter for Hamilton that adds context to log messages. + + This logger a adds context-aware prefix to log messages based on the current execution. The + logger is intended to be used with hamilton the `LoggingAdapter` lifecycle adapter. The context + is both thread-safe and async-safe. Context includes the following details: + - Graph run + - Task ID + - Node ID + + The adapter also supports the following extra fields: + - `override_context`: Overrides the current context with the specified context. + - `skip_context`: Skips the context for the current log message. + - `node_context`: Includes additional node information for task-based log messages. + """ + + @override + def process( + self, msg: str, kwargs: MutableMapping[str, Any] + ) -> tuple[str, MutableMapping[str, Any]]: + # Ensure that the extra fields are passed through correctly + kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})} + + # Add the current prefix to the log message + prefix = self._get_current_context(kwargs["extra"]) + msg = f"{prefix}{msg}" + + return (msg, kwargs) + + def _get_current_context(self, extra: Mapping[str, Any]) -> str: + """Returns the current context.""" + + # Extra option to override context + context = extra.get("override_context", None) + if not isinstance(context, _LoggingContext): + context = _local_context.get() + + # Extra option to skip context + if "skip_context" in extra: + return "" + + if context.task: + # Extra option to include node information on task-based log messages + if context.node and "node_context" in extra: + return f"Task '{context.task}' - Node '{context.node}' - " + return f"Task '{context.task}' - " + + if context.node: + return f"Node '{context.node}' - " + + if context.graph: + return f"Graph run '{context.graph}' - " + + return "" + + +class LoggingAdapter( + GraphExecutionHook, + NodeExecutionHook, + TaskGroupingHook, + TaskSubmissionHook, + TaskExecutionHook, + TaskResolutionHook, +): + """Hamilton lifecycle adapter that logs runtime execution events. + + This adapter logs the following hamilton events: + - Graph start (`GraphExecutionHook`) + - Task grouping (`TaskGroupingHook`) + - Task submission (`TaskSubmissionHook`) + - Task pre-execution (`TaskExecutionHook`)) + - Node pre-execution (`NodeExecutionHook`) + - Node post-execution (`NodeExecutionHook`) + - Task post-execution (`TaskExecutionHook`)) + - Task resolution (`TaskResolutionHook`) + - Graph completion (`GraphExecutionHook`) + + This adapter can be run with both node-based and task-based execution (using the V2 executor). + When run with a node-based executor, the adapter logs the execution of each *node* as `INFO`. + When run with a task-based executor, the adapter logs the execution of each *task* as `INFO` + and the execution of each *node* as `DEBUG`. + """ + + def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: + # Precompute or overridden nodes + self._inputs_nodes: Set[str] = set() + self._override_nodes: Set[str] = set() + + if logger is None: + self.logger = logging.getLogger(__name__) + elif isinstance(logger, logging.Logger): + self.logger = logger + else: + self.logger = logging.getLogger(logger) + + if not isinstance(self.logger, ContextLogger): + self.logger = ContextLogger(self.logger) + + self._exception_logged = False # For tracking remote exceptions + + @override + def run_before_graph_execution( + self, + *, + inputs: Optional[Dict[str, Any]], + overrides: Optional[Dict[str, Any]], + run_id: str, + **future_kwargs: Any, + ): + # Set context before logging + _local_context.set(_LoggingContext(graph=run_id)) + + self._inputs_nodes.update(inputs or []) + self._override_nodes.update(overrides or []) + + message = "Starting graph execution" + self.logger.info(message) + + if inputs: + names = ", ".join(f"'{key}'" for key in inputs) + self.logger.info("Using inputs %s", names) + + if overrides: + names = ", ".join(f"'{key}'" for key in overrides) + self.logger.info("Using overrides %s", names) + + @override + def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs): + self.logger.info("Dynamic DAG detected; task-based logging is enabled") + + @override + def run_after_task_expansion(self, **future_kwargs): + pass # Note currently uses; required for TaskGroupingHook + + @override + def run_before_task_submission( + self, + *, + run_id: str, + task_id: str, + spawning_task_id: Optional[str], + **future_kwargs, + ): + # Set context before logging + _local_context.set(_LoggingContext(graph=run_id, task=task_id)) + + if spawning_task_id: + self.logger.debug("Spawning task and submitting to executor") + else: + self.logger.debug("Initializing new task and submitting to executor") + + @override + def run_before_task_execution( + self, + *, + task_id: str, + run_id: str, + nodes: List[HamiltonNode], + **future_kwargs, + ): + # Set context before logging + _local_context.set(_LoggingContext(graph=run_id, task=task_id)) + + # Do not log if the task matches the inputs or overrides + if task_id in self._inputs_nodes or task_id in self._override_nodes: + return + + message = "Starting execution" + if len(nodes) == 1 and nodes[0].name == task_id: # single node task + self.logger.debug(message) + else: + message += " of nodes %s" + names = ", ".join(f"'{node.name}'" for node in nodes) + self.logger.debug(message, names) + + @override + def run_before_node_execution( + self, + *, + node_name: str, + node_kwargs: Dict[str, Any], + task_id: Optional[str], + run_id: str, + **future_kwargs: Any, + ): + # Set context before logging + _local_context.set(_LoggingContext(graph=run_id, task=task_id, node=node_name)) + + message = "Starting execution" + extra = {"include_task_node": True} + if node_kwargs: + message += " with dependencies %s" + params = ", ".join(f"'{key}'" for key in node_kwargs) + self.logger.debug(message, params, extra=extra) + else: + message += " without dependencies" + self.logger.debug(message, extra=extra) + + @override + def run_after_node_execution( + self, + *, + node_name: str, + error: Optional[Exception], + success: bool, + task_id: Optional[str], + run_id: str, + **future_kwargs: Any, + ): + # Reset context before logging via the token + _local_context.set(_LoggingContext(graph=run_id, task=task_id)) + + # Logger should use previous context and include node information for this method + extra = { + "override_context": _LoggingContext(graph=run_id, task=task_id, node=node_name), + "node_context": True, + } + + if success: + log_func = self.logger.debug if task_id else self.logger.info + log_func("Finished execution [OK]", extra=extra) + elif error: + self.logger.exception("Encountered error", extra=extra) + self._exception_logged = True + + @override + def run_after_task_execution( + self, + *, + task_id: str, + run_id: str, + success: bool, + error: Exception, + **future_kwargs, + ): + # Reset context before logging + _local_context.set(_LoggingContext(graph=run_id)) + + # Logger should use previous context for this method + extra = {"override_context": _LoggingContext(graph=run_id, task=task_id)} + + # Do not log if the task matches the inputs or overrides + if task_id in self._inputs_nodes or task_id in self._override_nodes: + return + + if success: + self.logger.debug("Finished execution [Ok]", extra=extra) + elif error: + self.logger.error("Execution failed due to errors", extra=extra) + + @override + def run_after_task_resolution( + self, + *, + run_id: str, + task_id: str, + nodes: List[Node], + success: bool, + error: Optional[Exception], + **future_kwargs, + ): + # Hard reset context before logging + _local_context.set(_LoggingContext(graph=run_id)) + + # Logger should use previous context for this method + extra = {"override_context": _LoggingContext(graph=run_id, task=task_id)} + + if success: + # Input and override tasks should be logged as debug + if task_id in self._inputs_nodes or task_id in self._override_nodes: + log_func = self.logger.debug + else: + log_func = self.logger.info + log_func("Task completed [OK]", extra=extra) + elif error and not self._exception_logged: + self.logger.exception("Task completion failed due to errors", extra=extra) + self._exception_logged = True + + @override + def run_after_graph_execution( + self, + *, + success: bool, + run_id: str, + **future_kwargs: Any, + ): + # Hard reset context before logging + _local_context.set(_LoggingContext()) + + # Logger should use previous context for this method + extra = {"override_context": _LoggingContext(graph=run_id)} + + if success: + self.logger.info("Finished graph execution [OK]", extra=extra) + else: + self.logger.error("Graph execution failed due to errors", extra=extra) + + +class AsyncLoggingAdapter(GraphExecutionHook, BasePreNodeExecute, BasePostNodeExecuteAsync): + """Async version of the `LoggingAdapter`. + + This adapter logs the following hamilton events: + - Graph start (`BasePreGraphExecuteAsync`) + - Node pre-execution (`BasePreNodeExecuteAsync`) + - Node post-execution (`BasePostNodeExecuteAsync`) + - Graph completion (`BasePostGraphExecuteAsync`) + + Note that this adapter is intended to be used with the async driver. Due to current limitations + with the async driver, is only able to approximate when the async node has been submitted. It + cannot currently log the exact moment the async node begins execution. + """ + + def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: + self._impl = LoggingAdapter(logger) + + @override + def run_before_graph_execution( + self, + *, + inputs: Dict[str, Any], + overrides: Dict[str, Any], + run_id: str, + **future_kwargs: Any, + ): + self._impl.run_before_graph_execution(inputs=inputs, overrides=overrides, run_id=run_id) + + @override + def pre_node_execute( + self, *, run_id: str, node_: Node, kwargs: Dict[str, Any], task_id: str | None = None + ): + # NOTE: We call the base synchronous method here in order to approximate when the async task + # has bee submitted. This is a workaround until further work is done on the async adapter. + + # Set context before logging + _local_context.set(_LoggingContext(graph=run_id, task=None, node=node_.name)) + + message = "Submitting async node" + extra = {"include_task_node": True} + if kwargs: + message += " with dependencies %s" + params = ", ".join(f"'{key}'" for key in kwargs) + self._impl.logger.debug(message, params, extra=extra) + else: + message += " without dependencies" + self._impl.logger.debug(message, extra=extra) + + @override + async def post_node_execute( + self, + *, + run_id: str, + node_: Node, + kwargs: Dict[str, Any], + success: bool, + error: Exception | None, + result: Any, + task_id: str | None = None, + ): + self._impl.run_after_node_execution( + node_name=node_.name, + error=error, + success=success, + task_id=task_id, + run_id=run_id, + ) + + @override + def run_after_graph_execution( + self, + *, + success: bool, + run_id: str, + **future_kwargs: Any, + ): + self._impl.run_after_graph_execution(success=success, run_id=run_id) From e638ffbb51c7ac399c7781f0c93cd90767f658e9 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 16 Mar 2025 22:42:01 -0400 Subject: [PATCH 18/31] Add tests for Add context-aware logging adapters --- tests/plugins/test_logging.py | 387 ++++++++++++++++++++++++++++++++++ 1 file changed, 387 insertions(+) create mode 100644 tests/plugins/test_logging.py diff --git a/tests/plugins/test_logging.py b/tests/plugins/test_logging.py new file mode 100644 index 000000000..9bd913997 --- /dev/null +++ b/tests/plugins/test_logging.py @@ -0,0 +1,387 @@ +import asyncio +import logging +import sys + +import pytest + +from hamilton import ad_hoc_utils, async_driver, driver +from hamilton.execution import executors +from hamilton.htypes import Collect, Parallelizable +from hamilton.plugins.h_dask import DaskExecutor +from hamilton.plugins.h_logging import AsyncLoggingAdapter, LoggingAdapter, get_logger +from hamilton.plugins.h_ray import RayTaskExecutor +from hamilton.plugins.h_threadpool import FutureAdapter + + +def _split_log_messages(caplog, name): + debug, info, warning, error = [], [], [], [] + for record in caplog.records: + if record.name == name: + if record.levelno == logging.DEBUG: + debug.append(record.message) + elif record.levelno == logging.INFO: + info.append(record.message) + elif record.levelno == logging.WARNING: + warning.append(record.message) + elif record.levelno == logging.ERROR: + error.append(record.message) + return debug, info, warning, error + + +def test_logging_serial_nodes_at_info_level(caplog): + """Test logging of serial nodes at INFO level - log order matters for this test.""" + + name = "test_logging_serial_nodes_at_info_level" + caplog.set_level(logging.INFO, logger=name) + + def a() -> str: + return "a" + + def b(a: str) -> str: + logger = get_logger(name) + logger.warning("Context aware message") + return a + " b" + + def c(b: str) -> str: + return b + " c" + + modules = ad_hoc_utils.create_temporary_module(a, b, c) + dr = driver.Builder().with_modules(modules).with_adapters(LoggingAdapter(name)).build() + result = dr.execute(["c"]) + assert result["c"] == "a b c" + + messages = [record.message for record in caplog.records if record.name == name] + + assert messages[0].startswith("Graph run") + assert messages[1:-1] == [ + "Node 'a' - Finished execution [OK]", + "Node 'b' - Context aware message", + "Node 'b' - Finished execution [OK]", + "Node 'c' - Finished execution [OK]", + ] + assert messages[-1].endswith("- Finished graph execution [OK]") + + levels = [record.levelname for record in caplog.records if record.name == name] + assert levels == ["INFO", "INFO", "WARNING", "INFO", "INFO", "INFO"] + + +def test_logging_serial_nodes_at_debug_level(caplog): + """Test logging of serial nodes at DEBUG level - log order matters for this test.""" + + name = "test_logging_serial_nodes_at_debug_level" + caplog.set_level(logging.DEBUG, logger=name) + + def a() -> str: + return "a" + + def b(a: str) -> str: + logger = get_logger(name) + logger.warning("Context aware message") + return a + " b" + + def c(b: str) -> str: + return b + " c" + + modules = ad_hoc_utils.create_temporary_module(a, b, c) + dr = driver.Builder().with_modules(modules).with_adapters(LoggingAdapter(name)).build() + + result = dr.execute(["c"]) + assert result["c"] == "a b c" + + messages = [record.message for record in caplog.records if record.name == name] + assert messages[0].startswith("Graph run") + assert messages[1:-1] == [ + "Node 'a' - Starting execution without dependencies", + "Node 'a' - Finished execution [OK]", + "Node 'b' - Starting execution with dependencies 'a'", + "Node 'b' - Context aware message", + "Node 'b' - Finished execution [OK]", + "Node 'c' - Starting execution with dependencies 'b'", + "Node 'c' - Finished execution [OK]", + ] + assert messages[-1].endswith("- Finished graph execution [OK]") + + levels = [record.levelname for record in caplog.records if record.name == name] + assert levels == ["INFO", "DEBUG", "INFO", "DEBUG", "WARNING", "INFO", "DEBUG", "INFO", "INFO"] + + +@pytest.mark.parametrize("adapter", [None, FutureAdapter]) +def test_logging_branching_nodes(caplog, adapter): + """Test logging of branching nodes at multiple logging levels.""" + + name = "test_logging_branching_nodes" + caplog.set_level(logging.DEBUG, logger=name) + + def a() -> str: + return "a" + + def b() -> str: + return "b" + + def c() -> str: + logger = get_logger(name) + logger.warning("Context aware message") + return "c" + + def d(a: str, b: str) -> str: + return a + " " + b + " d" + + def e(c: str) -> str: + return c + " e" + + def f(d: str, e: str) -> str: + return d + " " + e + " f" + + modules = ad_hoc_utils.create_temporary_module(a, b, c, d, e, f) + adapters = [LoggingAdapter(name)] + if adapter: + adapters.append(adapter()) + dr = driver.Builder().with_modules(modules).with_adapters(*adapters).build() + + result = dr.execute(["f"]) + assert result["f"] == "a b d c e f" + + debug, info, warning, _ = _split_log_messages(caplog, name) + + assert info[0].startswith("Graph run") + assert set(info[1:-1]) == { + "Node 'a' - Finished execution [OK]", + "Node 'b' - Finished execution [OK]", + "Node 'c' - Finished execution [OK]", + "Node 'd' - Finished execution [OK]", + "Node 'e' - Finished execution [OK]", + "Node 'f' - Finished execution [OK]", + } + assert info[-1].endswith("- Finished graph execution [OK]") + + assert set(debug) == { + "Node 'a' - Starting execution without dependencies", + "Node 'b' - Starting execution without dependencies", + "Node 'c' - Starting execution without dependencies", + "Node 'd' - Starting execution with dependencies 'a', 'b'", + "Node 'e' - Starting execution with dependencies 'c'", + "Node 'f' - Starting execution with dependencies 'd', 'e'", + } + + assert len(warning) == 1 + assert warning[0] == "Node 'c' - Context aware message" + + +def test_logging_async_nodes(caplog): + """Test logging of async nodes at at multiple logging levels.""" + + name = "test_logging_async_nodes" + caplog.set_level(logging.DEBUG, logger=name) + + async def a() -> str: + return "a" + + async def b() -> str: + return "b" + + async def c() -> str: + logger = get_logger(name) + logger.warning("Context aware message") + return "c" + + async def d(a: str, b: str) -> str: + return a + " " + b + " d" + + async def e(c: str) -> str: + return c + " e" + + async def f(d: str, e: str) -> str: + return d + " " + e + " f" + + async def run_async(module, name): + dr = ( + await async_driver.Builder() # type: ignore + .with_modules(module) + .with_adapters(AsyncLoggingAdapter(name)) + .build() + ) + result = await dr.execute(["f"]) + return result + + module = ad_hoc_utils.create_temporary_module(a, b, c, d, e, f) + result = asyncio.run(run_async(module, name)) + + assert result["f"] == "a b d c e f" + + debug, info, warning, _ = _split_log_messages(caplog, name) + + assert info[0].startswith("Graph run") + assert set(info[1:-1]) == { + "Node 'a' - Finished execution [OK]", + "Node 'b' - Finished execution [OK]", + "Node 'c' - Finished execution [OK]", + "Node 'd' - Finished execution [OK]", + "Node 'e' - Finished execution [OK]", + "Node 'f' - Finished execution [OK]", + } + assert info[-1].endswith("- Finished graph execution [OK]") + + assert set(debug) == { + "Node 'a' - Submitting async node without dependencies", + "Node 'b' - Submitting async node without dependencies", + "Node 'c' - Submitting async node without dependencies", + "Node 'd' - Submitting async node with dependencies 'a', 'b'", + "Node 'e' - Submitting async node with dependencies 'c'", + "Node 'f' - Submitting async node with dependencies 'd', 'e'", + } + + assert len(warning) == 1 + assert warning[0] == "Node 'c' - Context aware message" + + +@pytest.mark.parametrize( + ["executor_type", "executor_args"], + [ + (executors.SynchronousLocalTaskExecutor, {}), + pytest.param( + executors.MultiProcessingExecutor, + {"max_tasks": 5}, + marks=pytest.mark.skipif( + sys.platform == "win32", reason="Windows does not support fork" + ), + ), + (executors.MultiThreadingExecutor, {"max_tasks": 5}), + (RayTaskExecutor, {}), + (DaskExecutor, {"client": None}), + ], +) +def test_logging_parallel_nodes(caplog, executor_type, executor_args): + """Test logging of parallel nodes at multiple logging levels.""" + + # NOTE: These test is brittle, as it depends on undocumented names of the expanded tasks. + + name = "test_logging_parallel_nodes_at_info_level" + caplog.set_level(logging.DEBUG, logger=name) + + def b() -> int: + return 5 + + def c(b: int) -> Parallelizable[int]: + for i in range(b): + yield i + + def d(c: int) -> int: + logger = get_logger(name) + logger.warning("Context aware message") + return 2 * c + + def e(d: Collect[int]) -> int: + return sum(d) + + def f(e: int) -> int: + return e + + if executor_type == DaskExecutor: + import dask.distributed + + cluster = dask.distributed.LocalCluster(n_workers=5) + client = dask.distributed.Client(cluster) + executor_args["client"] = client + + modules = ad_hoc_utils.create_temporary_module(b, c, d, e, f) + adapters = [LoggingAdapter(name)] + dr = ( + driver.Builder() + .with_modules(modules) + .with_adapters(*adapters) + .enable_dynamic_execution(allow_experimental_mode=True) + .with_remote_executor(executor_type(**executor_args)) + .build() + ) + + result = dr.execute(["f"]) + assert result["f"] == 20 + + debug, info, warning, _ = _split_log_messages(caplog, name) + + assert info[0].startswith("Graph run") + assert info[1].endswith("task-based logging is enabled") + assert info[-1].endswith("- Finished graph execution [OK]") + + assert set(info[2:-1]) == { + "Task 'b' - Task completed [OK]", + "Task 'expand-c' - Task completed [OK]", + "Task 'expand-c.0.block-c' - Task completed [OK]", + "Task 'expand-c.1.block-c' - Task completed [OK]", + "Task 'expand-c.2.block-c' - Task completed [OK]", + "Task 'expand-c.3.block-c' - Task completed [OK]", + "Task 'expand-c.4.block-c' - Task completed [OK]", + "Task 'collect-c' - Task completed [OK]", + "Task 'f' - Task completed [OK]", + } + + # Note: Certain executors do not log task and node level debug messages (especially if they + # are not running in the same process as the driver). + local_debug_log = { + "Task 'b' - Initializing new task and submitting to executor", + "Task 'b' - Starting execution", + "Task 'b' - Starting execution without dependencies", + "Task 'b' - Node 'b' - Finished execution [OK]", + "Task 'b' - Finished execution [Ok]", + "Task 'expand-c' - Initializing new task and submitting to executor", + "Task 'expand-c' - Starting execution of nodes 'c'", + "Task 'expand-c' - Starting execution with dependencies 'b'", + "Task 'expand-c' - Node 'c' - Finished execution [OK]", + "Task 'expand-c' - Finished execution [Ok]", + "Task 'expand-c.0.block-c' - Spawning task and submitting to executor", + "Task 'expand-c.1.block-c' - Spawning task and submitting to executor", + "Task 'expand-c.2.block-c' - Spawning task and submitting to executor", + "Task 'expand-c.3.block-c' - Spawning task and submitting to executor", + "Task 'expand-c.4.block-c' - Spawning task and submitting to executor", + "Task 'collect-c' - Initializing new task and submitting to executor", + "Task 'collect-c' - Starting execution of nodes 'e'", + "Task 'collect-c' - Starting execution with dependencies 'd'", + "Task 'collect-c' - Node 'e' - Finished execution [OK]", + "Task 'collect-c' - Finished execution [Ok]", + "Task 'f' - Initializing new task and submitting to executor", + "Task 'f' - Starting execution", + "Task 'f' - Starting execution with dependencies 'e'", + "Task 'f' - Node 'f' - Finished execution [OK]", + "Task 'f' - Finished execution [Ok]", + } + assert local_debug_log.issubset(set(debug)) + + +def test_logging_with_inputs(caplog): + """Test logging of nodes with inputs.""" + + name = "test_logging_with_inputs" + caplog.set_level(logging.DEBUG, logger=name) + + def a(x: str) -> str: + return x + + modules = ad_hoc_utils.create_temporary_module(a) + dr = driver.Builder().with_modules(modules).with_adapters(LoggingAdapter(name)).build() + + result = dr.execute(["a"], inputs={"x": "test"}) + assert result["a"] == "test" + + _, info, _, _ = _split_log_messages(caplog, name) + + assert info[1].endswith("Using inputs 'x'") + + +def test_logging_with_overrides(caplog): + """Test logging of nodes with overrides.""" + + name = "test_logging_with_overrides" + caplog.set_level(logging.DEBUG, logger=name) + + def a(x: str) -> str: + return x + + modules = ad_hoc_utils.create_temporary_module(a) + dr = driver.Builder().with_modules(modules).with_adapters(LoggingAdapter(name)).build() + + result = dr.execute(["a"], overrides={"a": "test"}) + assert result["a"] == "test" + + _, info, _, _ = _split_log_messages(caplog, name) + + assert info[1].endswith("Using overrides 'a'") From 8504e067eaa33334ca824d778fe80bb91012eb48 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 16 Mar 2025 23:23:20 -0400 Subject: [PATCH 19/31] Fix docstrings --- hamilton/plugins/h_logging.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index bc32cf726..9857383c9 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -357,14 +357,14 @@ class AsyncLoggingAdapter(GraphExecutionHook, BasePreNodeExecute, BasePostNodeEx """Async version of the `LoggingAdapter`. This adapter logs the following hamilton events: - - Graph start (`BasePreGraphExecuteAsync`) - - Node pre-execution (`BasePreNodeExecuteAsync`) + - Graph start (`GraphExecutionHook`) + - Node pre-execution (`BasePreNodeExecute`) - Node post-execution (`BasePostNodeExecuteAsync`) - - Graph completion (`BasePostGraphExecuteAsync`) + - Graph completion (`GraphExecutionHook`) Note that this adapter is intended to be used with the async driver. Due to current limitations - with the async driver, is only able to approximate when the async node has been submitted. It - cannot currently log the exact moment the async node begins execution. + with the async driver, this adapter is only able to approximate when the async node has been + submitted. It cannot currently log the exact moment the async node begins execution. """ def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: From 9aaf948b22f50f76c2c76a64dd77cf7ffcce5ba0 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 17 Mar 2025 00:24:04 -0400 Subject: [PATCH 20/31] Fix typo Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- hamilton/lifecycle/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hamilton/lifecycle/api.py b/hamilton/lifecycle/api.py index 98db5a9c1..63e176264 100644 --- a/hamilton/lifecycle/api.py +++ b/hamilton/lifecycle/api.py @@ -764,7 +764,7 @@ def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, A @abc.abstractmethod def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs): - """Runs after task grouping. This allows you to capture information about the which tasks were + """Runs after task grouping. This allows you to capture information about which tasks were created for a given run. :param run_id: ID of the run, unique in scope of the driver. From 11fb134b193c2d21910b7ce18b97f76666c4ef73 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 17 Mar 2025 00:24:36 -0400 Subject: [PATCH 21/31] Fix typo Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- hamilton/plugins/h_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index 9857383c9..1168eae81 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -58,7 +58,7 @@ def get_logger(name: Optional[str] = None) -> "ContextLogger": class ContextLogger(LoggerAdapter): """Custom logger adapter for Hamilton that adds context to log messages. - This logger a adds context-aware prefix to log messages based on the current execution. The + This logger adds context-aware prefix to log messages based on the current execution. The logger is intended to be used with hamilton the `LoggingAdapter` lifecycle adapter. The context is both thread-safe and async-safe. Context includes the following details: - Graph run From 84925135daa2563f50b63bbd2c90c671a3053da5 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 17 Mar 2025 00:24:46 -0400 Subject: [PATCH 22/31] Fix typo Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- tests/plugins/test_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_logging.py b/tests/plugins/test_logging.py index 9bd913997..34dfb4b01 100644 --- a/tests/plugins/test_logging.py +++ b/tests/plugins/test_logging.py @@ -168,7 +168,7 @@ def f(d: str, e: str) -> str: def test_logging_async_nodes(caplog): - """Test logging of async nodes at at multiple logging levels.""" + """Test logging of async nodes at multiple logging levels.""" name = "test_logging_async_nodes" caplog.set_level(logging.DEBUG, logger=name) From ad45dc7cb22bb26829df98b5665cf114bdcecd0f Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 17 Mar 2025 00:25:09 -0400 Subject: [PATCH 23/31] Fix typo Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- tests/plugins/test_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_logging.py b/tests/plugins/test_logging.py index 34dfb4b01..5a1f67778 100644 --- a/tests/plugins/test_logging.py +++ b/tests/plugins/test_logging.py @@ -322,7 +322,7 @@ def f(e: int) -> int: "Task 'b' - Starting execution", "Task 'b' - Starting execution without dependencies", "Task 'b' - Node 'b' - Finished execution [OK]", - "Task 'b' - Finished execution [Ok]", + "Task 'b' - Finished execution [OK]", "Task 'expand-c' - Initializing new task and submitting to executor", "Task 'expand-c' - Starting execution of nodes 'c'", "Task 'expand-c' - Starting execution with dependencies 'b'", From 370e4de6d71ecc91775e34c96ef036dfdcba2fd4 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 17 Mar 2025 00:27:00 -0400 Subject: [PATCH 24/31] Add message about duplicate exception logs --- hamilton/plugins/h_logging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index 1168eae81..4a35fa2c0 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -330,6 +330,7 @@ def run_after_task_resolution( log_func = self.logger.info log_func("Task completed [OK]", extra=extra) elif error and not self._exception_logged: + # NOTE: _exception_logged is used to prevent duplicate exception logging self.logger.exception("Task completion failed due to errors", extra=extra) self._exception_logged = True From d5ab9f30b3440011f49495b94f956e39ab29d920 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 17 Mar 2025 00:34:33 -0400 Subject: [PATCH 25/31] Temporarily remove dask and ray test --- tests/plugins/test_logging.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/plugins/test_logging.py b/tests/plugins/test_logging.py index 5a1f67778..7b321951f 100644 --- a/tests/plugins/test_logging.py +++ b/tests/plugins/test_logging.py @@ -7,9 +7,11 @@ from hamilton import ad_hoc_utils, async_driver, driver from hamilton.execution import executors from hamilton.htypes import Collect, Parallelizable -from hamilton.plugins.h_dask import DaskExecutor + +# from hamilton.plugins.h_dask import DaskExecutor # FIXME: Not available CI (see below) from hamilton.plugins.h_logging import AsyncLoggingAdapter, LoggingAdapter, get_logger -from hamilton.plugins.h_ray import RayTaskExecutor + +# from hamilton.plugins.h_ray import RayTaskExecutor # FIXME: Not available in the CI (see below) from hamilton.plugins.h_threadpool import FutureAdapter @@ -246,8 +248,8 @@ async def run_async(module, name): ), ), (executors.MultiThreadingExecutor, {"max_tasks": 5}), - (RayTaskExecutor, {}), - (DaskExecutor, {"client": None}), + # (RayTaskExecutor, {}), + # (DaskExecutor, {"client": None}), ], ) def test_logging_parallel_nodes(caplog, executor_type, executor_args): @@ -276,12 +278,13 @@ def e(d: Collect[int]) -> int: def f(e: int) -> int: return e - if executor_type == DaskExecutor: - import dask.distributed + # FIXME: dask is not available in the CI environment + # if executor_type == DaskExecutor: + # import dask.distributed - cluster = dask.distributed.LocalCluster(n_workers=5) - client = dask.distributed.Client(cluster) - executor_args["client"] = client + # cluster = dask.distributed.LocalCluster(n_workers=5) + # client = dask.distributed.Client(cluster) + # executor_args["client"] = client modules = ad_hoc_utils.create_temporary_module(b, c, d, e, f) adapters = [LoggingAdapter(name)] From 9d09ef0f40073f9af06673df067aa16fab6b37f7 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Wed, 19 Mar 2025 22:58:00 -0400 Subject: [PATCH 26/31] Make type hints backward compatible --- hamilton/plugins/h_logging.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index 4a35fa2c0..caf5368ba 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -4,7 +4,18 @@ import sys from contextvars import ContextVar from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional, Set, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) from hamilton.graph_types import HamiltonNode from hamilton.lifecycle.api import ( @@ -74,7 +85,7 @@ class ContextLogger(LoggerAdapter): @override def process( self, msg: str, kwargs: MutableMapping[str, Any] - ) -> tuple[str, MutableMapping[str, Any]]: + ) -> Tuple[str, MutableMapping[str, Any]]: # Ensure that the extra fields are passed through correctly kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})} @@ -384,7 +395,7 @@ def run_before_graph_execution( @override def pre_node_execute( - self, *, run_id: str, node_: Node, kwargs: Dict[str, Any], task_id: str | None = None + self, *, run_id: str, node_: Node, kwargs: Dict[str, Any], task_id: Optional[str] = None ): # NOTE: We call the base synchronous method here in order to approximate when the async task # has bee submitted. This is a workaround until further work is done on the async adapter. @@ -410,9 +421,9 @@ async def post_node_execute( node_: Node, kwargs: Dict[str, Any], success: bool, - error: Exception | None, + error: Optional[Exception], result: Any, - task_id: str | None = None, + task_id: Optional[str] = None, ): self._impl.run_after_node_execution( node_name=node_.name, From a53a61ffad0baf78d9c9d3f47156ca258449e084 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Wed, 19 Mar 2025 22:58:30 -0400 Subject: [PATCH 27/31] =?UTF-8?q?Fix=20log=20typo=20=F0=9F=A4=A6=E2=80=8D?= =?UTF-8?q?=E2=99=82=EF=B8=8F=20and=20test=20assertion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hamilton/plugins/h_logging.py | 2 +- tests/plugins/test_logging.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index caf5368ba..ca2a09365 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -312,7 +312,7 @@ def run_after_task_execution( return if success: - self.logger.debug("Finished execution [Ok]", extra=extra) + self.logger.debug("Finished execution [OK]", extra=extra) elif error: self.logger.error("Execution failed due to errors", extra=extra) diff --git a/tests/plugins/test_logging.py b/tests/plugins/test_logging.py index 7b321951f..80a8610a2 100644 --- a/tests/plugins/test_logging.py +++ b/tests/plugins/test_logging.py @@ -330,7 +330,7 @@ def f(e: int) -> int: "Task 'expand-c' - Starting execution of nodes 'c'", "Task 'expand-c' - Starting execution with dependencies 'b'", "Task 'expand-c' - Node 'c' - Finished execution [OK]", - "Task 'expand-c' - Finished execution [Ok]", + "Task 'expand-c' - Finished execution [OK]", "Task 'expand-c.0.block-c' - Spawning task and submitting to executor", "Task 'expand-c.1.block-c' - Spawning task and submitting to executor", "Task 'expand-c.2.block-c' - Spawning task and submitting to executor", @@ -340,14 +340,14 @@ def f(e: int) -> int: "Task 'collect-c' - Starting execution of nodes 'e'", "Task 'collect-c' - Starting execution with dependencies 'd'", "Task 'collect-c' - Node 'e' - Finished execution [OK]", - "Task 'collect-c' - Finished execution [Ok]", + "Task 'collect-c' - Finished execution [OK]", "Task 'f' - Initializing new task and submitting to executor", "Task 'f' - Starting execution", "Task 'f' - Starting execution with dependencies 'e'", "Task 'f' - Node 'f' - Finished execution [OK]", - "Task 'f' - Finished execution [Ok]", + "Task 'f' - Finished execution [OK]", } - assert local_debug_log.issubset(set(debug)) + assert local_debug_log.difference(set(debug)) == set() def test_logging_with_inputs(caplog): From ad2246623be34531cac38864ceb4f8347bc2c66e Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Wed, 19 Mar 2025 23:01:58 -0400 Subject: [PATCH 28/31] Add context check; move task nodes to separate file --- tests/plugins/test_logging.py | 56 +++++++++--------------- tests/plugins/test_logging_task_nodes.py | 29 ++++++++++++ 2 files changed, 49 insertions(+), 36 deletions(-) create mode 100644 tests/plugins/test_logging_task_nodes.py diff --git a/tests/plugins/test_logging.py b/tests/plugins/test_logging.py index 80a8610a2..f1ce4e34a 100644 --- a/tests/plugins/test_logging.py +++ b/tests/plugins/test_logging.py @@ -1,12 +1,10 @@ import asyncio import logging -import sys import pytest from hamilton import ad_hoc_utils, async_driver, driver from hamilton.execution import executors -from hamilton.htypes import Collect, Parallelizable # from hamilton.plugins.h_dask import DaskExecutor # FIXME: Not available CI (see below) from hamilton.plugins.h_logging import AsyncLoggingAdapter, LoggingAdapter, get_logger @@ -14,6 +12,8 @@ # from hamilton.plugins.h_ray import RayTaskExecutor # FIXME: Not available in the CI (see below) from hamilton.plugins.h_threadpool import FutureAdapter +from . import test_logging_task_nodes + def _split_log_messages(caplog, name): debug, info, warning, error = [], [], [], [] @@ -237,47 +237,23 @@ async def run_async(module, name): @pytest.mark.parametrize( - ["executor_type", "executor_args"], + ["executor_type", "executor_args", "check_context"], [ - (executors.SynchronousLocalTaskExecutor, {}), - pytest.param( - executors.MultiProcessingExecutor, - {"max_tasks": 5}, - marks=pytest.mark.skipif( - sys.platform == "win32", reason="Windows does not support fork" - ), - ), - (executors.MultiThreadingExecutor, {"max_tasks": 5}), - # (RayTaskExecutor, {}), - # (DaskExecutor, {"client": None}), + (executors.SynchronousLocalTaskExecutor, {}, True), + (executors.MultiProcessingExecutor, {"max_tasks": 1}, False), + (executors.MultiThreadingExecutor, {"max_tasks": 2}, True), + # (RayTaskExecutor, {}, True), # FIXME: Not available in the CI environment + # (DaskExecutor, {"client": None}, False), # FIXME: Not available in the CI environment ], ) -def test_logging_parallel_nodes(caplog, executor_type, executor_args): +def test_logging_parallel_nodes(caplog, executor_type, executor_args, check_context): """Test logging of parallel nodes at multiple logging levels.""" # NOTE: These test is brittle, as it depends on undocumented names of the expanded tasks. - name = "test_logging_parallel_nodes_at_info_level" + name = "test_logging_parallel_nodes" caplog.set_level(logging.DEBUG, logger=name) - def b() -> int: - return 5 - - def c(b: int) -> Parallelizable[int]: - for i in range(b): - yield i - - def d(c: int) -> int: - logger = get_logger(name) - logger.warning("Context aware message") - return 2 * c - - def e(d: Collect[int]) -> int: - return sum(d) - - def f(e: int) -> int: - return e - # FIXME: dask is not available in the CI environment # if executor_type == DaskExecutor: # import dask.distributed @@ -286,11 +262,10 @@ def f(e: int) -> int: # client = dask.distributed.Client(cluster) # executor_args["client"] = client - modules = ad_hoc_utils.create_temporary_module(b, c, d, e, f) adapters = [LoggingAdapter(name)] dr = ( driver.Builder() - .with_modules(modules) + .with_modules(test_logging_task_nodes) .with_adapters(*adapters) .enable_dynamic_execution(allow_experimental_mode=True) .with_remote_executor(executor_type(**executor_args)) @@ -349,6 +324,15 @@ def f(e: int) -> int: } assert local_debug_log.difference(set(debug)) == set() + if check_context: + assert set(warning) == { + "Task 'expand-c.0.block-c' - Context aware message", + "Task 'expand-c.1.block-c' - Context aware message", + "Task 'expand-c.2.block-c' - Context aware message", + "Task 'expand-c.3.block-c' - Context aware message", + "Task 'expand-c.4.block-c' - Context aware message", + } + def test_logging_with_inputs(caplog): """Test logging of nodes with inputs.""" diff --git a/tests/plugins/test_logging_task_nodes.py b/tests/plugins/test_logging_task_nodes.py new file mode 100644 index 000000000..323ff5300 --- /dev/null +++ b/tests/plugins/test_logging_task_nodes.py @@ -0,0 +1,29 @@ +# NOTE: This file contains nodes for the test 'test_logging_parallel_nodes' in +# test_logging_task_nodes.py. They are required to be in a separate file in order to properly +# test the multi-processing executor. + +from hamilton.htypes import Collect, Parallelizable +from hamilton.plugins.h_logging import get_logger + + +def b() -> int: + return 5 + + +def c(b: int) -> Parallelizable[int]: + for i in range(b): + yield i + + +def d(c: int) -> int: + logger = get_logger("test_logging_parallel_nodes") + logger.warning("Context aware message") + return 2 * c + + +def e(d: Collect[int]) -> int: + return sum(d) + + +def f(e: int) -> int: + return e From 37f36ea62dc3c2efd2e6d9ac48cbf058843d2716 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 20 Mar 2025 20:06:50 -0400 Subject: [PATCH 29/31] Add default `extra` parameter for python <3.10 --- hamilton/plugins/h_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index ca2a09365..0d07b9393 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -63,7 +63,7 @@ def get_logger(name: Optional[str] = None) -> "ContextLogger": :param name: Name of the logger, defaults to root logger if not provided. """ logger = logging.getLogger(name) - return ContextLogger(logger) + return ContextLogger(logger, extra=None) class ContextLogger(LoggerAdapter): From 713538a1a6806f2eaf722a238a18026d1cc17e0c Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 20 Mar 2025 20:40:17 -0400 Subject: [PATCH 30/31] Use empty dict for default `extra` parameters --- hamilton/plugins/h_logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index 0d07b9393..74cf0204c 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -63,7 +63,7 @@ def get_logger(name: Optional[str] = None) -> "ContextLogger": :param name: Name of the logger, defaults to root logger if not provided. """ logger = logging.getLogger(name) - return ContextLogger(logger, extra=None) + return ContextLogger(logger, extra={}) class ContextLogger(LoggerAdapter): @@ -162,7 +162,7 @@ def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: self.logger = logging.getLogger(logger) if not isinstance(self.logger, ContextLogger): - self.logger = ContextLogger(self.logger) + self.logger = ContextLogger(self.logger, extra={}) self._exception_logged = False # For tracking remote exceptions From de99d49688496f8a84db9670800e6d1744b963ee Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sat, 22 Mar 2025 07:25:52 -0400 Subject: [PATCH 31/31] Update 'task resolution' to 'task return' --- hamilton/plugins/h_logging.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index 74cf0204c..542b6c425 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -23,7 +23,7 @@ NodeExecutionHook, TaskExecutionHook, TaskGroupingHook, - TaskResolutionHook, + TaskReturnHook, TaskSubmissionHook, ) from hamilton.lifecycle.base import BasePostNodeExecuteAsync, BasePreNodeExecute @@ -128,7 +128,7 @@ class LoggingAdapter( TaskGroupingHook, TaskSubmissionHook, TaskExecutionHook, - TaskResolutionHook, + TaskReturnHook, ): """Hamilton lifecycle adapter that logs runtime execution events. @@ -140,7 +140,7 @@ class LoggingAdapter( - Node pre-execution (`NodeExecutionHook`) - Node post-execution (`NodeExecutionHook`) - Task post-execution (`TaskExecutionHook`)) - - Task resolution (`TaskResolutionHook`) + - Task resolution (`TaskReturnHook`) - Graph completion (`GraphExecutionHook`) This adapter can be run with both node-based and task-based execution (using the V2 executor). @@ -317,7 +317,7 @@ def run_after_task_execution( self.logger.error("Execution failed due to errors", extra=extra) @override - def run_after_task_resolution( + def run_after_task_return( self, *, run_id: str,