How to correctly load model matrices into a learner (e.g., a Random Forest)? #125
Comments
I think I have a workaround! First, as background: the original training script, which doesn't have to load any model matrices, achieves the following test set performance:
That script also saves the model matrices using the approach I outlined earlier. It therefore follows that, for loading to work, I should be able to set up a new script, create the same Random Forest (with the same parameters), load those model matrices, run prediction, and get the same ROC AUC and the same MAE, without doing any training. My insight is the following: I use the same new script as in my previous post, but after setting the opts I set opts.depth = 1 so that training runs with minimal lag. So unfortunately I do need to train, but the training is only done to force initialization. Then I manually load the matrices, change the depth back to the larger value, and predict, and I get the same AUC and MAE. To be explicit, here's what the new script looks like:
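Roughly, a sketch of this approach, assuming BIDMach's RandomForest.learner constructor, BIDMat's loadFMat/loadIMat helpers, and a setmodelmats setter on the model (the paths, option values, and the ordering of the matrices in the array are placeholders and may not match the real code exactly):

```scala
import BIDMat.{FMat, IMat, Mat}
import BIDMat.MatFunctions._
import BIDMach.models.RandomForest

// Placeholder training data, loaded only so the learner can be constructed.
val traindata = loadFMat("traindata.fmat.lz4")
val trainlabels = loadIMat("trainlabels.imat.lz4")

// Same learner setup as the original training script.
val (nn, opts) = RandomForest.learner(traindata, trainlabels)
opts.ntrees = 20
opts.nnodes = 200000

// Train at depth 1 only to force the learner to initialize its model matrices.
opts.depth = 1
nn.train

// Load the model matrices saved by the earlier training run (placeholder directory).
val modeldir = "/path/to/rf-model/"
val ctrees = loadFMat(modeldir + "ctrees.fmat.lz4")
val ftrees = loadIMat(modeldir + "ftrees.imat.lz4")
val itrees = loadIMat(modeldir + "itrees.imat.lz4")
val vtrees = loadIMat(modeldir + "vtrees.imat.lz4")

// Overwrite the freshly initialized matrices with the trained ones
// (the array ordering here is an assumption), restore the real depth, and predict.
nn.model.setmodelmats(Array(ctrees, ftrees, itrees, vtrees))
opts.depth = 20
// ... run prediction on the test set as in the original script ...
```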
The reason why this works: in line 225 of the RandomForest.scala code (as of today), the training path performs the correct assignment:
So I need to do that myself.
Correction to the above: the better way is to do the following:
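Something along these lines (a sketch only, assuming the Learner exposes an init method that does the setup train would otherwise do, and that setmodelmats exists on the model; matrix0 through matrix3 and the paths are placeholders):

```scala
// Build the learner with the full depth; no throwaway depth-1 training pass needed.
val (nn, opts) = RandomForest.learner(traindata, trainlabels)
opts.ntrees = 20
opts.depth = 20
opts.nnodes = 200000

// Initialize the learner and its model without training.
nn.init

// Load the saved model matrices (placeholder paths; ordering is an assumption).
val matrix0 = loadFMat("/path/to/rf-model/ctrees.fmat.lz4")
val matrix1 = loadIMat("/path/to/rf-model/ftrees.imat.lz4")
val matrix2 = loadIMat("/path/to/rf-model/itrees.imat.lz4")
val matrix3 = loadIMat("/path/to/rf-model/vtrees.imat.lz4")

// Install the trained matrices directly.
nn.model.setmodelmats(Array(matrix0, matrix1, matrix2, matrix3))

// ... predict on the test set; nn.train is never called ...
```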
This way you can avoid calling the train method entirely. Edit: matrix0, etc., are the matrices loaded from the files saved by the previous training run.
BIDMach version: 1037ae7 (August 12)
BIDMat version: 1383cb4ccf3933a8175073b8eab9819be7e252bf (August 12)
OS: Linux (it's on "stout")
Here's the problem setup. I have run Random Forests on some training data. At the end of my script, I call the model's save method to save the model matrices:
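Presumably something like this (the directory name is a placeholder, and the save signature is assumed from the BIDMach Model API):

```scala
// After nn.train finishes, persist the Random Forest's model matrices to disk.
nn.model.save("/path/to/rf-model/")
```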
This creates four files, since RFs have four model matrices (ctrees.fmat.lz4, ftrees.imat.lz4, itrees.imat.lz4, vtrees.imat.lz4). Now, in a separate script, I want to create a new Random Forest that loads these model matrices so that it doesn't have to train. Here's what an example script might look like, with file names removed for privacy:
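A sketch of the shape of that script, with placeholder paths and the same assumed learner constructor and load helpers as above:

```scala
// Same setup and options as the original training script.
val (nn, opts) = RandomForest.learner(traindata, trainlabels)
opts.ntrees = 20
opts.depth = 20
opts.nnodes = 200000

// Load the four saved model matrices instead of training.
val modeldir = "/path/to/rf-model/"
val ctrees = loadFMat(modeldir + "ctrees.fmat.lz4")
val ftrees = loadIMat(modeldir + "ftrees.imat.lz4")
val itrees = loadIMat(modeldir + "itrees.imat.lz4")
val vtrees = loadIMat(modeldir + "vtrees.imat.lz4")

// ... then try to predict on the test data without calling nn.train ...
```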
I have the training data there even though I don't think it's needed; I keep it to stay consistent with the original script that ran training. I'm assuming that if the Random Forest was trained with tree depth 20, then this script should also use tree depth 20 if we're going to load the model matrices, and so on.
Unfortunately, running the above (with the appropriate data, though I think any data will do), I get:
This error happens in the init method, implying that the Random Forest has to be initialized somehow. That happens automatically when you call the train method, but I don't know how to get it initialized without calling train. Are there examples of scripts that do that here? I couldn't find any by searching. The Random Forest's 'load' method looks like it "returns" a Random Forest model, but I cannot simply do:
Do you have any advice? I'm currently working through this issue, so hopefully I can figure out how to do it, but even if I do, it would be great to have confirmation that I'm doing the steps the way they're supposed to work.
Final (somewhat unrelated) comment: the Random Forest model matrices all have dimension (opts.nnodes, opts.ntrees). Therefore, if we want to combine the trees from multiple trained Random Forests for a test set, we have to concatenate the matrices horizontally (not vertically) to make more columns. The Random Forest code doesn't seem to have a method for that, but I can do it offline myself.
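For example, in BIDMat a \ b concatenates matrices horizontally and a on b concatenates vertically, so combining two trained forests would look roughly like this (matrix names are placeholders):

```scala
// Horizontally concatenate matching model matrices from two trained forests,
// producing nnodes x (ntrees1 + ntrees2) matrices.
val ctreesAll = ctrees1 \ ctrees2
val ftreesAll = ftrees1 \ ftrees2
val itreesAll = itrees1 \ itrees2
val vtreesAll = vtrees1 \ vtrees2
```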
-Daniel