Skip to content

Commit

Permalink
Raise a friendly message when a str is provided to `TensorFrame(col…
Browse files Browse the repository at this point in the history
…_names_dict)` instead of a `list[str]` (#469)

### Repro
```python
import torch
import torch_frame
from torch_frame.data import TensorFrame

feat_dict = {torch_frame.categorical: torch.randint(0, 3, size=(10, 1))}
col_names_dict = {torch_frame.categorical: 'cat_1'}
TensorFrame(feat_dict, col_names_dict)
```

### Before this PR
```
Traceback (most recent call last):
  File "/home/aki/work/github.com/kumo-ai/kumo/../../akihironitta/gist/test.py", line 7, in <module>
    TensorFrame(feat_dict, col_names_dict)
  File "/home/aki/miniconda3/envs/kumo/lib/python3.10/site-packages/torch_frame/data/tensor_frame.py", line 78, in __init__
    self.validate()
  File "/home/aki/miniconda3/envs/kumo/lib/python3.10/site-packages/torch_frame/data/tensor_frame.py", line 113, in validate
    raise ValueError(
ValueError: The expected number of columns for categorical feature is 5, which does not align with the column dimensionality of feat_dict[categorical] (got 1)
```

### After this PR
```
Traceback (most recent call last):
  File "/home/aki/work/github.com/pyg-team/pytorch-frame/../../akihironitta/gist/test.py", line 7, in <module>
    TensorFrame(feat_dict, col_names_dict)
  File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/data/tensor_frame.py", line 78, in __init__
    self.validate()
  File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/data/tensor_frame.py", line 100, in validate
    raise ValueError(
ValueError: col_names_dict[categorical] must be a list of column names.
```
  • Loading branch information
akihironitta authored Dec 7, 2024
1 parent 6fca625 commit e7a6474
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/data/test_tensor_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,11 @@ def test_custom_tf_get_col_feat():
assert torch.equal(feat, feat_dict['numerical'][:, 0:1])
feat = tf.get_col_feat('num_2')
assert torch.equal(feat, feat_dict['numerical'][:, 1:2])


def test_non_list_col_names_dict():
feat_dict = {torch_frame.categorical: torch.randint(0, 3, size=(10, 1))}
# Oops, user provided a single column name without wrapping it in a list:
col_names_dict = {torch_frame.categorical: 'cat_1'}
with pytest.raises(ValueError, match='must be a list of column names'):
TensorFrame(feat_dict, col_names_dict)
8 changes: 7 additions & 1 deletion torch_frame/data/tensor_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ def validate(self) -> None:
num_rows = self.num_rows
empty_stypes: list[torch_frame.stype] = []
for stype_name, feats in self.feat_dict.items():
num_cols = len(self.col_names_dict[stype_name])
col_names = self.col_names_dict[stype_name]
if not isinstance(col_names, list):
raise ValueError(
f"col_names_dict[{stype_name}] must be a list of column "
f"names.")

num_cols = len(col_names)
if num_cols == 0:
empty_stypes.append(stype_name)

Expand Down

0 comments on commit e7a6474

Please sign in to comment.