Skip to content

Commit

Permalink
feat: add Model.transform_columns property (#1661)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Salem Boyland <[email protected]>
Co-authored-by: Tim Swast <[email protected]>
  • Loading branch information
3 people authored Oct 12, 2023
1 parent faa50b9 commit 5ceed05
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
71 changes: 71 additions & 0 deletions google/cloud/bigquery/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

"""Define resources for the BigQuery ML Models API."""

from __future__ import annotations # type: ignore

import copy
import datetime
import typing
Expand Down Expand Up @@ -184,6 +186,21 @@ def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]:
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
def transform_columns(self) -> Sequence[TransformColumn]:
"""The input feature columns that were used to train this model.
The output transform columns used to train this model.
See REST API:
https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn
Read-only.
"""
resources: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("transformColumns", [])
)
return [TransformColumn(resource) for resource in resources]

@property
def label_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Label columns that were used to train this model.
Expand Down Expand Up @@ -434,6 +451,60 @@ def __repr__(self):
)


class TransformColumn:
"""TransformColumn represents a transform column feature.
See
https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn
Args:
resource:
A dictionary representing a transform column feature.
"""

def __init__(self, resource: Dict[str, Any]):
self._properties = resource

@property
def name(self) -> Optional[str]:
"""Name of the column."""
return self._properties.get("name")

@property
def type_(self) -> Optional[standard_sql.StandardSqlDataType]:
"""Data type of the column after the transform.
Returns:
Optional[google.cloud.bigquery.standard_sql.StandardSqlDataType]:
Data type of the column.
"""
type_json = self._properties.get("type")
if type_json is None:
return None
return standard_sql.StandardSqlDataType.from_api_repr(type_json)

@property
def transform_sql(self) -> Optional[str]:
"""The SQL expression used in the column transform."""
return self._properties.get("transformSql")

@classmethod
def from_api_repr(cls, resource: Dict[str, Any]) -> "TransformColumn":
"""Constructs a transform column feature given its API representation
Args:
resource:
Transform column feature representation from the API
Returns:
Transform column feature parsed from ``resource``.
"""
this = cls({})
resource = copy.deepcopy(resource)
this._properties = resource
return this


def _model_arg_to_model_ref(value, default_project=None):
"""Helper to convert a string or Model to ModelReference.
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[mypy]
python_version = 3.6
python_version = 3.8
namespace_packages = True
68 changes: 68 additions & 0 deletions tests/unit/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import pytest


import google.cloud._helpers
import google.cloud.bigquery.model

KMS_KEY_NAME = "projects/1/locations/us/keyRings/1/cryptoKeys/1"

Expand Down Expand Up @@ -136,6 +138,7 @@ def test_from_api_repr(target_class):
google.cloud._helpers._rfc3339_to_datetime(got.training_runs[2]["startTime"])
== expiration_time
)
assert got.transform_columns == []


def test_from_api_repr_w_minimal_resource(target_class):
Expand Down Expand Up @@ -293,6 +296,71 @@ def test_feature_columns(object_under_test):
assert object_under_test.feature_columns == expected


def test_from_api_repr_w_transform_columns(target_class):
resource = {
"modelReference": {
"projectId": "my-project",
"datasetId": "my_dataset",
"modelId": "my_model",
},
"transformColumns": [
{
"name": "transform_name",
"type": {"typeKind": "INT64"},
"transformSql": "transform_sql",
}
],
}
got = target_class.from_api_repr(resource)
assert len(got.transform_columns) == 1
transform_column = got.transform_columns[0]
assert isinstance(transform_column, google.cloud.bigquery.model.TransformColumn)
assert transform_column.name == "transform_name"


def test_transform_column_name():
transform_columns = google.cloud.bigquery.model.TransformColumn(
{"name": "is_female"}
)
assert transform_columns.name == "is_female"


def test_transform_column_transform_sql():
transform_columns = google.cloud.bigquery.model.TransformColumn(
{"transformSql": "is_female"}
)
assert transform_columns.transform_sql == "is_female"


def test_transform_column_type():
transform_columns = google.cloud.bigquery.model.TransformColumn(
{"type": {"typeKind": "BOOL"}}
)
assert transform_columns.type_.type_kind == "BOOL"


def test_transform_column_type_none():
transform_columns = google.cloud.bigquery.model.TransformColumn({})
assert transform_columns.type_ is None


def test_transform_column_from_api_repr_with_unknown_properties():
transform_column = google.cloud.bigquery.model.TransformColumn.from_api_repr(
{
"name": "is_female",
"type": {"typeKind": "BOOL"},
"transformSql": "is_female",
"test": "one",
}
)
assert transform_column._properties == {
"name": "is_female",
"type": {"typeKind": "BOOL"},
"transformSql": "is_female",
"test": "one",
}


def test_label_columns(object_under_test):
from google.cloud.bigquery import standard_sql

Expand Down

0 comments on commit 5ceed05

Please sign in to comment.