diff --git a/docs/reference/decorators/with_columns.rst b/docs/reference/decorators/with_columns.rst index 9dbbb7b1f..fd9918714 100644 --- a/docs/reference/decorators/with_columns.rst +++ b/docs/reference/decorators/with_columns.rst @@ -2,10 +2,10 @@ with_columns ======================= -Pandas --------------- +We support the `with_columns` operation that appends the results as new columns to the original dataframe for several libraries: -We have a ``with_columns`` option to run operations on columns of a Pandas dataframe and append the results as new columns. +Pandas +----------------------- **Reference Documentation** @@ -13,6 +13,24 @@ We have a ``with_columns`` option to run operations on columns of a Pandas dataf :special-members: __init__ +Polar (Eager) +----------------------- + +**Reference Documentation** + +.. autoclass:: hamilton.plugins.h_polars.with_columns + :special-members: __init__ + + +Polars (Lazy) +----------------------- + +**Reference Documentation** + +.. autoclass:: hamilton.plugins.h_polars_lazyframe.with_columns + :special-members: __init__ + + PySpark -------------- diff --git a/examples/pandas/with_columns/README b/examples/pandas/with_columns/README index 95cad3cfe..53a422d5a 100644 --- a/examples/pandas/with_columns/README +++ b/examples/pandas/with_columns/README @@ -2,6 +2,6 @@ We show the ability to use the familiar `with_columns` from either `pyspark` or `polars` on a Pandas dataframe. -To see the example look at the notebook. +To see the example look at the [notebook](notebook.ipynb). ![image info](./dag.png) diff --git a/examples/pandas/with_columns/notebook.ipynb b/examples/pandas/with_columns/notebook.ipynb index 49eca8ed7..495b46aa1 100644 --- a/examples/pandas/with_columns/notebook.ipynb +++ b/examples/pandas/with_columns/notebook.ipynb @@ -30,9 +30,7 @@ "output_type": "stream", "text": [ "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", - " warnings.warn(\n", - "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + " warnings.warn(\n" ] } ], @@ -59,334 +57,222 @@ "\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "cluster__legend\n", - "\n", - "Legend\n", + "\n", + "Legend\n", "\n", "\n", "\n", "case\n", - "\n", - "\n", - "\n", - "case\n", - "thousands\n", + "\n", + "\n", + "\n", + "case\n", + "thousands\n", "\n", - "\n", + "\n", "\n", - "initial_df\n", - "\n", - "initial_df\n", - "DataFrame\n", - "\n", - "\n", - "\n", - "final_df.signups\n", - "\n", - "final_df.signups\n", - "Series\n", + "final_df.spend_zero_mean\n", + "\n", + "final_df.spend_zero_mean\n", + "Series\n", "\n", - "\n", - "\n", - "initial_df->final_df.signups\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "Series\n", "\n", - "\n", - "\n", - "final_df.__append\n", - "\n", - "final_df.__append\n", - "DataFrame\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", "\n", - "\n", - "\n", - "initial_df->final_df.__append\n", - "\n", - "\n", + "\n", + "\n", + "final_df\n", + "\n", + "final_df\n", + "DataFrame\n", "\n", "\n", - "\n", + "\n", "final_df.spend\n", - "\n", - "final_df.spend\n", - "Series\n", + "\n", + "final_df.spend\n", + "Series\n", "\n", - "\n", - "\n", - "initial_df->final_df.spend\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup\n", + "\n", + "final_df.spend_per_signup\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_per_signup\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "final_df.avg_3wk_spend\n", - "\n", - "final_df.avg_3wk_spend: case\n", - "Series\n", + "\n", + "final_df.avg_3wk_spend: case\n", + "Series\n", "\n", - "\n", - "\n", - "final_df.avg_3wk_spend->final_df.__append\n", - "\n", - "\n", - "\n", - "\n", - "final_df.signups->final_df.__append\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend->final_df.avg_3wk_spend\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_per_signup\n", - "\n", - "final_df.spend_per_signup\n", - "Series\n", + "\n", + "\n", + "final_df.spend_mean\n", + "\n", + "final_df.spend_mean\n", + "float\n", "\n", - "\n", - "\n", - "final_df.signups->final_df.spend_per_signup\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_mean\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean\n", - "\n", - "final_df.spend_zero_mean\n", - "Series\n", + "\n", + "\n", + "final_df.spend_std_dev\n", + "\n", + "final_df.spend_std_dev\n", + "float\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean_unit_variance\n", - "\n", - "final_df.spend_zero_mean_unit_variance\n", - "Series\n", + "\n", + "\n", + "final_df.spend->final_df.spend_std_dev\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", - "\n", - "\n", + "\n", + "\n", + "final_df.__append\n", + "\n", + "final_df.__append\n", + "DataFrame\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean_unit_variance->final_df.__append\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_per_signup->final_df.__append\n", + "\n", "\n", - "\n", - "\n", - "final_df\n", - "\n", - "final_df\n", - "DataFrame\n", + "\n", + "\n", + "final_df.avg_3wk_spend->final_df.__append\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "final_df.__append->final_df\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "final_df.spend_mean\n", - "\n", - "final_df.spend_mean\n", - "float\n", - "\n", - "\n", - "\n", - "final_df.spend_mean->final_df.spend_zero_mean\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_std_dev\n", - "\n", - "final_df.spend_std_dev\n", - "float\n", + "\n", + "\n", + "final_df.signups\n", + "\n", + "final_df.signups\n", + "Series\n", "\n", - "\n", + "\n", "\n", - "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", - "\n", - "\n", + "final_df.signups->final_df.spend_per_signup\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.avg_3wk_spend\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance->final_df.__append\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_zero_mean\n", - "\n", - "\n", + "\n", + "\n", + "initial_df\n", + "\n", + "initial_df\n", + "DataFrame\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.__append\n", - "\n", + "\n", + "\n", + "initial_df->final_df.spend\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_mean\n", - "\n", - "\n", + "\n", + "\n", + "initial_df->final_df.__append\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_std_dev\n", - "\n", - "\n", + "\n", + "\n", + "initial_df->final_df.signups\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_per_signup\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_mean->final_df.spend_zero_mean\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_per_signup->final_df.__append\n", - "\n", + "\n", + "\n", + "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", "\n", "\n", "\n", "config\n", - "\n", - "\n", - "\n", - "config\n", + "\n", + "\n", + "\n", + "config\n", "\n", "\n", "\n", "function\n", - "\n", - "function\n", + "\n", + "function\n", "\n", "\n", "\n", "output\n", - "\n", - "output\n", + "\n", + "output\n", "\n", "\n", "\n" ], "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
signupsspendavg_3wk_spendspend_per_signupspend_zero_mean_unit_variance
0110000000.0NaN10000000.0-1.064405
11010000000.0NaN1000000.0-1.064405
25020000000.013333.333333400000.0-0.483821
310040000000.023333.333333400000.00.677349
420040000000.033333.333333200000.00.677349
540050000000.043333.333333125000.01.257934
\n", - "
" - ], - "text/plain": [ - " signups spend avg_3wk_spend spend_per_signup \\\n", - "0 1 10000000.0 NaN 10000000.0 \n", - "1 10 10000000.0 NaN 1000000.0 \n", - "2 50 20000000.0 13333.333333 400000.0 \n", - "3 100 40000000.0 23333.333333 400000.0 \n", - "4 200 40000000.0 33333.333333 200000.0 \n", - "5 400 50000000.0 43333.333333 125000.0 \n", - "\n", - " spend_zero_mean_unit_variance \n", - "0 -1.064405 \n", - "1 -1.064405 \n", - "2 -0.483821 \n", - "3 0.677349 \n", - "4 0.677349 \n", - "5 1.257934 " + "" ] }, "metadata": {}, @@ -400,8 +286,8 @@ "import my_functions\n", "\n", "output_columns = [\n", - " \"spend\",\n", - " \"signups\",\n", + " # \"spend\",\n", + " # \"signups\",\n", " \"avg_3wk_spend\",\n", " \"spend_per_signup\",\n", " \"spend_zero_mean_unit_variance\",\n", @@ -461,232 +347,222 @@ "\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "cluster__legend\n", - "\n", - "Legend\n", + "\n", + "Legend\n", "\n", "\n", "\n", "case\n", - "\n", - "\n", - "\n", - "case\n", - "millions\n", + "\n", + "\n", + "\n", + "case\n", + "millions\n", "\n", - "\n", + "\n", "\n", - "initial_df\n", - "\n", - "initial_df\n", - "DataFrame\n", - "\n", - "\n", - "\n", - "final_df.signups\n", - "\n", - "final_df.signups\n", - "Series\n", + "final_df.spend_zero_mean\n", + "\n", + "final_df.spend_zero_mean\n", + "Series\n", "\n", - "\n", - "\n", - "initial_df->final_df.signups\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "Series\n", "\n", - "\n", - "\n", - "final_df.__append\n", - "\n", - "final_df.__append\n", - "DataFrame\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", "\n", - "\n", - "\n", - "initial_df->final_df.__append\n", - "\n", - "\n", + "\n", + "\n", + "final_df\n", + "\n", + "final_df\n", + "DataFrame\n", "\n", "\n", - "\n", + "\n", "final_df.spend\n", - "\n", - "final_df.spend\n", - "Series\n", + "\n", + "final_df.spend\n", + "Series\n", "\n", - "\n", - "\n", - "initial_df->final_df.spend\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup\n", + "\n", + "final_df.spend_per_signup\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_per_signup\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "final_df.avg_3wk_spend\n", - "\n", - "final_df.avg_3wk_spend: case\n", - "Series\n", + "\n", + "final_df.avg_3wk_spend: case\n", + "Series\n", "\n", - "\n", - "\n", - "final_df.avg_3wk_spend->final_df.__append\n", - "\n", - "\n", - "\n", - "\n", - "final_df.signups->final_df.__append\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend->final_df.avg_3wk_spend\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_per_signup\n", - "\n", - "final_df.spend_per_signup\n", - "Series\n", + "\n", + "\n", + "final_df.spend_mean\n", + "\n", + "final_df.spend_mean\n", + "float\n", "\n", - "\n", - "\n", - "final_df.signups->final_df.spend_per_signup\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_mean\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean\n", - "\n", - "final_df.spend_zero_mean\n", - "Series\n", + "\n", + "\n", + "final_df.spend_std_dev\n", + "\n", + "final_df.spend_std_dev\n", + "float\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean_unit_variance\n", - "\n", - "final_df.spend_zero_mean_unit_variance\n", - "Series\n", + "\n", + "\n", + "final_df.spend->final_df.spend_std_dev\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", - "\n", - "\n", + "\n", + "\n", + "final_df.__append\n", + "\n", + "final_df.__append\n", + "DataFrame\n", "\n", - "\n", - "\n", - "final_df.spend_zero_mean_unit_variance->final_df.__append\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_per_signup->final_df.__append\n", + "\n", "\n", - "\n", - "\n", - "final_df\n", - "\n", - "final_df\n", - "DataFrame\n", + "\n", + "\n", + "final_df.avg_3wk_spend->final_df.__append\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "final_df.__append->final_df\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "final_df.spend_mean\n", - "\n", - "final_df.spend_mean\n", - "float\n", - "\n", - "\n", - "\n", - "final_df.spend_mean->final_df.spend_zero_mean\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_std_dev\n", - "\n", - "final_df.spend_std_dev\n", - "float\n", + "\n", + "\n", + "final_df.signups\n", + "\n", + "final_df.signups\n", + "Series\n", "\n", - "\n", + "\n", "\n", - "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", - "\n", - "\n", + "final_df.signups->final_df.spend_per_signup\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.avg_3wk_spend\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance->final_df.__append\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_zero_mean\n", - "\n", - "\n", + "\n", + "\n", + "initial_df\n", + "\n", + "initial_df\n", + "DataFrame\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.__append\n", - "\n", + "\n", + "\n", + "initial_df->final_df.spend\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_mean\n", - "\n", - "\n", + "\n", + "\n", + "initial_df->final_df.__append\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_std_dev\n", - "\n", - "\n", + "\n", + "\n", + "initial_df->final_df.signups\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend->final_df.spend_per_signup\n", - "\n", - "\n", + "\n", + "\n", + "final_df.spend_mean->final_df.spend_zero_mean\n", + "\n", + "\n", "\n", - "\n", - "\n", - "final_df.spend_per_signup->final_df.__append\n", - "\n", + "\n", + "\n", + "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", "\n", "\n", "\n", "config\n", - "\n", - "\n", - "\n", - "config\n", + "\n", + "\n", + "\n", + "config\n", "\n", "\n", "\n", "function\n", - "\n", - "function\n", + "\n", + "function\n", "\n", "\n", "\n", "output\n", - "\n", - "output\n", + "\n", + "output\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -701,6 +577,102 @@ "dr.visualize_execution(final_vars=[\"final_df\"])\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# We can also run it async" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext hamilton.plugins.jupyter_magic" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module with_columns_async\n", + "\n", + "import asyncio\n", + "import pandas as pd\n", + "from hamilton.plugins.h_pandas import with_columns\n", + "\n", + "async def data_input() -> pd.DataFrame:\n", + " await asyncio.sleep(0.0001)\n", + " return pd.DataFrame({\n", + " \"a\": [1, 2, 3],\n", + " \"b\": [4, 5, 6],\n", + " \"c\": [7, 8, 9]\n", + " })\n", + "\n", + "\n", + "async def multiply_a(a: pd.Series) -> pd.Series:\n", + " await asyncio.sleep(0.0001)\n", + " return a * 10\n", + "\n", + "\n", + "async def mean_b(b: pd.Series) -> pd.Series:\n", + " await asyncio.sleep(5)\n", + " return b.mean()\n", + "\n", + "async def a_plus_b(a: pd.Series, b: pd.Series) -> pd.Series:\n", + " await asyncio.sleep(1)\n", + " return a + b\n", + "\n", + "async def multiply_a_plus_mean_b(multiply_a: pd.Series, mean_b: pd.Series) -> pd.Series:\n", + " await asyncio.sleep(0.0001)\n", + " return multiply_a + mean_b\n", + "\n", + "\n", + "@with_columns(\n", + " multiply_a,mean_b,a_plus_b, multiply_a_plus_mean_b,\n", + " columns_to_pass=[\"a\", \"b\"]\n", + ")\n", + "def final_df(data_input: pd.DataFrame) -> pd.DataFrame:\n", + " return data_input" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " a b c multiply_a mean_b a_plus_b multiply_a_plus_mean_b\n", + "0 1 4 7 10 5.0 5 15.0\n", + "1 2 5 8 20 5.0 7 25.0\n", + "2 3 6 9 30 5.0 9 35.0\n" + ] + } + ], + "source": [ + "import asyncio\n", + "from hamilton import async_driver\n", + "import with_columns_async\n", + "\n", + "async def main():\n", + " await asyncio.sleep(2)\n", + " dr = (await async_driver.Builder()\n", + " .with_modules(with_columns_async)\n", + " .with_config({\"case\":\"millions\"})\n", + " .build())\n", + " results = await dr.execute([\"final_df\"])\n", + " print(results[\"final_df\"])\n", + "\n", + "await main()\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/polars/notebook.ipynb b/examples/polars/notebook.ipynb index c8cad7e44..c81678590 100644 --- a/examples/polars/notebook.ipynb +++ b/examples/polars/notebook.ipynb @@ -38,8 +38,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", - " warnings.warn(\n" + "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", + " warnings.warn(\n", + "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -70,177 +72,177 @@ "\n", "\n", - "\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "cluster__legend\n", - "\n", - "Legend\n", + "\n", + "Legend\n", "\n", - "\n", + "\n", "\n", + "spend_zero_mean\n", + "\n", + "spend_zero_mean\n", + "Series\n", + "\n", + "\n", + "\n", + "spend_zero_mean_unit_variance\n", + "\n", + "spend_zero_mean_unit_variance\n", + "Series\n", + "\n", + "\n", + "\n", + "spend_zero_mean->spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "spend_mean\n", + "\n", + "spend_mean\n", + "float\n", + "\n", + "\n", + "\n", + "spend_mean->spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", "spend_per_signup\n", "\n", "spend_per_signup\n", "Series\n", "\n", "\n", - "\n", + "\n", "avg_3wk_spend\n", - "\n", - "avg_3wk_spend\n", - "Series\n", + "\n", + "avg_3wk_spend\n", + "Series\n", "\n", - "\n", - "\n", - "spend\n", - "\n", - "spend\n", - "Series\n", + "\n", + "\n", + "base_df\n", + "\n", + "base_df\n", + "DataFrame\n", "\n", - "\n", - "\n", - "spend->spend_per_signup\n", - "\n", - "\n", + "\n", + "\n", + "signups\n", + "\n", + "signups\n", + "Series\n", "\n", - "\n", - "\n", - "spend->avg_3wk_spend\n", - "\n", - "\n", + "\n", + "\n", + "base_df->signups\n", + "\n", + "\n", "\n", - "\n", - "\n", - "spend_mean\n", - "\n", - "spend_mean\n", - "float\n", + "\n", + "\n", + "spend\n", + "\n", + "spend\n", + "Series\n", "\n", - "\n", - "\n", - "spend->spend_mean\n", - "\n", - "\n", + "\n", + "\n", + "base_df->spend\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "spend_std_dev\n", "\n", "spend_std_dev\n", "float\n", "\n", - "\n", + "\n", "\n", - "spend->spend_std_dev\n", - "\n", - "\n", + "spend_std_dev->spend_zero_mean_unit_variance\n", + "\n", + "\n", "\n", - "\n", - "\n", - "spend_zero_mean\n", - "\n", - "spend_zero_mean\n", - "Series\n", + "\n", + "\n", + "signups->spend_per_signup\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "spend->spend_zero_mean\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "spend_mean->spend_zero_mean\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", - "\n", - "base_df\n", - "\n", - "base_df\n", - "DataFrame\n", + "\n", + "\n", + "spend->spend_mean\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "base_df->spend\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "signups\n", - "\n", - "signups\n", - "Series\n", + "spend->spend_per_signup\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "base_df->signups\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "spend_zero_mean_unit_variance\n", - "\n", - "spend_zero_mean_unit_variance\n", - "Series\n", - "\n", - "\n", - "\n", - "spend_std_dev->spend_zero_mean_unit_variance\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "signups->spend_per_signup\n", - "\n", - "\n", + "spend->avg_3wk_spend\n", + "\n", + "\n", "\n", - "\n", - "\n", - "spend_zero_mean->spend_zero_mean_unit_variance\n", - "\n", - "\n", + "\n", + "\n", + "spend->spend_std_dev\n", + "\n", + "\n", "\n", "\n", "\n", "_base_df_inputs\n", - "\n", - "base_df_location\n", - "str\n", + "\n", + "base_df_location\n", + "str\n", "\n", "\n", - "\n", + "\n", "_base_df_inputs->base_df\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "input\n", - "\n", - "input\n", + "\n", + "input\n", "\n", "\n", "\n", "function\n", - "\n", - "function\n", + "\n", + "function\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -777,7 +779,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "hamilton", "language": "python", "name": "python3" }, @@ -791,7 +793,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/examples/polars/with_columns/DAG_DataFrame.png b/examples/polars/with_columns/DAG_DataFrame.png new file mode 100644 index 000000000..cc3d713b7 Binary files /dev/null and b/examples/polars/with_columns/DAG_DataFrame.png differ diff --git a/examples/polars/with_columns/DAG_lazy.png b/examples/polars/with_columns/DAG_lazy.png new file mode 100644 index 000000000..d2566fd7b Binary files /dev/null and b/examples/polars/with_columns/DAG_lazy.png differ diff --git a/examples/polars/with_columns/README b/examples/polars/with_columns/README new file mode 100644 index 000000000..385ef5fb7 --- /dev/null +++ b/examples/polars/with_columns/README @@ -0,0 +1,8 @@ +# Using with_columns with Polars + +We show the ability to use the familiar `with_columns` from `polars`. Supported for both: `pl.DataFrame` and `pl.LazyFrame`. + +To see the example look at the [notebook](notebook.ipynb). + +![image info](./DAG_DataFrame.png) +![image info](./DAG_lazy.png) diff --git a/examples/polars/with_columns/my_functions.py b/examples/polars/with_columns/my_functions.py new file mode 100644 index 000000000..3b2c401b9 --- /dev/null +++ b/examples/polars/with_columns/my_functions.py @@ -0,0 +1,51 @@ +import polars as pl + +from hamilton.function_modifiers import config + +""" +Notes: + 1. This file is used for all the [ray|dask|spark]/hello_world examples. + 2. It therefore show cases how you can write something once and not only scale it, but port it + to different frameworks with ease! +""" + + +@config.when(case="millions") +def avg_3wk_spend__millions(spend: pl.Series) -> pl.Series: + """Rolling 3 week average spend.""" + return ( + spend.to_frame("spend").select(pl.col("spend").rolling_mean(window_size=3) / 1e6) + ).to_series(0) + + +@config.when(case="thousands") +def avg_3wk_spend__thousands(spend: pl.Series) -> pl.Series: + """Rolling 3 week average spend.""" + return ( + spend.to_frame("spend").select(pl.col("spend").rolling_mean(window_size=3) / 1e3) + ).to_series(0) + + +def spend_per_signup(spend: pl.Series, signups: pl.Series) -> pl.Series: + """The cost per signup in relation to spend.""" + return spend / signups + + +def spend_mean(spend: pl.Series) -> float: + """Shows function creating a scalar. In this case it computes the mean of the entire column.""" + return spend.mean() + + +def spend_zero_mean(spend: pl.Series, spend_mean: float) -> pl.Series: + """Shows function that takes a scalar. In this case to zero mean spend.""" + return spend - spend_mean + + +def spend_std_dev(spend: pl.Series) -> float: + """Function that computes the standard deviation of the spend column.""" + return spend.std() + + +def spend_zero_mean_unit_variance(spend_zero_mean: pl.Series, spend_std_dev: float) -> pl.Series: + """Function showing one way to make spend have zero mean and unit variance.""" + return spend_zero_mean / spend_std_dev diff --git a/examples/polars/with_columns/my_functions_lazy.py b/examples/polars/with_columns/my_functions_lazy.py new file mode 100644 index 000000000..4b65b2ac2 --- /dev/null +++ b/examples/polars/with_columns/my_functions_lazy.py @@ -0,0 +1,47 @@ +import polars as pl + +from hamilton.function_modifiers import config + +""" +Notes: + 1. This file is used for all the [ray|dask|spark]/hello_world examples. + 2. It therefore show cases how you can write something once and not only scale it, but port it + to different frameworks with ease! +""" + + +@config.when(case="millions") +def avg_3wk_spend__millions(spend: pl.Expr) -> pl.Expr: + """Rolling 3 week average spend.""" + return spend.rolling_mean(window_size=3) / 1e6 + + +@config.when(case="thousands") +def avg_3wk_spend__thousands(spend: pl.Expr) -> pl.Expr: + """Rolling 3 week average spend.""" + return spend.rolling_mean(window_size=3) / 1e3 + + +def spend_per_signup(spend: pl.Expr, signups: pl.Expr) -> pl.Expr: + """The cost per signup in relation to spend.""" + return spend / signups + + +def spend_mean(spend: pl.Expr) -> float: + """Shows function creating a scalar. In this case it computes the mean of the entire column.""" + return spend.mean() + + +def spend_zero_mean(spend: pl.Expr, spend_mean: float) -> pl.Expr: + """Shows function that takes a scalar. In this case to zero mean spend.""" + return spend - spend_mean + + +def spend_std_dev(spend: pl.Expr) -> float: + """Function that computes the standard deviation of the spend column.""" + return spend.std() + + +def spend_zero_mean_unit_variance(spend_zero_mean: pl.Expr, spend_std_dev: float) -> pl.Expr: + """Function showing one way to make spend have zero mean and unit variance.""" + return spend_zero_mean / spend_std_dev diff --git a/examples/polars/with_columns/notebook.ipynb b/examples/polars/with_columns/notebook.ipynb new file mode 100644 index 000000000..39bd66d35 --- /dev/null +++ b/examples/polars/with_columns/notebook.ipynb @@ -0,0 +1,1239 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Execute this cell to install dependencies\n", + "%pip install sf-hamilton[visualization]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of using with_columns for Polars DataFrame [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/polars/with_columns/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/polars/with_columns/notebook.ipynb)\n", + "\n", + "This allows you to efficiently run groups of map operations on a dataframe.\n", + "Here's an example of calling it -- if you've seen `@subdag`, you should be familiar with the concepts." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "%reload_ext hamilton.plugins.jupyter_magic\n", + "from hamilton import driver\n", + "import my_functions\n", + "\n", + "my_builder = driver.Builder().with_modules(my_functions).with_config({\"case\":\"thousands\"})\n", + "output_node = [\"final_df\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "case\n", + "\n", + "\n", + "\n", + "case\n", + "thousands\n", + "\n", + "\n", + "\n", + "final_df.spend\n", + "\n", + "final_df.spend\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev\n", + "\n", + "final_df.spend_std_dev\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_std_dev\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup\n", + "\n", + "final_df.spend_per_signup\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append\n", + "\n", + "final_df.__append\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean\n", + "\n", + "final_df.spend_zero_mean\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean\n", + "\n", + "final_df.spend_mean\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend\n", + "\n", + "final_df.avg_3wk_spend: case\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.avg_3wk_spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df\n", + "\n", + "final_df\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df\n", + "\n", + "initial_df\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "initial_df->final_df.spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups\n", + "\n", + "final_df.signups\n", + "Series\n", + "\n", + "\n", + "\n", + "initial_df->final_df.signups\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append->final_df\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module with_columns_example --builder my_builder --display --execute output_node\n", + "import polars as pl\n", + "from hamilton.plugins.h_polars import with_columns\n", + "import my_functions\n", + "\n", + "output_columns = [\n", + " \"spend\",\n", + " \"signups\",\n", + " \"avg_3wk_spend\",\n", + " \"spend_per_signup\",\n", + " \"spend_zero_mean_unit_variance\",\n", + "]\n", + "\n", + "def initial_df()->pl.DataFrame:\n", + " return pl.DataFrame(\n", + " { \n", + " \"signups\": pl.Series([1, 10, 50, 100, 200, 400]),\n", + " \"spend\": pl.Series([10, 10, 20, 40, 40, 50])*1e6,\n", + " }\n", + " )\n", + "\n", + "# the with_columns call\n", + "@with_columns(\n", + " *[my_functions],\n", + " columns_to_pass=[\"spend\", \"signups\"], # The columns to select from the dataframe\n", + " # select=output_columns, # The columns to append to the dataframe\n", + " # config_required = [\"a\"]\n", + ")\n", + "def final_df(initial_df: pl.DataFrame) -> pl.DataFrame:\n", + " return initial_df" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape: (6, 6)\n", + "┌─────────┬───────┬───────────────┬──────────────────┬─────────────────┬───────────────────────────┐\n", + "│ signups ┆ spend ┆ avg_3wk_spend ┆ spend_per_signup ┆ spend_zero_mean ┆ spend_zero_mean_unit_vari │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ ance │\n", + "│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ --- │\n", + "│ ┆ ┆ ┆ ┆ ┆ f64 │\n", + "╞═════════╪═══════╪═══════════════╪══════════════════╪═════════════════╪═══════════════════════════╡\n", + "│ 1 ┆ 1e7 ┆ null ┆ 1e7 ┆ -1.8333e7 ┆ -1.064405 │\n", + "│ 10 ┆ 1e7 ┆ null ┆ 1e6 ┆ -1.8333e7 ┆ -1.064405 │\n", + "│ 50 ┆ 2e7 ┆ 13.333333 ┆ 400000.0 ┆ -8.3333e6 ┆ -0.483821 │\n", + "│ 100 ┆ 4e7 ┆ 23.333333 ┆ 400000.0 ┆ 1.1667e7 ┆ 0.677349 │\n", + "│ 200 ┆ 4e7 ┆ 33.333333 ┆ 200000.0 ┆ 1.1667e7 ┆ 0.677349 │\n", + "│ 400 ┆ 5e7 ┆ 43.333333 ┆ 125000.0 ┆ 2.1667e7 ┆ 1.257934 │\n", + "└─────────┴───────┴───────────────┴──────────────────┴─────────────────┴───────────────────────────┘\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "case\n", + "\n", + "\n", + "\n", + "case\n", + "millions\n", + "\n", + "\n", + "\n", + "final_df.spend\n", + "\n", + "final_df.spend\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev\n", + "\n", + "final_df.spend_std_dev\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_std_dev\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup\n", + "\n", + "final_df.spend_per_signup\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append\n", + "\n", + "final_df.__append\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean\n", + "\n", + "final_df.spend_zero_mean\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean\n", + "\n", + "final_df.spend_mean\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend\n", + "\n", + "final_df.avg_3wk_spend: case\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.avg_3wk_spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "Series\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df\n", + "\n", + "final_df\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df\n", + "\n", + "initial_df\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "initial_df->final_df.spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups\n", + "\n", + "final_df.signups\n", + "Series\n", + "\n", + "\n", + "\n", + "initial_df->final_df.signups\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append->final_df\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import with_columns_example\n", + "dr = driver.Builder().with_modules(my_functions, with_columns_example).with_config({\"case\":\"millions\"}).build()\n", + "print(dr.execute(final_vars=[\"final_df\"])[\"final_df\"])\n", + "dr.visualize_execution(final_vars=[\"final_df\"])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of using with_columns for Polars LazyFrame\n", + "\n", + "This allows you to efficiently run groups of map operations on a dataframe.\n", + "Here's an example of calling it -- if you've seen `@subdag`, you should be familiar with the concepts." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext hamilton.plugins.jupyter_magic\n", + "from hamilton import driver\n", + "import my_functions_lazy\n", + "\n", + "my_builder_lazy = driver.Builder().with_modules(my_functions_lazy).with_config({\"case\":\"thousands\"})\n", + "output_node = [\"final_df\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "case\n", + "\n", + "\n", + "\n", + "case\n", + "thousands\n", + "\n", + "\n", + "\n", + "final_df.spend\n", + "\n", + "final_df.spend\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev\n", + "\n", + "final_df.spend_std_dev\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_std_dev\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup\n", + "\n", + "final_df.spend_per_signup\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append\n", + "\n", + "final_df.__append\n", + "LazyFrame\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean\n", + "\n", + "final_df.spend_zero_mean\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean\n", + "\n", + "final_df.spend_mean\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend\n", + "\n", + "final_df.avg_3wk_spend: case\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.avg_3wk_spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df\n", + "\n", + "final_df\n", + "LazyFrame\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df\n", + "\n", + "initial_df\n", + "LazyFrame\n", + "\n", + "\n", + "\n", + "initial_df->final_df.spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups\n", + "\n", + "final_df.signups\n", + "Expr\n", + "\n", + "\n", + "\n", + "initial_df->final_df.signups\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append->final_df\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module with_columns_lazy_example --builder my_builder_lazy --display --execute output_node\n", + "import polars as pl\n", + "from hamilton.plugins.h_polars_lazyframe import with_columns\n", + "import my_functions_lazy\n", + "\n", + "output_columns = [\n", + " \"spend\",\n", + " \"signups\",\n", + " \"avg_3wk_spend\",\n", + " \"spend_per_signup\",\n", + " \"spend_zero_mean_unit_variance\",\n", + "]\n", + "\n", + "def initial_df()->pl.LazyFrame:\n", + " return pl.DataFrame(\n", + " { \n", + " \"signups\": pl.Series([1, 10, 50, 100, 200, 400]),\n", + " \"spend\": pl.Series([10, 10, 20, 40, 40, 50])*1e6,\n", + " }\n", + " ).lazy()\n", + "\n", + "# the with_columns call\n", + "@with_columns(\n", + " *[my_functions_lazy],\n", + " columns_to_pass=[\"spend\", \"signups\"], # The columns to select from the dataframe\n", + " # select=output_columns, # The columns to append to the dataframe\n", + " # config_required = [\"a\"]\n", + ")\n", + "def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame:\n", + " return initial_df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape: (6, 6)\n", + "┌─────────┬───────┬───────────────┬──────────────────┬─────────────────┬───────────────────────────┐\n", + "│ signups ┆ spend ┆ avg_3wk_spend ┆ spend_per_signup ┆ spend_zero_mean ┆ spend_zero_mean_unit_vari │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ ance │\n", + "│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ --- │\n", + "│ ┆ ┆ ┆ ┆ ┆ f64 │\n", + "╞═════════╪═══════╪═══════════════╪══════════════════╪═════════════════╪═══════════════════════════╡\n", + "│ 1 ┆ 1e7 ┆ null ┆ 1e7 ┆ -1.8333e7 ┆ -1.064405 │\n", + "│ 10 ┆ 1e7 ┆ null ┆ 1e6 ┆ -1.8333e7 ┆ -1.064405 │\n", + "│ 50 ┆ 2e7 ┆ 13.333333 ┆ 400000.0 ┆ -8.3333e6 ┆ -0.483821 │\n", + "│ 100 ┆ 4e7 ┆ 23.333333 ┆ 400000.0 ┆ 1.1667e7 ┆ 0.677349 │\n", + "│ 200 ┆ 4e7 ┆ 33.333333 ┆ 200000.0 ┆ 1.1667e7 ┆ 0.677349 │\n", + "│ 400 ┆ 5e7 ┆ 43.333333 ┆ 125000.0 ┆ 2.1667e7 ┆ 1.257934 │\n", + "└─────────┴───────┴───────────────┴──────────────────┴─────────────────┴───────────────────────────┘\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "case\n", + "\n", + "\n", + "\n", + "case\n", + "millions\n", + "\n", + "\n", + "\n", + "final_df.spend\n", + "\n", + "final_df.spend\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev\n", + "\n", + "final_df.spend_std_dev\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_std_dev\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup\n", + "\n", + "final_df.spend_per_signup\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append\n", + "\n", + "final_df.__append\n", + "LazyFrame\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean\n", + "\n", + "final_df.spend_zero_mean\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean\n", + "\n", + "final_df.spend_mean\n", + "float\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.spend_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend\n", + "\n", + "final_df.avg_3wk_spend: case\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend->final_df.avg_3wk_spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "\n", + "final_df.spend_zero_mean_unit_variance\n", + "Expr\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean_unit_variance->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df\n", + "\n", + "final_df\n", + "LazyFrame\n", + "\n", + "\n", + "\n", + "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df\n", + "\n", + "initial_df\n", + "LazyFrame\n", + "\n", + "\n", + "\n", + "initial_df->final_df.spend\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups\n", + "\n", + "final_df.signups\n", + "Expr\n", + "\n", + "\n", + "\n", + "initial_df->final_df.signups\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "initial_df->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_per_signup->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.spend_per_signup\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.signups->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "final_df.__append->final_df\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_zero_mean->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.spend_mean->final_df.spend_zero_mean\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "final_df.avg_3wk_spend->final_df.__append\n", + "\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "config\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import with_columns_lazy_example\n", + "from hamilton import base\n", + "from hamilton.plugins import h_polars\n", + "\n", + "dr = (\n", + " driver.Builder()\n", + " .with_adapter(\n", + " adapter=base.SimplePythonGraphAdapter(result_builder=h_polars.PolarsDataFrameResult()))\n", + " .with_modules(my_functions_lazy, with_columns_lazy_example)\n", + " .with_config({\"case\":\"millions\"})\n", + " .build()\n", + " )\n", + "print(dr.execute(final_vars=[\"final_df\"]))\n", + "dr.visualize_execution(final_vars=[\"final_df\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hamilton", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/spark/pyspark/out.png b/examples/spark/pyspark/out.png index e86907761..5ac2eaa1b 100644 Binary files a/examples/spark/pyspark/out.png and b/examples/spark/pyspark/out.png differ diff --git a/hamilton/function_modifiers/README b/hamilton/function_modifiers/README new file mode 100644 index 000000000..d826e29e9 --- /dev/null +++ b/hamilton/function_modifiers/README @@ -0,0 +1,21 @@ +# with_columns_base + +Documenting the current design flow for the `with_columns` decorator. + +For now, it belongs to the `NodeInjector` lifecycle since it still runs the decorated function as a node but injects the dataframe with columns appended columns as one of the parameters. + +The `with_columns` consists of three parts that are represented in the corresponding three abstract methods in `with_columns_base`: + +1. `get_initial_nodes` -- Input node(s): Either a dataframe if `pass_datafame_as` is used or extracted columns into nodes if `columns_to_pass` and is library specific. +2. `get_subdag_nodes` -- Subdag nodes: Creating the `subdag` is outsourced to `recursive.subdag`, left flexibility to pre- and post-process since some libraries need that (see h_spark). +3. `chain_subdag_nodes` -- Merge node: The append functionality between dataframe and selected columns is library specific. + +Each plugin library that can implement `with_columns` should subclass from this base class and implement the three abstract methods (four since `validate()` is also abstract). The child +classes need to override the `init` where they call out to the parent `init` and pass in `dataframe_type` which is registered in the corresponding `extensions` and has information of what +columns types are permitted for the given dataframe type. + +Keeping it for now loosely coupled to the `registry` and detached from `ResultBuilder`. The API is private, should we want to switch to `registry`, the refactoring is straightforward and shouldn't get us into trouble down the road. + +## NOTE +The handling of scalars and dataframe types varies between library to library. We made the decision that such a thing should not be permissible, so all the selected columns that want to be +appended to the original dataframe need to have the matching column type that is registered in the `registry` and set in the library extension modules. diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index 744d222e2..a330714a9 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -1,5 +1,8 @@ +import abc import inspect import sys +import typing +from collections import defaultdict from types import ModuleType from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, TypedDict, Union @@ -11,7 +14,6 @@ else: from typing import NotRequired - # Copied this over from function_graph # TODO -- determine the best place to put this code from hamilton import graph_utils, node @@ -624,5 +626,226 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L if dep not in seen_nodes and dep in node_name_map: dep_node = node_name_map[dep] stack.append(dep_node) - seen_nodes.add(dep) + seen_nodes.add(dep) + + if not set(select) <= set([node_.name for node_ in output]): + raise ValueError( + "At least one of the selected nodes is not not in the DAG. " + f"You selected: {select}, but we only found nodes: {nodes}." + ) return output + + +def _default_inject_parameter(fn: Callable, target_dataframe: str = None) -> str: + if target_dataframe is not None: + inject_parameter = target_dataframe + else: + # If we don't have a specified dataframe we assume it's the first argument + function_parameters = list(inspect.signature(fn).parameters.values()) + if function_parameters: + inject_parameter = function_parameters[0].name + else: + raise ValueError( + f"Function {fn.__qualname__} has no parameters, but was " + f"decorated with with_columns. with_columns requires the first " + f"parameter to be a dataframe or using the on_input argument." + ) + return inject_parameter + + +class with_columns_base(base.NodeInjector, abc.ABC): + """Factory for with_columns operation on a dataframe. This is used when you want to extract some + columns out of the dataframe, perform operations on them and then append to the original dataframe. + + This is an internal class that is meant to be extended by each individual dataframe library implementing + the following abstract methods: + + - get_initial_nodes + - get_subdag_nodes + - chain_subdag_nodes + - validate + """ + + # TODO: if we rename the column nodes into something smarter this can be avoided and + # can also modify columns in place + @staticmethod + def contains_duplicates(nodes_: List[node.Node]) -> bool: + """Ensures that we don't run into name clashing of columns and group operations. + + In the case when we extract columns for the user, because ``columns_to_pass`` was used, we want + to safeguard against nameclashing with functions that are passed into ``with_columns`` - i.e. + there are no functions that have the same name as the columns. This effectively means that + using ``columns_to_pass`` will only append new columns to the dataframe and for changing + existing columns ``pass_dataframe_as`` or ``on_input`` needs to be used. + """ + node_counter = defaultdict(int) + for node_ in nodes_: + node_counter[node_.name] += 1 + if node_counter[node_.name] > 1: + return True + return False + + @staticmethod + def validate_dataframe( + fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]], required_type: Type + ) -> None: + input_types = typing.get_type_hints(fn) + if inject_parameter not in params: + raise InvalidDecoratorException( + f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. " + f"@with_columns requires the parameter names to match the function parameters. " + f"If you wish do not wish to use the first argument, please use ``pass_dataframe_as`` or ``on_input`` option. " + f"It might not be compatible with some other decorators." + ) + + if isinstance(input_types[inject_parameter], required_type): + raise InvalidDecoratorException( + "The selected dataframe parameter is not the correct dataframe type. " + f"You selected a parameter of type {input_types[inject_parameter]}, but we expect to get {required_type}" + ) + + def __init__( + self, + *load_from: Union[Callable, ModuleType], + columns_to_pass: List[str] = None, + pass_dataframe_as: str = None, + on_input: str = None, + select: List[str] = None, + namespace: str = None, + config_required: List[str] = None, + dataframe_type: Type = None, + ): + """Instantiates a ``@with_columns`` decorator. + + :param load_from: The functions or modules that will be used to generate the group of map operations. + :param columns_to_pass: The initial schema of the dataframe. This is used to determine which + upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is + left empty (and external_inputs is as well), we will assume that all dependencies come + from the dataframe. This cannot be used in conjunction with pass_dataframe_as. + :param pass_dataframe_as: The name of the dataframe that we're modifying, as known to the subdag. + If you pass this in, you are responsible for extracting columns out. If not provided, you have + to pass columns_to_pass in, and we will extract the columns out for you. + :param on_input: the dataframe parameter that we are applying with_columns on. By default we + will assume the first parameter is the corresponding dataframe. + :param select: The end nodes that represent columns to be appended to the original dataframe + via with_columns. Existing columns will be overridden. + :param namespace: The namespace of the nodes, so they don't clash with the global namespace + and so this can be reused. If its left out, there will be no namespace (in which case you'll want + to be careful about repeating it/reusing the nodes in other parts of the DAG.) + :param config_required: the list of config keys that are required to resolve any functions. Pass in None\ + if you want the functions/modules to have access to all possible config. + """ + + self.subdag_functions = subdag.collect_functions(load_from) + self.select = select + + # This is here to restrict to using either pass_dataframe_as or on_input or columns_to_pass + # TODO: decouple columns_to_pass, pass_dataframe_as and on_input + # For spark, we always perform with_columns on first parameter and use pass_dataframe_as; + # for pandas/polars, we can select which dataframe with on_input, but columns_to_pass, will always only work on first parameter + # We can decouple it so that on_input selects the target dataframe parameter that will inject into the next node + # pass_dataframe_as selects the original dataframe we want to extract columns from + # columns_to_pass is optinal helper that can be toggled on/off so no need to raise this error. + if ( + int(pass_dataframe_as is None) + int(columns_to_pass is None) + int(on_input is None) + == 1 + ): + raise ValueError( + "You must specify only one of ``columns_to_pass``, ``pass_dataframe_as``, and ``on_input``. " + "This is because specifying ``pass_dataframe_as`` or ``on_input`` injects into " + "the set of columns, allowing you to perform your own extraction" + "from the dataframe. We then execute all columns in the subdag" + "in order, passing in that initial dataframe. If you want" + "to reference columns in your code, you'll have to specify " + "the set of initial columns, and allow the subdag decorator " + "to inject the dataframe through. The initial columns tell " + "us which parameters to take from that dataframe, so we can" + "feed the right data into the right columns." + ) + + self.initial_schema = columns_to_pass + self.dataframe_subdag_param = pass_dataframe_as + self.target_dataframe = on_input + self.namespace = namespace + self.config_required = config_required + + if dataframe_type is None: + raise InvalidDecoratorException( + "Please provide the dataframe type for this specific library." + ) + + self.dataframe_type = dataframe_type + + def required_config(self) -> List[str]: + return self.config_required + + @abc.abstractmethod + def get_initial_nodes( + self, fn: Callable, params: Dict[str, Type[Type]] + ) -> Tuple[str, Collection[node.Node]]: + """Preparation stage where columns get extracted into nodes. In case `pass_dataframe_as` or `on_input` is + used, this should return an empty list (no column nodes) since the users will extract it + themselves. + + :param fn: the function we are decorating. By using the inspect library you can get information. + about what arguments it has / find out the dataframe argument. + :param params: Dictionary of all the type names one wants to inject. + :return: name of the dataframe parameter and list of nodes representing the extracted columns (can be empty). + """ + pass + + @abc.abstractmethod + def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + """Creates subdag from the passed in module / functions. + + :param config: Configuration with which the DAG was constructed. + :return: the subdag as a list of nodes. + """ + pass + + @abc.abstractmethod + def chain_subdag_nodes( + self, fn: Callable, inject_parameter: str, generated_nodes: Collection[node.Node] + ) -> node.Node: + """Combines the origanl dataframe with selected columns. This should produce a + dataframe output that is injected into the decorated function with new columns + appended and existing columns overriden. + + :param inject_parameter: the name of the original dataframe that. + :return: the new dataframe with the columns appended / overwritten. + """ + pass + + def inject_nodes( + self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable + ) -> Tuple[List[node.Node], Dict[str, str]]: + namespace = fn.__name__ if self.namespace is None else self.namespace + + inject_parameter, initial_nodes = self.get_initial_nodes(fn=fn, params=params) + subdag_nodes = self.get_subdag_nodes(fn=fn, config=config) + generated_nodes = initial_nodes + subdag_nodes + # TODO: for now we restrict that if user wants to change columns that already exist, he needs to + # pass the dataframe and extract them himself. If we add namespace to initial nodes and rewire the + # initial node names with the ongoing ones that have a column argument, we can also allow in place + # changes when using columns_to_pass + if with_columns_base.contains_duplicates(generated_nodes): + raise ValueError( + "You can only specify columns once. You used `columns_to_pass` and we " + "extract the columns for you. In this case they cannot be overwritten -- only new columns get " + "appended. If you want to modify in-place columns pass in a dataframe and " + "extract + modify the columns and afterwards select them." + ) + + pruned_nodes = prune_nodes(nodes=generated_nodes, select=self.select) + if len(pruned_nodes) == 0: + raise ValueError( + f"No nodes found upstream from select columns: {self.select} for function: " + f"{fn.__qualname__}" + ) + + # Node combining columns and dataframe might need info about prior nodes + output_nodes, current_param = self.chain_subdag_nodes( + fn=fn, inject_parameter=inject_parameter, generated_nodes=pruned_nodes + ) + output_nodes = subdag.add_namespace(output_nodes, namespace) + return output_nodes, {inject_parameter: assign_namespace(current_param, namespace)} diff --git a/hamilton/plugins/h_pandas.py b/hamilton/plugins/h_pandas.py index 722896dcf..bcbc2e2b3 100644 --- a/hamilton/plugins/h_pandas.py +++ b/hamilton/plugins/h_pandas.py @@ -1,9 +1,6 @@ -import inspect import sys -import typing -from collections import defaultdict from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union, get_type_hints _sys_version_info = sys.version_info _version_tuple = (_sys_version_info.major, _sys_version_info.minor, _sys_version_info.micro) @@ -13,17 +10,17 @@ else: pass -import pandas as pd - -# Copied this over from function_graph -# TODO -- determine the best place to put this code -from hamilton import node -from hamilton.function_modifiers import base +from hamilton import node, registry from hamilton.function_modifiers.expanders import extract_columns -from hamilton.function_modifiers.recursive import assign_namespace, prune_nodes, subdag +from hamilton.function_modifiers.recursive import ( + _default_inject_parameter, + subdag, + with_columns_base, +) +from hamilton.plugins.pandas_extensions import DATAFRAME_TYPE -class with_columns(base.NodeInjector): +class with_columns(with_columns_base): """Initializes a with_columns decorator for pandas. This allows you to efficiently run groups of map operations on a dataframe. Here's an example of calling it -- if you've seen ``@subdag``, you should be familiar with @@ -81,69 +78,56 @@ def final_df(initial_df: pd.DataFrame) -> pd.DataFrame: from my_module. It starts with the columns a_from_df and b_from_df, and then adds the columns a, b, and a_plus_b to the dataframe. It then returns the dataframe, and does some processing on it." - In case you need more flexibility you can alternatively use ``pass_dataframe_as``, for example, + In case you need more flexibility you can alternatively use ``on_input``, for example, .. code-block:: python - # with_columns_module.py - def a_from_df(initial_df: pd.Series) -> pd.Series: - return initial_df["a_from_df"] / 100 + # with_columns_module.py + def a_from_df(initial_df: pd.Series) -> pd.Series: + return initial_df["a_from_df"] / 100 def b_from_df(initial_df: pd.Series) -> pd.Series: - return initial_df["b_from_df"] / 100 + return initial_df["b_from_df"] / 100 - # the with_columns call - @with_columns( - *[my_module], - *[a_from_df], - columns_to_pass=["a_from_df", "b_from_df"], - select=["a_from_df", "b_from_df", "a", "b", "a_plus_b", "a_b_average"], - ) - def final_df(initial_df: pd.DataFrame) -> pd.DataFrame: - # process, or just return unprocessed - ... + # the with_columns call + @with_columns( + *[my_module], + *[a_from_df], + on_input="initial_df", + select=["a_from_df", "b_from_df", "a", "b", "a_plus_b", "a_b_average"], + ) + def final_df(initial_df: pd.DataFrame, ...) -> pd.DataFrame: + # process, or just return unprocessed + ... the above would output a dataframe where the two columns ``a_from_df`` and ``b_from_df`` get overwritten. """ - @staticmethod - def _check_for_duplicates(nodes_: List[node.Node]) -> bool: - """Ensures that we don't run into name clashing of columns and group operations. - - In the case when we extract columns for the user, because ``columns_to_pass`` was used, we want - to safeguard against nameclashing with functions that are passed into ``with_columns`` - i.e. - there are no functions that have the same name as the columns. This effectively means that - using ``columns_to_pass`` will only append new columns to the dataframe and for changing - existing columns ``pass_dataframe_as`` needs to be used. - """ - node_counter = defaultdict(int) - for node_ in nodes_: - node_counter[node_.name] += 1 - if node_counter[node_.name] > 1: - return True - return False - def __init__( self, *load_from: Union[Callable, ModuleType], columns_to_pass: List[str] = None, pass_dataframe_as: str = None, + on_input: str = None, select: List[str] = None, namespace: str = None, config_required: List[str] = None, ): - """Instantiates a ``@with_column`` decorator. + """Instantiates a ``@with_columns`` decorator. :param load_from: The functions or modules that will be used to generate the group of map operations. :param columns_to_pass: The initial schema of the dataframe. This is used to determine which upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is left empty (and external_inputs is as well), we will assume that all dependencies come - from the dataframe. This cannot be used in conjunction with pass_dataframe_as. - :param pass_dataframe_as: The name of the dataframe that we're modifying, as known to the subdag. + from the dataframe. This cannot be used in conjunction with on_input. + :param on_input: The name of the dataframe that we're modifying, as known to the subdag. If you pass this in, you are responsible for extracting columns out. If not provided, you have - to pass columns_to_pass in, and we will extract the columns out for you. + to pass columns_to_pass in, and we will extract the columns out on the first parameter for you. + :param select: The end nodes that represent columns to be appended to the original dataframe + via with_columns. Existing columns will be overridden. The selected nodes need to have the + corresponding column type, in this case pd.Series, to be appended to the original dataframe. :param namespace: The namespace of the nodes, so they don't clash with the global namespace and so this can be reused. If its left out, there will be no namespace (in which case you'll want to be careful about repeating it/reusing the nodes in other parts of the DAG.) @@ -151,44 +135,28 @@ def __init__( if you want the functions/modules to have access to all possible config. """ - self.subdag_functions = subdag.collect_functions(load_from) - - if select is None: - raise ValueError("Please specify at least one column to append or update.") - else: - self.select = select - - if (pass_dataframe_as is not None and columns_to_pass is not None) or ( - pass_dataframe_as is None and columns_to_pass is None - ): - raise ValueError( - "You must specify only one of columns_to_pass and " - "pass_dataframe_as. " - "This is because specifying pass_dataframe_as injects into " - "the set of columns, allowing you to perform your own extraction" - "from the dataframe. We then execute all columns in the sbudag" - "in order, passing in that initial dataframe. If you want" - "to reference columns in your code, you'll have to specify " - "the set of initial columns, and allow the subdag decorator " - "to inject the dataframe through. The initial columns tell " - "us which parameters to take from that dataframe, so we can" - "feed the right data into the right columns." + if pass_dataframe_as is not None: + raise NotImplementedError( + "We currently do not support pass_dataframe_as for pandas. Please reach out if you need this " + "functionality." ) - self.initial_schema = columns_to_pass - self.dataframe_subdag_param = pass_dataframe_as - self.namespace = namespace - self.config_required = config_required - - def required_config(self) -> List[str]: - return self.config_required + super().__init__( + *load_from, + columns_to_pass=columns_to_pass, + on_input=on_input, + select=select, + namespace=namespace, + config_required=config_required, + dataframe_type=DATAFRAME_TYPE, + ) def _create_column_nodes( - self, inject_parameter: str, params: Dict[str, Type[Type]] + self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]] ) -> List[node.Node]: output_type = params[inject_parameter] - def temp_fn(**kwargs) -> pd.DataFrame: + def temp_fn(**kwargs) -> Any: return kwargs[inject_parameter] # We recreate the df node to use extract columns @@ -204,88 +172,68 @@ def temp_fn(**kwargs) -> pd.DataFrame: out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn) return out_nodes[1:] - def _get_inital_nodes( + def get_initial_nodes( self, fn: Callable, params: Dict[str, Type[Type]] ) -> Tuple[str, Collection[node.Node]]: """Selects the correct dataframe and optionally extracts out columns.""" - initial_nodes = [] - if self.dataframe_subdag_param is not None: - inject_parameter = self.dataframe_subdag_param - else: - # If we don't have a specified dataframe we assume it's the first argument - sig = inspect.signature(fn) - inject_parameter = list(sig.parameters.values())[0].name - input_types = typing.get_type_hints(fn) - - if not input_types[inject_parameter] == pd.DataFrame: - raise ValueError( - "First argument has to be a pandas DataFrame. If you wish to use a " - "different argument, please use `pass_dataframe_as` option." - ) - - initial_nodes.extend( - self._create_column_nodes(inject_parameter=inject_parameter, params=params) - ) + inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) - if inject_parameter not in params: - raise base.InvalidDecoratorException( - f"Function: {fn.__name__} has a first parameter that is not a dependency. " - f"@with_columns requires the parameter names to match the function parameters. " - f"Thus it might not be compatible with some other decorators" - ) + initial_nodes = ( + [] + if self.target_dataframe is not None + else self._create_column_nodes(fn=fn, inject_parameter=inject_parameter, params=params) + ) return inject_parameter, initial_nodes - def _create_merge_node(self, upstream_node: str, node_name: str) -> node.Node: + def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + return subdag.collect_nodes(config, self.subdag_functions) + + def chain_subdag_nodes( + self, fn: Callable, inject_parameter: str, generated_nodes: Collection[node.Node] + ) -> node.Node: "Node that adds to / overrides columns for the original dataframe based on selected output." + # In case no node is selected we append all possible nodes that have a column type matching + # what the dataframe expects + if self.select is None: + self.select = [ + sink_node.name + for sink_node in generated_nodes + if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type) + ] def new_callable(**kwargs) -> Any: - df = kwargs[upstream_node] + df = kwargs[inject_parameter] columns_to_append = {} for column in self.select: columns_to_append[column] = kwargs[column] return df.assign(**columns_to_append) - input_map = {column: pd.Series for column in self.select} - input_map[upstream_node] = pd.DataFrame - - return node.Node( - name=node_name, - typ=pd.DataFrame, + column_type = registry.get_column_type_from_df_type(self.dataframe_type) + input_map = {column: column_type for column in self.select} + input_map[inject_parameter] = self.dataframe_type + merge_node = node.Node( + name="_append", + typ=self.dataframe_type, callabl=new_callable, input_types=input_map, ) - - def inject_nodes( - self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> Tuple[List[node.Node], Dict[str, str]]: - namespace = fn.__name__ if self.namespace is None else self.namespace - - inject_parameter, initial_nodes = self._get_inital_nodes(fn=fn, params=params) - - subdag_nodes = subdag.collect_nodes(config, self.subdag_functions) - - if with_columns._check_for_duplicates(initial_nodes + subdag_nodes): - raise ValueError( - "You can only specify columns once. You used `columns_to_pass` and we " - "extract the columns for you. In this case they cannot be overwritten -- only new columns get " - "appended. If you want to modify in-place columns pass in a dataframe and " - "extract + modify the columns and afterwards select them." - ) - - pruned_nodes = prune_nodes(subdag_nodes, self.select) - if len(pruned_nodes) == 0: - raise ValueError( - f"No nodes found upstream from select columns: {self.select} for function: " - f"{fn.__qualname__}" - ) - - merge_node = self._create_merge_node(inject_parameter, node_name="__append") - - output_nodes = initial_nodes + pruned_nodes + [merge_node] - output_nodes = subdag.add_namespace(output_nodes, namespace) - return output_nodes, {inject_parameter: assign_namespace(merge_node.name, namespace)} + output_nodes = generated_nodes + [merge_node] + return output_nodes, merge_node.name def validate(self, fn: Callable): - pass + inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) + params = get_type_hints(fn) + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) diff --git a/hamilton/plugins/h_polars.py b/hamilton/plugins/h_polars.py index 799882a30..4ef8609ab 100644 --- a/hamilton/plugins/h_polars.py +++ b/hamilton/plugins/h_polars.py @@ -1,8 +1,27 @@ -from typing import Any, Dict, Type, Union +import sys +from types import ModuleType +from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union, get_type_hints import polars as pl -from hamilton import base +_sys_version_info = sys.version_info +_version_tuple = (_sys_version_info.major, _sys_version_info.minor, _sys_version_info.micro) + +if _version_tuple < (3, 11, 0): + pass +else: + pass + +# Copied this over from function_graph +# TODO -- determine the best place to put this code +from hamilton import base, node, registry +from hamilton.function_modifiers.expanders import extract_columns +from hamilton.function_modifiers.recursive import ( + _default_inject_parameter, + subdag, + with_columns_base, +) +from hamilton.plugins.polars_extensions import DATAFRAME_TYPE class PolarsDataFrameResult(base.ResultMixin): @@ -54,3 +73,216 @@ def build_result( def output_type(self) -> Type: return pl.DataFrame + + +# Do we need this here? +class with_columns(with_columns_base): + """Initializes a with_columns decorator for polars. + + This allows you to efficiently run groups of map operations on a dataframe. We support + both eager and lazy mode in polars. In case of using eager mode the type should be + pl.DataFrame and the subsequent operations run on columns with type pl.Series. + + Here's an example of calling in eager mode -- if you've seen ``@subdag``, you should be familiar with + the concepts: + + .. code-block:: python + + # my_module.py + def a_b_average(a: pl.Series, b: pl.Series) -> pl.Series: + return (a + b) / 2 + + + .. code-block:: python + + # with_columns_module.py + def a_plus_b(a: pl.Series, b: pl.Series) -> pl.Series: + return a + b + + + # the with_columns call + @with_columns( + *[my_module], # Load from any module + *[a_plus_b], # or list operations directly + columns_to_pass=["a", "b"], # The columns to pass from the dataframe to + # the subdag + select=["a_plus_b", "a_b_average"], # The columns to append to the dataframe + ) + def final_df(initial_df: pl.DataFrame) -> pl.DataFrame: + # process, or just return unprocessed + ... + + In this instance the ``initial_df`` would get two columns added: ``a_plus_b`` and ``a_b_average``. + + Note that the operation is "append", meaning that the columns that are selected are appended + onto the dataframe. + + If the function takes multiple dataframes, the dataframe input to process will always be + the first argument. This will be passed to the subdag, transformed, and passed back to the function. + This follows the hamilton rule of reference by parameter name. To demonstarte this, in the code + above, the dataframe that is passed to the subdag is `initial_df`. That is transformed + by the subdag, and then returned as the final dataframe. + + You can read it as: + + "final_df is a function that transforms the upstream dataframe initial_df, running the transformations + from my_module. It starts with the columns a_from_df and b_from_df, and then adds the columns + a, b, and a_plus_b to the dataframe. It then returns the dataframe, and does some processing on it." + + In case you need more flexibility you can alternatively use ``on_input``, for example, + + .. code-block:: python + + # with_columns_module.py + def a_from_df() -> pl.Expr: + return pl.col(a).alias("a") / 100 + + def b_from_df() -> pl.Expr: + return pl.col(b).alias("b") / 100 + + + # the with_columns call + @with_columns( + *[my_module], + on_input="initial_df", + select=["a_from_df", "b_from_df", "a_plus_b", "a_b_average"], + ) + def final_df(initial_df: pl.DataFrame) -> pl.DataFrame: + # process, or just return unprocessed + ... + + the above would output a dataframe where the two columns ``a`` and ``b`` get + overwritten. + """ + + def __init__( + self, + *load_from: Union[Callable, ModuleType], + columns_to_pass: List[str] = None, + pass_dataframe_as: str = None, + on_input: str = None, + select: List[str] = None, + namespace: str = None, + config_required: List[str] = None, + ): + """Instantiates a ``@with_columns`` decorator. + + :param load_from: The functions or modules that will be used to generate the group of map operations. + :param columns_to_pass: The initial schema of the dataframe. This is used to determine which + upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is + left empty (and external_inputs is as well), we will assume that all dependencies come + from the dataframe. This cannot be used in conjunction with on_input. + :param on_input: The name of the dataframe that we're modifying, as known to the subdag. + If you pass this in, you are responsible for extracting columns out. If not provided, you have + to pass columns_to_pass in, and we will extract the columns out on the first parameter for you. + :param select: The end nodes that represent columns to be appended to the original dataframe + via with_columns. Existing columns will be overridden. The selected nodes need to have the + corresponding column type, in this case pl.Series, to be appended to the original dataframe. + :param namespace: The namespace of the nodes, so they don't clash with the global namespace + and so this can be reused. If its left out, there will be no namespace (in which case you'll want + to be careful about repeating it/reusing the nodes in other parts of the DAG.) + :param config_required: the list of config keys that are required to resolve any functions. Pass in None\ + if you want the functions/modules to have access to all possible config. + """ + + if pass_dataframe_as is not None: + raise NotImplementedError( + "We currently do not support pass_dataframe_as for pandas. Please reach out if you need this " + "functionality." + ) + + super().__init__( + *load_from, + columns_to_pass=columns_to_pass, + on_input=on_input, + select=select, + namespace=namespace, + config_required=config_required, + dataframe_type=DATAFRAME_TYPE, + ) + + def _create_column_nodes( + self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]] + ) -> List[node.Node]: + output_type = params[inject_parameter] + + def temp_fn(**kwargs) -> Any: + return kwargs[inject_parameter] + + # We recreate the df node to use extract columns + temp_node = node.Node( + name=inject_parameter, + typ=output_type, + callabl=temp_fn, + input_types={inject_parameter: output_type}, + ) + + extract_columns_decorator = extract_columns(*self.initial_schema) + + out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn) + return out_nodes[1:] + + def get_initial_nodes( + self, fn: Callable, params: Dict[str, Type[Type]] + ) -> Tuple[str, Collection[node.Node]]: + """Selects the correct dataframe and optionally extracts out columns.""" + inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) + + initial_nodes = ( + [] + if self.target_dataframe is not None + else self._create_column_nodes(fn=fn, inject_parameter=inject_parameter, params=params) + ) + + return inject_parameter, initial_nodes + + def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + return subdag.collect_nodes(config, self.subdag_functions) + + def chain_subdag_nodes( + self, fn: Callable, inject_parameter: str, generated_nodes: Collection[node.Node] + ) -> node.Node: + "Node that adds to / overrides columns for the original dataframe based on selected output." + + if self.select is None: + self.select = [ + sink_node.name + for sink_node in generated_nodes + if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type) + ] + + def new_callable(**kwargs) -> Any: + df = kwargs[inject_parameter] + columns_to_append = {} + for column in self.select: + columns_to_append[column] = kwargs[column] + + return df.with_columns(**columns_to_append) + + column_type = registry.get_column_type_from_df_type(self.dataframe_type) + input_map = {column: column_type for column in self.select} + input_map[inject_parameter] = self.dataframe_type + merge_node = node.Node( + name="_append", + typ=self.dataframe_type, + callabl=new_callable, + input_types=input_map, + ) + output_nodes = generated_nodes + [merge_node] + return output_nodes, merge_node.name + + def validate(self, fn: Callable): + inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) + params = get_type_hints(fn) + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) diff --git a/hamilton/plugins/h_polars_lazyframe.py b/hamilton/plugins/h_polars_lazyframe.py index a933762a7..00f4326e1 100644 --- a/hamilton/plugins/h_polars_lazyframe.py +++ b/hamilton/plugins/h_polars_lazyframe.py @@ -1,8 +1,16 @@ -from typing import Any, Dict, Type, Union +from types import ModuleType +from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union, get_type_hints import polars as pl -from hamilton import base +from hamilton import base, node, registry +from hamilton.function_modifiers.expanders import extract_columns +from hamilton.function_modifiers.recursive import ( + _default_inject_parameter, + subdag, + with_columns_base, +) +from hamilton.plugins.polars_lazyframe_extensions import DATAFRAME_TYPE class PolarsLazyFrameResult(base.ResultMixin): @@ -45,3 +53,214 @@ def build_result( def output_type(self) -> Type: return pl.LazyFrame + + +class with_columns(with_columns_base): + """Initializes a with_columns decorator for polars. + + This allows you to efficiently run groups of map operations on a dataframe. We support + both eager and lazy mode in polars. For lazy execution, use pl.LazyFrame and the subsequent + operations should be typed as pl.Expr. See examples/polars/with_columns for a practical + implementation in both variations. + + The lazy execution would be: + + .. code-block:: python + + # my_module.py + def a_b_average(a: pl.Expr, b: pl.Expr) -> pl.Expr: + return (a + b) / 2 + + + .. code-block:: python + + # with_columns_module.py + def a_plus_b(a: pl.Expr, b: pl.Expr) -> pl.Expr: + return a + b + + + # the with_columns call + @with_columns( + *[my_module], # Load from any module + *[a_plus_b], # or list operations directly + columns_to_pass=["a_from_df", "b_from_df"], # The columns to pass from the dataframe to + # the subdag + select=["a_plus_b", "a_b_average"], # The columns to append to the dataframe + ) + def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame: + # process, or just return unprocessed + ... + + Note that the operation is "append", meaning that the columns that are selected are appended + onto the dataframe. + + If the function takes multiple dataframes, the dataframe input to process will always be + the first argument. This will be passed to the subdag, transformed, and passed back to the function. + This follows the hamilton rule of reference by parameter name. To demonstarte this, in the code + above, the dataframe that is passed to the subdag is `initial_df`. That is transformed + by the subdag, and then returned as the final dataframe. + + You can read it as: + + "final_df is a function that transforms the upstream dataframe initial_df, running the transformations + from my_module. It starts with the columns a_from_df and b_from_df, and then adds the columns + a, b, and a_plus_b to the dataframe. It then returns the dataframe, and does some processing on it." + + In case you need more flexibility you can alternatively use ``on_input``, for example, + + .. code-block:: python + + # with_columns_module.py + def a_from_df() -> pl.Expr: + return pl.col(a).alias("a") / 100 + + def b_from_df() -> pd.Expr: + return pl.col(a).alias("b") / 100 + + + # the with_columns call + @with_columns( + *[my_module], + on_input="initial_df", + select=["a_from_df", "b_from_df", "a_plus_b", "a_b_average"], + ) + def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame: + # process, or just return unprocessed + ... + + the above would output a dataframe where the two columns ``a`` and ``b`` get + overwritten. + """ + + def __init__( + self, + *load_from: Union[Callable, ModuleType], + columns_to_pass: List[str] = None, + pass_dataframe_as: str = None, + on_input: str = None, + select: List[str] = None, + namespace: str = None, + config_required: List[str] = None, + ): + """Instantiates a ``@with_columns`` decorator. + + :param load_from: The functions or modules that will be used to generate the group of map operations. + :param columns_to_pass: The initial schema of the dataframe. This is used to determine which + upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is + left empty (and external_inputs is as well), we will assume that all dependencies come + from the dataframe. This cannot be used in conjunction with on_input. + :param on_input: The name of the dataframe that we're modifying, as known to the subdag. + If you pass this in, you are responsible for extracting columns out. If not provided, you have + to pass columns_to_pass in, and we will extract the columns out on the first parameter for you. + :param select: The end nodes that represent columns to be appended to the original dataframe + via with_columns. Existing columns will be overridden. The selected nodes need to have the + corresponding column type, in this case pl.Expr, to be appended to the original dataframe. + :param namespace: The namespace of the nodes, so they don't clash with the global namespace + and so this can be reused. If its left out, there will be no namespace (in which case you'll want + to be careful about repeating it/reusing the nodes in other parts of the DAG.) + :param config_required: the list of config keys that are required to resolve any functions. Pass in None\ + if you want the functions/modules to have access to all possible config. + """ + + if pass_dataframe_as is not None: + raise NotImplementedError( + "We currently do not support pass_dataframe_as for pandas. Please reach out if you need this " + "functionality." + ) + + super().__init__( + *load_from, + columns_to_pass=columns_to_pass, + on_input=on_input, + select=select, + namespace=namespace, + config_required=config_required, + dataframe_type=DATAFRAME_TYPE, + ) + + def _create_column_nodes( + self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]] + ) -> List[node.Node]: + output_type = params[inject_parameter] + + def temp_fn(**kwargs) -> Any: + return kwargs[inject_parameter] + + # We recreate the df node to use extract columns + temp_node = node.Node( + name=inject_parameter, + typ=output_type, + callabl=temp_fn, + input_types={inject_parameter: output_type}, + ) + + extract_columns_decorator = extract_columns(*self.initial_schema) + + out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn) + return out_nodes[1:] + + def get_initial_nodes( + self, fn: Callable, params: Dict[str, Type[Type]] + ) -> Tuple[str, Collection[node.Node]]: + """Selects the correct dataframe and optionally extracts out columns.""" + inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) + + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) + + initial_nodes = ( + [] + if self.target_dataframe is not None + else self._create_column_nodes(fn=fn, inject_parameter=inject_parameter, params=params) + ) + + return inject_parameter, initial_nodes + + def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + return subdag.collect_nodes(config, self.subdag_functions) + + def chain_subdag_nodes( + self, fn: Callable, inject_parameter: str, generated_nodes: Collection[node.Node] + ) -> node.Node: + "Node that adds to / overrides columns for the original dataframe based on selected output." + + if self.select is None: + self.select = [ + sink_node.name + for sink_node in generated_nodes + if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type) + ] + + def new_callable(**kwargs) -> Any: + df = kwargs[inject_parameter] + columns_to_append = {} + for column in self.select: + columns_to_append[column] = kwargs[column] + + return df.with_columns(**columns_to_append) + + column_type = registry.get_column_type_from_df_type(self.dataframe_type) + input_map = {column: column_type for column in self.select} + input_map[inject_parameter] = self.dataframe_type + merge_node = node.Node( + name="_append", + typ=self.dataframe_type, + callabl=new_callable, + input_types=input_map, + ) + output_nodes = generated_nodes + [merge_node] + return output_nodes, merge_node.name + + def validate(self, fn: Callable): + inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) + params = get_type_hints(fn) + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) diff --git a/hamilton/plugins/h_spark.py b/hamilton/plugins/h_spark.py index b60eb7908..499a488b8 100644 --- a/hamilton/plugins/h_spark.py +++ b/hamilton/plugins/h_spark.py @@ -20,9 +20,10 @@ from hamilton.execution import graph_functions from hamilton.function_modifiers import base as fm_base from hamilton.function_modifiers import subdag -from hamilton.function_modifiers.recursive import assign_namespace, prune_nodes +from hamilton.function_modifiers.recursive import with_columns_base from hamilton.htypes import custom_subclass_check from hamilton.lifecycle import base as lifecycle_base +from hamilton.plugins.pyspark_pandas_extensions import DATAFRAME_TYPE logger = logging.getLogger(__name__) @@ -904,12 +905,13 @@ def _identify_upstream_dataframe_nodes(nodes: List[node.Node]) -> List[str]: return list(df_deps) -class with_columns(fm_base.NodeCreator): +class with_columns(with_columns_base): def __init__( self, *load_from: Union[Callable, ModuleType], columns_to_pass: List[str] = None, pass_dataframe_as: str = None, + on_input: str = None, select: List[str] = None, namespace: str = None, mode: str = "append", @@ -992,29 +994,24 @@ def final_df(initial_df: ps.DataFrame) -> ps.DataFrame: :param config_required: the list of config keys that are required to resolve any functions. Pass in None\ if you want the functions/modules to have access to all possible config. """ - self.subdag_functions = subdag.collect_functions(load_from) - self.select = select - self.initial_schema = columns_to_pass - if (pass_dataframe_as is not None and columns_to_pass is not None) or ( - pass_dataframe_as is None and columns_to_pass is None - ): - raise ValueError( - "You must specify only one of columns_to_pass and " - "pass_dataframe_as. " - "This is because specifying pass_dataframe_as injects into " - "the set of columns, allowing you to perform your own extraction" - "from the dataframe. We then execute all columns in the sbudag" - "in order, passing in that initial dataframe. If you want" - "to reference columns in your code, you'll have to specify " - "the set of initial columns, and allow the subdag decorator " - "to inject the dataframe through. The initial columns tell " - "us which parameters to take from that dataframe, so we can" - "feed the right data into the right columns." + + if on_input is not None: + raise NotImplementedError( + "We currently do not support on_input for spark. Please reach out if you need this " + "functionality." ) - self.dataframe_subdag_param = pass_dataframe_as - self.namespace = namespace + + super().__init__( + *load_from, + columns_to_pass=columns_to_pass, + pass_dataframe_as=pass_dataframe_as, + select=select, + namespace=namespace, + config_required=config_required, + dataframe_type=DATAFRAME_TYPE, + ) + self.mode = mode - self.config_required = config_required @staticmethod def _prep_nodes(initial_nodes: List[node.Node]) -> List[node.Node]: @@ -1118,42 +1115,43 @@ def _validate_dataframe_subdag_parameter(self, nodes: List[node.Node], fn_name: def required_config(self) -> List[str]: return self.config_required - def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: - """Generates nodes in the with_columns groups. This does the following: - - 1. Collects all the nodes from the subdag functions - 2. Prunes them to only include the ones that are upstream from the select columns - 3. Sorts them topologically - 4. Creates a new node for each one, injecting the dataframe parameter into the first one - 5. Creates a new node for the final one, injecting the last node into that one - 6. Returns the list of nodes + def get_initial_nodes( + self, fn: Callable, params: Dict[str, Type[Type]] + ) -> Tuple[str, Collection[node.Node]]: + inject_parameter = _derive_first_dataframe_parameter_from_fn(fn=fn) + with_columns_base.validate_dataframe( + fn=fn, + inject_parameter=inject_parameter, + params=params, + required_type=self.dataframe_type, + ) + # Cannot extract columns in pyspark + initial_nodes = [] + return inject_parameter, initial_nodes - :param fn: Function to generate from - :param config: Config to use for generating/collecting nodes - :return: List of nodes that this function produces - """ - namespace = fn.__name__ if self.namespace is None else self.namespace + def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: initial_nodes = subdag.collect_nodes(config, self.subdag_functions) transformed_nodes = with_columns._prep_nodes(initial_nodes) + self._validate_dataframe_subdag_parameter(transformed_nodes, fn.__qualname__) - pruned_nodes = prune_nodes(transformed_nodes, self.select) - if len(pruned_nodes) == 0: - raise ValueError( - f"No nodes found upstream from select columns: {self.select} for function: " - f"{fn.__qualname__}" - ) - sorted_initial_nodes = graph_functions.topologically_sort_nodes(pruned_nodes) - output_nodes = [] - inject_parameter = _derive_first_dataframe_parameter_from_fn(fn) - current_dataframe_node = inject_parameter + return transformed_nodes + + def chain_subdag_nodes( + self, fn: Callable, inject_parameter: str, generated_nodes: Collection[node.Node] + ) -> node.Node: + generated_nodes = graph_functions.topologically_sort_nodes(generated_nodes) + # Columns that it is dependent on could be from the group of transforms created - columns_produced_within_mapgroup = {node_.name for node_ in pruned_nodes} + columns_produced_within_mapgroup = {node_.name for node_ in generated_nodes} + # Or from the dataframe passed in... columns_passed_in_from_dataframe = ( set(self.initial_schema) if self.initial_schema is not None else [] ) + + current_dataframe_node = inject_parameter + output_nodes = [] drop_list = [] - # Or from the dataframe passed in... - for node_ in sorted_initial_nodes: + for node_ in generated_nodes: # dependent columns are broken into two sets: # 1. Those that come from the group of transforms dependent_columns_in_mapgroup = { @@ -1183,18 +1181,20 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node dependent_columns_in_mapgroup, dependent_columns_in_dataframe, ) + if self.select is not None and sparkified.name not in self.select: # we need to create a drop list because we don't want to drop # original columns from the DF by accident. drop_list.append(sparkified.name) + output_nodes.append(sparkified) current_dataframe_node = sparkified.name - # We get the final node, which is the function we're using - # and reassign inputs to be the dataframe + if self.mode == "select": - # this selects over the original DF and the additions + # Have to redo this here since for spark the nodes are of type dataframe and not columns + # so with_columns.inject_nodes does not correctly select all the sink nodes select_columns = ( - self.select if self.select is not None else [item.name for item in output_nodes] + self.select if self.select is not None else [item.name for item in generated_nodes] ) select_node = with_columns.create_selector_node( upstream_name=current_dataframe_node, @@ -1214,11 +1214,8 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node ) output_nodes.append(select_node) current_dataframe_node = select_node.name - output_nodes = subdag.add_namespace(output_nodes, namespace) - final_node = node.Node.from_fn(fn).reassign_inputs( - {inject_parameter: assign_namespace(current_dataframe_node, namespace)} - ) - return output_nodes + [final_node] + + return output_nodes, current_dataframe_node def validate(self, fn: Callable): _derive_first_dataframe_parameter_from_fn(fn) diff --git a/hamilton/plugins/polars_extensions.py b/hamilton/plugins/polars_extensions.py index 7fe500c7d..2c8b9d5c7 100644 --- a/hamilton/plugins/polars_extensions.py +++ b/hamilton/plugins/polars_extensions.py @@ -48,7 +48,7 @@ def get_column_polars(df: pl.DataFrame, column_name: str) -> pl.Series: def fill_with_scalar_polars(df: pl.DataFrame, column_name: str, scalar_value: Any) -> pl.DataFrame: if not isinstance(scalar_value, pl.Series): scalar_value = [scalar_value] - return df.with_column(pl.Series(name=column_name, values=scalar_value)) + return df.with_columns(pl.Series(name=column_name, values=scalar_value)) register_types() diff --git a/hamilton/plugins/polars_lazyframe_extensions.py b/hamilton/plugins/polars_lazyframe_extensions.py index 6130dfce8..29a82396a 100644 --- a/hamilton/plugins/polars_lazyframe_extensions.py +++ b/hamilton/plugins/polars_lazyframe_extensions.py @@ -41,8 +41,8 @@ from hamilton.io.data_adapters import DataLoader DATAFRAME_TYPE = pl.LazyFrame -COLUMN_TYPE = None -COLUMN_FRIENDLY_DF_TYPE = False +COLUMN_TYPE = pl.Expr +# COLUMN_FRIENDLY_DF_TYPE = False def register_types(): @@ -50,6 +50,25 @@ def register_types(): registry.register_types("polars_lazyframe", DATAFRAME_TYPE, COLUMN_TYPE) +@registry.get_column.register(pl.LazyFrame) +def get_column_polars_lazyframe(df: pl.LazyFrame, column_name: str) -> pl.Expr: + # TODO: figure out if we can validate this here already or need to wait to the end + # when query.collect() resolves the lazy frame + # df.collect_schema().names() gives a list of names but it can be expensive + # https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.columns.html + # https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.collect_schema.html#polars.LazyFrame.collect_schema + return pl.col(column_name) + + +@registry.fill_with_scalar.register(pl.LazyFrame) +def fill_with_scalar_polars_lazyframe( + df: pl.LazyFrame, column_name: str, scalar_value: Any +) -> pl.LazyFrame: + if not isinstance(scalar_value, pl.Expr): + scalar_value = pl.lit(scalar_value) + return df.with_columns(scalar_value.alias(column_name)) + + register_types() diff --git a/hamilton/plugins/polars_pre_1_0_0_extension.py b/hamilton/plugins/polars_pre_1_0_0_extension.py index 39b75c262..3814f0fe1 100644 --- a/hamilton/plugins/polars_pre_1_0_0_extension.py +++ b/hamilton/plugins/polars_pre_1_0_0_extension.py @@ -64,7 +64,7 @@ def get_column_polars(df: pl.DataFrame, column_name: str) -> pl.Series: def fill_with_scalar_polars(df: pl.DataFrame, column_name: str, scalar_value: Any) -> pl.DataFrame: if not isinstance(scalar_value, pl.Series): scalar_value = [scalar_value] - return df.with_column(pl.Series(name=column_name, values=scalar_value)) + return df.with_columns(pl.Series(name=column_name, values=scalar_value)) @dataclasses.dataclass diff --git a/plugin_tests/h_pandas/resources/with_columns_end_to_end.py b/plugin_tests/h_pandas/resources/with_columns_end_to_end.py index 16d493e46..58192b119 100644 --- a/plugin_tests/h_pandas/resources/with_columns_end_to_end.py +++ b/plugin_tests/h_pandas/resources/with_columns_end_to_end.py @@ -59,7 +59,9 @@ def col_3(initial_df: pd.DataFrame) -> pd.Series: @with_columns( col_3, - pass_dataframe_as="initial_df", + multiply_3__by_5, + multiply_3__by_7, + on_input="initial_df", select=["col_3", "multiply_3"], ) def final_df_2(initial_df: pd.DataFrame) -> pd.DataFrame: diff --git a/plugin_tests/h_pandas/test_with_columns.py b/plugin_tests/h_pandas/test_with_columns.py index b3a84d666..f9012e718 100644 --- a/plugin_tests/h_pandas/test_with_columns.py +++ b/plugin_tests/h_pandas/test_with_columns.py @@ -8,58 +8,96 @@ from .resources import with_columns_end_to_end -def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: - return col_1 + 100 +def test__create_column_nodes(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def dummy_df() -> pd.DataFrame: + return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) + + def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: + return upstream_df + + decorator = with_columns( + dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"] + ) -def test_detect_duplicate_nodes(): - node_a = node.Node.from_fn(dummy_fn_with_columns, name="a") - node_b = node.Node.from_fn(dummy_fn_with_columns, name="a") - node_c = node.Node.from_fn(dummy_fn_with_columns, name="c") + column_nodes = decorator._create_column_nodes( + fn=target_fn, inject_parameter="upstream_df", params={"upstream_df": pd.DataFrame} + ) + + col1 = column_nodes[0] + col2 = column_nodes[1] + + assert col1.name == "col_1" + assert col2.name == "col_2" + + pd.testing.assert_series_equal( + col1.callable(upstream_df=dummy_df()), + pd.Series([1, 2, 3, 4]), + check_names=False, + ) - if not with_columns._check_for_duplicates([node_a, node_b, node_c]): - raise (AssertionError) + pd.testing.assert_series_equal( + col2.callable(upstream_df=dummy_df()), + pd.Series([11, 12, 13, 14]), + check_names=False, + ) - if with_columns._check_for_duplicates([node_a, node_c]): - raise (AssertionError) +def test__get_initial_nodes_when_extracting_columns(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 -def test_select_not_empty(): - error_message = "Please specify at least one column to append or update." + def dummy_df() -> pd.DataFrame: + return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) - with pytest.raises(ValueError) as e: - with_columns(dummy_fn_with_columns) - assert str(e.value) == error_message + def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: + return upstream_df + dummy_node = node.Node.from_fn(target_fn) -def test_columns_to_pass_and_pass_dataframe_as_raises_error(): - error_message = ( - "You must specify only one of columns_to_pass and " - "pass_dataframe_as. " - "This is because specifying pass_dataframe_as injects into " - "the set of columns, allowing you to perform your own extraction" - "from the dataframe. We then execute all columns in the sbudag" - "in order, passing in that initial dataframe. If you want" - "to reference columns in your code, you'll have to specify " - "the set of initial columns, and allow the subdag decorator " - "to inject the dataframe through. The initial columns tell " - "us which parameters to take from that dataframe, so we can" - "feed the right data into the right columns." + decorator = with_columns( + dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"] ) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) - with pytest.raises(ValueError) as e: - with_columns( - dummy_fn_with_columns, columns_to_pass=["a"], pass_dataframe_as="a", select=["a"] - ) - assert str(e.value) == error_message + inject_parameter, initial_nodes = decorator.get_initial_nodes( + fn=target_fn, params=injectable_params + ) + assert inject_parameter == "upstream_df" + assert len(initial_nodes) == 2 -def test_first_parameter_is_dataframe(): - error_message = ( - "First argument has to be a pandas DataFrame. If you wish to use a " - "different argument, please use `pass_dataframe_as` option." + +def test__get_initial_nodes_when_passing_dataframe(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + + def dummy_df() -> pd.DataFrame: + return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) + + def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: + return upstream_df + + dummy_node = node.Node.from_fn(target_fn) + + decorator = with_columns( + dummy_fn_with_columns, on_input="upstream_df", select=["dummy_fn_with_columns"] + ) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + inject_parameter, initial_nodes = decorator.get_initial_nodes( + fn=target_fn, params=injectable_params ) + assert inject_parameter == "upstream_df" + assert len(initial_nodes) == 0 + + +def test_first_parameter_is_dataframe(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def target_fn(upstream_df: int) -> pd.DataFrame: return upstream_df @@ -70,23 +108,25 @@ def target_fn(upstream_df: int) -> pd.DataFrame: ) injectable_params = NodeInjector.find_injectable_params([dummy_node]) - with pytest.raises(ValueError) as e: - decorator._get_inital_nodes(fn=target_fn, params=injectable_params) - - assert str(e.value) == error_message + # Raises error that is not pandas dataframe + with pytest.raises(NotImplementedError): + decorator.get_initial_nodes(fn=target_fn, params=injectable_params) def test_create_column_nodes_pass_dataframe(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def target_fn(some_var: int, upstream_df: pd.DataFrame) -> pd.DataFrame: return upstream_df dummy_node = node.Node.from_fn(target_fn) decorator = with_columns( - dummy_fn_with_columns, pass_dataframe_as="upstream_df", select=["dummy_fn_with_columns"] + dummy_fn_with_columns, on_input="upstream_df", select=["dummy_fn_with_columns"] ) injectable_params = NodeInjector.find_injectable_params([dummy_node]) - inject_parameter, initial_nodes = decorator._get_inital_nodes( + inject_parameter, initial_nodes = decorator.get_initial_nodes( fn=target_fn, params=injectable_params ) @@ -95,6 +135,9 @@ def target_fn(some_var: int, upstream_df: pd.DataFrame) -> pd.DataFrame: def test_create_column_nodes_extract_single_columns(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def dummy_df() -> pd.DataFrame: return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) @@ -107,8 +150,7 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: dummy_fn_with_columns, columns_to_pass=["col_1"], select=["dummy_fn_with_columns"] ) injectable_params = NodeInjector.find_injectable_params([dummy_node]) - - inject_parameter, initial_nodes = decorator._get_inital_nodes( + inject_parameter, initial_nodes = decorator.get_initial_nodes( fn=target_fn, params=injectable_params ) @@ -124,6 +166,9 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: def test_create_column_nodes_extract_multiple_columns(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def dummy_df() -> pd.DataFrame: return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) @@ -137,7 +182,7 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: ) injectable_params = NodeInjector.find_injectable_params([dummy_node]) - inject_parameter, initial_nodes = decorator._get_inital_nodes( + inject_parameter, initial_nodes = decorator.get_initial_nodes( fn=target_fn, params=injectable_params ) @@ -160,6 +205,9 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: def test_no_matching_select_column_error(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: return upstream_df @@ -171,17 +219,14 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: ) injectable_params = NodeInjector.find_injectable_params([dummy_node]) - error_message = ( - f"No nodes found upstream from select columns: {select} for function: " - f"{target_fn.__qualname__}" - ) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError): decorator.inject_nodes(injectable_params, {}, fn=target_fn) - assert str(e.value) == error_message - def test_append_into_original_df(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def dummy_df() -> pd.DataFrame: return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) @@ -191,13 +236,17 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: decorator = with_columns( dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"] ) - merge_node = decorator._create_merge_node(upstream_node="upstream_df", node_name="merge_node") + + output_nodes, _ = decorator.chain_subdag_nodes( + fn=target_fn, inject_parameter="upstream_df", generated_nodes=[] + ) + merge_node = output_nodes[-1] output_df = merge_node.callable( upstream_df=dummy_df(), dummy_fn_with_columns=dummy_fn_with_columns(col_1=pd.Series([1, 2, 3, 4])), ) - assert merge_node.name == "merge_node" + assert merge_node.name == "__append" assert merge_node.type == pd.DataFrame pd.testing.assert_series_equal(output_df["col_1"], pd.Series([1, 2, 3, 4]), check_names=False) @@ -219,11 +268,14 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: def col_1() -> pd.Series: return pd.Series([0, 3, 5, 7]) - decorator = with_columns(col_1, pass_dataframe_as=["upstream_df"], select=["col_1"]) - merge_node = decorator._create_merge_node(upstream_node="upstream_df", node_name="merge_node") + decorator = with_columns(col_1, on_input="upstream_df", select=["col_1"]) + output_nodes, _ = decorator.chain_subdag_nodes( + fn=target_fn, inject_parameter="upstream_df", generated_nodes=[] + ) + merge_node = output_nodes[-1] output_df = merge_node.callable(upstream_df=dummy_df(), col_1=col_1()) - assert merge_node.name == "merge_node" + assert merge_node.name == "__append" assert merge_node.type == pd.DataFrame pd.testing.assert_series_equal(output_df["col_1"], pd.Series([0, 3, 5, 7]), check_names=False) @@ -233,6 +285,9 @@ def col_1() -> pd.Series: def test_assign_custom_namespace_with_columns(): + def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series: + return col_1 + 100 + def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: return upstream_df @@ -244,12 +299,11 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame: namespace="dummy_namespace", ) nodes_ = decorator.transform_dag([dummy_node], {}, target_fn) - + print(nodes_) assert nodes_[0].name == "target_fn" - assert nodes_[1].name == "dummy_namespace.col_1" - assert nodes_[2].name == "dummy_namespace.col_2" - assert nodes_[3].name == "dummy_namespace.dummy_fn_with_columns" - assert nodes_[4].name == "dummy_namespace.__append" + assert nodes_[1].name == "dummy_namespace.dummy_fn_with_columns" + assert nodes_[2].name == "dummy_namespace.col_1" + assert nodes_[3].name == "dummy_namespace.__append" def test_end_to_end_with_columns_automatic_extract(): diff --git a/plugin_tests/h_polars/__init__.py b/plugin_tests/h_polars/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugin_tests/h_polars/conftest.py b/plugin_tests/h_polars/conftest.py new file mode 100644 index 000000000..bc5ef5b5a --- /dev/null +++ b/plugin_tests/h_polars/conftest.py @@ -0,0 +1,4 @@ +from hamilton import telemetry + +# disable telemetry for all tests! +telemetry.disable_telemetry() diff --git a/plugin_tests/h_polars/resources/__init__.py b/plugin_tests/h_polars/resources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugin_tests/h_polars/resources/with_columns_end_to_end.py b/plugin_tests/h_polars/resources/with_columns_end_to_end.py new file mode 100644 index 000000000..a893818fc --- /dev/null +++ b/plugin_tests/h_polars/resources/with_columns_end_to_end.py @@ -0,0 +1,68 @@ +import polars as pl + +from hamilton.function_modifiers import config +from hamilton.plugins.h_polars import with_columns + + +def upstream_factor() -> int: + return 3 + + +def initial_df() -> pl.DataFrame: + return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14], "col_3": [1, 1, 1, 1]}) + + +def subtract_1_from_2(col_1: pl.Series, col_2: pl.Series) -> pl.Series: + return col_2 - col_1 + + +@config.when(factor=5) +def multiply_3__by_5(col_3: pl.Series) -> pl.Series: + return col_3 * 5 + + +@config.when(factor=7) +def multiply_3__by_7(col_3: pl.Series) -> pl.Series: + return col_3 * 7 + + +def add_1_by_user_adjustment_factor(col_1: pl.Series, user_factor: int) -> pl.Series: + return col_1 + user_factor + + +def multiply_2_by_upstream_3(col_2: pl.Series, upstream_factor: int) -> pl.Series: + return col_2 * upstream_factor + + +@with_columns( + subtract_1_from_2, + multiply_3__by_5, + multiply_3__by_7, + add_1_by_user_adjustment_factor, + multiply_2_by_upstream_3, + columns_to_pass=["col_1", "col_2", "col_3"], + select=[ + "subtract_1_from_2", + "multiply_3", + "add_1_by_user_adjustment_factor", + "multiply_2_by_upstream_3", + ], + namespace="some_subdag", +) +def final_df(initial_df: pl.DataFrame) -> pl.DataFrame: + return initial_df + + +def col_3(initial_df: pl.DataFrame) -> pl.Series: + return pl.Series([0, 2, 4, 6]) + + +@with_columns( + col_3, + multiply_3__by_5, + multiply_3__by_7, + on_input="initial_df", + select=["col_3", "multiply_3"], +) +def final_df_2(initial_df: pl.DataFrame) -> pl.DataFrame: + return initial_df diff --git a/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py b/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py new file mode 100644 index 000000000..367cfacf4 --- /dev/null +++ b/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py @@ -0,0 +1,80 @@ +import polars as pl + +from hamilton.function_modifiers import config +from hamilton.plugins.h_polars_lazyframe import with_columns + + +def upstream_factor() -> int: + return 3 + + +def initial_df() -> pl.LazyFrame: + return pl.DataFrame( + {"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14], "col_3": [1, 1, 1, 1]} + ).lazy() + + +def subtract_1_from_2(col_1: pl.Expr, col_2: pl.Expr) -> pl.Expr: + return col_2 - col_1 + + +@config.when(factor=5) +def multiply_3__by_5(col_3: pl.Expr) -> pl.Expr: + return col_3 * 5 + + +@config.when(factor=7) +def multiply_3__by_7(col_3: pl.Expr) -> pl.Expr: + return col_3 * 7 + + +def add_1_by_user_adjustment_factor(col_1: pl.Expr, user_factor: int) -> pl.Expr: + return col_1 + user_factor + + +def multiply_2_by_upstream_3(col_2: pl.Expr, upstream_factor: int) -> pl.Expr: + return col_2 * upstream_factor + + +@with_columns( + subtract_1_from_2, + multiply_3__by_5, + multiply_3__by_7, + add_1_by_user_adjustment_factor, + multiply_2_by_upstream_3, + columns_to_pass=["col_1", "col_2", "col_3"], + select=[ + "subtract_1_from_2", + "multiply_3", + "add_1_by_user_adjustment_factor", + "multiply_2_by_upstream_3", + ], + namespace="some_subdag", +) +def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame: + return initial_df + + +def col_1(initial_df: pl.LazyFrame) -> pl.Expr: + return pl.col("col_1") + + +@config.when(factor=5) +def multiply_1__by_5(col_1: pl.Expr) -> pl.Expr: + return col_1 * 5 + + +@config.when_not(factor=5) +def multiply_1__by_1(col_1: pl.Expr) -> pl.Expr: + return col_1 * 1 + + +@with_columns( + col_1, + multiply_1__by_5, + multiply_1__by_1, + on_input="initial_df", + select=["col_1", "multiply_1"], +) +def final_df_2(initial_df: pl.LazyFrame) -> pl.LazyFrame: + return initial_df diff --git a/plugin_tests/h_polars/test_with_columns.py b/plugin_tests/h_polars/test_with_columns.py new file mode 100644 index 000000000..151347fb7 --- /dev/null +++ b/plugin_tests/h_polars/test_with_columns.py @@ -0,0 +1,265 @@ +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +from hamilton import driver, node +from hamilton.function_modifiers.base import NodeInjector +from hamilton.plugins.h_polars import with_columns + +from .resources import with_columns_end_to_end + + +def test_create_column_nodes_pass_dataframe(): + def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series: + return col_1 + 100 + + def target_fn(some_var: int, upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + dummy_node = node.Node.from_fn(target_fn) + + decorator = with_columns( + dummy_fn_with_columns, on_input="upstream_df", select=["dummy_fn_with_columns"] + ) + + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + inject_parameter, initial_nodes = decorator.get_initial_nodes( + fn=target_fn, params=injectable_params + ) + + assert inject_parameter == "upstream_df" + assert len(initial_nodes) == 0 + + +def test_create_column_nodes_extract_single_columns(): + def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series: + return col_1 + 100 + + def dummy_df() -> pl.DataFrame: + return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) + + def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + dummy_node = node.Node.from_fn(target_fn) + + decorator = with_columns( + dummy_fn_with_columns, columns_to_pass=["col_1"], select=["dummy_fn_with_columns"] + ) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + inject_parameter, initial_nodes = decorator.get_initial_nodes( + fn=target_fn, params=injectable_params + ) + + assert inject_parameter == "upstream_df" + assert len(initial_nodes) == 1 + assert initial_nodes[0].name == "col_1" + assert initial_nodes[0].type == pl.Series + pl.testing.assert_series_equal( + initial_nodes[0].callable(upstream_df=dummy_df()), + pl.Series([1, 2, 3, 4]), + check_names=False, + ) + + +def test_create_column_nodes_extract_multiple_columns(): + def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series: + return col_1 + 100 + + def dummy_df() -> pl.DataFrame: + return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) + + def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + dummy_node = node.Node.from_fn(target_fn) + + decorator = with_columns( + dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"] + ) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + inject_parameter, initial_nodes = decorator.get_initial_nodes( + fn=target_fn, params=injectable_params + ) + + assert inject_parameter == "upstream_df" + assert len(initial_nodes) == 2 + assert initial_nodes[0].name == "col_1" + assert initial_nodes[1].name == "col_2" + assert initial_nodes[0].type == pl.Series + assert initial_nodes[1].type == pl.Series + pl.testing.assert_series_equal( + initial_nodes[0].callable(upstream_df=dummy_df()), + pl.Series([1, 2, 3, 4]), + check_names=False, + ) + pl.testing.assert_series_equal( + initial_nodes[1].callable(upstream_df=dummy_df()), + pl.Series([11, 12, 13, 14]), + check_names=False, + ) + + +def test_no_matching_select_column_error(): + def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series: + return col_1 + 100 + + def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + dummy_node = node.Node.from_fn(target_fn) + select = "wrong_column" + + decorator = with_columns( + dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=select + ) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + with pytest.raises(ValueError): + decorator.inject_nodes(injectable_params, {}, fn=target_fn) + + +def test_append_into_original_df(): + def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series: + return col_1 + 100 + + def dummy_df() -> pl.DataFrame: + return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) + + def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + decorator = with_columns( + dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"] + ) + + output_nodes, _ = decorator.chain_subdag_nodes( + fn=target_fn, inject_parameter="upstream_df", generated_nodes=[] + ) + merge_node = output_nodes[-1] + + output_df = merge_node.callable( + upstream_df=dummy_df(), + dummy_fn_with_columns=dummy_fn_with_columns(col_1=pl.Series([1, 2, 3, 4])), + ) + assert merge_node.name == "__append" + assert merge_node.type == pl.DataFrame + + pl.testing.assert_series_equal(output_df["col_1"], pl.Series([1, 2, 3, 4]), check_names=False) + pl.testing.assert_series_equal( + output_df["col_2"], pl.Series([11, 12, 13, 14]), check_names=False + ) + pl.testing.assert_series_equal( + output_df["dummy_fn_with_columns"], pl.Series([101, 102, 103, 104]), check_names=False + ) + + +def test_override_original_column_in_df(): + def dummy_df() -> pl.DataFrame: + return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) + + def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + def col_1() -> pl.Series: + return pl.col("col_1") * 100 + + decorator = with_columns(col_1, on_input="upstream_df", select=["col_1"]) + + output_nodes, _ = decorator.chain_subdag_nodes( + fn=target_fn, inject_parameter="upstream_df", generated_nodes=[] + ) + merge_node = output_nodes[-1] + + output_df = merge_node.callable(upstream_df=dummy_df(), col_1=col_1()) + assert merge_node.name == "__append" + assert merge_node.type == pl.DataFrame + + pl.testing.assert_series_equal( + output_df["col_1"], pl.Series([100, 200, 300, 400]), check_names=False + ) + pl.testing.assert_series_equal( + output_df["col_2"], pl.Series([11, 12, 13, 14]), check_names=False + ) + + +def test_assign_custom_namespace_with_columns(): + def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series: + return col_1 + 100 + + def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame: + return upstream_df + + dummy_node = node.Node.from_fn(target_fn) + decorator = with_columns( + dummy_fn_with_columns, + columns_to_pass=["col_1", "col_2"], + select=["dummy_fn_with_columns"], + namespace="dummy_namespace", + ) + nodes_ = decorator.transform_dag([dummy_node], {}, target_fn) + + assert nodes_[0].name == "target_fn" + assert nodes_[1].name == "dummy_namespace.dummy_fn_with_columns" + assert nodes_[2].name == "dummy_namespace.col_1" + assert nodes_[3].name == "dummy_namespace.__append" + + +def test_end_to_end_with_columns_automatic_extract(): + config_5 = { + "factor": 5, + } + dr = driver.Builder().with_modules(with_columns_end_to_end).with_config(config_5).build() + result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"] + + expected_df = pl.DataFrame( + { + "col_1": [1, 2, 3, 4], + "col_2": [11, 12, 13, 14], + "col_3": [1, 1, 1, 1], + "subtract_1_from_2": [10, 10, 10, 10], + "multiply_3": [5, 5, 5, 5], + "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004], + "multiply_2_by_upstream_3": [33, 36, 39, 42], + } + ) + pl.testing.assert_frame_equal(result, expected_df) + + config_7 = { + "factor": 7, + } + dr = driver.Builder().with_modules(with_columns_end_to_end).with_config(config_7).build() + result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"] + + expected_df = pl.DataFrame( + { + "col_1": [1, 2, 3, 4], + "col_2": [11, 12, 13, 14], + "col_3": [1, 1, 1, 1], + "subtract_1_from_2": [10, 10, 10, 10], + "multiply_3": [7, 7, 7, 7], + "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004], + "multiply_2_by_upstream_3": [33, 36, 39, 42], + } + ) + assert_frame_equal(result, expected_df) + + +def test_end_to_end_with_columns_pass_dataframe(): + config_5 = { + "factor": 5, + } + dr = driver.Builder().with_modules(with_columns_end_to_end).with_config(config_5).build() + + result = dr.execute(final_vars=["final_df_2"])["final_df_2"] + expected_df = pl.DataFrame( + { + "col_1": [1, 2, 3, 4], + "col_2": [11, 12, 13, 14], + "col_3": [0, 2, 4, 6], + "multiply_3": [0, 10, 20, 30], + } + ) + assert_frame_equal(result, expected_df) diff --git a/plugin_tests/h_polars/test_with_columns_lazy.py b/plugin_tests/h_polars/test_with_columns_lazy.py new file mode 100644 index 000000000..2cb52c4db --- /dev/null +++ b/plugin_tests/h_polars/test_with_columns_lazy.py @@ -0,0 +1,64 @@ +import polars as pl +from polars.testing import assert_frame_equal + +from hamilton import driver + +from .resources import with_columns_end_to_end_lazy + + +def test_end_to_end_with_columns_automatic_extract_lazy(): + config_5 = { + "factor": 5, + } + dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_5).build() + result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"] + + expected_df = pl.DataFrame( + { + "col_1": [1, 2, 3, 4], + "col_2": [11, 12, 13, 14], + "col_3": [1, 1, 1, 1], + "subtract_1_from_2": [10, 10, 10, 10], + "multiply_3": [5, 5, 5, 5], + "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004], + "multiply_2_by_upstream_3": [33, 36, 39, 42], + } + ) + pl.testing.assert_frame_equal(result.collect(), expected_df) + + config_7 = { + "factor": 7, + } + dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_7).build() + result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"] + + expected_df = pl.DataFrame( + { + "col_1": [1, 2, 3, 4], + "col_2": [11, 12, 13, 14], + "col_3": [1, 1, 1, 1], + "subtract_1_from_2": [10, 10, 10, 10], + "multiply_3": [7, 7, 7, 7], + "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004], + "multiply_2_by_upstream_3": [33, 36, 39, 42], + } + ) + assert_frame_equal(result.collect(), expected_df) + + +def test_end_to_end_with_columns_pass_dataframe_lazy(): + config_5 = { + "factor": 5, + } + dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_5).build() + + result = dr.execute(final_vars=["final_df_2"])["final_df_2"] + expected_df = pl.DataFrame( + { + "col_1": [1, 2, 3, 4], + "col_2": [11, 12, 13, 14], + "col_3": [1, 1, 1, 1], + "multiply_1": [5, 10, 15, 20], + } + ) + assert_frame_equal(result.collect(), expected_df) diff --git a/plugin_tests/h_spark/test_h_spark.py b/plugin_tests/h_spark/test_h_spark.py index 2f12ea5d8..36bc295a4 100644 --- a/plugin_tests/h_spark/test_h_spark.py +++ b/plugin_tests/h_spark/test_h_spark.py @@ -11,6 +11,8 @@ from pyspark.sql.functions import column from hamilton import base, driver, htypes, node +from hamilton.function_modifiers.base import NodeInjector +from hamilton.function_modifiers.recursive import prune_nodes from hamilton.plugins import h_spark from hamilton.plugins.h_spark import SparkInputValidator @@ -569,7 +571,7 @@ def test_prune_nodes_no_select(): node.Node.from_fn(fn) for fn in [basic_spark_dag.a, basic_spark_dag.b, basic_spark_dag.c] ] select = None - assert {n for n in h_spark.prune_nodes(nodes, select)} == set(nodes) + assert {n for n in prune_nodes(nodes, select)} == set(nodes) def test_prune_nodes_single_select(): @@ -577,7 +579,7 @@ def test_prune_nodes_single_select(): node.Node.from_fn(fn) for fn in [basic_spark_dag.a, basic_spark_dag.b, basic_spark_dag.c] ] select = ["a", "b"] - assert {n for n in h_spark.prune_nodes(nodes, select)} == set(nodes[0:2]) + assert {n for n in prune_nodes(nodes, select)} == set(nodes[0:2]) def test_generate_nodes_invalid_select(): @@ -593,7 +595,10 @@ def test_generate_nodes_invalid_select(): def df_as_pandas(df: DataFrame) -> pd.DataFrame: return df.toPandas() - dec.generate_nodes(df_as_pandas, {}) + dummy_node = node.Node.from_fn(df_as_pandas) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas) def test_with_columns_generate_nodes_no_select(): @@ -607,13 +612,16 @@ def test_with_columns_generate_nodes_no_select(): def df_as_pandas(df: DataFrame) -> pd.DataFrame: return df.toPandas() - nodes = dec.generate_nodes(df_as_pandas, {}) + dummy_node = node.Node.from_fn(df_as_pandas) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas) + nodes_by_names = {n.name: n for n in nodes} assert set(nodes_by_names.keys()) == { "df_as_pandas.a", "df_as_pandas.b", "df_as_pandas.c", - "df_as_pandas", } @@ -629,9 +637,14 @@ def test_with_columns_generate_nodes_select(): def df_as_pandas(df: DataFrame) -> pd.DataFrame: return df.toPandas() - nodes = dec.generate_nodes(df_as_pandas, {}) + dummy_node = node.Node.from_fn(df_as_pandas) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas) nodes_by_names = {n.name: n for n in nodes} - assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas"} + assert set(nodes_by_names.keys()) == { + "df_as_pandas.c", + } def test_with_columns_generate_nodes_select_append_mode(): @@ -644,10 +657,13 @@ def test_with_columns_generate_nodes_select_append_mode(): def df_as_pandas(df: DataFrame) -> pd.DataFrame: return df.toPandas() - nodes = dec.generate_nodes(df_as_pandas, {}) + dummy_node = node.Node.from_fn(df_as_pandas) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas) + nodes_by_names = {n.name: n for n in nodes} assert set(nodes_by_names.keys()) == { - "df_as_pandas", "df_as_pandas._select", "df_as_pandas.a", "df_as_pandas.b", @@ -668,11 +684,13 @@ def test_with_columns_generate_nodes_select_mode_select(): def df_as_pandas(df: DataFrame) -> pd.DataFrame: return df.toPandas() - nodes = dec.generate_nodes(df_as_pandas, {}) + dummy_node = node.Node.from_fn(df_as_pandas) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas) nodes_by_names = {n.name: n for n in nodes} assert set(nodes_by_names.keys()) == { "df_as_pandas.c", - "df_as_pandas", "df_as_pandas._select", } @@ -689,9 +707,16 @@ def test_with_columns_generate_nodes_specify_namespace(): def df_as_pandas(df: DataFrame) -> pd.DataFrame: return df.toPandas() - nodes = dec.generate_nodes(df_as_pandas, {}) + dummy_node = node.Node.from_fn(df_as_pandas) + injectable_params = NodeInjector.find_injectable_params([dummy_node]) + + nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas) nodes_by_names = {n.name: n for n in nodes} - assert set(nodes_by_names.keys()) == {"foo.a", "foo.b", "foo.c", "df_as_pandas"} + assert set(nodes_by_names.keys()) == { + "foo.a", + "foo.b", + "foo.c", + } def test__format_pandas_udf(): diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index 7b3ae7a91..e9b76686c 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -5,6 +5,7 @@ import pytest +import hamilton from hamilton import ad_hoc_utils, graph from hamilton.function_modifiers import ( InvalidDecoratorException, @@ -16,7 +17,7 @@ ) from hamilton.function_modifiers.base import NodeTransformer from hamilton.function_modifiers.dependencies import source -from hamilton.function_modifiers.recursive import _validate_config_inputs +from hamilton.function_modifiers.recursive import _validate_config_inputs, with_columns_base import tests.resources.reuse_subdag @@ -539,3 +540,16 @@ def test_recursive_validate_config_inputs_happy(config, inputs): def test_recursive_validate_config_inputs_sad(config, inputs): with pytest.raises(InvalidDecoratorException): _validate_config_inputs(config, inputs) + + +def dummy_fn_with_columns(col_1: int) -> int: + return col_1 + 100 + + +def test_columns_and_subdag_nodes_do_not_clash(): + node_a = hamilton.node.Node.from_fn(dummy_fn_with_columns, name="a") + node_b = hamilton.node.Node.from_fn(dummy_fn_with_columns, name="a") + node_c = hamilton.node.Node.from_fn(dummy_fn_with_columns, name="c") + + assert not with_columns_base.contains_duplicates([node_a, node_c]) + assert with_columns_base.contains_duplicates([node_a, node_b, node_c])