From 7a7df5f161481e610ca13b469d7d5ce1e0ca1cfe Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sun, 13 Apr 2025 13:23:09 -0400 Subject: [PATCH 1/8] Use single synchronous extractor function --- hamilton/function_modifiers/expanders.py | 33 +++++++----------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 4aa380129..f0d56840e 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -836,30 +836,15 @@ def dict_generator(*args, **kwargs): for field, field_type in self.fields.items(): doc_string = base_doc # default doc string of base function. - # if fn is async - if inspect.iscoroutinefunction(fn): - - async def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type: - dt = kwargs[node_.name] - if field_to_extract not in dt: - raise base.InvalidDecoratorException( - f"No such field: {field_to_extract} produced by {node_.name}. " - f"It only produced {list(dt.keys())}" - ) - return kwargs[node_.name][field_to_extract] - - else: - - def extractor_fn( - field_to_extract: str = field, **kwargs - ) -> field_type: # avoiding problems with closures - dt = kwargs[node_.name] - if field_to_extract not in dt: - raise base.InvalidDecoratorException( - f"No such field: {field_to_extract} produced by {node_.name}. " - f"It only produced {list(dt.keys())}" - ) - return kwargs[node_.name][field_to_extract] + # This extractor is constructed to avoid closure issues. + def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type: # type: ignore + dt = kwargs[node_.name] + if field_to_extract not in dt: + raise base.InvalidDecoratorException( + f"No such field: {field_to_extract} produced by {node_.name}. 
" + f"It only produced {list(dt.keys())}" + ) + return kwargs[node_.name][field_to_extract] output_nodes.append( node.Node( From f23ed37a56cda27166e6874a4ebd88b9cd5945e4 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 8 May 2025 22:28:36 -0400 Subject: [PATCH 2/8] Upgrade `extract_fields` --- hamilton/function_modifiers/expanders.py | 149 ++++++--- tests/function_modifiers/test_expanders.py | 334 ++++++++++++++++----- 2 files changed, 370 insertions(+), 113 deletions(-) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index f0d56840e..1207ad2d8 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -3,7 +3,7 @@ import functools import inspect import typing -from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union import typing_extensions import typing_inspect @@ -699,6 +699,88 @@ def extractor_fn( return output_nodes +def _process_extract_fields( + fields: Optional[Dict[str, Any] | List[str]], output_type: Any +) -> Dict[str, Any]: + """Processes the fields and base output type to extract a dict of field types. + + :param fields: Dict of fields to extract. + :param output_type: The output type of the node function. + :return: List of field types. + """ + + output_type_error = ( + f"For extracting fields, the decorated function output type must be a `dict` or a " + f"`typing.Dict` with or without type parameters (i.e. `dict[str, int]` or " + f"`typing.Dict[str, int]`), not: {output_type}" + ) + + if output_type == dict or output_type == Dict: + # NOTE: typing_inspect.is_generic_type(typing.Dict) without type parameters returns True, + # so we need to address the bare dictionaries first before generics. 
+ if fields is None or not isinstance(fields, dict): + raise base.InvalidDecoratorException( + "When extracting fields from a function that returns a bare `dict` output without " + "type parameters, you must supply a `dict` mapping field names to types." + ) + elif typing_inspect.is_generic_type(output_type): + base_type = typing_inspect.get_origin(output_type) + if base_type != dict and base_type != Dict: + raise base.InvalidDecoratorException(output_type_error) + if fields is None: + raise base.InvalidDecoratorException( + "When extracting fields from a function that returns a generic `dict`, you must " + "supply either a `dict` (`typing.Dict`) mapping field names to types or " + "alternatively a `list` (`typing.List`) of field names." + ) + output_args = typing_inspect.get_args(output_type) + if len(output_args) != 2: + raise base.InvalidDecoratorException( + f"When extracting fields from a function that returns a generic `dict`, you " + f"must specify only two type parameters (key, value), not {output_args}." 
+ ) + if isinstance(fields, list): + fields = {field: output_args[1] for field in fields} # Infer type from annotation + elif typing_extensions.is_typeddict(output_type): + typed_dict_fields = typing.get_type_hints(output_type) # Dict of field name -> type + errors = [] + if fields is None: + fields = typed_dict_fields # Infer fields and types from annotation + elif isinstance(fields, list): + reduced_fields = {} + for field in fields: + if field not in typed_dict_fields: + errors.append(f"{field} is not a field in the `TypedDict` {output_type}.") + reduced_fields[field] = typed_dict_fields[field] + fields = reduced_fields + elif isinstance(fields, dict): + for field_name, field_type in fields.items(): + expected_type = typed_dict_fields.get(field_name, None) + if expected_type is None: + errors.append(f"{field_name} is not a field in the `TypedDict` {output_type}.") + continue + elif expected_type == field_type or htypes.custom_subclass_check( + field_type, expected_type + ): + continue + errors.append( + f"Error {field_name} did not match the TypedDict annotation's field " + f"{field_type}. Expected {expected_type}." + ) + if errors: + raise base.InvalidDecoratorException( + f"Error {fields} did not match a subset of the TypedDict annotation's fields " + f"{typed_dict_fields}. The following fields were not valid: {errors}." + ) + else: + raise base.InvalidDecoratorException(output_type_error) + + assert isinstance(fields, dict), "Internal error: fields should be a dict at this point." + _validate_extract_fields(fields) + + return fields + + def _validate_extract_fields(fields: dict): """Validates the fields dict for extract field. 
Rules are: @@ -739,18 +821,31 @@ def _validate_extract_fields(fields: dict): class extract_fields(base.SingleNodeNodeTransformer): """Extracts fields from a dictionary of output.""" - def __init__(self, fields: dict = None, fill_with: Any = None): + output_type: Any + resolved_fields: Dict[str, Type] + + def __init__( + self, + fields: Optional[Dict[str, Any] | List[str] | Any] = None, + *others, + fill_with: Any = None, + ): """Constructor for a modifier that expands a single function into the following nodes: - n functions, each of which take in the original dict and output a specific field - 1 function that outputs the original dict - :param fields: Fields to extract. A dict of 'field_name' -> 'field_type'. + :param fields: Fields to extract. Can be a dict of field names to types, a list of field names, or a single field name. + :param others: Additional fields names to extract - argument unpacking. Ignored if `fields` is a dict. :param fill_with: If you want to extract a field that doesn't exist, do you want to fill it with a default \ value? Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a \ field value. """ super(extract_fields, self).__init__() + if isinstance(fields, list): + fields = fields + list(others) + elif fields and not isinstance(fields, dict): + fields = [fields] + list(others) self.fields = fields self.fill_with = fill_with @@ -760,40 +855,8 @@ def validate(self, fn: Callable): :param fn: Function to validate. :raises: InvalidDecoratorException If the function is not annotated with a dict or typing.Dict type as output. 
""" - output_type = typing.get_type_hints(fn).get("return") - if typing_inspect.is_generic_type(output_type): - base_type = typing_inspect.get_origin(output_type) - if base_type == dict or base_type == Dict: - _validate_extract_fields(self.fields) - else: - raise base.InvalidDecoratorException( - f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}" - ) - elif output_type == dict: - _validate_extract_fields(self.fields) - elif typing_extensions.is_typeddict(output_type): - if self.fields is None: - self.fields = typing.get_type_hints(output_type) - else: - # check that fields is a subset of TypedDict that is defined - typed_dict_fields = typing.get_type_hints(output_type) - for field_name, field_type in self.fields.items(): - expected_type = typed_dict_fields.get(field_name, None) - if expected_type == field_type: - pass # we're definitely good - elif expected_type is not None and htypes.custom_subclass_check( - field_type, expected_type - ): - pass - else: - raise base.InvalidDecoratorException( - f"Error {self.fields} did not match a subset of the TypedDict annotation's fields {typed_dict_fields}." 
- ) - _validate_extract_fields(self.fields) - else: - raise base.InvalidDecoratorException( - f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}" - ) + self.output_type = typing.get_type_hints(fn).get("return") + self.resolved_fields = _process_extract_fields(self.fields, self.output_type) def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable @@ -813,27 +876,27 @@ def transform_node( # if fn is async if inspect.iscoroutinefunction(fn): - async def dict_generator(*args, **kwargs): + async def dict_generator(*args, **kwargs): # type: ignore dict_generated = await fn(*args, **kwargs) if self.fill_with is not None: - for field in self.fields: + for field in self.resolved_fields: if field not in dict_generated: dict_generated[field] = self.fill_with return dict_generated else: - def dict_generator(*args, **kwargs): + def dict_generator(*args, **kwargs): # type: ignore dict_generated = fn(*args, **kwargs) if self.fill_with is not None: - for field in self.fields: + for field in self.resolved_fields: if field not in dict_generated: dict_generated[field] = self.fill_with return dict_generated output_nodes = [node_.copy_with(callabl=dict_generator)] - for field, field_type in self.fields.items(): + for field, field_type in self.resolved_fields.items(): doc_string = base_doc # default doc string of base function. # This extractor is constructed to avoid closure issues. 
@@ -852,7 +915,7 @@ def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type: # type field_type, doc_string, extractor_fn, - input_types={node_.name: dict}, + input_types={node_.name: self.output_type}, tags=node_.tags.copy(), ) ) diff --git a/tests/function_modifiers/test_expanders.py b/tests/function_modifiers/test_expanders.py index f272cfd54..6939b26e3 100644 --- a/tests/function_modifiers/test_expanders.py +++ b/tests/function_modifiers/test_expanders.py @@ -333,24 +333,6 @@ class MyDictBad(TypedDict): test2: str -@pytest.mark.parametrize( - "return_type", - [ - dict, - Dict, - Dict[str, str], - Dict[str, Any], - MyDict, - ], -) -def test_extract_fields_validate_happy(return_type): - def return_dict() -> return_type: - return {} - - annotation = function_modifiers.extract_fields({"test": int}) - annotation.validate(return_dict) - - class SomeObject: pass @@ -369,95 +351,306 @@ class MyDictInheritanceBadCase(TypedDict): test2: str -def test_extract_fields_validate_happy_inheritance(): - def return_dict() -> MyDictInheritance: - return {} - - annotation = function_modifiers.extract_fields({"test": InheritedObject}) - annotation.validate(return_dict) - - -def test_extract_fields_validate_not_subclass(): - def return_dict() -> MyDictInheritanceBadCase: - return {} - - annotation = function_modifiers.extract_fields({"test": SomeObject}) - with pytest.raises(base.InvalidDecoratorException): - annotation.validate(return_dict) - - @pytest.mark.parametrize( - "return_type", - [(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)], + "return_type_str,fields", + [ + ("Dict[str, int]", ("A", "B")), + ("Dict[str, int]", (["A", "B"])), + ("Dict", {"A": str, "B": int}), + ("MyDict", ()), + ("MyDict", {"test2": str}), + ("MyDictInheritance", {"test": InheritedObject}), + pytest.param("dict[str, int]", ("A", "B"), marks=skipif(**prior_to_py39)), + pytest.param("dict[str, int]", (["A", "B"]), marks=skipif(**prior_to_py39)), + pytest.param("dict", {"A": str, 
"B": int}, marks=skipif(**prior_to_py39)), + ], ) -def test_extract_fields_validate_errors(return_type): - def return_dict() -> return_type: - return {} - - annotation = function_modifiers.extract_fields({"test": int}) - with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): - annotation.validate(return_dict) +def test_extract_fields_valid_annotations_for_inferred_types(return_type_str, fields): + return_type = eval(return_type_str) + def function() -> return_type: # type: ignore + return {} # Only testing validation, so return value doesn't matter -def test_extract_fields_typeddict_empty_fields(): - def return_dict() -> MyDict: - return {} + if isinstance(fields, tuple): + annotation = function_modifiers.extract_fields(*fields) + else: + annotation = function_modifiers.extract_fields(fields) + annotation.validate(function) - # don't need fields for TypedDict - annotation = function_modifiers.extract_fields() - annotation.validate(return_dict) +@pytest.mark.parametrize( + "return_type_str,fields", + [ + ("Dict", ("A", "B")), + ("Dict", (["A", "B"])), + ("Dict", (["A"])), + ("Dict", (["A", "B", "C"])), + ("int", {"A": int}), + ("list", {"A": int}), + ("np.ndarray", {"A": int}), + ("pd.DataFrame", {"A": int}), + ("MyDictBad", {"A": int}), + ("MyDictInheritanceBadCase", {"A": SomeObject}), + pytest.param("dict", ("A", "B"), marks=skipif(**prior_to_py39)), + pytest.param("dict", (["A", "B"]), marks=skipif(**prior_to_py39)), + pytest.param("dict", (["A"]), marks=skipif(**prior_to_py39)), + pytest.param("dict", (["A", "B", "C"]), marks=skipif(**prior_to_py39)), + ], +) +def test_extract_fields_invalid_annotations_for_inferred_types(return_type_str, fields): + return_type = eval(return_type_str) -def test_extract_fields_typeddict_subset(): - def return_dict() -> MyDict: - return {} + def function() -> return_type: # type: ignore + return {} # Only testing validation, so return value doesn't matter - # test that a subset of fields is fine - 
annotation = function_modifiers.extract_fields({"test2": str}) - annotation.validate(return_dict) + if isinstance(fields, tuple): + annotation = function_modifiers.extract_fields(*fields) + else: + annotation = function_modifiers.extract_fields(fields) + with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): + annotation.validate(function) -def test_valid_extract_fields(): - """Tests whole extract_fields decorator.""" +def test_extract_fields_transform_on_bare_dict_with_explicit_types(): + """Tests whole extract_fields decorator using a bare, non-generic, dict and explicit types.""" annotation = function_modifiers.extract_fields( {"col_1": list, "col_2": int, "col_3": np.ndarray} ) - def dummy_dict_generator() -> dict: + def dummy_dict() -> dict: # bare dict, not generic """dummy doc""" return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])} - nodes = list( - annotation.transform_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator) - ) + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + assert len(nodes) == 4 assert nodes[0] == node.Node( - name=dummy_dict_generator.__name__, + name=dummy_dict.__name__, typ=dict, - doc_string=dummy_dict_generator.__doc__, - callabl=dummy_dict_generator, + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, tags={"module": "tests.function_modifiers.test_expanders"}, ) assert nodes[1].name == "col_1" assert nodes[1].type == list assert nodes[1].documentation == "dummy doc" # we default to base function doc. 
- assert nodes[1].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)} + assert nodes[1].input_types == {dummy_dict.__name__: (dict, DependencyType.REQUIRED)} assert nodes[2].name == "col_2" assert nodes[2].type == int assert nodes[2].documentation == "dummy doc" - assert nodes[2].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)} + assert nodes[2].input_types == {dummy_dict.__name__: (dict, DependencyType.REQUIRED)} assert nodes[3].name == "col_3" assert nodes[3].type == np.ndarray assert nodes[3].documentation == "dummy doc" - assert nodes[3].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)} + assert nodes[3].input_types == {dummy_dict.__name__: (dict, DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_generic_dict_with_explicit_types(): + """Tests whole extract_fields decorator using a generic dict and explicit types.""" + annotation = function_modifiers.extract_fields({"col_1": int, "col_2": int}) + + def dummy_dict() -> Dict[str, int]: + """dummy doc""" + return {"col_1": 1, "col_2": 2} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 3 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=Dict[str, int], + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "col_1" + assert nodes[1].type == int + assert nodes[1].documentation == "dummy doc" # we default to base function doc. 
+ assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)} + assert nodes[2].name == "col_2" + assert nodes[2].type == int + assert nodes[2].documentation == "dummy doc" + assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_generic_dict_with_field_list(): + """Tests whole extract_fields decorator using a generic dict and a list of field names.""" + annotation = function_modifiers.extract_fields(["col_1", "col_2"]) + + def dummy_dict() -> Dict[str, int]: + """dummy doc""" + return {"col_1": 1, "col_2": 2} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 3 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=Dict[str, int], + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "col_1" + assert nodes[1].type == int + assert nodes[1].documentation == "dummy doc" # we default to base function doc. 
+ assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)} + assert nodes[2].name == "col_2" + assert nodes[2].type == int + assert nodes[2].documentation == "dummy doc" + assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_generic_dict_with_unpacked_fields(): + """Tests whole extract_fields decorator using a generic dict and unpacked field names.""" + annotation = function_modifiers.extract_fields("col_1", "col_2") + + def dummy_dict() -> Dict[str, int]: + """dummy doc""" + return {"col_1": 1, "col_2": 2} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 3 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=Dict[str, int], + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "col_1" + assert nodes[1].type == int + assert nodes[1].documentation == "dummy doc" # we default to base function doc. 
+ assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)} + assert nodes[2].name == "col_2" + assert nodes[2].type == int + assert nodes[2].documentation == "dummy doc" + assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_typed_dict_with_explicit_types(): + """Tests whole extract_fields decorator using a TypedDict and explicit types.""" + annotation = function_modifiers.extract_fields({"test2": str}) + + def dummy_dict() -> MyDict: + """dummy doc""" + return {"test": 1, "test2": "2"} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 2 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=MyDict, + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "test2" + assert nodes[1].type == str + assert nodes[1].documentation == "dummy doc" + assert nodes[1].input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_typed_dict_with_field_list(): + """Tests whole extract_fields decorator using a TypedDict and a list of field names.""" + annotation = function_modifiers.extract_fields(["test2"]) + + def dummy_dict() -> MyDict: + """dummy doc""" + return {"test": 1, "test2": "2"} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 2 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=MyDict, + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "test2" + assert nodes[1].type == str + assert nodes[1].documentation == "dummy doc" + assert nodes[1].input_types == 
{dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_typed_dict_with_unpacked_fields(): + """Tests whole extract_fields decorator using a TypedDict and explicit types.""" + annotation = function_modifiers.extract_fields("test2") + + def dummy_dict() -> MyDict: + """dummy doc""" + return {"test": 1, "test2": "2"} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 2 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=MyDict, + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "test2" + assert nodes[1].type == str + assert nodes[1].documentation == "dummy doc" + assert nodes[1].input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)} + + +def test_extract_fields_transform_on_typed_dict_with_inferred_types(): + """Tests whole extract_fields decorator using a TypedDict and inferred types.""" + annotation = function_modifiers.extract_fields() + + def dummy_dict() -> MyDict: + """dummy doc""" + return {"test": 1, "test2": "2"} + + annotation.validate(dummy_dict) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + + assert len(nodes) == 3 + assert nodes[0] == node.Node( + name=dummy_dict.__name__, + typ=MyDict, + doc_string=getattr(dummy_dict, "__doc__", ""), + callabl=dummy_dict, + tags={"module": "tests.function_modifiers.test_expanders"}, + ) + + assert nodes[1].name == "test" + assert nodes[1].type == int + assert nodes[1].documentation == "dummy doc" # we default to base function doc. 
+ assert nodes[1].input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)} + assert nodes[2].name == "test2" + assert nodes[2].type == str + assert nodes[2].documentation == "dummy doc" + assert nodes[2].input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)} -def test_extract_fields_fill_with(): +def test_extract_fields_transform_using_fill_with(): def dummy_dict() -> dict: """dummy doc""" return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])} annotation = function_modifiers.extract_fields({"col_2": int, "col_4": float}, fill_with=1.0) + annotation.validate(dummy_dict) original_node, extracted_field_node, missing_field_node = annotation.transform_node( node.Node.from_fn(dummy_dict), {}, dummy_dict ) @@ -468,12 +661,13 @@ def dummy_dict() -> dict: assert missing_field == 1.0 -def test_extract_fields_no_fill_with(): +def test_extract_fields_transform_not_using_fill_with(): def dummy_dict() -> dict: """dummy doc""" return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])} annotation = function_modifiers.extract_fields({"col_4": int}) + annotation.validate(dummy_dict) nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): nodes[1].callable(dummy_dict=dummy_dict()) From 750092ee139a2abac090a0d6df75a4e5f221ed1e Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 8 May 2025 22:30:02 -0400 Subject: [PATCH 3/8] Update `extract_fields` documentation --- docs/concepts/function-modifiers.rst | 68 ++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/docs/concepts/function-modifiers.rst b/docs/concepts/function-modifiers.rst index 4c4797b27..e374c138f 100644 --- a/docs/concepts/function-modifiers.rst +++ b/docs/concepts/function-modifiers.rst @@ -140,7 +140,7 @@ The ``@check_output`` function modifiers are applied on the **node output / func .. 
note:: - In the future, validatation capabailities may be added to ``@schema``. For now, it's only added metadata. + In the future, validation capabilities may be added to ``@schema``. For now, it's only added metadata. @check_output* ~~~~~~~~~~~~~~ @@ -216,14 +216,14 @@ Now, ``X_train``, ``X_validation``, and ``X_test`` are available to other nodes @extract_fields ~~~~~~~~~~~~~~~ -Additionally, we can extract fields from an output dictionary using ``@extract_fields``. In this case, you must specify the dictionary keys and their types. The function must return a dictionary that contains, at a minimum, those keys specified in the decorator. +Additionally, we can extract fields from an output dictionary using ``@extract_fields``. The function must return a dictionary that contains, at a minimum, those keys specified in the decorator. In this case, you can specify a dictionary of fields and their types: .. code-block:: python from typing import Dict from hamilton.function_modifiers import extract_fields - @extract_fields(dict( # don't forget the dictionary + @extract_fields(dict( # fields specified as a dictionary X_train=np.ndarray, X_validation=np.ndarray, X_test=np.ndarray, @@ -240,6 +240,68 @@ Additionally, we can extract fields from an output dictionary using ``@extract_f .. image:: ./_function-modifiers/extract_fields.png :height: 250px +Or if you are using a generic dictionary, you can specify solely the field names. + +.. code-block:: python + + from typing import Dict + from hamilton.function_modifiers import extract_fields + + @extract_fields("X_train", "X_validation", "X_test") # field names only + def dataset_splits(X: np.ndarray) -> Dict[str, np.ndarray]: # generic dict + """Randomly split data into train, validation, test""" + X_train, X_validation, X_test = random_split(X) + return dict( + X_train=X_train, + X_validation=X_validation, + X_test=X_test, + ) + +If you are using a `TypedDict`, you can specify just the field names. + +.. 
code-block:: python + + from typing import TypedDict + from hamilton.function_modifiers import extract_fields + + class DatasetSplits(TypedDict): + X_train: np.ndarray + X_validation: np.ndarray + X_test: np.ndarray + + @extract_fields("X_train", "X_validation", "X_test") + def dataset_splits(X: np.ndarray) -> DatasetSplits: + """Randomly split data into train, validation, test""" + X_train, X_validation, X_test = random_split(X) + return dict( + X_train=X_train, + X_validation=X_validation, + X_test=X_test, + ) + + +Or you can leave the field names empty and extract all fields from the `TypedDict`. + +.. code-block:: python + + from typing import TypedDict + from hamilton.function_modifiers import extract_fields + + class DatasetSplits(TypedDict): + X_train: np.ndarray + X_validation: np.ndarray + X_test: np.ndarray + + @extract_fields() # no field names: all fields are inferred from the TypedDict + def dataset_splits(X: np.ndarray) -> DatasetSplits: + """Randomly split data into train, validation, test""" + X_train, X_validation, X_test = random_split(X) + return dict( + X_train=X_train, + X_validation=X_validation, + X_test=X_test, + ) + Again, ``X_train``, ``X_validation``, and ``X_test`` are now available to other nodes, or you can query the ``dataset_splits`` node to retrieve all splits in a dictionary. 
From cfce56c33a65458383750991ed6d69a9e3a8a676 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 8 May 2025 22:30:46 -0400 Subject: [PATCH 4/8] Rename `unpack_fields` tests for consistency --- tests/function_modifiers/test_expanders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/function_modifiers/test_expanders.py b/tests/function_modifiers/test_expanders.py index 6939b26e3..33292fe42 100644 --- a/tests/function_modifiers/test_expanders.py +++ b/tests/function_modifiers/test_expanders.py @@ -673,7 +673,7 @@ def dummy_dict() -> dict: nodes[1].callable(dummy_dict=dummy_dict()) -def test_unpack_fields_valid_explicit_tuple(): +def test_unpack_fields_transform_on_explicit_tuple(): def dummy() -> Tuple[int, str, int]: """dummy doc""" return 1, "2", 3 @@ -704,7 +704,7 @@ def dummy() -> Tuple[int, str, int]: assert nodes[3].input_types == {dummy.__name__: (Tuple[int, str, int], DependencyType.REQUIRED)} -def test_unpack_fields_valid_explicit_tuple_subset(): +def test_unpack_fields_transform_on_explicit_tuple_subset(): def dummy() -> Tuple[int, str, int]: """dummy doc""" return 1, "2", 3 @@ -727,7 +727,7 @@ def dummy() -> Tuple[int, str, int]: assert nodes[1].input_types == {dummy.__name__: (Tuple[int, str, int], DependencyType.REQUIRED)} -def test_unpack_fields_valid_indeterminate_tuple(): +def test_unpack_fields_transform_on_indeterminate_tuple(): def dummy() -> Tuple[int, ...]: """dummy doc""" return 1, 2, 3 From c7f97ea3c98dda61cfcb647e417d572043c858b2 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 8 May 2025 23:01:10 -0400 Subject: [PATCH 5/8] Add backward compatible `Union` --- hamilton/function_modifiers/expanders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 1207ad2d8..98156a611 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -700,7 +700,7 @@ def 
extractor_fn( def _process_extract_fields( - fields: Optional[Dict[str, Any] | List[str]], output_type: Any + fields: Optional[Union[Dict[str, Any], List[str]]], output_type: Any ) -> Dict[str, Any]: """Processes the fields and base output type to extract a dict of field types. @@ -826,7 +826,7 @@ class extract_fields(base.SingleNodeNodeTransformer): def __init__( self, - fields: Optional[Dict[str, Any] | List[str] | Any] = None, + fields: Optional[Union[Dict[str, Any], List[str], Any]] = None, *others, fill_with: Any = None, ): From 39203fd5ff63250f7a732f65e2721ef3f8bf60bf Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Thu, 8 May 2025 23:11:25 -0400 Subject: [PATCH 6/8] Update docs/concepts/function-modifiers.rst Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- docs/concepts/function-modifiers.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/concepts/function-modifiers.rst b/docs/concepts/function-modifiers.rst index e374c138f..8b4e4fe32 100644 --- a/docs/concepts/function-modifiers.rst +++ b/docs/concepts/function-modifiers.rst @@ -201,7 +201,7 @@ A good example is splitting a dataset into training, validation, and test splits from typing import Tuple from hamilton.function_modifiers import unpack_fields - @unpack_fields("X_train" "X_validation", "X_test") + @unpack_fields("X_train", "X_validation", "X_test") def dataset_splits(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Randomly split data into train, validation, test""" X_train, X_validation, X_test = random_split(X) From 1c8dbaba63964531ae3428e28022b93dc83b2ac9 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 19 May 2025 12:57:16 -0400 Subject: [PATCH 7/8] Rename process functions --- hamilton/function_modifiers/expanders.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 
98156a611..3b10f1e69 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -699,10 +699,11 @@ def extractor_fn( return output_nodes -def _process_extract_fields( +def _determine_fields_to_extract( fields: Optional[Union[Dict[str, Any], List[str]]], output_type: Any ) -> Dict[str, Any]: - """Processes the fields and base output type to extract a dict of field types. + """Determines which fields to extract based on user requested fields and the output type of + the return type of the function. :param fields: Dict of fields to extract. :param output_type: The output type of the node function. @@ -856,7 +857,7 @@ def validate(self, fn: Callable): :raises: InvalidDecoratorException If the function is not annotated with a dict or typing.Dict type as output. """ self.output_type = typing.get_type_hints(fn).get("return") - self.resolved_fields = _process_extract_fields(self.fields, self.output_type) + self.resolved_fields = _determine_fields_to_extract(self.fields, self.output_type) def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable @@ -922,8 +923,9 @@ def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type: # type return output_nodes -def _process_unpack_fields(fields: List[str], output_type: Any) -> List[Type]: - """Processes the fields and base output type to extract a list of field types. +def _determine_fields_to_unpack(fields: List[str], output_type: Any) -> List[Type]: + """Determines which fields to unpack based on user requested fields and the output type of + the return type of the function. :param fields: List of fields to to unpack. :param output_type: The output type of the node function. 
@@ -1006,7 +1008,7 @@ def __init__(self, *fields: str): @override def validate(self, fn: Callable): output_type = typing.get_type_hints(fn).get("return") - field_types = _process_unpack_fields(self.fields, output_type) + field_types = _determine_fields_to_unpack(self.fields, output_type) self.field_types = field_types self.output_type = output_type From 2bf1d898ed3833ed3a28bf31c9d4783608e8a994 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Mon, 19 May 2025 12:58:07 -0400 Subject: [PATCH 8/8] Update docstrings --- hamilton/function_modifiers/expanders.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 3b10f1e69..dfbb16227 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -851,7 +851,8 @@ def __init__( self.fill_with = fill_with def validate(self, fn: Callable): - """A function is invalid if it is not annotated with a dict or typing.Dict return type. + """A function is invalid if it is not annotated with a dict or typing.Dict return type or if the + fields to extract are not valid. :param fn: Function to validate. :raises: InvalidDecoratorException If the function is not annotated with a dict or typing.Dict type as output. @@ -1007,6 +1008,11 @@ def __init__(self, *fields: str): @override def validate(self, fn: Callable): + """Validates that the return type of the function is a tuple or typing.Tuple. + + :param fn: Function to validate + :raises: InvalidDecoratorException If the function does not output a tuple or typing.Tuple type. 
+ """ output_type = typing.get_type_hints(fn).get("return") field_types = _determine_fields_to_unpack(self.fields, output_type) self.field_types = field_types @@ -1016,6 +1022,14 @@ def validate(self, fn: Callable): def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: + """Unpacks the specified fields from the tuple output into separate nodes. + + :param node_: Node to transform + :param config: Config to use + :param fn: Function to unpack fields from. Must output a tuple. + :return: A collection of nodes -- + one for the original tuple generator, and another for each field to unpack. + """ fn = node_.callable base_doc = node_.documentation base_tags = node_.tags.copy()