-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add GTFS downloader #202
base: main
Are you sure you want to change the base?
Changes from all commits
05bfc20
4cedc56
9996c6f
c643f75
b2dbfa9
6a87fe5
91b55f9
9511ce5
2d70d5b
a7ff8e0
3876105
6bdb355
f26a41b
1702b19
9ab9fd1
e61c230
6adda74
cc34dc2
9516798
a634b2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from srai.downloaders import GTFSDownloader" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"downloader = GTFSDownloader(update_catalog=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"downloader.search(query=\"Wrocław, PL\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import osmnx as ox" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"wro_gdf = ox.geocode_to_gdf(\"Wrocław, PL\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"downloader.catalog[downloader.catalog.intersects(wro_gdf[\"geometry\"].iloc[0])]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ax = (\n", | ||
" downloader.catalog[downloader.catalog.intersects(wro_gdf[\"geometry\"].iloc[0])]\n", | ||
" .iloc[[6]]\n", | ||
" .plot(alpha=0.5)\n", | ||
")\n", | ||
"wro_gdf.plot(ax=ax, color=\"red\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"downloader.search(area=wro_gdf)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Downloaders.""" | ||
|
||
from .gtfs_downloader import GTFSDownloader | ||
|
||
__all__ = ["GTFSDownloader"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
""" | ||
GTFS Downloader. | ||
|
||
This module contains the GTFS downloader, which is a proxy to The Mobility Database[1]. | ||
|
||
References: | ||
[1] https://database.mobilitydata.org/ | ||
""" | ||
|
||
import unicodedata | ||
from pathlib import Path | ||
from typing import List, Optional | ||
|
||
import geopandas as gpd | ||
import pandas as pd | ||
from functional import seq | ||
from shapely.geometry import box | ||
|
||
from srai.constants import WGS84_CRS | ||
from srai.utils import download_file | ||
|
||
# Address the catalog CSV file is downloaded from.
CATALOG_URL = "https://bit.ly/catalogs-csv"

# Per-user cache directory (~/.cache/srai) where the downloaded catalog is stored.
CACHE_DIR = Path.home().joinpath(".cache", "srai")

# Catalog columns that text search queries are compared against.
CATALOG_SEARCH_COLUMNS = [
    "name",
    *(f"location.{field}" for field in ("country_code", "subdivision_name", "municipality")),
    "provider",
]

# Catalog columns with the bounding box of a feed, ordered as
# (min longitude, min latitude, max longitude, max latitude).
CATALOG_BBOX_COLUMNS = [
    f"location.bounding_box.{extreme}_{axis}"
    for extreme in ("minimum", "maximum")
    for axis in ("longitude", "latitude")
]
|
||
|
||
class GTFSDownloader:
    """
    GTFSDownloader.

    This class provides methods to search the catalog of GTFS feeds published by
    The Mobility Database[1].

    The catalog is stored in the `catalog` attribute as a GeoDataFrame whose `geometry` column
    holds the bounding box of each feed.

    References:
        [1] https://database.mobilitydata.org/
    """

    def __init__(self, update_catalog: bool = False) -> None:
        """
        Initialize GTFS downloader.

        Args:
            update_catalog (bool, optional): Update catalog file if present. Defaults to False.
        """
        self.catalog = self._load_catalog(update_catalog)

    def update_catalog(self) -> None:
        """Download the catalog file again and reload the catalog from it."""
        self.catalog = self._load_catalog(update_catalog=True)

    def search(
        self, query: Optional[str] = None, area: Optional[gpd.GeoDataFrame] = None
    ) -> pd.DataFrame:
        """
        Search catalog by name, location or area.

        When both `query` and `area` are given, only rows matching both are returned.

        Examples for text queries: "Wrocław, PL", "New York, US", "Amtrak".

        Args:
            query (str): Search query with elements separated by comma.
            area (gpd.GeoDataFrame): Area to search in.

        Returns:
            pd.DataFrame: Search results.

        Raises:
            ValueError: If `area` is not a GeoDataFrame (has no geometry column).
            ValueError: If neither `query` nor `area` is provided.
        """
        if query is None and area is None:
            raise ValueError("Either query or area must be provided.")

        # Validate before doing any work. `getattr` makes objects without `columns`
        # raise the documented ValueError instead of an AttributeError.
        if area is not None and "geometry" not in getattr(area, "columns", ()):
            raise ValueError("Provided area has no geometry column.")

        # Start from "everything matches" and narrow down. All masks share the catalog index,
        # so combining them is a plain element-wise AND between aligned boolean Series.
        matches = pd.Series(True, index=self.catalog.index, dtype=bool)

        if query is not None:
            matches &= self._search_by_query(query)

        if area is not None:
            matches &= self._search_by_area(area)

        return self.catalog[matches]

    def _search_by_query(self, query: str) -> pd.Series:
        """
        Perform search by query.

        The query is split on commas; each element is stripped of accents and surrounding
        whitespace. A row matches when every element is equal to the (accent-stripped) value
        of at least one of the searchable columns. Empty elements (e.g. from a trailing comma)
        are ignored.

        Args:
            query (str): Search query with elements separated by comma.

        Returns:
            pd.Series: Series of booleans indicating if row matches the query.
        """
        query_processed = (
            seq(query.split(","))
            .map(self._remove_accents)
            .map(str.strip)
            .filter(bool)  # drop empty elements so "Wrocław," behaves like "Wrocław"
            .to_list()
        )
        catalog_processed = (
            self.catalog[CATALOG_SEARCH_COLUMNS].fillna("").applymap(self._remove_accents)
        )

        res: List[bool] = [
            all(q in row for q in query_processed)
            for row in catalog_processed.itertuples(index=False, name=None)
        ]
        # Reuse the catalog index so the mask aligns with the catalog rows.
        return pd.Series(res, index=self.catalog.index, dtype=bool)

    def _search_by_area(self, area: gpd.GeoDataFrame) -> pd.Series:
        """
        Perform search by area.

        Args:
            area (gpd.GeoDataFrame): Area to search in.

        Returns:
            pd.Series: Series of booleans indicating if row matches the area.
        """
        # The catalog geometries are in WGS84, so bring the area to the same CRS first.
        area = area.to_crs(WGS84_CRS)
        return self.catalog.intersects(area.geometry.unary_union)

    def _remove_accents(self, text: str) -> str:
        """
        Remove accents from text.

        Will remove all accents ("ś" -> "s", "ü" -> "u") and replace "ł"/"Ł" with "l"/"L".

        Args:
            text (str): Text to process.

        Returns:
            str: Text without accents.
        """
        # NFD splits accented characters into a base character followed by combining marks
        # (category "Mn"), which are then dropped.
        result = "".join(
            c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn"
        )
        # "ł" and "Ł" have no decomposition, so they need explicit handling (required for Polish).
        return result.replace("ł", "l").replace("Ł", "L")

    def _load_catalog(self, update_catalog: bool = False) -> gpd.GeoDataFrame:
        """
        Load catalog and add geometry column.

        The catalog file is downloaded to the cache directory when it is missing or when
        `update_catalog` is set.

        Args:
            update_catalog (bool, optional): Update catalog file if present. Defaults to False.

        Returns:
            gpd.GeoDataFrame: Catalog with bounding boxes in the `geometry` column.
        """
        catalog_file = CACHE_DIR / "catalog.csv"

        if update_catalog or not catalog_file.exists():
            download_file(CATALOG_URL, catalog_file)

        df = pd.read_csv(catalog_file)

        # Missing bounds are replaced with 0, so a feed without any bounding box information
        # ends up with a degenerate box at (0, 0).
        df[CATALOG_BBOX_COLUMNS] = df[CATALOG_BBOX_COLUMNS].fillna(0)

        # Column order is (minx, miny, maxx, maxy), which is what `box` expects.
        df["geometry"] = [
            box(*bounds) for bounds in df[CATALOG_BBOX_COLUMNS].to_numpy().tolist()
        ]

        return gpd.GeoDataFrame(df, geometry="geometry", crs=WGS84_CRS)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,33 +7,34 @@ | |
|
||
|
||
def download_file( | ||
url: str, fname: str, chunk_size: int = 1024, force_download: bool = True | ||
url: str, filename: Path, chunk_size: int = 1024, force_download: bool = True | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We shouldn't force users to use only a string or only a Path; we can accept an input like this: |
||
) -> None: | ||
""" | ||
Download a file with progress bar. | ||
|
||
Args: | ||
url (str): URL to download. | ||
fname (str): File name. | ||
filename (Path): File to save data to. | ||
chunk_size (str): Chunk size. | ||
force_download (bool): Flag to force download even if file exists. | ||
|
||
Source: https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 | ||
""" | ||
if Path(fname).exists() and not force_download: | ||
if filename.exists() and not force_download: | ||
warnings.warn("File exists. Skipping download.", stacklevel=1) | ||
return | ||
|
||
Path(fname).parent.mkdir(parents=True, exist_ok=True) | ||
filename.parent.mkdir(parents=True, exist_ok=True) | ||
resp = requests.get( | ||
url, | ||
headers={"User-Agent": "SRAI Python package (https://github.com/kraina-ai/srai)"}, | ||
stream=True, | ||
) | ||
resp.raise_for_status() | ||
total = int(resp.headers.get("content-length", 0)) | ||
with open(fname, "wb") as file, tqdm( | ||
desc=fname.split("/")[-1], | ||
|
||
with filename.open("wb") as file, tqdm( | ||
desc=filename.name, | ||
total=total, | ||
unit="iB", | ||
unit_scale=True, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't make another namespace / module, since
PbfFileDownloader
and other classes related to pbf files are next to OsmPbfLoader
. I'd even go further and merge the logic of GTFSDownloader
with GTFSLoader
to simplify the usage for the user. Maybe create a directory gtfs
within loaders and place both files there. By this I mean modifying the existing function: