Skip to content

Commit

Permalink
fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongyoonlee committed Dec 6, 2024
1 parent 5390458 commit 70df9ef
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 40 deletions.
65 changes: 50 additions & 15 deletions causalml/inference/tree/_tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def get_n_leaves(self):
return self.tree_.n_leaves

def _support_missing_values(self, X):
return not issparse(X) and self._get_tags()["allow_nan"] and self.monotonic_cst is None
return (
not issparse(X)
and self._get_tags()["allow_nan"]
and self.monotonic_cst is None
)

def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
"""Return boolean mask denoting if there are missing values for each feature.
Expand Down Expand Up @@ -242,22 +246,36 @@ def _fit(

# _compute_missing_values_in_feature_mask will check for finite values and
# compute the missing mask if the tree supports missing values
check_X_params = dict(dtype=DTYPE, accept_sparse="csc", force_all_finite=False)
check_X_params = dict(
dtype=DTYPE, accept_sparse="csc", force_all_finite=False
)
check_y_params = dict(ensure_2d=False, dtype=None)
X, y = self._validate_data(X, y, validate_separately=(check_X_params, check_y_params))
X, y = self._validate_data(
X, y, validate_separately=(check_X_params, check_y_params)
)

missing_values_in_feature_mask = self._compute_missing_values_in_feature_mask(X)
missing_values_in_feature_mask = (
self._compute_missing_values_in_feature_mask(X)
)
if issparse(X):
X.sort_indices()

if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
raise ValueError("No support for np.int64 index based sparse matrices")
raise ValueError(
"No support for np.int64 index based sparse matrices"
)

if self.criterion == "poisson":
if np.any(y < 0):
raise ValueError("Some value(s) of y are negative which is" " not allowed for Poisson regression.")
raise ValueError(
"Some value(s) of y are negative which is"
" not allowed for Poisson regression."
)
if np.sum(y) <= 0:
raise ValueError("Sum of y is not positive which is " "necessary for Poisson regression.")
raise ValueError(
"Sum of y is not positive which is "
"necessary for Poisson regression."
)

# Determine output settings
n_samples, self.n_features_in_ = X.shape
Expand Down Expand Up @@ -291,7 +309,9 @@ def _fit(
y = y_encoded

if self.class_weight is not None:
expanded_class_weight = compute_sample_weight(self.class_weight, y_original)
expanded_class_weight = compute_sample_weight(
self.class_weight, y_original
)

self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

Expand Down Expand Up @@ -333,7 +353,10 @@ def _fit(
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

if len(y) != n_samples:
raise ValueError("Number of labels=%d does not match number of samples=%d" % (len(y), n_samples))
raise ValueError(
"Number of labels=%d does not match number of samples=%d"
% (len(y), n_samples)
)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
Expand All @@ -354,7 +377,9 @@ def _fit(
criterion = self.criterion
if not isinstance(criterion, Criterion):
if is_classification:
criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, self.n_classes_)
criterion = CRITERIA_CLF[self.criterion](
self.n_outputs_, self.n_classes_
)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
else:
Expand All @@ -369,7 +394,9 @@ def _fit(
monotonic_cst = None
else:
if self.n_outputs_ > 1:
raise ValueError("Monotonicity constraints are not supported with multiple outputs.")
raise ValueError(
"Monotonicity constraints are not supported with multiple outputs."
)
# Check to correct monotonicity constraint' specification,
# by applying element-wise logical conjunction
# Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)`
Expand All @@ -385,12 +412,16 @@ def _fit(
if not np.all(valid_constraints):
unique_constaints_value = np.unique(monotonic_cst)
raise ValueError(
"monotonic_cst must be None or an array-like of -1, 0 or 1, but" f" got {unique_constaints_value}"
"monotonic_cst must be None or an array-like of -1, 0 or 1, but"
f" got {unique_constaints_value}"
)
monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8)
if is_classifier(self):
if self.n_classes_[0] > 2:
raise ValueError("Monotonicity constraints are not supported with multiclass " "classification")
raise ValueError(
"Monotonicity constraints are not supported with multiclass "
"classification"
)
# Binary classification trees are built by constraining probabilities
# of the *negative class* in order to make the implementation similar
# to regression trees.
Expand Down Expand Up @@ -1370,10 +1401,14 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
The value of the partial dependence function on each grid point.
"""
grid = np.asarray(grid, dtype=DTYPE, order="C")
averaged_predictions = np.zeros(shape=grid.shape[0], dtype=np.float64, order="C")
averaged_predictions = np.zeros(
shape=grid.shape[0], dtype=np.float64, order="C"
)
target_features = np.asarray(target_features, dtype=np.intp, order="C")

self.tree_.compute_partial_dependence(grid, target_features, averaged_predictions)
self.tree_.compute_partial_dependence(
grid, target_features, averaged_predictions
)
return averaged_predictions

def _more_tags(self):
Expand Down
110 changes: 85 additions & 25 deletions causalml/inference/tree/_tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import numpy as np

from sklearn.base import is_classifier
from sklearn.utils._param_validation import HasMethods, Interval, StrOptions, validate_params
from sklearn.utils._param_validation import (
HasMethods,
Interval,
StrOptions,
validate_params,
)
from sklearn.utils.validation import check_array, check_is_fitted
from . import DecisionTreeClassifier, DecisionTreeRegressor, _criterion, _tree
from ._reingold_tilford import Tree, buchheim
Expand Down Expand Up @@ -251,7 +256,9 @@ def get_color(self, value):
else:
# Regression tree or multi-output
color = list(self.colors["rgb"][0])
alpha = (value - self.colors["bounds"][0]) / (self.colors["bounds"][1] - self.colors["bounds"][0])
alpha = (value - self.colors["bounds"][0]) / (
self.colors["bounds"][1] - self.colors["bounds"][0]
)
# compute the color as alpha against white
color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
# Return html color code in #RRGGBB format
Expand All @@ -275,7 +282,11 @@ def get_fill_color(self, tree, node_id):
self.colors["bounds"] = (np.min(tree.value), np.max(tree.value))
if tree.n_outputs == 1:
node_val = tree.value[node_id][0, :]
if tree.n_classes[0] == 1 and isinstance(node_val, Iterable) and self.colors["bounds"] is not None:
if (
tree.n_classes[0] == 1
and isinstance(node_val, Iterable)
and self.colors["bounds"] is not None
):
# Unpack the float only for the regression tree case.
# Classification tree requires an Iterable in `get_color`.
node_val = node_val.item()
Expand Down Expand Up @@ -331,13 +342,17 @@ def node_to_str(self, tree, node_id, criterion):
criterion = "impurity"
if labels:
node_string += "%s = " % criterion
node_string += str(round(tree.impurity[node_id], self.precision)) + characters[4]
node_string += (
str(round(tree.impurity[node_id], self.precision)) + characters[4]
)

# Write node sample count
if labels:
node_string += "samples = "
if self.proportion:
percent = 100.0 * tree.n_node_samples[node_id] / float(tree.n_node_samples[0])
percent = (
100.0 * tree.n_node_samples[node_id] / float(tree.n_node_samples[0])
)
node_string += str(round(percent, 1)) + "%" + characters[4]
else:
node_string += str(tree.n_node_samples[node_id]) + characters[4]
Expand Down Expand Up @@ -369,7 +384,11 @@ def node_to_str(self, tree, node_id, criterion):
node_string += value_text + characters[4]

# Write node majority class
if self.class_names is not None and tree.n_classes[0] != 1 and tree.n_outputs == 1:
if (
self.class_names is not None
and tree.n_classes[0] != 1
and tree.n_outputs == 1
):
# Only done for single-output classification trees
if labels:
node_string += "class = "
Expand Down Expand Up @@ -462,7 +481,9 @@ def tail(self):
# If required, draw leaf nodes at same depth as each other
if self.leaves_parallel:
for rank in sorted(self.ranks):
self.out_file.write("{rank=same ; " + "; ".join(r for r in self.ranks[rank]) + "} ;\n")
self.out_file.write(
"{rank=same ; " + "; ".join(r for r in self.ranks[rank]) + "} ;\n"
)
self.out_file.write("}")

def head(self):
Expand All @@ -476,7 +497,9 @@ def head(self):
if self.rounded:
rounded_filled.append("rounded")
if len(rounded_filled) > 0:
self.out_file.write(', style="%s", color="black"' % ", ".join(rounded_filled))
self.out_file.write(
', style="%s", color="black"' % ", ".join(rounded_filled)
)

self.out_file.write(', fontname="%s"' % self.fontname)
self.out_file.write("] ;\n")
Expand Down Expand Up @@ -507,10 +530,14 @@ def recurse(self, tree, node_id, criterion, parent=None, depth=0):
else:
self.ranks[str(depth)].append(str(node_id))

self.out_file.write("%d [label=%s" % (node_id, self.node_to_str(tree, node_id, criterion)))
self.out_file.write(
"%d [label=%s" % (node_id, self.node_to_str(tree, node_id, criterion))
)

if self.filled:
self.out_file.write(', fillcolor="%s"' % self.get_fill_color(tree, node_id))
self.out_file.write(
', fillcolor="%s"' % self.get_fill_color(tree, node_id)
)
self.out_file.write("] ;\n")

