Commit

Implement a wrapper class with custom fit, predict and predict_proba for XGBClassifier.

bojan-karlas committed Feb 24, 2024
1 parent f9fd86e commit a798974
Showing 1 changed file with 105 additions and 2 deletions.
107 changes: 105 additions & 2 deletions experiments/datascope/experiments/pipelines/models.py
@@ -7,20 +7,23 @@
from enum import Enum
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from pandas import DataFrame, Series
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC, LinearSVC
from sklearn.utils.multiclass import unique_labels
from transformers import AutoImageProcessor, ResNetForImageClassification, TrainingArguments, Trainer
-from typing import Dict
-from xgboost import XGBClassifier
+from typing import Dict, Optional, Union, Sequence, List, Callable, Any
+from xgboost import XGBClassifier as XGBClassifierOriginal


from datascope.importance.common import SklearnModel
from ..baselines.matchingnet import resnet12, default_transform, MatchingNetworks


@@ -55,6 +58,106 @@ class ModelType(str, Enum):
}


class XGBClassifier(SklearnModel, BaseEstimator, ClassifierMixin):
    """Wrapper around xgboost.XGBClassifier that accepts numpy arrays or pandas
    DataFrames as input and label-encodes targets before training."""

    model: XGBClassifierOriginal

    def __init__(
        self,
        max_depth: Optional[int] = None,
        max_leaves: Optional[int] = None,
        max_bin: Optional[int] = None,
        grow_policy: Optional[str] = None,
        learning_rate: Optional[float] = None,
        n_estimators: int = 100,
        booster: Optional[str] = None,
        tree_method: Optional[str] = None,
        gamma: Optional[float] = None,
        min_child_weight: Optional[float] = None,
        max_delta_step: Optional[float] = None,
        subsample: Optional[float] = None,
        sampling_method: Optional[str] = None,
        colsample_bytree: Optional[float] = None,
        colsample_bylevel: Optional[float] = None,
        colsample_bynode: Optional[float] = None,
        reg_alpha: Optional[float] = None,
        reg_lambda: Optional[float] = None,
        scale_pos_weight: Optional[float] = None,
        base_score: Optional[float] = None,
        random_state: Optional[Union[np.random.RandomState, int]] = None,
        missing: float = np.nan,
        num_parallel_tree: Optional[int] = None,
        monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
        interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
        importance_type: Optional[str] = None,
        gpu_id: Optional[int] = None,
        validate_parameters: Optional[bool] = None,
        predictor: Optional[str] = None,
        enable_categorical: bool = False,
        max_cat_to_onehot: Optional[int] = None,
        max_cat_threshold: Optional[int] = None,
        eval_metric: Optional[Union[str, List[str], Callable]] = None,
        early_stopping_rounds: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Delegate all hyperparameters to the underlying XGBoost estimator.
        self.model = XGBClassifierOriginal(
            max_depth=max_depth,
            max_leaves=max_leaves,
            max_bin=max_bin,
            grow_policy=grow_policy,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            booster=booster,
            tree_method=tree_method,
            gamma=gamma,
            min_child_weight=min_child_weight,
            max_delta_step=max_delta_step,
            subsample=subsample,
            sampling_method=sampling_method,
            colsample_bytree=colsample_bytree,
            colsample_bylevel=colsample_bylevel,
            colsample_bynode=colsample_bynode,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            scale_pos_weight=scale_pos_weight,
            base_score=base_score,
            random_state=random_state,
            missing=missing,
            num_parallel_tree=num_parallel_tree,
            monotone_constraints=monotone_constraints,
            interaction_constraints=interaction_constraints,
            importance_type=importance_type,
            gpu_id=gpu_id,
            validate_parameters=validate_parameters,
            predictor=predictor,
            enable_categorical=enable_categorical,
            max_cat_to_onehot=max_cat_to_onehot,
            max_cat_threshold=max_cat_threshold,
            eval_metric=eval_metric,
            early_stopping_rounds=early_stopping_rounds,
            **kwargs,
        )
        # Maps arbitrary class labels to the contiguous integers that XGBoost expects.
        self.label_encoder = LabelEncoder()

    def fit(
        self, X: Union[NDArray, DataFrame], y: Union[NDArray, Series], sample_weight: Optional[NDArray] = None
    ) -> "XGBClassifier":
        # XGBoost requires targets encoded as integers in [0, n_classes).
        y = self.label_encoder.fit_transform(y)
        self.classes_ = self.label_encoder.classes_
        if isinstance(X, DataFrame):
            X = X.to_numpy()
        self.model.fit(X, y, sample_weight=sample_weight)
        # Return self to follow the scikit-learn estimator convention.
        return self

    def predict(self, X: Union[NDArray, DataFrame]) -> NDArray:
        if isinstance(X, DataFrame):
            X = X.to_numpy()
        # Map encoded predictions back to the original label space.
        return self.label_encoder.inverse_transform(self.model.predict(X))

    def predict_proba(self, X: Union[NDArray, DataFrame]) -> NDArray:
        if isinstance(X, DataFrame):
            X = X.to_numpy()
        # Column i holds the probability of self.classes_[i]; the orders match
        # because LabelEncoder assigns indices in sorted label order.
        return self.model.predict_proba(X)

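For context, here is a minimal usage sketch of the wrapper (not part of the commit). The toy data, variable names, and hyperparameter values are hypothetical, and it assumes XGBClassifier is imported from this module.

import numpy as np

# Hypothetical toy data; string labels exercise the LabelEncoder path.
X_train = np.random.rand(100, 4)
y_train = np.random.choice(["cat", "dog"], size=100)

model = XGBClassifier(n_estimators=50, max_depth=3)
model.fit(X_train, y_train)

predictions = model.predict(X_train)  # returned in the original label space, e.g. "cat"
probabilities = model.predict_proba(X_train)  # columns ordered as model.classes_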

class TorchDataset(torch.utils.data.Dataset):
    def __init__(self, X: NDArray, y: NDArray, feature_extractor=None):
        if X.ndim == 3:
