diff --git a/examples/datasets/airbnb_multicity.ipynb b/examples/datasets/airbnb_multicity.ipynb new file mode 100644 index 00000000..ce2fdbba --- /dev/null +++ b/examples/datasets/airbnb_multicity.ipynb @@ -0,0 +1,140 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from srai.datasets import AirbnbMulticityDataset\n", + "\n", + "%load_ext dotenv\n", + "%dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "airbnb_multicity = AirbnbMulticityDataset()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Loading default version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_token = os.getenv(\"HF_TOKEN\")\n", + "gdf_train, gdf_test = airbnb_multicity.load(hf_token=hf_token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting aggregated hexagon values " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3, test_h3 = airbnb_multicity.get_h3_with_labels(\n", + " resolution=8, train_gdf=gdf_train, test_gdf=gdf_test\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_h3.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Loading raw, full data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_all, _ = airbnb_multicity.load(hf_token=hf_token, version=\"all\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_all.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/datasets/chicago_crime.ipynb b/examples/datasets/chicago_crime.ipynb new file mode 100644 index 00000000..d96fc1df --- /dev/null +++ b/examples/datasets/chicago_crime.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from srai.datasets import ChicagoCrimeDataset\n", + "\n", + "%load_ext dotenv\n", + "%dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chicago_crime = ChicagoCrimeDataset()\n", + "hf_token = os.getenv(\"HF_TOKEN\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load default data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train, gdf_test = chicago_crime.load(hf_token=hf_token)" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting target values for h3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3, test_h3 = chicago_crime.get_h3_with_labels(\n", + " resolution=9, train_gdf=gdf_train, test_gdf=gdf_test\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_h3.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data from 2022" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_2022, _ = chicago_crime.load(hf_token=hf_token, version=\"2022\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_2022.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/datasets/house_sales_in_king_county.ipynb b/examples/datasets/house_sales_in_king_county.ipynb new file mode 100644 index 00000000..3472091a --- /dev/null +++ b/examples/datasets/house_sales_in_king_county.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from srai.datasets import HouseSalesInKingCountyDataset\n", + "\n", + "%load_ext dotenv\n", + "%dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hskc_dataset = HouseSalesInKingCountyDataset()\n", + "hf_token = os.getenv(\"HF_TOKEN\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load default version of dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train, gdf_test = hskc_dataset.load(hf_token=hf_token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting the h3 with target values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3, test_h3 = hskc_dataset.get_h3_with_labels(\n", + " resolution=9, train_gdf=gdf_train, test_gdf=gdf_test\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load raw version of dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_all, _ = hskc_dataset.load(hf_token=hf_token, version=\"all\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"gdf_all.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/datasets/philadelphia_crime.ipynb b/examples/datasets/philadelphia_crime.ipynb new file mode 100644 index 00000000..a2880f11 --- /dev/null +++ b/examples/datasets/philadelphia_crime.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from srai.datasets import PhiladelphiaCrimeDataset\n", + "\n", + "%load_ext dotenv\n", + "%dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_token = os.getenv(\"HF_TOKEN\")\n", + "philadelphia_crime = PhiladelphiaCrimeDataset()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get data using .load() method -> a default version 'res_8' " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train, gdf_test = philadelphia_crime.load(hf_token=hf_token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting the h3 with target values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3, test_h3 = philadelphia_crime.get_h3_with_labels(\n", + " resolution=8, train_gdf=gdf_train, test_gdf=gdf_test\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get data from 2013 year." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_2013, _ = philadelphia_crime.load(hf_token=hf_token, version=\"2013\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_2013.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/datasets/police_department_incidents.ipynb b/examples/datasets/police_department_incidents.ipynb new file mode 100644 index 00000000..3fdf7b8e --- /dev/null +++ b/examples/datasets/police_department_incidents.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from srai.datasets import PoliceDepartmentIncidentsDataset\n", + "\n", + "%load_ext dotenv\n", + "%dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "police_department_incidents = PoliceDepartmentIncidentsDataset()\n", + "hf_token = os.getenv(\"HF_TOKEN\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Default config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train, gdf_test = police_department_incidents.load(hf_token=hf_token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting the h3 with target values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_h3, test_h3 = police_department_incidents.get_h3_with_labels(\n", + " resolution=8, train_gdf=gdf_train, test_gdf=gdf_test\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pdm.lock b/pdm.lock index 4be14941..7541312c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "dev", "docs", "gtfs", "license", "lint", "osm", "plotting", "test", "torch", "visualization", "voronoi"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:74755c36478585ebfc0977c1e0cf7bbbcbeb987d8aeaa6607a5f096883697a5f" +content_hash = "sha256:3d53efa15ca661295e8b218b498960d739697ad107e7f5238dc586e6eaa2b1bb" [[metadata.targets]] requires_python = ">=3.9" @@ -15,7 +15,7 @@ name = "aiohappyeyeballs" version = "2.4.4" requires_python = ">=3.8" summary = "Happy Eyeballs for asyncio" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] files = [ {file = "aiohappyeyeballs-2.4.4-py3-none-any.whl", hash = "sha256:a980909d50efcd44795c4afeca523296716d50cd756ddca6af8c65b996e27de8"}, {file 
= "aiohappyeyeballs-2.4.4.tar.gz", hash = "sha256:5fdd7d87889c63183afc18ce9271f9b0a7d32c2303e394468dd45d514a757745"}, @@ -26,7 +26,7 @@ name = "aiohttp" version = "3.11.11" requires_python = ">=3.9" summary = "Async http client/server framework (asyncio)" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] dependencies = [ "aiohappyeyeballs>=2.3.0", "aiosignal>=1.1.2", @@ -121,7 +121,7 @@ name = "aiosignal" version = "1.3.2" requires_python = ">=3.9" summary = "aiosignal: a list of registered asynchronous callbacks" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] dependencies = [ "frozenlist>=1.1.0", ] @@ -130,6 +130,18 @@ files = [ {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.9.3" +summary = "ANTLR 4.9.3 runtime for Python 3.7" +groups = ["default"] +dependencies = [ + "typing; python_version < \"3.5\"", +] +files = [ + {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, +] + [[package]] name = "anywidget" version = "0.9.13" @@ -505,7 +517,7 @@ name = "async-timeout" version = "5.0.1" requires_python = ">=3.8" summary = "Timeout context manager for asyncio programs" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] marker = "python_version < \"3.11\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, @@ -517,7 +529,7 @@ name = "attrs" version = "23.2.0" requires_python = ">=3.7" summary = "Classes Without Boilerplate" -groups = ["all", "docs", "license", "torch"] +groups = ["default", "all", "docs", "license", "torch"] dependencies = [ "importlib-metadata; python_version < \"3.8\"", ] @@ -950,6 +962,33 @@ files = [ {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, ] +[[package]] +name = "datasets" +version = "3.2.0" +requires_python = ">=3.9.0" +summary = "HuggingFace community-driven open-source library of datasets" +groups = ["default"] +dependencies = [ + "aiohttp", + "dill<0.3.9,>=0.3.0", + "filelock", + "fsspec[http]<=2024.9.0,>=2023.1.0", + "huggingface-hub>=0.23.0", + "multiprocess<0.70.17", + "numpy>=1.17", + "packaging", + "pandas", + "pyarrow>=15.0.0", + "pyyaml>=5.1", + "requests>=2.32.2", + "tqdm>=4.66.3", + "xxhash", +] +files = [ + {file = "datasets-3.2.0-py3-none-any.whl", hash = "sha256:f3d2ba2698b7284a4518019658596a6a8bc79f31e51516524249d6c59cf0fe2a"}, + {file = "datasets-3.2.0.tar.gz", hash = "sha256:9a6e1a356052866b5dbdd9c9eedb000bf3fc43d986e3584d9b028f4976937229"}, +] + [[package]] name = "debugpy" version = "1.8.11" @@ -1005,13 +1044,13 @@ files = [ [[package]] name = "dill" -version = "0.3.9" +version = "0.3.8" requires_python = ">=3.8" summary = "serialize all of Python" groups = ["default"] files = [ - {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, - {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, ] [[package]] @@ -1168,7 +1207,7 @@ name = "filelock" version = "3.16.1" requires_python = 
">=3.8" summary = "A platform independent file lock." -groups = ["all", "lint", "test", "torch"] +groups = ["default", "all", "lint", "test", "torch"] files = [ {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, @@ -1248,7 +1287,7 @@ name = "frozenlist" version = "1.5.0" requires_python = ">=3.8" summary = "A list-like structure which implements collections.abc.MutableSequence" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] files = [ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, @@ -1331,29 +1370,29 @@ files = [ [[package]] name = "fsspec" -version = "2024.12.0" +version = "2024.9.0" requires_python = ">=3.8" summary = "File-system specification" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] files = [ - {file = "fsspec-2024.12.0-py3-none-any.whl", hash = "sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2"}, - {file = "fsspec-2024.12.0.tar.gz", hash = "sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f"}, + {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"}, + {file = "fsspec-2024.9.0.tar.gz", hash = "sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8"}, ] [[package]] name = "fsspec" -version = "2024.12.0" +version = "2024.9.0" extras = ["http"] requires_python = ">=3.8" summary = "File-system specification" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] dependencies = [ "aiohttp!=4.0.0a0,!=4.0.0a1", - "fsspec==2024.12.0", + "fsspec==2024.9.0", ] files = [ - {file = "fsspec-2024.12.0-py3-none-any.whl", hash = "sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2"}, - {file = "fsspec-2024.12.0.tar.gz", hash = "sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f"}, + {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"}, + {file = "fsspec-2024.9.0.tar.gz", hash = "sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8"}, ] [[package]] @@ -1622,6 +1661,26 @@ files = [ {file = "haversine-2.9.0.tar.gz", hash = "sha256:1103d7e1f0f108c25b31b63452c54d9d6f29389a70de7dd75fd4b908329b6fcf"}, ] +[[package]] +name = "huggingface-hub" +version = "0.27.0" +requires_python = ">=3.8.0" +summary = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +groups = ["default"] +dependencies = [ + "filelock", + "fsspec>=2023.5.0", + "packaging>=20.9", + "pyyaml>=5.1", + "requests", + "tqdm>=4.42.1", + "typing-extensions>=3.7.4.3", +] +files = [ + {file = "huggingface_hub-0.27.0-py3-none-any.whl", hash = "sha256:8f2e834517f1f1ddf1ecc716f91b120d7333011b7485f665a9a412eacb1a2a81"}, + {file = "huggingface_hub-0.27.0.tar.gz", hash = "sha256:902cce1a1be5739f5589e560198a65a8edcfd3b830b1666f36e4b961f0454fac"}, +] + [[package]] name = "identify" version = "2.6.3" @@ -2656,7 +2715,7 @@ name = "multidict" version = "6.1.0" requires_python = ">=3.8" summary = "multidict implementation" -groups = ["all", "torch"] +groups = 
["default", "all", "torch"] dependencies = [ "typing-extensions>=4.1.0; python_version < \"3.11\"", ] @@ -2740,6 +2799,27 @@ files = [ {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, ] +[[package]] +name = "multiprocess" +version = "0.70.16" +requires_python = ">=3.8" +summary = "better multiprocessing and multithreading in Python" +groups = ["default"] +dependencies = [ + "dill>=0.3.8", +] +files = [ + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, +] + [[package]] name = "mypy" version = "1.14.0" @@ -3132,6 +3212,22 @@ files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "omegaconf" +version = "2.3.0" +requires_python = ">=3.6" +summary = "A flexible configuration library" +groups = ["default"] +dependencies = [ + "PyYAML>=5.1.0", + "antlr4-python3-runtime==4.9.*", + "dataclasses; python_version == \"3.6\"", +] +files = [ + {file = "omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b"}, + {file = "omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7"}, +] + [[package]] name = "osmnx" version = "2.0.0" @@ -3509,7 +3605,7 @@ name = "propcache" version = "0.2.1" requires_python = ">=3.9" summary = "Accelerated property cache" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] files = [ {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6"}, {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2"}, @@ -4055,6 +4151,17 @@ files = [ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] +[[package]] +name = "python-dotenv" +version = "1.0.1" +requires_python = ">=3.8" +summary = "Read key-value pairs from a .env file and set them as environment variables" +groups = ["default"] +files = [ + {file = "python-dotenv-1.0.1.tar.gz", hash = 
"sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, +] + [[package]] name = "pytorch-lightning" version = "2.5.0.post0" @@ -4114,7 +4221,7 @@ name = "pyyaml" version = "6.0.2" requires_python = ">=3.8" summary = "YAML parser and emitter for Python" -groups = ["all", "docs", "lint", "torch", "voronoi"] +groups = ["default", "all", "docs", "lint", "torch", "voronoi"] files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -5450,6 +5557,101 @@ files = [ {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, ] +[[package]] +name = "xxhash" +version = "3.5.0" +requires_python = ">=3.7" +summary = "Python binding for xxHash" +groups = ["default"] +files = [ + {file = "xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212"}, + {file = "xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c5d3e570ef46adaf93fc81b44aca6002b5a4d8ca11bd0580c07eac537f36680"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb29a034301e2982df8b1fe6328a84f4b676106a13e9135a0d7e0c3e9f806da"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0d307d27099bb0cbeea7260eb39ed4fdb99c5542e21e94bb6fd29e49c57a23"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0342aafd421795d740e514bc9858ebddfc705a75a8c5046ac56d85fe97bf196"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dbbd9892c5ebffeca1ed620cf0ade13eb55a0d8c84e0751a6653adc6ac40d0c"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4cc2d67fdb4d057730c75a64c5923abfa17775ae234a71b0200346bfb0a7f482"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ec28adb204b759306a3d64358a5e5c07d7b1dd0ccbce04aa76cb9377b7b70296"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1328f6d8cca2b86acb14104e381225a3d7b42c92c4b86ceae814e5c400dbb415"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8d47ebd9f5d9607fd039c1fbf4994e3b071ea23eff42f4ecef246ab2b7334198"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b96d559e0fcddd3343c510a0fe2b127fbff16bf346dd76280b82292567523442"}, + {file = "xxhash-3.5.0-cp310-cp310-win32.whl", hash = "sha256:61c722ed8d49ac9bc26c7071eeaa1f6ff24053d553146d5df031802deffd03da"}, + {file = "xxhash-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9bed5144c6923cc902cd14bb8963f2d5e034def4486ab0bbe1f58f03f042f9a9"}, + {file = "xxhash-3.5.0-cp310-cp310-win_arm64.whl", hash = "sha256:893074d651cf25c1cc14e3bea4fceefd67f2921b1bb8e40fcfeba56820de80c6"}, + {file = "xxhash-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:02c2e816896dc6f85922ced60097bcf6f008dedfc5073dcba32f9c8dd786f3c1"}, + {file = "xxhash-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6027dcd885e21581e46d3c7f682cfb2b870942feeed58a21c29583512c3f09f8"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1308fa542bbdbf2fa85e9e66b1077eea3a88bef38ee8a06270b4298a7a62a166"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c28b2fdcee797e1c1961cd3bcd3d545cab22ad202c846235197935e1df2f8ef7"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:924361811732ddad75ff23e90efd9ccfda4f664132feecb90895bade6a1b4623"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89997aa1c4b6a5b1e5b588979d1da048a3c6f15e55c11d117a56b75c84531f5a"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:685c4f4e8c59837de103344eb1c8a3851f670309eb5c361f746805c5471b8c88"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbd2ecfbfee70bc1a4acb7461fa6af7748ec2ab08ac0fa298f281c51518f982c"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:25b5a51dc3dfb20a10833c8eee25903fd2e14059e9afcd329c9da20609a307b2"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a8fb786fb754ef6ff8c120cb96629fb518f8eb5a61a16aac3a979a9dbd40a084"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a905ad00ad1e1c34fe4e9d7c1d949ab09c6fa90c919860c1534ff479f40fd12d"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:963be41bcd49f53af6d795f65c0da9b4cc518c0dd9c47145c98f61cb464f4839"}, + {file = "xxhash-3.5.0-cp311-cp311-win32.whl", hash = "sha256:109b436096d0a2dd039c355fa3414160ec4d843dfecc64a14077332a00aeb7da"}, + {file = "xxhash-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b702f806693201ad6c0a05ddbbe4c8f359626d0b3305f766077d51388a6bac58"}, + {file = "xxhash-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:c4dcb4120d0cc3cc448624147dba64e9021b278c63e34a38789b688fd0da9bf3"}, + {file = "xxhash-3.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:14470ace8bd3b5d51318782cd94e6f94431974f16cb3b8dc15d52f3b69df8e00"}, + {file = "xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59aa1203de1cb96dbeab595ded0ad0c0056bb2245ae11fac11c0ceea861382b9"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08424f6648526076e28fae6ea2806c0a7d504b9ef05ae61d196d571e5c879c84"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61a1ff00674879725b194695e17f23d3248998b843eb5e933007ca743310f793"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2f2c61bee5844d41c3eb015ac652a0229e901074951ae48581d58bfb2ba01be"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d32a592cac88d18cc09a89172e1c32d7f2a6e516c3dfde1b9adb90ab5df54a6"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70dabf941dede727cca579e8c205e61121afc9b28516752fd65724be1355cc90"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e5d0ddaca65ecca9c10dcf01730165fd858533d0be84c75c327487c37a906a27"}, + {file = 
"xxhash-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e5b5e16c5a480fe5f59f56c30abdeba09ffd75da8d13f6b9b6fd224d0b4d0a2"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149b7914451eb154b3dfaa721315117ea1dac2cc55a01bfbd4df7c68c5dd683d"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:eade977f5c96c677035ff39c56ac74d851b1cca7d607ab3d8f23c6b859379cab"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa9f547bd98f5553d03160967866a71056a60960be00356a15ecc44efb40ba8e"}, + {file = "xxhash-3.5.0-cp312-cp312-win32.whl", hash = "sha256:f7b58d1fd3551b8c80a971199543379be1cee3d0d409e1f6d8b01c1a2eebf1f8"}, + {file = "xxhash-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0cafd3a2af231b4e113fba24a65d7922af91aeb23774a8b78228e6cd785e3e"}, + {file = "xxhash-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:586886c7e89cb9828bcd8a5686b12e161368e0064d040e225e72607b43858ba2"}, + {file = "xxhash-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37889a0d13b0b7d739cfc128b1c902f04e32de17b33d74b637ad42f1c55101f6"}, + {file = "xxhash-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97a662338797c660178e682f3bc180277b9569a59abfb5925e8620fba00b9fc5"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f85e0108d51092bdda90672476c7d909c04ada6923c14ff9d913c4f7dc8a3bc"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2fd827b0ba763ac919440042302315c564fdb797294d86e8cdd4578e3bc7f3"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82085c2abec437abebf457c1d12fccb30cc8b3774a0814872511f0f0562c768c"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07fda5de378626e502b42b311b049848c2ef38784d0d67b6f30bb5008642f8eb"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c279f0d2b34ef15f922b77966640ade58b4ccdfef1c4d94b20f2a364617a493f"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89e66ceed67b213dec5a773e2f7a9e8c58f64daeb38c7859d8815d2c89f39ad7"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bcd51708a633410737111e998ceb3b45d3dbc98c0931f743d9bb0a209033a326"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3ff2c0a34eae7df88c868be53a8dd56fbdf592109e21d4bfa092a27b0bf4a7bf"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4e28503dccc7d32e0b9817aa0cbfc1f45f563b2c995b7a66c4c8a0d232e840c7"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6c50017518329ed65a9e4829154626f008916d36295b6a3ba336e2458824c8c"}, + {file = "xxhash-3.5.0-cp313-cp313-win32.whl", hash = "sha256:53a068fe70301ec30d868ece566ac90d873e3bb059cf83c32e76012c889b8637"}, + {file = "xxhash-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:80babcc30e7a1a484eab952d76a4f4673ff601f54d5142c26826502740e70b43"}, + {file = "xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b"}, + {file = "xxhash-3.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc8cdd7f33d57f0468b0614ae634cc38ab9202c6957a60e31d285a71ebe0301"}, + {file = "xxhash-3.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0c48b6300cd0b0106bf49169c3e0536408dfbeb1ccb53180068a18b03c662ab"}, + {file 
= "xxhash-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe1a92cfbaa0a1253e339ccec42dbe6db262615e52df591b68726ab10338003f"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33513d6cc3ed3b559134fb307aae9bdd94d7e7c02907b37896a6c45ff9ce51bd"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eefc37f6138f522e771ac6db71a6d4838ec7933939676f3753eafd7d3f4c40bc"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a606c8070ada8aa2a88e181773fa1ef17ba65ce5dd168b9d08038e2a61b33754"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42eca420c8fa072cc1dd62597635d140e78e384a79bb4944f825fbef8bfeeef6"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:604253b2143e13218ff1ef0b59ce67f18b8bd1c4205d2ffda22b09b426386898"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6e93a5ad22f434d7876665444a97e713a8f60b5b1a3521e8df11b98309bff833"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:7a46e1d6d2817ba8024de44c4fd79913a90e5f7265434cef97026215b7d30df6"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:30eb2efe6503c379b7ab99c81ba4a779748e3830241f032ab46bd182bf5873af"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c8aa771ff2c13dd9cda8166d685d7333d389fae30a4d2bb39d63ab5775de8606"}, + {file = "xxhash-3.5.0-cp39-cp39-win32.whl", hash = "sha256:5ed9ebc46f24cf91034544b26b131241b699edbfc99ec5e7f8f3d02d6eb7fba4"}, + {file = "xxhash-3.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:220f3f896c6b8d0316f63f16c077d52c412619e475f9372333474ee15133a558"}, + {file = "xxhash-3.5.0-cp39-cp39-win_arm64.whl", hash = "sha256:a7b1d8315d9b5e9f89eb2933b73afae6ec9597a258d52190944437158b49d38e"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2014c5b3ff15e64feecb6b713af12093f75b7926049e26a580e94dcad3c73d8c"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab81ef75003eda96239a23eda4e4543cedc22e34c373edcaf744e721a163986"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e2febf914ace002132aa09169cc572e0d8959d0f305f93d5828c4836f9bc5a6"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d3a10609c51da2a1c0ea0293fc3968ca0a18bd73838455b5bca3069d7f8e32b"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a74f23335b9689b66eb6dbe2a931a88fcd7a4c2cc4b1cb0edba8ce381c7a1da"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:531af8845aaadcadf951b7e0c1345c6b9c68a990eeb74ff9acd8501a0ad6a1c9"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ce379bcaa9fcc00f19affa7773084dd09f5b59947b3fb47a1ceb0179f91aaa1"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd1b2281d01723f076df3c8188f43f2472248a6b63118b036e641243656b1b0f"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c770750cc80e8694492244bca7251385188bc5597b6a39d98a9f30e8da984e0"}, + {file = 
"xxhash-3.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b150b8467852e1bd844387459aa6fbe11d7f38b56e901f9f3b3e6aba0d660240"}, + {file = "xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f"}, +] + [[package]] name = "xyzservices" version = "2024.9.0" @@ -5466,7 +5668,7 @@ name = "yarl" version = "1.18.3" requires_python = ">=3.9" summary = "Yet another URL library" -groups = ["all", "torch"] +groups = ["default", "all", "torch"] dependencies = [ "idna>=2.0", "multidict>=4.0", diff --git a/pyproject.toml b/pyproject.toml index 98fc0454..efca0869 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ dependencies = [ "requests", "h3ronpy>=0.20.1", "osmnx>=1.3.0", + "omegaconf", + "python-dotenv", + "datasets", ] requires-python = ">=3.9" readme = "README.md" diff --git a/srai/datasets/__init__.py b/srai/datasets/__init__.py new file mode 100644 index 00000000..9465c99a --- /dev/null +++ b/srai/datasets/__init__.py @@ -0,0 +1,21 @@ +""" +This module contains dataset used to load dataset containing spatial information. + +Datasets can be loaded using .load() method. Some of them may need name of version. +""" + +from ._base import HuggingFaceDataset +from .airbnb_multicity import AirbnbMulticityDataset +from .chicago_crime import ChicagoCrimeDataset +from .house_sales_in_king_county import HouseSalesInKingCountyDataset +from .philadelphia_crime import PhiladelphiaCrimeDataset +from .police_department_incidents import PoliceDepartmentIncidentsDataset + +__all__ = [ + "HuggingFaceDataset", + "AirbnbMulticityDataset", + "HouseSalesInKingCountyDataset", + "PhiladelphiaCrimeDataset", + "ChicagoCrimeDataset", + "PoliceDepartmentIncidentsDataset", +] diff --git a/srai/datasets/_base.py b/srai/datasets/_base.py new file mode 100644 index 00000000..759b8f6a --- /dev/null +++ b/srai/datasets/_base.py @@ -0,0 +1,314 @@ +"""Base classes for Datasets.""" + +import abc +from typing import Optional + +import geopandas as gpd +import h3 +import numpy as np +import pandas as pd +from datasets import load_dataset +from shapely.geometry import Polygon +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MinMaxScaler + +from srai.regionalizers import H3Regionalizer + + +class HuggingFaceDataset(abc.ABC): + """Abstract class for HuggingFace datasets.""" + + def __init__( + self, + path: str, + version: Optional[str] = None, + type: Optional[str] = None, + numerical_columns: Optional[list[str]] = None, + categorical_columns: Optional[list[str]] = None, + target: Optional[str] = None, + ) -> None: + self.path = path + self.version = version + self.numerical_columns = numerical_columns + self.categorical_columns = categorical_columns + self.target = target + self.type = type + + @abc.abstractmethod + def _preprocessing(self, data: pd.DataFrame, version: Optional[str] = None) -> gpd.GeoDataFrame: + """ + Preprocess the dataset from HuggingFace. + + Args: + data (pd.DataFrame): a dataset to preprocess + version (str, optional): version of dataset + + Returns: + gpd.GeoDataFrame: preprocessed data. + """ + raise NotImplementedError + + def load( + self, hf_token: Optional[str] = None, version: Optional[str] = None + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Method to load dataset. + + Args: + hf_token (str, optional): If needed, a User Access Token needed to authenticate to + the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used. + Defaults to None. 
+ version (str, optional): version of a dataset + + Returns: + gpd.GeoDataFrame, gpd.Geodataframe | None : Loaded train data and test data if exist. + """ + dataset_name = self.path + version = version or self.version + data = load_dataset(dataset_name, version, token=hf_token, trust_remote_code=True) + train = data["train"].to_pandas() + processed_train = self._preprocessing(train) + if "test" in data: + test = data["test"].to_pandas() + processed_test = self._preprocessing(test) + else: + processed_test = None + + return processed_train, processed_test + + def train_test_split_bucket_regression( + self, + gdf: gpd.GeoDataFrame, + target_column: Optional[str] = None, + resolution: int = 9, + test_size: float = 0.2, + bucket_number: int = 7, + random_state: Optional[int] = None, + ) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: + """Method to generate train and test split from GeoDataFrame, based on the distribution of the target_column values. + + Args: + gdf (gpd.GeoDataFrame): GeoDataFrame on which train, dev, test split will be performed. + target_column (Optional[str], optional): Target column name. If None, the split is generated on the basis of the number \ + of points within a hex of given resolution. + resolution (int, optional): h3 resolution to regionalize data. Defaults to 9. + test_size (float, optional): Percentage of test set. Defaults to 0.2. + bucket_number (int, optional): Bucket number used to stratify target data. Defaults to 7. + random_state (int, optional): Controls the shuffling applied to the data before applying the split. \ + Pass an int for reproducible output across multiple function calls. Defaults to None. + + Returns: + tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train, Test splits in GeoDataFrames + """ # noqa: E501, W505 + if self.type != "point": + raise ValueError("This split can be performed only on point data type!") + if target_column is None: + # target_column = self.target + target_column = "count" + + gdf_ = gdf.copy() + + if target_column == "count": + regionalizer = H3Regionalizer(resolution=resolution) + regions = regionalizer.transform(gdf) + joined_gdf = gpd.sjoin(gdf, regions, how="left", predicate="within") # noqa: E501 + joined_gdf.rename(columns={"index_right": "h3_index"}, inplace=True) + + averages_hex = joined_gdf.groupby("h3_index").size().reset_index(name=target_column) + gdf_ = regions.merge( + averages_hex, how="inner", left_on="region_id", right_on="h3_index" + ) + gdf_.rename(columns={"h3_index": "region_id"}, inplace=True) + gdf_.index = gdf_["region_id"] + + splits = np.linspace( + 0, 1, num=bucket_number + 1 + ) # quantile levels used for bucketing + quantiles = gdf_[target_column].quantile(splits) # compute quantiles + bins = [quantiles[i] for i in splits] + gdf_["bucket"] = pd.cut(gdf_[target_column], bins=bins, include_lowest=True).apply( + lambda x: x.mid + ) # noqa: E501 + + train_indices, test_indices = train_test_split( + range(len(gdf_)), + test_size=test_size, # * 2 multiply for dev set also + stratify=gdf_.bucket, # stratify by bucket value + random_state=random_state, + ) + + # dev_indices, test_indices = train_test_split( + # range(len(test_indices)), + # test_size=0.5, + # stratify=gdf_.iloc[test_indices].bucket, + # ) + train = gdf_.iloc[train_indices] + test = gdf_.iloc[test_indices] + if target_column == "count": + train_hex_indexes = train["region_id"].unique() + test_hex_indexes = test["region_id"].unique() + train = joined_gdf[joined_gdf["h3_index"].isin(train_hex_indexes)] + test = joined_gdf[joined_gdf["h3_index"].isin(test_hex_indexes)] + train = train.drop(columns=["h3_index"]) + test = test.drop(columns=["h3_index"]) + + return train, test # , gdf_.iloc[dev_indices] + + def train_test_split_spatial_points( + self, + gdf: gpd.GeoDataFrame, + test_size: float = 0.2, + resolution: int = 8, # TODO: add a per-dataset h3_train_resolution field + resolution_subsampling: int = 1, + random_state: Optional[int] = None, + ) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: + """ + Method to generate train and test split from GeoDataFrame, based on the spatial h3 + resolution. + + Args: + gdf (gpd.GeoDataFrame): GeoDataFrame on which train, dev, test split will be performed. + test_size (float, optional): Percentage of test set. Defaults to 0.2. + resolution (int, optional): h3 resolution to regionalize data. Defaults to 8. + resolution_subsampling (int, optional): h3 resolution difference to subsample \ + data for stratification. Defaults to 1. + random_state (int, optional): Controls the shuffling applied to the data before applying the split. \ + Pass an int for reproducible output across multiple function calls. Defaults to None. + + Raises: + ValueError: If type of data is not Points. + + Returns: + tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train, Test splits in GeoDataFrames + """ # noqa: W505, E501, D205 + if self.type != "point": + raise ValueError("This split can be performed only on Points data type!") + gdf_ = gdf.copy() + + regionalizer = H3Regionalizer(resolution=resolution) + regions = regionalizer.transform(gdf_) + + regions.index = regions.index.map( + lambda idx: h3.cell_to_parent(idx, resolution - resolution_subsampling) + ) # get parent h3 region + regions["geometry"] = regions.index.map( + lambda idx: Polygon([(lon, lat) for lat, lon in h3.cell_to_boundary(idx)]) + ) # get geometry of the h3 region + + joined_gdf = gpd.sjoin(gdf_, regions, how="left", predicate="within") + joined_gdf.rename(columns={"index_right": "h3_index"}, inplace=True) + joined_gdf.drop_duplicates(inplace=True) + + if joined_gdf["h3_index"].isnull().sum() != 0: # handle outliers + joined_gdf.loc[joined_gdf["h3_index"].isnull(), "h3_index"] = "fffffffffffffff" + # set outlier index fffffffffffffff + outlier_indices = joined_gdf["h3_index"].value_counts() + outlier_indices = outlier_indices[ + outlier_indices <= 4 + ].index # hexes with 4 or fewer points are treated as outliers + joined_gdf.loc[joined_gdf["h3_index"].isin(outlier_indices), "h3_index"] = "fffffffffffffff" + + train_indices, test_indices = train_test_split( + range(len(joined_gdf)), + test_size=test_size, # * 2, # multiply for dev set also + stratify=joined_gdf.h3_index, # stratify by spatial h3 + random_state=random_state, + ) + + # dev_indices, test_indices = train_test_split( + # range(len(test_indices)), + # test_size=0.5, + # stratify=joined_gdf.iloc[ + # test_indices + # ].h3_index, # perform spatial stratify (by h3 index) + # ) + + return ( + gdf_.iloc[train_indices], + gdf_.iloc[test_indices], + ) # , gdf_.iloc[dev_indices], + + def get_h3_with_labels( + self, + resolution: int, + train_gdf: gpd.GeoDataFrame, + test_gdf: Optional[gpd.GeoDataFrame], + target_column: Optional[str] = None, + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Returns h3 indexes with target labels from the dataset. + + Points are aggregated to hexes and target column values are averaged; if target column \ + is None, then the number of points within a hex is calculated and scaled to [0,1]. 
+ + Args: + resolution (int): h3 resolution to regionalize data. + train_gdf (gpd.GeoDataFrame): GeoDataFrame with training data. + test_gdf (Optional[gpd.GeoDataFrame]): GeoDataFrame with testing data. + target_column (Optional[str], optional): Target column name. If None, aggregates h3 \ + on the basis of the number of points within a hex of given resolution. Defaults to None. + + Returns: + tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: Train, Test hexes with target \ + labels in GeoDataFrames + """ + # if target_column is None: + # target_column = "count" + if target_column is None: + target_column = getattr(self, "target", None) or "count" + + _train_gdf = self._aggregate_hexes(train_gdf, resolution, target_column) + + if test_gdf is not None: + _test_gdf = self._aggregate_hexes(test_gdf, resolution, target_column) + else: + _test_gdf = None + + # Scale the "count" column to [0, 1] if it is the target column + if target_column == "count": + scaler = MinMaxScaler() + # Fit the scaler on the train dataset and transform + _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]]) + if _test_gdf is not None: + _test_gdf["count"] = scaler.transform(_test_gdf[["count"]]) + + return _train_gdf, _test_gdf + + def _aggregate_hexes( + self, + gdf: gpd.GeoDataFrame, + resolution: int, + target_column: Optional[str] = None, + ) -> gpd.GeoDataFrame: + """ + Aggregates points within each hex and computes either their count or the mean of their target column. + + Args: + gdf (gpd.GeoDataFrame): GeoDataFrame with data. + resolution (int): h3 resolution to regionalize data. + target_column (Optional[str], optional): Target column name. If None, aggregates h3 on \ + the basis of the number of points within a hex of given resolution. Defaults to None. + + Returns: + gpd.GeoDataFrame: GeoDataFrame with aggregated data. + """ + gdf_ = gdf.copy() + regionalizer = H3Regionalizer(resolution=resolution) + regions = regionalizer.transform(gdf) + joined_gdf = gpd.sjoin(gdf, regions, how="left", predicate="within") # noqa: E501 + joined_gdf.rename(columns={"index_right": "h3_index"}, inplace=True) + if target_column == "count": + aggregated = joined_gdf.groupby("h3_index").size().reset_index(name=target_column) + + else: + # Calculate mean of the target column within each hex + aggregated = ( + joined_gdf.groupby("h3_index")[target_column].mean().reset_index(name=target_column) + ) + + gdf_ = regions.merge(aggregated, how="inner", left_on="region_id", right_on="h3_index") + gdf_.rename(columns={"h3_index": "region_id"}, inplace=True) + # gdf_.index = gdf_["region_id"] + + gdf_.drop(columns=["geometry"], inplace=True) + return gdf_ diff --git a/srai/datasets/airbnb_multicity.py b/srai/datasets/airbnb_multicity.py new file mode 100644 index 00000000..b90ea885 --- /dev/null +++ b/srai/datasets/airbnb_multicity.py @@ -0,0 +1,81 @@ +""" +AirbnbMulticity dataset loader. + +This module contains AirbnbMulticity dataset. +""" + +from typing import Optional + +import geopandas as gpd +import pandas as pd + +from srai.constants import WGS84_CRS +from srai.datasets import HuggingFaceDataset + + +class AirbnbMulticityDataset(HuggingFaceDataset): + """ + AirbnbMulticity dataset. + + Dataset description will be added. 
+ """ + + def __init__(self) -> None: + """Create the dataset.""" + categorical_columns = ["name", "host_name", "neighborhood", "room_type", "city"] + numerical_columns = [ + "number_of_reviews", + "minimum_nights", + "availability_365", + "calculated_host_listings_count", + "number_of_reviews_ltm", + ] + target = "price" + type = "point" + super().__init__( + "kraina/airbnb_multicity", + type=type, + numerical_columns=numerical_columns, + categorical_columns=categorical_columns, + target=target, + ) + + def _preprocessing(self, data: pd.DataFrame, version: Optional[str] = None) -> gpd.GeoDataFrame: + """ + Preprocessing to get GeoDataFrame with location data, based on GEO_EDA files. + + Args: + data (pd.DataFrame): Data of AirbnbMulticity dataset. + version (str, optional): version of a dataset + + Returns: + gpd.GeoDataFrame: preprocessed data. + """ + df = data.copy() + gdf = gpd.GeoDataFrame( + df.drop(["latitude", "longitude"], axis=1), + geometry=gpd.points_from_xy(x=df["longitude"], y=df["latitude"]), + crs=WGS84_CRS, + ) + + return gdf + + def load( + self, hf_token: Optional[str] = None, version: Optional[str] = "res_8" + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Method to load dataset. + + Args: + hf_token (str, optional): If needed, a User Access Token needed to authenticate to + the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used. + Defaults to None. + version (str, optional): version of a dataset. + Available: 'res_8', 'res_9', 'res_10'. Defaults to 'res_8'. Benchmark version \ + comprises six cities: Paris, Rome, London, Amsterdam, Melbourne, New York City. + Raw, full data from ~80 cities available as 'all'. + + Returns: + gpd.GeoDataFrame, gpd.Geodataframe | None : Loaded train data and test data if exist. + """ + return super().load(hf_token, version) diff --git a/srai/datasets/chicago_crime.py b/srai/datasets/chicago_crime.py new file mode 100644 index 00000000..1f32e61b --- /dev/null +++ b/srai/datasets/chicago_crime.py @@ -0,0 +1,85 @@ +""" +Chicago Crime dataset loader. + +This module contains Chicago Crime Dataset. +""" + +from typing import Optional + +import geopandas as gpd +import pandas as pd + +from srai.constants import WGS84_CRS +from srai.datasets import HuggingFaceDataset + + +class ChicagoCrimeDataset(HuggingFaceDataset): + """ + Chicago Crime dataset. + + This dataset reflects reported incidents of crime (with the exception of murders where data + exists for each victim) that occurred in the City of Chicago. Data is extracted from the Chicago + Police Department's CLEAR (Citizen Law Enforcement Analysis and Reporting) system. + """ + + def __init__(self) -> None: + """Create the dataset.""" + numerical_columns = ["Ward", "Community Area"] + categorical_columns = [ + "Primary Type", + "Description", + "Location Description", + "Arrest", + "Domestic", + "Year", + "FBI Code", + ] + type = "point" + # target = "Primary Type" + target = None + super().__init__( + "kraina/chicago_crime", + type=type, + numerical_columns=numerical_columns, + categorical_columns=categorical_columns, + target=target, + ) + + def _preprocessing(self, data: pd.DataFrame, version: Optional[str] = None) -> gpd.GeoDataFrame: + """ + Preprocessing to get GeoDataFrame with location data, based on GEO_EDA files. + + Args: + data: Data of Chicago Crime dataset. + version: version of a dataset + + Returns: + gpd.GeoDataFrame: preprocessed data. 
+ """ + df = data.copy() + gdf = gpd.GeoDataFrame( + df.drop(["Latitude", "Longitude", "X Coordinate", "Y Coordinate"], axis=1), + geometry=gpd.points_from_xy(x=df["Longitude"], y=df["Latitude"]), + crs=WGS84_CRS, + ) + return gdf + + def load( + self, hf_token: Optional[str] = None, version: Optional[str] = "res_9" + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Method to load dataset. + + Args: + hf_token (str, optional): If needed, a User Access Token needed to authenticate to + the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used. + Defaults to None. + version (str, optional): version of a dataset. + Available: Official spatial train-test split from year 2022 in chosen h3 resolution: + 'res_8', 'res_9, 'res_10'. Defaults to 'res_9'. Raw data from other years available + as: '2020', '2021', '2022'. + + Returns: + gpd.GeoDataFrame, gpd.Geodataframe | None : Loaded train data and test data if exist. + """ + return super().load(hf_token, version) diff --git a/srai/datasets/house_sales_in_king_county.py b/srai/datasets/house_sales_in_king_county.py new file mode 100644 index 00000000..cbac14d8 --- /dev/null +++ b/srai/datasets/house_sales_in_king_county.py @@ -0,0 +1,90 @@ +""" +House Sales in King County dataset loader. + +This module contains House Sales in King County Dataset. +""" + +from typing import Optional + +import geopandas as gpd +import pandas as pd + +from srai.constants import WGS84_CRS +from srai.datasets import HuggingFaceDataset + + +class HouseSalesInKingCountyDataset(HuggingFaceDataset): + """ + House Sales in King County dataset. + + This dataset contains house sale prices for King County, which includes Seattle. It includes + homes sold between May 2014 and May 2015. + + It's a great dataset for evaluating simple regression models. + """ + + def __init__(self) -> None: + """Create the dataset.""" + numerical_columns = [ + "bathrooms", + "sqft_living", + "sqft_lot", + "floors", + "condition", + "grade", + "sqft_above", + "sqft_basement", + "sqft_living15", + "sqft_lot15", + ] + categorical_columns = ["view", "yr_built", "yr_renovated", "waterfront"] + type = "point" + target = "price" + super().__init__( + "kraina/house_sales_in_king_county", + type=type, + numerical_columns=numerical_columns, + categorical_columns=categorical_columns, + target=target, + ) + + def _preprocessing( + self, data: pd.DataFrame, version: Optional[str] = "res_8" + ) -> gpd.GeoDataFrame: + """ + Preprocess the dataset from HuggingFace. + + Args: + data (pd.DataFrame): a dataset to preprocess + version (str, optional): version of dataset. + Available: 'res_8', 'res_9', 'res_10'. Defaults to 'res_8'. + Raw data available as 'all'. + + Returns: + gpd.GeoDataFrame: preprocessed data. + """ + gdf = gpd.GeoDataFrame( + data.drop(["lat", "long"], axis=1), + geometry=gpd.points_from_xy(x=data["long"], y=data["lat"]), + crs=WGS84_CRS, + ) + return gdf + + def load( + self, hf_token: Optional[str] = None, version: Optional[str] = "res_8" + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Method to load dataset. + + Args: + hf_token (str, optional): If needed, a User Access Token needed to authenticate to + the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used. + Defaults to None. + version (str, optional): version of a dataset. + Available: 'res_8', 'res_9', 'res_10'. Defaults to 'res_8'. \ + Raw, full data available as 'all'. + + Returns: + gpd.GeoDataFrame, gpd.Geodataframe | None : Loaded train data and test data if exist. 
+ """ + return super().load(hf_token, version) diff --git a/srai/datasets/philadelphia_crime.py b/srai/datasets/philadelphia_crime.py new file mode 100644 index 00000000..3ce24f11 --- /dev/null +++ b/srai/datasets/philadelphia_crime.py @@ -0,0 +1,96 @@ +""" +Philadelphia Crime dataset loader. + +This module contains Philadelphia Crime Dataset. +""" + +from typing import Optional + +import geopandas as gpd + +from srai.constants import WGS84_CRS +from srai.datasets import HuggingFaceDataset + +years_previous: list[int] = [2013, 2014, 2015, 2016, 2017, 2018, 2019] +years_current: list[int] = [2020, 2021, 2022, 2023] + + +class PhiladelphiaCrimeDataset(HuggingFaceDataset): + """ + Philadelphia Crime dataset. + + Crime incidents from the Philadelphia Police Department. Part I crimes include violent offenses + such as aggravated assault, rape, arson, among others. Part II crimes include simple assault, + prostitution, gambling, fraud, and other non-violent offenses. + """ + + def __init__(self) -> None: + """Create the dataset.""" + numerical_columns = None + categorical_columns = [ + "hour", + "dispatch_date", + "dispatch_time", + "dc_dist", + "psa", + ] + type = "point" + # target = "text_general_code" + target = None + super().__init__( + "kraina/philadelphia_crime", + type=type, + numerical_columns=numerical_columns, + categorical_columns=categorical_columns, + target=target, + ) + + def _preprocessing( + self, data: gpd.GeoDataFrame, version: Optional[str] = None + ) -> gpd.GeoDataFrame: + """ + Preprocess the dataset from HuggingFace. + + Args: + data (pd.DataFrame): a dataset to preprocess + version (str, optional): version of dataset + + Returns: + gpd.GeoDataFrame: preprocessed data. + """ + df = data.copy() + gdf = gpd.GeoDataFrame( + df.drop(["lng", "lat"], axis=1), + geometry=gpd.points_from_xy(df["lng"], df["lat"]), + crs=WGS84_CRS, + ) + # TODO: Add numerical and categorical columns + # if version in years_previous: + # self.numerical_columns = None + # self.categorical_columns = None + # else: + # self.numerical_columns = None + # self.categorical_columns = None + + return gdf + + def load( + self, hf_token: Optional[str] = None, version: Optional[str] = "res_8" + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Method to load dataset. + + Args: + hf_token (str, optional): If needed, a User Access Token needed to authenticate to + the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used. + Defaults to None. + version (str, optional): version of a dataset. + Available: Official spatial train-test split from year 2023 in chosen h3 resolution: + 'res_8', 'res_9, 'res_10'. Defaults to 'res_8'. Raw data from other years available + as: '2013', '2014', '2015', '2016', '2017', '2018','2019', '2020', '2021', + '2022', '2023'. + + Returns: + gpd.GeoDataFrame, gpd.Geodataframe | None : Loaded train data and test data if exist. + """ + return super().load(hf_token, version) diff --git a/srai/datasets/police_department_incidents.py b/srai/datasets/police_department_incidents.py new file mode 100644 index 00000000..6ac43eec --- /dev/null +++ b/srai/datasets/police_department_incidents.py @@ -0,0 +1,92 @@ +""" +The San Francisco Police Department's (SFPD) Incident Report dataset loader. + +This module contains The San Francisco Police Department's (SFPD) Incident Report Datatset. 
+""" + +from typing import Optional + +import geopandas as gpd +import pandas as pd + +from srai.constants import WGS84_CRS +from srai.datasets import HuggingFaceDataset + + +class PoliceDepartmentIncidentsDataset(HuggingFaceDataset): + """ + The San Francisco Police Department's (SFPD) Incident Report Datatset. + + This dataset includes incident reports that have been filed as of January 1, 2018 till March, + 2024. These reports are filed by officers or self-reported by members of the public using SFPD’s + online reporting system. + """ + + def __init__(self) -> None: + """Create the dataset.""" + numerical_columns = None + categorical_columns = [ + "Incdident Year", + "Incident Day of Week", + "Police District", + "Analysis Neighborhood", + "Incident Description", + "Incident Time", + "Incident Code", + "Report Type Code", + "Police District", + "Analysis Neighborhood", + ] + type = "point" + # target = "Incident Category" + target = None + super().__init__( + "kraina/police_department_incidents", + type=type, + numerical_columns=numerical_columns, + categorical_columns=categorical_columns, + target=target, + ) + + def _preprocessing( + self, data: pd.DataFrame, version: Optional[str] = "res_9" + ) -> gpd.GeoDataFrame: + """ + Preprocess the dataset from HuggingFace. + + Args: + data (pd.DataFrame): a dataset to preprocess + version (str, optional): version of dataset. + Available: Official spatial train-test split in chosen h3 resolution: + 'res_8', 'res_9, 'res_10'. Defaults to 'res_9'. All data available + as 'all'. + + Returns: + gpd.GeoDataFrame: preprocessed data. + """ + df = data.copy() + gdf = gpd.GeoDataFrame( + df.drop(["Latitude", "Longitude"], axis=1), + geometry=gpd.points_from_xy(x=df["Longitude"], y=df["Latitude"]), + crs=WGS84_CRS, + ) + return gdf + + def load( + self, hf_token: Optional[str] = None, version: Optional[str] = "res_9" + ) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame]]: + """ + Method to load dataset. + + Args: + hf_token (str, optional): If needed, a User Access Token needed to authenticate to + the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used. + Defaults to None. + version (str, optional): version of a dataset. + Available: 'res_8', 'res_9', 'res_10'. Defaults to 'res_9'. \ + Raw, full data available as 'all'. + + Returns: + gpd.GeoDataFrame, gpd.Geodataframe | None : Loaded train data and test data if exist. + """ + return super().load(hf_token, version)