if parent is not None:
Expand Down Expand Up @@ -601,10 +628,16 @@ def _make_tree(self, node_id, et, criterion, depth=0):
# traverses _tree.Tree recursively, builds intermediate
# "_reingold_tilford.Tree" object
name = self.node_to_str(et, node_id, criterion=criterion)
if et.children_left[node_id] != _tree.TREE_LEAF and (self.max_depth is None or depth <= self.max_depth):
if et.children_left[node_id] != _tree.TREE_LEAF and (
self.max_depth is None or depth <= self.max_depth
):
children = [
self._make_tree(et.children_left[node_id], et, criterion, depth=depth + 1),
self._make_tree(et.children_right[node_id], et, criterion, depth=depth + 1),
self._make_tree(
et.children_left[node_id], et, criterion, depth=depth + 1
),
self._make_tree(
et.children_right[node_id], et, criterion, depth=depth + 1
),
]
else:
return Tree(name, node_id)
Expand Down Expand Up @@ -646,12 +679,16 @@ def export(self, decision_tree, ax=None):
# adjust fontsize to avoid overlap
# get max box width and height
extents = [
bbox_patch.get_window_extent() for ann in anns if (bbox_patch := ann.get_bbox_patch()) is not None
bbox_patch.get_window_extent()
for ann in anns
if (bbox_patch := ann.get_bbox_patch()) is not None
]
max_width = max([extent.width for extent in extents])
max_height = max([extent.height for extent in extents])
# width should be around scale_x in axis coordinates
size = anns[0].get_fontsize() * min(scale_x / max_width, scale_y / max_height)
size = anns[0].get_fontsize() * min(
scale_x / max_width, scale_y / max_height
)
for ann in anns:
ann.set_fontsize(size)

Expand Down Expand Up @@ -863,9 +900,13 @@ def export_graphviz(
'digraph Tree {...
"""
if feature_names is not None:
feature_names = check_array(feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0)
feature_names = check_array(
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)
if class_names is not None and not isinstance(class_names, bool):
class_names = check_array(class_names, ensure_2d=False, dtype=None, ensure_min_samples=0)
class_names = check_array(
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)

check_is_fitted(decision_tree)
own_file = False
Expand Down Expand Up @@ -911,13 +952,19 @@ def _compute_depth(tree, node):
Returns the depth of the subtree rooted in node.
"""

def compute_depth_(current_node, current_depth, children_left, children_right, depths):
def compute_depth_(
current_node, current_depth, children_left, children_right, depths
):
depths += [current_depth]
left = children_left[current_node]
right = children_right[current_node]
if left != -1 and right != -1:
compute_depth_(left, current_depth + 1, children_left, children_right, depths)
compute_depth_(right, current_depth + 1, children_left, children_right, depths)
compute_depth_(
left, current_depth + 1, children_left, children_right, depths
)
compute_depth_(
right, current_depth + 1, children_left, children_right, depths
)

depths = []
compute_depth_(node, 1, tree.children_left, tree.children_right, depths)
Expand Down Expand Up @@ -1013,9 +1060,13 @@ def export_text(
| | |--- class: 2
"""
if feature_names is not None:
feature_names = check_array(feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0)
feature_names = check_array(
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)
if class_names is not None:
class_names = check_array(class_names, ensure_2d=False, dtype=None, ensure_min_samples=0)
class_names = check_array(
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)

check_is_fitted(decision_tree)
tree_ = decision_tree.tree_
Expand All @@ -1034,7 +1085,10 @@ def export_text(
truncation_fmt = "{} {}\n"

if feature_names is not None and len(feature_names) != tree_.n_features:
raise ValueError("feature_names must contain %d elements, got %d" % (tree_.n_features, len(feature_names)))
raise ValueError(
"feature_names must contain %d elements, got %d"
% (tree_.n_features, len(feature_names))
)

if isinstance(decision_tree, DecisionTreeClassifier):
value_fmt = "{}{} weights: {}\n"
Expand All @@ -1044,7 +1098,10 @@ def export_text(
value_fmt = "{}{} value: {}\n"

if feature_names is not None:
feature_names_ = [feature_names[i] if i != _tree.TREE_UNDEFINED else None for i in tree_.feature]
feature_names_ = [
feature_names[i] if i != _tree.TREE_UNDEFINED else None
for i in tree_.feature
]
else:
feature_names_ = ["feature_{}".format(i) for i in tree_.feature]

Expand All @@ -1054,7 +1111,10 @@ def _add_leaf(value, weighted_n_node_samples, class_name, indent):
val = ""
if isinstance(decision_tree, DecisionTreeClassifier):
if show_weights:
val = ["{1:.{0}f}, ".format(decimals, v * weighted_n_node_samples) for v in value]
val = [
"{1:.{0}f}, ".format(decimals, v * weighted_n_node_samples)
for v in value
]
val = "[" + "".join(val)[:-2] + "]"
weighted_n_node_samples
val += " class: " + str(class_name)
Expand Down

0 comments on commit 70df9ef

Please sign in to comment.