Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Raise a friendly message when a
str
is provided to `TensorFrame(col…
…_names_dict)` instead of a `list[str]` (#469) ### Repro ```python import torch import torch_frame from torch_frame.data import TensorFrame feat_dict = {torch_frame.categorical: torch.randint(0, 3, size=(10, 1))} col_names_dict = {torch_frame.categorical: 'cat_1'} TensorFrame(feat_dict, col_names_dict) ``` ### Before this PR ``` Traceback (most recent call last): File "/home/aki/work/github.com/kumo-ai/kumo/../../akihironitta/gist/test.py", line 7, in <module> TensorFrame(feat_dict, col_names_dict) File "/home/aki/miniconda3/envs/kumo/lib/python3.10/site-packages/torch_frame/data/tensor_frame.py", line 78, in __init__ self.validate() File "/home/aki/miniconda3/envs/kumo/lib/python3.10/site-packages/torch_frame/data/tensor_frame.py", line 113, in validate raise ValueError( ValueError: The expected number of columns for categorical feature is 5, which does not align with the column dimensionality of feat_dict[categorical] (got 1) ``` ### After this PR ``` Traceback (most recent call last): File "/home/aki/work/github.com/pyg-team/pytorch-frame/../../akihironitta/gist/test.py", line 7, in <module> TensorFrame(feat_dict, col_names_dict) File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/data/tensor_frame.py", line 78, in __init__ self.validate() File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/data/tensor_frame.py", line 100, in validate raise ValueError( ValueError: col_names_dict[categorical] must be a list of column names. ```
- Loading branch information