diff --git a/pipeline/01-train.R b/pipeline/01-train.R index bbcf5b59..2166b7ee 100644 --- a/pipeline/01-train.R +++ b/pipeline/01-train.R @@ -71,28 +71,28 @@ message("Creating and fitting linear baseline model") # Create a linear model recipe with additional imputation, transformations, # and feature interactions -# lin_recipe <- model_lin_recipe( -# data = training_data_full %>% -# mutate(meta_sale_price = log(meta_sale_price)), -# pred_vars = params$model$predictor$all, -# cat_vars = params$model$predictor$categorical, -# id_vars = params$model$predictor$id -# ) -# -# # Create a linear model specification and workflow -# lin_model <- parsnip::linear_reg() %>% -# set_mode("regression") %>% -# set_engine("lm") -# lin_wflow <- workflow() %>% -# add_model(lin_model) %>% -# add_recipe( -# recipe = lin_recipe, -# blueprint = hardhat::default_recipe_blueprint(allow_novel_levels = TRUE) -# ) -# -# # Fit the linear model on the training data -# lin_wflow_final_fit <- lin_wflow %>% -# fit(data = train %>% mutate(meta_sale_price = log(meta_sale_price))) +lin_recipe <- model_lin_recipe( + data = training_data_full %>% + mutate(meta_sale_price = log(meta_sale_price)), + pred_vars = params$model$predictor$all, + cat_vars = params$model$predictor$categorical, + id_vars = params$model$predictor$id +) + +# Create a linear model specification and workflow +lin_model <- parsnip::linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") +lin_wflow <- workflow() %>% + add_model(lin_model) %>% + add_recipe( + recipe = lin_recipe, + blueprint = hardhat::default_recipe_blueprint(allow_novel_levels = TRUE) + ) + +# Fit the linear model on the training data +lin_wflow_final_fit <- lin_wflow %>% + fit(data = train %>% mutate(meta_sale_price = log(meta_sale_price))) @@ -396,10 +396,10 @@ message("Finalizing and saving trained model") test %>% mutate( pred_card_initial_fmv = predict(lgbm_wflow_final_fit, test)$.pred, - # pred_card_initial_fmv_lin = exp(predict( - # lin_wflow_final_fit, - # test %>% mutate(meta_sale_price = log(meta_sale_price)) - # )$.pred) + pred_card_initial_fmv_lin = exp(predict( + lin_wflow_final_fit, + test %>% mutate(meta_sale_price = log(meta_sale_price)) + )$.pred) ) %>% select( meta_year, meta_pin, meta_class, meta_card_num, meta_triad_code,