diff --git a/causalml/inference/tree/causal/_builder.pyx b/causalml/inference/tree/causal/_builder.pyx index c2568f2b..ae8430ef 100644 --- a/causalml/inference/tree/causal/_builder.pyx +++ b/causalml/inference/tree/causal/_builder.pyx @@ -190,9 +190,15 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder): is_leaf = (is_leaf or split.pos >= end or (split.improvement + EPSILON < min_impurity_decrease)) - node_id = tree._add_node(parent, is_left, is_leaf, split.feature, + IF SKLEARN_VERSION < 13: + node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, weighted_n_node_samples) + ELSE: + node_id = tree._add_node(parent, is_left, is_leaf, split.feature, + split.threshold, impurity, n_node_samples, + weighted_n_node_samples, + split.missing_go_to_left) if node_id == SIZE_MAX: rc = -1 @@ -459,12 +465,22 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder): is_leaf = (is_leaf or split.pos >= end or split.improvement + EPSILON < min_impurity_decrease) - node_id = tree._add_node(parent - tree.nodes + IF SKLEARN_VERSION < 13: + node_id = tree._add_node(parent - tree.nodes if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, weighted_n_node_samples) + ELSE: + node_id = tree._add_node(parent - tree.nodes + if parent != NULL + else _TREE_UNDEFINED, + is_left, is_leaf, + split.feature, split.threshold, impurity, n_node_samples, + weighted_n_node_samples, + split.missing_go_to_left) + if node_id == SIZE_MAX: return -1 diff --git a/pyproject.toml b/pyproject.toml index 7046161b..c521c550 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "scipy>=1.4.1", "matplotlib", "pandas>=0.24.1", - "scikit-learn>=1.0.0", + "scikit-learn>=1.0.0,<1.4", "statsmodels>=0.9.0", "seaborn", "xgboost", @@ -62,7 +62,7 @@ requires = [ "wheel", "Cython<=0.29.36", "numpy", - "scikit-learn<=1.3.1", + "scikit-learn<1.4", ] [project.urls]