diff --git a/.ci/test.sh b/.ci/test.sh index ab5f13661..c71dfb1d8 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -51,6 +51,13 @@ if [[ ${TASK} == "vaex" ]]; then exit 0 fi +if [[ ${TASK} == "narwhals" ]]; then + pip install -e . + pip install polars pandas narwhals + pytest plugin_tests/h_narwhals + exit 0 +fi + if [[ ${TASK} == "tests" ]]; then pip install . pytest \ diff --git a/.circleci/config.yml b/.circleci/config.yml index dfa860cb8..36f667c9b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -155,3 +155,21 @@ workflows: name: integrations-py312 python-version: '3.12' task: integrations + - test: + requires: + - check_for_changes + name: narwhals-py39 + python-version: '3.9' + task: narwhals + - test: + requires: + - check_for_changes + name: narwhals-py310 + python-version: '3.10' + task: narwhals + - test: + requires: + - check_for_changes + name: narwhals-py311 + python-version: '3.11' + task: narwhals diff --git a/docs/integrations/index.rst b/docs/integrations/index.rst index 3f5c68d6e..5ef802111 100644 --- a/docs/integrations/index.rst +++ b/docs/integrations/index.rst @@ -26,3 +26,4 @@ This section showcases how Hamilton integrates with popular frameworks. Slack Spark Vaex + Narwhals diff --git a/examples/narwhals/README.md b/examples/narwhals/README.md new file mode 100644 index 000000000..12e2d148e --- /dev/null +++ b/examples/narwhals/README.md @@ -0,0 +1,28 @@ +# Narwhals + +[Narwhals](https://narwhals-dev.github.io/narwhals/) is a library that aims +to unify expressions across dataframe libraries. It is meant to be lightweight +and focuses on Python-first dataframe libraries. + +This example shows how you can write dataframe-agnostic code +and then load either pandas or polars data to use with it. + +## Running the example + +You can run the example by doing: + +```bash +# cd examples/narwhals/ +python example.py +``` +This will run both variants one after the other. 
+ +or running the notebook: + +```bash +# cd examples/narwhals +jupyter notebook # pip install jupyter if you don't have it +``` +Or you can open up the notebook in Colab: + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/narwhals/notebook.ipynb) diff --git a/examples/narwhals/example.png b/examples/narwhals/example.png new file mode 100644 index 000000000..765d0c2a4 Binary files /dev/null and b/examples/narwhals/example.png differ diff --git a/examples/narwhals/example.py b/examples/narwhals/example.py new file mode 100644 index 000000000..2fb16b06e --- /dev/null +++ b/examples/narwhals/example.py @@ -0,0 +1,70 @@ +import narwhals as nw +import pandas as pd +import polars as pl + +from hamilton.function_modifiers import config, tag + + +@config.when(load="pandas") +def df__pandas() -> nw.DataFrame: + return pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [4, 5, 6, 7, 8]}) + + +@config.when(load="pandas") +def series__pandas() -> nw.Series: + return pd.Series([1, 3]) + + +@config.when(load="polars") +def df__polars() -> nw.DataFrame: + return pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [4, 5, 6, 7, 8]}) + + +@config.when(load="polars") +def series__polars() -> nw.Series: + return pl.Series([1, 3]) + + +@tag(nw_kwargs=["eager_only"]) +def example1(df: nw.DataFrame, series: nw.Series, col_name: str) -> int: + return df.filter(nw.col(col_name).is_in(series.to_numpy())).shape[0] + + +def group_by_mean(df: nw.DataFrame) -> nw.DataFrame: + return df.group_by("a").agg(nw.col("b").mean()).sort("a") + + +if __name__ == "__main__": + import __main__ as example + + from hamilton import base, driver + from hamilton.plugins import h_narwhals, h_polars + + # pandas + dr = ( + driver.Builder() + .with_config({"load": "pandas"}) + .with_modules(example) + .with_adapters( + h_narwhals.NarwhalsAdapter(), + h_narwhals.NarwhalsDataFrameResultBuilder(base.PandasDataFrameResult()), + ) 
+ .build() + ) + r = dr.execute([example.group_by_mean, example.example1], inputs={"col_name": "a"}) + print(r) + + # polars + dr = ( + driver.Builder() + .with_config({"load": "polars"}) + .with_modules(example) + .with_adapters( + h_narwhals.NarwhalsAdapter(), + h_narwhals.NarwhalsDataFrameResultBuilder(h_polars.PolarsDataFrameResult()), + ) + .build() + ) + r = dr.execute([example.group_by_mean, example.example1], inputs={"col_name": "a"}) + print(r) + dr.display_all_functions("example.png") diff --git a/examples/narwhals/notebook.ipynb b/examples/narwhals/notebook.ipynb new file mode 100644 index 000000000..d6283a859 --- /dev/null +++ b/examples/narwhals/notebook.ipynb @@ -0,0 +1,340 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": "!pip install 'sf-hamilton[visualization]' pandas polars" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# run me in google colab\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/narwhals/notebook.ipynb)" + ], + "id": "ce17944a48a226a9" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-01T19:02:16.560492Z", + "start_time": "2024-07-01T19:02:06.001758Z" + } + }, + "cell_type": "code", + "source": "%load_ext hamilton.plugins.jupyter_magic", + "id": "a4897501e00ed4e2", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cannot import name 'PolarsDataType' from 'polars' (/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/polars/__init__.py)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was 
not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", + " warnings.warn(\n" + ] + } + ], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-01T19:04:42.572389Z", + "start_time": "2024-07-01T19:04:42.567211Z" + } + }, + "cell_type": "code", + "source": [ + "config = {\n", + " \"mode\": \"pandas\"\n", + "}\n", + "from hamilton import driver\n", + "builder = driver.Builder()" + ], + "id": "6f0290a44e113076", + "outputs": [], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-01T19:06:01.729149Z", + "start_time": "2024-07-01T19:06:01.052739Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module example --display --config '{\"mode\":\"pandas\"}'\n", + "\n", + "import narwhals as nw\n", + "import pandas as pd\n", + "import polars as pl\n", + "\n", + "from hamilton.function_modifiers import config, tag\n", + "\n", + "\n", + "@config.when(load=\"pandas\")\n", + "def df__pandas() -> nw.DataFrame:\n", + " return pd.DataFrame({\"a\": [1, 1, 2, 2, 3], \"b\": [4, 5, 6, 7, 8]})\n", + "\n", + "\n", + "@config.when(load=\"pandas\")\n", + "def series__pandas() -> nw.Series:\n", + " return pd.Series([1, 3])\n", + "\n", + "\n", + "@config.when(load=\"polars\")\n", + "def df__polars() -> nw.DataFrame:\n", + " return pl.DataFrame({\"a\": [1, 1, 2, 2, 3], \"b\": [4, 5, 6, 7, 8]})\n", + "\n", + "\n", + "@config.when(load=\"polars\")\n", + "def series__polars() -> nw.Series:\n", + " return pl.Series([1, 3])\n", + "\n", + "\n", + "@tag(nw_kwargs=[\"eager_only\"])\n", + "def example1(df: nw.DataFrame, series: nw.Series, col_name: str) -> int:\n", + " return df.filter(nw.col(col_name).is_in(series.to_numpy())).shape[0]\n", + "\n", + "\n", + "def group_by_mean(df: nw.DataFrame) -> nw.DataFrame:\n", + " return 
df.group_by(\"a\").agg(nw.col(\"b\").mean()).sort(\"a\")\n" + ], + "id": "5c57c8bad9d004cd", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nmode\n\n\n\nmode\npandas\n\n\n\ngroup_by_mean\n\ngroup_by_mean\nDataFrame\n\n\n\nexample1\n\nexample1\nint\n\n\n\n_group_by_mean_inputs\n\ndf\nDataFrame\n\n\n\n_group_by_mean_inputs->group_by_mean\n\n\n\n\n\n_example1_inputs\n\nseries\nSeries\ncol_name\nstr\ndf\nDataFrame\n\n\n\n_example1_inputs->example1\n\n\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 12 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-01T19:08:20.197966Z", + "start_time": "2024-07-01T19:08:20.151820Z" + } + }, + "cell_type": "code", + "source": [ + "from hamilton import base, driver\n", + "from hamilton.plugins import h_narwhals, h_polars\n", + "# pandas\n", + "dr = (\n", + " driver.Builder()\n", + " .with_config({\"load\": \"pandas\"})\n", + " .with_modules(example)\n", + " .with_adapters(\n", + " h_narwhals.NarwhalsAdapter(),\n", + " h_narwhals.NarwhalsDataFrameResultBuilder(base.PandasDataFrameResult()),\n", + " )\n", + " .build()\n", + ")\n", + "result = dr.execute([example.group_by_mean, example.example1], inputs={\"col_name\": \"a\"})\n", + "result" + ], + "id": "4ec491ce248b32ec", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: a single pandas index was found, but there are also 1 outputs without an index. Please check whether the dataframe created matches what what you expect to happen.\n" + ] + }, + { + "data": { + "text/plain": [ + " group_by_mean.a group_by_mean.b example1\n", + "0 1 4.5 3\n", + "1 2 6.5 3\n", + "2 3 8.0 3" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
group_by_mean.agroup_by_mean.bexample1
014.53
126.53
238.03
\n", + "
" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 18 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-01T19:08:25.361471Z", + "start_time": "2024-07-01T19:08:25.322417Z" + } + }, + "cell_type": "code", + "source": [ + "# polars\n", + "dr = (\n", + " driver.Builder()\n", + " .with_config({\"load\": \"polars\"})\n", + " .with_modules(example)\n", + " .with_adapters(\n", + " h_narwhals.NarwhalsAdapter(),\n", + " h_narwhals.NarwhalsDataFrameResultBuilder(h_polars.PolarsDataFrameResult()),\n", + " )\n", + " .build()\n", + ")\n", + "result= dr.execute([example.group_by_mean, example.example1], inputs={\"col_name\": \"a\"})\n", + "result" + ], + "id": "b9e65f6b29a58a5d", + "outputs": [ + { + "data": { + "text/plain": [ + "shape: (3, 2)\n", + "┌───────────────┬──────────┐\n", + "│ group_by_mean ┆ example1 │\n", + "│ --- ┆ --- │\n", + "│ struct[2] ┆ i32 │\n", + "╞═══════════════╪══════════╡\n", + "│ {1,4.5} ┆ 3 │\n", + "│ {2,6.5} ┆ 3 │\n", + "│ {3,8.0} ┆ 3 │\n", + "└───────────────┴──────────┘" + ], + "text/html": [ + "
\n", + "shape: (3, 2)
group_by_meanexample1
struct[2]i32
{1,4.5}3
{2,6.5}3
{3,8.0}3
" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 19 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-01T19:07:42.534409Z", + "start_time": "2024-07-01T19:07:41.961806Z" + } + }, + "cell_type": "code", + "source": "dr.display_all_functions()", + "id": "c17d7b45b69cee61", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nload\n\n\n\nload\npolars\n\n\n\ngroup_by_mean\n\ngroup_by_mean\nDataFrame\n\n\n\nseries\n\nseries: load\nSeries\n\n\n\nexample1\n\nexample1\nint\n\n\n\nseries->example1\n\n\n\n\n\ndf\n\ndf: load\nDataFrame\n\n\n\ndf->group_by_mean\n\n\n\n\n\ndf->example1\n\n\n\n\n\n_example1_inputs\n\ncol_name\nstr\n\n\n\n_example1_inputs->example1\n\n\n\n\n\nconfig\n\n\n\nconfig\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 15 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "db8cb54bc64bbd27" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/narwhals/requirements.txt b/examples/narwhals/requirements.txt new file mode 100644 index 000000000..4c5cbc75a --- /dev/null +++ b/examples/narwhals/requirements.txt @@ -0,0 +1,4 @@ +narwhals +pandas +polars +sf-hamilton[visualization] diff --git a/hamilton/plugins/h_narwhals.py b/hamilton/plugins/h_narwhals.py new file mode 100644 index 000000000..7274ae17e --- /dev/null +++ 
b/hamilton/plugins/h_narwhals.py @@ -0,0 +1,62 @@ +from typing import Any, Dict, Optional, Type, Union + +import narwhals as nw + +from hamilton.lifecycle import api + + +class NarwhalsAdapter(api.NodeExecutionMethod): + + def run_to_execute_node( + self, + *, + node_name: str, + node_tags: Dict[str, Any], + node_callable: Any, + node_kwargs: Dict[str, Any], + task_id: Optional[str], + **future_kwargs: Any, + ) -> Any: + """This method is responsible for executing the node and returning the result. + + :param node_name: Name of the node. + :param node_tags: Tags of the node. + :param node_callable: Callable of the node. + :param node_kwargs: Keyword arguments to pass to the node. + :param task_id: The ID of the task, none if not in a task-based environment + :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility + :return: The result of the node execution -- up to you to return this. + """ + nw_kwargs = {} + if "nw_kwargs" in node_tags: + nw_kwargs = {k: True for k in node_tags["nw_kwargs"]} + nw_func = nw.narwhalify(node_callable, **nw_kwargs) + return nw_func(**node_kwargs) + + +class NarwhalsDataFrameResultBuilder(api.ResultBuilder): + """Builds the result. It unwraps the narwhals parts of it and delegates.""" + + def __init__(self, result_builder: Union[api.ResultBuilder, api.LegacyResultMixin]): + self.result_builder = result_builder + + def build_result(self, **outputs: Any) -> Any: + """Given a set of outputs, build the result. + + :param outputs: the outputs from the execution of the graph. + :return: the result of the execution of the graph. 
+ """ + de_narwhaled_outputs = {} + for key, value in outputs.items(): + if isinstance(value, (nw.DataFrame, nw.Series)): + de_narwhaled_outputs[key] = nw.to_native(value) + else: + de_narwhaled_outputs[key] = value + + return self.result_builder.build_result(**de_narwhaled_outputs) + + def output_type(self) -> Type: + """Returns the output type of this result builder + :return: the type that this creates + """ + return self.result_builder.output_type() diff --git a/plugin_tests/h_narwhals/__init__.py b/plugin_tests/h_narwhals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugin_tests/h_narwhals/conftest.py b/plugin_tests/h_narwhals/conftest.py new file mode 100644 index 000000000..bc5ef5b5a --- /dev/null +++ b/plugin_tests/h_narwhals/conftest.py @@ -0,0 +1,4 @@ +from hamilton import telemetry + +# disable telemetry for all tests! +telemetry.disable_telemetry() diff --git a/plugin_tests/h_narwhals/resources b/plugin_tests/h_narwhals/resources new file mode 120000 index 000000000..1e58ceb6d --- /dev/null +++ b/plugin_tests/h_narwhals/resources @@ -0,0 +1 @@ +../../tests/resources \ No newline at end of file diff --git a/plugin_tests/h_narwhals/test_h_narwhals.py b/plugin_tests/h_narwhals/test_h_narwhals.py new file mode 100644 index 000000000..adf588f6e --- /dev/null +++ b/plugin_tests/h_narwhals/test_h_narwhals.py @@ -0,0 +1,50 @@ +import json + +from hamilton import base, driver +from hamilton.plugins import h_narwhals, h_polars + +from .resources import narwhals_example + + +def test_pandas(): + # pandas + dr = ( + driver.Builder() + .with_config({"load": "pandas"}) + .with_modules(narwhals_example) + .with_adapters( + h_narwhals.NarwhalsAdapter(), + h_narwhals.NarwhalsDataFrameResultBuilder(base.PandasDataFrameResult()), + ) + .build() + ) + r = dr.execute( + [narwhals_example.group_by_mean, narwhals_example.example1], inputs={"col_name": "a"} + ) + assert r.to_dict() == { + "example1": {0: 3, 1: 3, 2: 3}, + "group_by_mean.a": {0: 1, 
1: 2, 2: 3}, + "group_by_mean.b": {0: 4.5, 1: 6.5, 2: 8.0}, + } + + +def test_polars(): + # polars + dr = ( + driver.Builder() + .with_config({"load": "polars"}) + .with_modules(narwhals_example) + .with_adapters( + h_narwhals.NarwhalsAdapter(), + h_narwhals.NarwhalsDataFrameResultBuilder(h_polars.PolarsDataFrameResult()), + ) + .build() + ) + r = dr.execute( + [narwhals_example.group_by_mean, narwhals_example.example1], inputs={"col_name": "a"} + ) + assert json.loads(r.write_json()) == [ + {"example1": 3, "group_by_mean": {"a": 1, "b": 4.5}}, + {"example1": 3, "group_by_mean": {"a": 2, "b": 6.5}}, + {"example1": 3, "group_by_mean": {"a": 3, "b": 8.0}}, + ] diff --git a/tests/resources/narwhals_example.py b/tests/resources/narwhals_example.py new file mode 100644 index 000000000..e32037a20 --- /dev/null +++ b/tests/resources/narwhals_example.py @@ -0,0 +1,34 @@ +import narwhals as nw +import pandas as pd +import polars as pl + +from hamilton.function_modifiers import config, tag + + +@config.when(load="pandas") +def df__pandas() -> nw.DataFrame: + return pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [4, 5, 6, 7, 8]}) + + +@config.when(load="pandas") +def series__pandas() -> nw.Series: + return pd.Series([1, 3]) + + +@config.when(load="polars") +def df__polars() -> nw.DataFrame: + return pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [4, 5, 6, 7, 8]}) + + +@config.when(load="polars") +def series__polars() -> nw.Series: + return pl.Series([1, 3]) + + +@tag(nw_kwargs=["eager_only"]) +def example1(df: nw.DataFrame, series: nw.Series, col_name: str) -> int: + return df.filter(nw.col(col_name).is_in(series.to_numpy())).shape[0] + + +def group_by_mean(df: nw.DataFrame) -> nw.DataFrame: + return df.group_by("a").agg(nw.col("b").mean()).sort("a")