Provide auto-complete in notebooks for table names starting with an underscore (#315)

* Provide auto-complete in notebooks for table names starting with an underscore

* improve documentation

* improve documentation

* fix tests
nanne-aben authored Feb 27, 2024
1 parent af5dad4 commit fd9cbc6
Showing 3 changed files with 153 additions and 19 deletions.
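
In short, the change makes tables like `default._person` reachable via notebook auto-complete by exposing them under a `u`-prefixed attribute. A minimal sketch of the new behaviour (assuming an active SparkSession `spark` and an existing table `default._person`; the notebook diff below walks through the full example):

    from typedspark import Databases

    db = Databases(spark)
    # The table is exposed under the escaped attribute `u_person`, so notebook
    # auto-complete can find it; the underlying table itself is not renamed.
    person, Person = db.default.u_person()  # type: ignore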
77 changes: 67 additions & 10 deletions docs/source/loading_datasets_in_notebooks.ipynb
@@ -43,16 +43,15 @@
"source": [
"import pandas as pd\n",
"\n",
"(\n",
" spark.createDataFrame(\n",
" pd.DataFrame(\n",
" dict(\n",
" name=[\"Jack\", \"John\", \"Jane\"],\n",
" age=[20, 30, 40],\n",
" )\n",
"df = spark.createDataFrame(\n",
" pd.DataFrame(\n",
" dict(\n",
" name=[\"Jack\", \"John\", \"Jane\"],\n",
" age=[20, 30, 40],\n",
" )\n",
" ).createOrReplaceTempView(\"person_table\")\n",
")"
" )\n",
")\n",
"df.createOrReplaceTempView(\"person_table\")"
]
},
{
@@ -298,14 +297,72 @@
"person, Person = db.person_table.load()"
]
},
+{
+"cell_type": "markdown",
+"id": "071cd277",
+"metadata": {},
+"source": [
+"## Names starting with an underscore\n",
+"\n",
+"For `Catalogs()`, `Databases()` and `Database()`, names starting with an underscore are problematic for auto-complete. For example, suppose we create this table:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "acfd55af",
+"metadata": {},
+"outputs": [],
+"source": [
+"df.write.saveAsTable(\"default._person\")"
+]
+},
+{
+"cell_type": "markdown",
+"id": "7fcf12ef",
+"metadata": {},
+"source": [
+"We then won't get auto-complete on `db._person`, since the notebook assumes we don't want auto-complete on private attributes. To circumvent this, we rename the attribute as follows:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "b0491ad1",
+"metadata": {},
+"outputs": [],
+"source": [
+"db = Database()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "f7f41731",
+"metadata": {},
+"outputs": [],
+"source": [
+"persons, Person = db.u_person()"
+]
+},
+{
+"cell_type": "markdown",
+"id": "5e1b3aaa",
+"metadata": {},
+"source": [
+"The underlying table is not renamed; only the class attribute used for auto-complete is.\n",
+"\n",
+"When renaming the attribute leads to a naming conflict (e.g. because `u_person` already exists), we resolve the conflict by adding more underscores (e.g. `_person` would then become `u__person`)."
+]
+},
{
"cell_type": "markdown",
"id": "e0187268",
"metadata": {},
"source": [
"## Loading a single DataSet\n",
"\n",
"Finally, if you really only want to load one DataSet, you can use `load_table()`."
"If you really only want to load one DataSet, you can use `load_table()`."
]
},
{
42 changes: 42 additions & 0 deletions tests/_utils/test_load_table.py
@@ -153,6 +153,48 @@ def test_databases_with_table(spark: SparkSession):
    _drop_table(spark, "default.table_b")


+def test_databases_with_table_name_starting_with_underscore(spark: SparkSession):
+    df = create_empty_dataset(spark, A)
+    df.write.saveAsTable("default._table_b")
+
+    try:
+        db = Databases(spark)
+        df_loaded, _ = db.default.u_table_b()  # type: ignore
+        assert_df_equality(df, df_loaded)
+        assert db.default.u_table_b.str == "default._table_b"  # type: ignore
+    except Exception as exception:
+        _drop_table(spark, "default._table_b")
+        raise exception
+
+    _drop_table(spark, "default._table_b")
+
+
+def test_databases_with_table_name_starting_with_underscore_with_naming_conflict(
+    spark: SparkSession,
+):
+    df_a = create_empty_dataset(spark, A)
+    df_b = create_empty_dataset(spark, B)
+    df_a.write.saveAsTable("default._table_b")
+    df_b.write.saveAsTable("default.u_table_b")
+
+    try:
+        db = Databases(spark)
+        df_loaded, _ = db.default.u__table_b()  # type: ignore
+        assert_df_equality(df_a, df_loaded)
+        assert db.default.u__table_b.str == "default._table_b"  # type: ignore
+
+        df_loaded, _ = db.default.u_table_b()  # type: ignore
+        assert_df_equality(df_b, df_loaded)
+        assert db.default.u_table_b.str == "default.u_table_b"  # type: ignore
+    except Exception as exception:
+        _drop_table(spark, "default._table_b")
+        _drop_table(spark, "default.u_table_b")
+        raise exception
+
+    _drop_table(spark, "default._table_b")
+    _drop_table(spark, "default.u_table_b")


def test_catalogs(spark: SparkSession):
    df = create_empty_dataset(spark, A)
    df.write.saveAsTable("spark_catalog.default.table_b")
53 changes: 44 additions & 9 deletions typedspark/_utils/databases.py
@@ -73,6 +73,34 @@ def _get_spark_session(spark: Optional[SparkSession]) -> SparkSession:
        raise ValueError("No active SparkSession found.")  # pragma: no cover


+def _resolve_names_starting_with_an_underscore(name: str, names: list[str]) -> str:
+    """Auto-complete is currently problematic when a name (of a table, database, or
+    catalog) starts with an underscore.
+
+    In this case, it's considered a private attribute and doesn't show up in the
+    auto-complete options in your notebook. To combat this behaviour, we add a "u"
+    prefix, followed by as many underscores as needed (up to 100) to keep the name
+    unique.
+    """
+    if not name.startswith("_"):
+        return name
+
+    prefix = "u"
+    proposed_name = prefix + name
+    i = 0
+    while proposed_name in names:
+        prefix = prefix + "_"
+        proposed_name = prefix + name
+        i += 1
+        if i > 100:
+            raise Exception(
+                "Couldn't find a unique name, even after adding 100 underscores. This is"
+                " unexpected behaviour; exiting to prevent an infinite loop."
+            )  # pragma: no cover
+
+    return proposed_name
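
As an editorial illustration (not part of the diff), the resolution rule maps a hypothetical table name `_person` as follows:

    _resolve_names_starting_with_an_underscore("_person", ["_person"])              # -> "u_person"
    _resolve_names_starting_with_an_underscore("_person", ["_person", "u_person"])  # -> "u__person"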


class Table:
"""Loads a table in a database."""

@@ -126,11 +154,13 @@ def __init__(
        self._db_name = f"{catalog_name}.{db_name}"

        tables = spark.sql(f"show tables from {self._db_name}").collect()
+        table_names = [table.tableName for table in tables]

        for table in tables:
-            table_name = table.tableName
+            escaped_name = _resolve_names_starting_with_an_underscore(table.tableName, table_names)
            self.__setattr__(
-                table_name,
-                Table(spark, self._db_name, table_name, table.isTemporary),
+                escaped_name,
+                Table(spark, self._db_name, table.tableName, table.isTemporary),
            )

    @property
@@ -156,12 +186,16 @@ def __init__(
            query = f"show databases in {catalog_name}"

        databases = spark.sql(query).collect()
+        database_names = [self._extract_db_name(database) for database in databases]
        timeout = DatabasesTimeout(silent, n=len(databases))

-        for i, database in enumerate(databases):
+        for i, db_name in enumerate(database_names):
            timeout.check_for_warning(i)
-            db_name = self._extract_db_name(database)
-            self.__setattr__(db_name, Database(spark, db_name, catalog_name))
+            escaped_name = _resolve_names_starting_with_an_underscore(db_name, database_names)
+            self.__setattr__(
+                escaped_name,
+                Database(spark, db_name, catalog_name),
+            )

    def _extract_db_name(self, database: Row) -> str:
        """Extracts the database name from a Row.
@@ -183,12 +217,13 @@ def __init__(self, spark: Optional[SparkSession] = None, silent: bool = False):
        spark = _get_spark_session(spark)

        catalogs = spark.sql("show catalogs").collect()
+        catalog_names = [catalog.catalog for catalog in catalogs]
        timeout = CatalogsTimeout(silent, n=len(catalogs))

-        for i, catalog in enumerate(catalogs):
+        for i, catalog_name in enumerate(catalog_names):
+            escaped_name = _resolve_names_starting_with_an_underscore(catalog_name, catalog_names)
            timeout.check_for_warning(i)
-            catalog_name: str = catalog.catalog
            self.__setattr__(
-                catalog_name,
+                escaped_name,
                Databases(spark, silent=True, catalog_name=catalog_name),
            )
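
As the tests above show, only the attribute name is escaped; each `Table` keeps the real table name, so `.str` (and hence any query built from it) still targets the original table. A small sketch, reusing the tables from the tests:

    db = Databases(spark)
    db.default.u__table_b.str  # -> "default._table_b"
    db.default.u_table_b.str   # -> "default.u_table_b"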
