Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-16420: change default values from -1 to the actual default values. #16421

Open
wants to merge 2 commits into
base: rel-3.46.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 14 additions & 39 deletions h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
package hex.Infogram;

import hex.*;
import water.*;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.ModelCategory;
import hex.genmodel.utils.DistributionFamily;
import water.DKV;
import water.H2O;
import water.Key;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

import java.util.*;
import java.util.stream.IntStream;
import hex.genmodel.utils.DistributionFamily;

import static hex.Infogram.InfogramModel.InfogramModelOutput.sortCMIRel;
import static hex.Infogram.InfogramModel.InfogramParameters.Algorithm.AUTO;
import static hex.Infogram.InfogramModel.InfogramParameters.Algorithm.gbm;
Expand Down Expand Up @@ -182,61 +189,29 @@ private void validateInfoGramParameters() {
_buildCore = _parms._protected_columns == null;

if (_buildCore) {
if (_parms._net_information_threshold == -1) { // not set
_parms._cmi_threshold = 0.1;
_parms._net_information_threshold = 0.1;
} else if (_parms._net_information_threshold > 1 || _parms._net_information_threshold < 0) {
if (_parms._net_information_threshold > 1 || _parms._net_information_threshold < 0) {
error("net_information_threshold", " should be set to be between 0 and 1.");
} else {
_parms._cmi_threshold = _parms._net_information_threshold;
}

if (_parms._total_information_threshold == -1) { // not set
_parms._relevance_threshold = 0.1;
_parms._total_information_threshold = 0.1;
} else if (_parms._total_information_threshold < 0 || _parms._total_information_threshold > 1) {
if (_parms._total_information_threshold < 0 || _parms._total_information_threshold > 1) {
error("total_information_threshold", " should be set to be between 0 and 1.");
} else {
_parms._relevance_threshold = _parms._total_information_threshold;
}

if (_parms._safety_index_threshold != -1) {
warn("safety_index_threshold", "Should not set safety_index_threshold for core infogram " +
"runs. Set net_information_threshold instead. Using default of 0.1 if not set");
}

if (_parms._relevance_index_threshold != -1) {
warn("relevance_index_threshold", "Should not set relevance_index_threshold for core " +
"infogram runs. Set total_information_threshold instead. Using default of 0.1 if not set");
}
} else { // fair infogram
if (_parms._safety_index_threshold == -1) {
_parms._cmi_threshold = 0.1;
_parms._safety_index_threshold = 0.1;
} else if (_parms._safety_index_threshold < 0 || _parms._safety_index_threshold > 1) {
if (_parms._safety_index_threshold < 0 || _parms._safety_index_threshold > 1) {
error("safety_index_threshold", " should be set to be between 0 and 1.");
} else {
_parms._cmi_threshold = _parms._safety_index_threshold;
}

if (_parms._relevance_index_threshold == -1) {
_parms._relevance_threshold = 0.1;
_parms._relevance_index_threshold = 0.1;
} else if (_parms._relevance_index_threshold < 0 || _parms._relevance_index_threshold > 1) {
if (_parms._relevance_index_threshold < 0 || _parms._relevance_index_threshold > 1) {
error("relevance_index_threshold", " should be set to be between 0 and 1.");
} else {
_parms._relevance_threshold = _parms._relevance_index_threshold;
}

if (_parms._net_information_threshold != -1) {
warn("net_information_threshold", "Should not set net_information_threshold for fair " +
"infogram runs, set safety_index_threshold instead. Using default of 0.1 if not set");
}
if (_parms._total_information_threshold != -1) {
warn("total_information_threshold", "Should not set total_information_threshold for fair" +
" infogram runs, set relevance_index_threshold instead. Using default of 0.1 if not set");
}

if (AUTO.equals(_parms._algorithm))
_parms._algorithm = gbm;
}
Expand Down
11 changes: 5 additions & 6 deletions h2o-admissibleml/src/main/java/hex/Infogram/InfogramModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import hex.*;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.schemas.*;
import hex.schemas.InfogramV3;
import water.*;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.TwoDimTable;

import java.lang.reflect.Field;
import java.util.*;
Expand Down Expand Up @@ -55,10 +54,10 @@ public static class InfogramParameters extends Model.Parameters {
public String[] _protected_columns = null; // store features to be excluded from final model
public double _cmi_threshold = 0.1; // default set by Deep
public double _relevance_threshold = 0.1; // default set by Deep
public double _total_information_threshold = -1; // relevance threshold for core infogram
public double _net_information_threshold = -1; // cmi threshold for core infogram
public double _safety_index_threshold = -1; // cmi threshold for safe infogram
public double _relevance_index_threshold = -1; // relevance threshold for safe infogram
public double _total_information_threshold = 0.1; // relevance threshold for core infogram
public double _net_information_threshold = 0.1; // cmi threshold for core infogram
public double _safety_index_threshold = 0.1; // cmi threshold for safe infogram
public double _relevance_index_threshold = 0.1; // relevance threshold for safe infogram
public double _data_fraction = 1.0; // fraction of data to use to calculate infogram
public Model.Parameters _infogram_algorithm_parameters; // store parameters of chosen algorithm
public int _top_n_features = 50; // if 0 consider all predictors, otherwise, consider topk predictors
Expand Down
29 changes: 13 additions & 16 deletions h2o-admissibleml/src/main/java/hex/schemas/InfogramV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
import water.api.SchemaServer;
import water.api.schemas3.KeyV3;
import water.api.schemas3.ModelParametersSchemaV3;
import static hex.util.DistributionUtils.distributionToFamily;

import java.util.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

import static hex.util.DistributionUtils.distributionToFamily;

public class InfogramV3 extends ModelBuilderSchema<Infogram, InfogramV3, InfogramV3.InfogramParametersV3> {
public static final class InfogramParametersV3 extends ModelParametersSchemaV3<InfogramModel.InfogramParameters, InfogramParametersV3> {
Expand Down Expand Up @@ -134,44 +132,43 @@ public static final class InfogramParametersV3 extends ModelParametersSchemaV3<I
level = API.Level.secondary, gridable=true)
public String[] protected_columns;

@API(help = "A number between 0 and 1 representing a threshold for total information, defaulting to 0.1. " +
@API(help = "A number between 0 and 1 representing a threshold for total information. " +
maurever marked this conversation as resolved.
Show resolved Hide resolved
"For a specific feature, if the total information is higher than this threshold, and the corresponding " +
"net information is also higher than the threshold ``net_information_threshold``, that feature will be " +
"considered admissible. The total information is the x-axis of the Core Infogram. " +
"Default is -1 which gets set to 0.1.",
"considered admissible. The total information is the x-axis of the Core Infogram. ",
level = API.Level.secondary, gridable = true)
public double total_information_threshold;

@API(help = "A number between 0 and 1 representing a threshold for net information, defaulting to 0.1. For a " +
@API(help = "A number between 0 and 1 representing a threshold for net information. For a " +
"specific feature, if the net information is higher than this threshold, and the corresponding total " +
"information is also higher than the total_information_threshold, that feature will be considered admissible. " +
"The net information is the y-axis of the Core Infogram. Default is -1 which gets set to 0.1.",
"The net information is the y-axis of the Core Infogram.",
level = API.Level.secondary, gridable = true)
public double net_information_threshold;

@API(help = "A number between 0 and 1 representing a threshold for the relevance index, defaulting to 0.1. This is " +
@API(help = "A number between 0 and 1 representing a threshold for the relevance index. This is " +
"only used when ``protected_columns`` is set by the user. For a specific feature, if the relevance index " +
"value is higher than this threshold, and the corresponding safety index is also higher than the " +
"safety_index_threshold``, that feature will be considered admissible. The relevance index is the x-axis " +
"of the Fair Infogram. Default is -1 which gets set to 0.1.",
"of the Fair Infogram.",
level = API.Level.secondary, gridable = true)
public double relevance_index_threshold;

@API(help = "A number between 0 and 1 representing a threshold for the safety index, defaulting to 0.1. This is " +
@API(help = "A number between 0 and 1 representing a threshold for the safety index. This is " +
"only used when protected_columns is set by the user. For a specific feature, if the safety index value " +
"is higher than this threshold, and the corresponding relevance index is also higher than the " +
"relevance_index_threshold, that feature will be considered admissible. The safety index is the y-axis of " +
"the Fair Infogram. Default is -1 which gets set to 0.1.",
"the Fair Infogram.",
level = API.Level.secondary, gridable = true)
public double safety_index_threshold;

@API(help = "The fraction of training frame to use to build the infogram model. Defaults to 1.0, and any value greater " +
@API(help = "The fraction of training frame to use to build the infogram model. Any value greater " +
"than 0 and less than or equal to 1.0 is acceptable.",
level = API.Level.secondary, gridable = true)
public double data_fraction;

@API(help = "An integer specifying the number of columns to evaluate in the infogram. The columns are ranked by " +
"variable importance, and the top N are evaluated. Defaults to 50.",
"variable importance, and the top N are evaluated.",
level = API.Level.secondary, gridable = true)
public int top_n_features;

Expand Down
Loading