Skip to content

Commit

Permalink
Merge branch 'master' into tab_frame_comp_2
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta authored Sep 6, 2024
2 parents 57a3115 + 475105e commit aa6ae6f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion benchmark/pytorch_tabular_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import FTTransformerConfig, TabTransformerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

Expand Down Expand Up @@ -114,6 +113,7 @@ def train_tabular_model() -> float:
)
else:
raise ValueError(f"Invalid model type: {args.model_type}")

tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
Expand Down

0 comments on commit aa6ae6f

Please sign in to comment.