diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py
index 6e366a07..44ae7d8b 100644
--- a/python/src/spark_rapids_ml/classification.py
+++ b/python/src/spark_rapids_ml/classification.py
@@ -396,9 +396,9 @@ class RandomForestClassifier(
         * ``4 or False`` - Enables all messages up to and including information messages.
         * ``5 or True`` - Enables all messages up to and including debug messages.
         * ``6`` - Enables all messages up to and including trace messages.
-    n_streams: int (default = 1)
+    n_streams: int (default = 4)
         Number of parallel streams used for forest building.
-        Please note that there is a bug running spark-rapids-ml on a node with multi-gpus
+        Please note that there could be a bug running spark-rapids-ml on a multi-GPU node
         when n_streams > 1. See https://github.com/rapidsai/cuml/issues/5402.
     min_samples_split: int or float (default = 2)
         The minimum number of samples required to split an internal node.
diff --git a/python/src/spark_rapids_ml/clustering.py b/python/src/spark_rapids_ml/clustering.py
index 28d38048..446ade99 100644
--- a/python/src/spark_rapids_ml/clustering.py
+++ b/python/src/spark_rapids_ml/clustering.py
@@ -769,7 +769,6 @@ def __init__(
         assert max_records_per_batch_str is not None
         self.max_records_per_batch = int(max_records_per_batch_str)
         self.BROADCAST_LIMIT = 8 << 30
-        self.verbose = verbose
         self.cuml_params["calc_core_sample_indices"] = False  # currently not supported
 
     def setEps(self: P, value: float) -> P:
@@ -810,7 +809,7 @@ def _fit(self, dataset: DataFrame) -> _CumlModel:
 
         # Create parameter-copied model without accessing the input dataframe
         # All information will be retrieved from Model and transform
-        model = DBSCANModel(verbose=self.verbose, n_cols=0, dtype="")
+        model = DBSCANModel(n_cols=0, dtype="")
 
         model._num_workers = self.num_workers
         self._copyValues(model)
@@ -842,7 +841,6 @@ def __init__(
         self,
         n_cols: int,
         dtype: str,
-        verbose: Union[int, bool],
     ):
         super(DBSCANClass, self).__init__()
         super(_CumlModelWithPredictionCol, self).__init__(n_cols=n_cols, dtype=dtype)
@@ -852,7 +850,6 @@ def __init__(
             idCol=alias.row_number,
         )
 
-        self.verbose = verbose
         self.BROADCAST_LIMIT = 8 << 30
         self._dbscan_spark_model = None
 
@@ -1069,12 +1066,3 @@ def _chunk_arr(
         # JOIN the transformed label column into the original input dataset
         # and discard the internal idCol for row matching
         return dataset.join(pred_df, idCol_name).drop(idCol_name)
-
-    def _get_model_attributes(self) -> Optional[Dict[str, Any]]:
-        """
-        Override parent method to bring broadcast variables to driver before JSON serialization.
-        """
-
-        self._model_attributes["verbose"] = self.verbose
-
-        return self._model_attributes
diff --git a/python/src/spark_rapids_ml/params.py b/python/src/spark_rapids_ml/params.py
index 156d7279..b775b79d 100644
--- a/python/src/spark_rapids_ml/params.py
+++ b/python/src/spark_rapids_ml/params.py
@@ -128,6 +128,36 @@ def _ensureIdCol(self, df: DataFrame) -> DataFrame:
         return df_withid
 
 
+class VerboseTypeConverters(TypeConverters):
+    @staticmethod
+    def _toIntOrBool(value: Any) -> Union[int, bool]:
+        if isinstance(value, bool):
+            return value
+
+        if TypeConverters._is_integer(value):
+            return int(value)
+
+        raise TypeError("Could not convert %s to Union[int, bool]" % value)
+
+
+class HasVerboseParam(Params):
+    """
+    Parameter to enable displaying verbose messages from cuML.
+    Refer to the cuML documentation for details on verbosity levels.
+    """
+
+    verbose: "Param[Union[int, bool]]" = Param(
+        Params._dummy(),
+        "verbose",
+        "cuml verbosity level (False, True or an integer between 0 and 6).",
+        typeConverter=VerboseTypeConverters._toIntOrBool,
+    )
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._setDefault(verbose=False)
+
+
 class _CumlClass(object):
     """
     Base class for all _CumlEstimator and _CumlModel implemenations.
@@ -215,7 +245,7 @@ def _get_cuml_params_default(self) -> Dict[str, Any]:
         raise NotImplementedError()
 
 
-class _CumlParams(_CumlClass, Params):
+class _CumlParams(_CumlClass, HasVerboseParam, Params):
     """
     Mix-in to handle common parameters for all Spark Rapids ML algorithms, along with
     utilties for synchronizing between Spark ML Params and cuML class parameters.
@@ -269,25 +299,46 @@ def num_workers(self, value: int) -> None:
         self._num_workers = value
 
     def copy(self: P, extra: Optional["ParamMap"] = None) -> P:
+        """
+        Create a copy of the current instance, including its parameters and cuml_params.
+
+        This function extends the default `copy()` method to ensure the `cuml_params` variable
+        is also copied. The default `super().copy()` method only handles `_paramMap` and
+        `_defaultParamMap`.
+
+        Parameters
+        ----------
+        extra : Optional[ParamMap]
+            A dictionary or ParamMap containing additional parameters to set in the copied instance.
+            Note: ParamMap = Dict[pyspark.ml.param.Param, Any].
+
+        Returns
+        -------
+        P
+            A new instance of the same type as the current object, with parameters and
+            cuml_params copied.
+
+        Raises
+        ------
+        TypeError
+            If any key in the `extra` dictionary is not an instance of `pyspark.ml.param.Param`.
+        """
         # override this function to update cuml_params if possible
         instance: P = super().copy(extra)
         cuml_params = instance.cuml_params.copy()
+        instance._cuml_params = cuml_params
 
         if isinstance(extra, dict):
             for param, value in extra.items():
                 if isinstance(param, Param):
-                    name = instance._get_cuml_param(param.name, silent=False)
-                    if name is not None:
-                        cuml_params[name] = instance._get_cuml_mapping_value(
-                            name, value
-                        )
+                    instance._set_params(**{param.name: value})
                 else:
                     raise TypeError(
                         "Expecting a valid instance of Param, but received: {}".format(
                             param
                         )
                     )
-        instance._cuml_params = cuml_params
+
         return instance
 
     def _initialize_cuml_params(self) -> None:
diff --git a/python/src/spark_rapids_ml/tree.py b/python/src/spark_rapids_ml/tree.py
index de0d2f69..8b9e7cce 100644
--- a/python/src/spark_rapids_ml/tree.py
+++ b/python/src/spark_rapids_ml/tree.py
@@ -38,7 +38,13 @@
     RandomForestClassificationModel as SparkRandomForestClassificationModel,
 )
 from pyspark.ml.linalg import Vector
-from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol
+from pyspark.ml.param.shared import (
+    HasFeaturesCol,
+    HasLabelCol,
+    Param,
+    Params,
+    TypeConverters,
+)
 from pyspark.ml.regression import DecisionTreeRegressionModel
 from pyspark.ml.regression import (
     RandomForestRegressionModel as SparkRandomForestRegressionModel,
 )
@@ -148,8 +154,57 @@ class _RandomForestCumlParams(
     HasFeaturesCols,
     HasLabelCol,
 ):
+
+    n_streams = Param(
+        Params._dummy(),
+        "n_streams",
+        "The n_streams parameter to use for cuml.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    min_samples_split = Param(
+        Params._dummy(),
+        "min_samples_split",
+        "The min_samples_split parameter to use for cuml.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    max_samples = Param(
+        Params._dummy(),
+        "max_samples",
+        "The max_samples parameter to use for cuml.",
+        typeConverter=TypeConverters.toFloat,
+    )
+
+    max_leaves = Param(
+        Params._dummy(),
+        "max_leaves",
+        "The max_leaves parameter to use for cuml.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    min_impurity_decrease = Param(
+        Params._dummy(),
+        "min_impurity_decrease",
+        "The min_impurity_decrease parameter to use for cuml.",
+        typeConverter=TypeConverters.toFloat,
+    )
+
+    max_batch_size = Param(
+        Params._dummy(),
+        "max_batch_size",
+        "The max_batch_size parameter to use for cuml.",
+        typeConverter=TypeConverters.toInt,
+    )
+
     def __init__(self) -> None:
         super().__init__()
+        self._setDefault(n_streams=4)
+        self._setDefault(min_samples_split=2)
+        self._setDefault(max_samples=1.0)
+        self._setDefault(max_leaves=-1)
+        self._setDefault(min_impurity_decrease=0.0)
+        self._setDefault(max_batch_size=4096)
         # restrict default seed to max value of 32-bit signed integer for CuML
         self._setDefault(seed=hash(type(self).__name__) & 0x07FFFFFFF)
diff --git a/python/tests/test_approximate_nearest_neighbors.py b/python/tests/test_approximate_nearest_neighbors.py
index 29744451..f3bb8fd8 100644
--- a/python/tests/test_approximate_nearest_neighbors.py
+++ b/python/tests/test_approximate_nearest_neighbors.py
@@ -110,6 +110,24 @@ def test_params(default_params: bool) -> None:
     _test_input_setter_getter(ApproximateNearestNeighbors)
 
 
+def test_ann_copy() -> None:
+    from .test_common_estimator import _test_est_copy
+
+    param_list: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = [
+        ({"k": 38}, {"n_neighbors": 38}),
+        ({"algorithm": "cagra"}, {"algorithm": "cagra"}),
+        ({"metric": "cosine"}, {"metric": "cosine"}),
+        (
+            {"algoParams": {"nlist": 999, "nprobe": 11}},
+            {"algo_params": {"nlist": 999, "nprobe": 11}},
+        ),
+        ({"verbose": True}, {"verbose": True}),
+    ]
+
+    for pair in param_list:
+        _test_est_copy(ApproximateNearestNeighbors, pair[0], pair[1])
+
+
 def test_search_index_params() -> None:
     # test cagra index params and search params
     cagra_index_param: Dict[str, Any] = {
diff --git a/python/tests/test_common_estimator.py b/python/tests/test_common_estimator.py
index 80b7c287..a730d93d 100644
--- a/python/tests/test_common_estimator.py
+++ b/python/tests/test_common_estimator.py
@@ -15,7 +15,7 @@
 #
 
 from abc import ABCMeta
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import pandas as pd
@@ -345,6 +345,38 @@ def _test_input_setter_getter(est_class: Any) -> None:
     ]
 
 
+def _test_est_copy(
+    Estimator: Type[_CumlEstimator],
+    input_spark_params: Dict[str, Any],
+    cuml_params_update: Optional[Dict[str, Any]],
+) -> None:
+    """
+    Tests the copy() method of an estimator.
+    For Spark-Rapids-ML-only parameters that have no cuML mapping (e.g. enable_sparse_data_optim in LogisticRegression), set cuml_params_update to None.
+ """ + + est = Estimator() + copy_params = {getattr(est, p): input_spark_params[p] for p in input_spark_params} + est_copy = est.copy(copy_params) + + # handle Spark-Rapids-ML-only params + if cuml_params_update is None: + for param in input_spark_params: + assert est_copy.getOrDefault(param) == input_spark_params[param] + return + + res_cuml_params = est.cuml_params.copy() + res_cuml_params.update(cuml_params_update) + assert ( + est.cuml_params != res_cuml_params + ), "please modify cuml_params_update because it does not change the default estimator.cuml_params" + assert est_copy.cuml_params == res_cuml_params + + # test init function + est_init = Estimator(**input_spark_params) + assert est_init.cuml_params == res_cuml_params + + def test_default_cuml_params() -> None: cuml_params = get_default_cuml_parameters([CumlDummy], ["b"]) spark_params = SparkRapidsMLDummy()._get_cuml_params_default() diff --git a/python/tests/test_dbscan.py b/python/tests/test_dbscan.py index fb1c4e12..e7e0d46c 100644 --- a/python/tests/test_dbscan.py +++ b/python/tests/test_dbscan.py @@ -108,6 +108,21 @@ def test_params( _test_input_setter_getter(DBSCAN) +def test_dbscan_copy() -> None: + from .test_common_estimator import _test_est_copy + + param_list: List[Dict[str, Any]] = [ + {"eps": 0.7}, + {"min_samples": 10}, + {"metric": "cosine"}, + {"algorithm": "rbc"}, + {"max_mbytes_per_batch": 1000}, + {"verbose": True}, + ] + for param in param_list: + _test_est_copy(DBSCAN, param, param) + + def test_dbscan_basic( gpu_number: int, tmp_path: str, caplog: LogCaptureFixture ) -> None: diff --git a/python/tests/test_kmeans.py b/python/tests/test_kmeans.py index 1c7f88a5..31618993 100644 --- a/python/tests/test_kmeans.py +++ b/python/tests/test_kmeans.py @@ -14,7 +14,7 @@ # limitations under the License. # -from typing import Any, Dict, List, Tuple, Type, TypeVar +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar import numpy as np import pyspark @@ -183,6 +183,22 @@ def test_kmeans_params( assert not kmeans_float32._float32_inputs +def test_kmeans_copy() -> None: + from .test_common_estimator import _test_est_copy + + param_list: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = [ + ({"k": 17}, {"n_clusters": 17}), + ({"initMode": "random"}, {"init": "random"}), + ({"tol": 0.0132}, {"tol": 0.0132}), + ({"maxIter": 27}, {"max_iter": 27}), + ({"seed": 11}, {"random_state": 11}), + ({"verbose": True}, {"verbose": True}), + ] + + for pair in param_list: + _test_est_copy(KMeans, pair[0], pair[1]) + + def test_kmeans_basic( gpu_number: int, tmp_path: str, caplog: LogCaptureFixture ) -> None: diff --git a/python/tests/test_linear_model.py b/python/tests/test_linear_model.py index 5585ddae..138f4073 100644 --- a/python/tests/test_linear_model.py +++ b/python/tests/test_linear_model.py @@ -14,7 +14,7 @@ # limitations under the License. # import warnings -from typing import Any, Dict, List, Tuple, Type, TypeVar, cast +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, cast import numpy as np import pyspark @@ -218,6 +218,25 @@ def test_linear_regression_params( assert not lr_float32._float32_inputs +def test_linear_regression_copy() -> None: + from .test_common_estimator import _test_est_copy + + # solver supports 'auto', 'normal' and 'eig', but all of them will be mapped to 'eig' in cuML. 
+    # loss supports 'squaredError' only.
+    param_list: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = [
+        ({"maxIter": 29}, {"max_iter": 29}),
+        ({"regParam": 0.12}, {"alpha": 0.12}),
+        ({"elasticNetParam": 0.23}, {"l1_ratio": 0.23}),
+        ({"fitIntercept": False}, {"fit_intercept": False}),
+        ({"standardization": False}, {"normalize": False}),
+        ({"tol": 0.0132}, {"tol": 0.0132}),
+        ({"verbose": True}, {"verbose": True}),
+    ]
+
+    for pair in param_list:
+        _test_est_copy(LinearRegression, pair[0], pair[1])
+
+
 @pytest.mark.parametrize("data_type", ["byte", "short", "int", "long"])
 def test_linear_regression_numeric_type(gpu_number: int, data_type: str) -> None:
     # reduce the number of GPUs for toy dataset to avoid empty partition
diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py
index a822d8ac..1370cb85 100644
--- a/python/tests/test_logistic_regression.py
+++ b/python/tests/test_logistic_regression.py
@@ -60,7 +60,7 @@
 from scipy.sparse import csr_matrix
 
 from spark_rapids_ml.classification import LogisticRegression, LogisticRegressionModel
-from spark_rapids_ml.core import _use_sparse_in_cuml, alias
+from spark_rapids_ml.core import _CumlEstimator, _use_sparse_in_cuml, alias
 from spark_rapids_ml.tuning import CrossValidator
 
 from .sparksession import CleanSparkSession
@@ -329,6 +329,56 @@ def test_params(tmp_path: str, caplog: LogCaptureFixture) -> None:
     _test_input_setter_getter(LogisticRegression)
 
 
+@pytest.mark.parametrize(
+    "input_spark_params,cuml_params_update",
+    [
+        (
+            {"regParam": 0.1, "elasticNetParam": 0.5},
+            {"penalty": "elasticnet", "C": 10.0, "l1_ratio": 0.5},
+        ),
+        (
+            {"maxIter": 13},
+            {"max_iter": 13},
+        ),
+        (
+            {"regParam": 0.25, "elasticNetParam": 0.0},
+            {"penalty": "l2", "C": 4.0, "l1_ratio": 0.0},
+        ),
+        (
+            {"regParam": 0.2, "elasticNetParam": 1.0},
+            {"penalty": "l1", "C": 5.0, "l1_ratio": 1.0},
+        ),
+        (
+            {"tol": 1e-3},
+            {"tol": 1e-3},
+        ),
+        (
+            {"fitIntercept": False},
+            {"fit_intercept": False},
+        ),
+        (
+            {"standardization": False},
+            {"standardization": False},
+        ),
+        (
+            {"enable_sparse_data_optim": True},
+            None,
+        ),
+        (
+            {"verbose": True},
+            {"verbose": True},
+        ),
+    ],
+)
+def test_lr_copy(
+    input_spark_params: Dict[str, Any],
+    cuml_params_update: Optional[Dict[str, Any]],
+) -> None:
+    from .test_common_estimator import _test_est_copy
+
+    _test_est_copy(LogisticRegression, input_spark_params, cuml_params_update)
+
+
 @pytest.mark.parametrize("fit_intercept", [True, False])
 @pytest.mark.parametrize("feature_type", ["array", "multi_cols", "vector"])
 @pytest.mark.parametrize("data_shape", [(2000, 8)], ids=idfn)
diff --git a/python/tests/test_nearest_neighbors.py b/python/tests/test_nearest_neighbors.py
index f81c6fa7..b738bc67 100644
--- a/python/tests/test_nearest_neighbors.py
+++ b/python/tests/test_nearest_neighbors.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -94,6 +94,20 @@ def test_params(default_params: bool, caplog: LogCaptureFixture) -> None:
     _test_input_setter_getter(NearestNeighbors)
 
 
+def test_knn_copy() -> None:
+    from .test_common_estimator import _test_est_copy
+
+    param_list: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = [
+        ({"k": 37}, {"n_neighbors": 37}),
+        ({"verbose": True}, {"verbose": True}),
+    ]
+
+    for pair in param_list:
+        spark_param = pair[0]
+        cuml_param = pair[1]
+        _test_est_copy(NearestNeighbors, spark_param, cuml_param)
+
+
 def func_test_example_no_id(
     tmp_path: str, gpu_knn: NNEstimator
 ) -> Tuple[NNEstimator, NNModel]:
diff --git a/python/tests/test_pca.py b/python/tests/test_pca.py
index 46ec0f2d..e7647940 100644
--- a/python/tests/test_pca.py
+++ b/python/tests/test_pca.py
@@ -100,6 +100,20 @@ def test_params(default_params: bool, caplog: LogCaptureFixture) -> None:
     _test_input_setter_getter(PCA)
 
 
+def test_pca_copy() -> None:
+    from .test_common_estimator import _test_est_copy
+
+    param_list = [
+        ({"k": 42}, {"n_components": 42}),
+        ({"verbose": 42},),
+    ]
+
+    for pair in param_list:
+        spark_param = pair[0]
+        cuml_param = spark_param if len(pair) == 1 else pair[1]
+        _test_est_copy(PCA, spark_param, cuml_param)
+
+
 def test_fit(gpu_number: int) -> None:
     # reduce the number of GPUs for toy dataset to avoid empty partition
     gpu_number = min(gpu_number, 2)
diff --git a/python/tests/test_random_forest.py b/python/tests/test_random_forest.py
index 92bc434f..1b7f885d 100644
--- a/python/tests/test_random_forest.py
+++ b/python/tests/test_random_forest.py
@@ -183,6 +183,41 @@ def test_params(default_params: bool, Estimator: RandomForest) -> None:
     _test_input_setter_getter(Estimator)
 
 
+def test_rf_copy() -> None:
+    from .test_common_estimator import _test_est_copy
+
+    param_list: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = [
+        ({"maxDepth": 51}, {"max_depth": 51}),
+        ({"maxBins": 61}, {"n_bins": 61}),
+        ({"minInstancesPerNode": 63}, {"min_samples_leaf": 63}),
+        ({"numTrees": 56}, {"n_estimators": 56}),
+        ({"featureSubsetStrategy": "onethird"}, {"max_features": 1.0 / 3.0}),
+        ({"seed": 21}, {"random_state": 21}),
+        ({"bootstrap": False}, {"bootstrap": False}),
+    ]
+
+    cuml_specific_params: List[Dict[str, Any]] = [
+        {"n_streams": 2},
+        {"min_samples_split": 19},
+        {"max_samples": 0.77},
+        {"max_leaves": 72},
+        {"min_impurity_decrease": 0.03},
+        {"max_batch_size": 1025},
+        {"verbose": True},
+    ]
+
+    param_list += [(p, p) for p in cuml_specific_params]
+
+    for pair in param_list:
+        _test_est_copy(RandomForestClassifier, pair[0], pair[1])
+        _test_est_copy(RandomForestRegressor, pair[0], pair[1])
+
+    # RandomForestRegressor supports impurity="variance" only
+    _test_est_copy(
+        RandomForestClassifier, {"impurity": "entropy"}, {"split_criterion": "entropy"}
+    )
+
+
 @pytest.mark.parametrize("RFEstimator", [RandomForestClassifier, RandomForestRegressor])
 def test_random_forest_params(
     tmp_path: str, RFEstimator: RandomForest, caplog: LogCaptureFixture
diff --git a/python/tests/test_umap.py b/python/tests/test_umap.py
index 2e93d8ab..81a2ffe0 100644
--- a/python/tests/test_umap.py
+++ b/python/tests/test_umap.py
@@ -15,7 +15,7 @@
 #
 
 import math
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cupy as cp
 import numpy as np
@@ -369,6 +369,45 @@ def test_params(tmp_path: str, default_params: bool) -> None:
     _test_input_setter_getter(UMAP)
 
 
+def test_umap_copy() -> None:
+    from .test_common_estimator import _test_est_copy
+
+    param_list: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = [
+        ({"n_neighbors": 21}, {"n_neighbors": 21}),
+        ({"n_components": 23}, {"n_components": 23}),
+        ({"metric": "cosine"}, {"metric": "cosine"}),
+        ({"metric_kwds": {"p": 5}}, {"metric_kwds": {"p": 5}}),
+        ({"n_epochs": 132}, {"n_epochs": 132}),
+        ({"learning_rate": 0.19}, {"learning_rate": 0.19}),
+        ({"init": "random"}, {"init": "random"}),
+        ({"min_dist": 0.24}, {"min_dist": 0.24}),
+        ({"spread": 0.24}, {"spread": 0.24}),
+        ({"set_op_mix_ratio": 0.94}, {"set_op_mix_ratio": 0.94}),
+        ({"local_connectivity": 0.98}, {"local_connectivity": 0.98}),
+        ({"repulsion_strength": 0.99}, {"repulsion_strength": 0.99}),
+        ({"negative_sample_rate": 7}, {"negative_sample_rate": 7}),
+        ({"transform_queue_size": 0.77}, {"transform_queue_size": 0.77}),
+        ({"a": 1.77}, {"a": 1.77}),
+        ({"b": 2.77}, {"b": 2.77}),
+        ({"precomputed_knn": [[0.1, 0.2]]}, {"precomputed_knn": [[0.1, 0.2]]}),
+        (
+            {"random_state": 81},
+            {"random_state": 81},
+        ),
+        ({"build_algo": "nn_descent"}, {"build_algo": "nn_descent"}),
+        (
+            {"build_kwds": {"nnd_graph_degree": 117}},
+            {"build_kwds": {"nnd_graph_degree": 117}},
+        ),
+        ({"sample_fraction": 0.74}, None),
+        ({"enable_sparse_data_optim": True}, None),
+        ({"verbose": True}, {"verbose": True}),
+    ]
+
+    for params in param_list:
+        _test_est_copy(UMAP, params[0], params[1])
+
+
 @pytest.mark.parametrize("sparse_fit", [True, False])
 def test_umap_model_persistence(
     sparse_fit: bool, gpu_number: int, tmp_path: str
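
Note on the copy() contract exercised by the new tests: copy(extra) now propagates Spark params into the synchronized cuml_params dictionary via _set_params instead of translating names by hand. A minimal usage sketch (illustrative only, not part of the patch; it assumes spark-rapids-ml's KMeans, whose Spark param "k" maps to cuML's "n_clusters" as exercised in test_kmeans.py above):

    from spark_rapids_ml.clustering import KMeans

    kmeans = KMeans()
    # copy() with an extra ParamMap updates both the Spark param map and the
    # cuML-side dictionary on the copied instance, leaving the original untouched.
    kmeans_copy = kmeans.copy({kmeans.k: 17, kmeans.verbose: True})
    assert kmeans_copy.getOrDefault("k") == 17
    assert kmeans_copy.cuml_params["n_clusters"] == 17  # "k" is mapped to "n_clusters"
    assert kmeans_copy.cuml_params["verbose"] is True   # HasVerboseParam is synchronized too
    assert kmeans.cuml_params["n_clusters"] != 17       # the original estimator is unchanged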