Skip to content

Commit

Permalink
Add params missing_go_to_left to Tree._add_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperBo committed Oct 17, 2023
1 parent b3e149b commit 3c7c0e6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
20 changes: 18 additions & 2 deletions causalml/inference/tree/causal/_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,15 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder):
is_leaf = (is_leaf or split.pos >= end or
(split.improvement + EPSILON < min_impurity_decrease))

node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
IF SKLEARN_VERSION < 13:
node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
split.threshold, impurity, n_node_samples,
weighted_n_node_samples)
ELSE:
node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
split.threshold, impurity, n_node_samples,
weighted_n_node_samples,
split.missing_go_to_left)

if node_id == SIZE_MAX:
rc = -1
Expand Down Expand Up @@ -459,12 +465,22 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder):
is_leaf = (is_leaf or split.pos >= end or
split.improvement + EPSILON < min_impurity_decrease)

node_id = tree._add_node(parent - tree.nodes
IF SKLEARN_VERSION < 13:
node_id = tree._add_node(parent - tree.nodes
if parent != NULL
else _TREE_UNDEFINED,
is_left, is_leaf,
split.feature, split.threshold, impurity, n_node_samples,
weighted_n_node_samples)
ELSE:
node_id = tree._add_node(parent - tree.nodes
if parent != NULL
else _TREE_UNDEFINED,
is_left, is_leaf,
split.feature, split.threshold, impurity, n_node_samples,
weighted_n_node_samples,
split.missing_go_to_left)

if node_id == SIZE_MAX:
return -1

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"scipy>=1.4.1",
"matplotlib",
"pandas>=0.24.1",
"scikit-learn>=1.0.0",
"scikit-learn>=1.0.0,<1.4",
"statsmodels>=0.9.0",
"seaborn",
"xgboost",
Expand Down Expand Up @@ -62,7 +62,7 @@ requires = [
"wheel",
"Cython<=0.29.36",
"numpy",
"scikit-learn<=1.3.1",
"scikit-learn<1.4",
]

[project.urls]
Expand Down

0 comments on commit 3c7c0e6

Please sign in to comment.