-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathtrain.py
56 lines (49 loc) · 2.15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import mlflow
import mlflow.sklearn
import common
# Module-level tracking client; run() uses it to look up the experiment and to
# search and delete runs.
client = mlflow.client.MlflowClient()
# Import-time diagnostics: print the installed MLflow version and the tracking
# URI that mlflow reports.
print(f"MLflow Version: {mlflow.__version__}")
print("Tracking URI:", mlflow.get_tracking_uri())
def train(X_train, X_test, y_train, y_test, max_depth):
    """Fit one decision-tree regressor and record it as a single MLflow run.

    Inside a fresh MLflow run this tags the MLflow version, logs the
    ``max_depth`` parameter, fits the tree on the training split, logs the
    RMSE measured on the test split, prints a one-line summary and logs the
    fitted model under the artifact path ``sklearn-model``.

    Args:
        X_train: Training features, passed to ``DecisionTreeRegressor.fit``.
        X_test: Test features, passed to ``DecisionTreeRegressor.predict``.
        y_train: Training targets.
        y_test: Test targets the predictions are scored against.
        max_depth: Tree depth limit; printed with the ``2d`` format, so it
            is expected to be an int.

    Returns:
        None. Results are reported through MLflow and stdout.
    """
    with mlflow.start_run() as run:
        mlflow.set_tag("mlflow_version", mlflow.__version__)
        mlflow.log_param("max_depth", max_depth)

        regressor = DecisionTreeRegressor(max_depth=max_depth)
        regressor.fit(X_train, y_train)

        # Root of the mean squared error on the held-out split.
        mse = mean_squared_error(y_test, regressor.predict(X_test))
        rmse = np.sqrt(mse)
        mlflow.log_metric("rmse", rmse)

        # Summary columns: rmse, depth, run ID, experiment ID.
        print(f"{rmse:5.3f} {max_depth:2d} {run.info.run_id} {run.info.experiment_id}")
        mlflow.sklearn.log_model(regressor, "sklearn-model")
def _existing_run_ids(experiment_id):
    """Return the IDs of all runs that search_runs reports for the experiment.

    search_runs is paginated, so this follows the page token until the
    result no longer carries one instead of trusting the first page only.
    """
    run_ids = []
    page_token = None
    while True:
        page = client.search_runs([experiment_id], page_token=page_token)
        run_ids.extend(found.info.run_id for found in page)
        page_token = page.token
        if not page_token:
            return run_ids


def run(experiment_name, data_path):
    """Reset the experiment, train at several tree depths, and report the best run.

    Steps:
      1. Select ``experiment_name`` as the active MLflow experiment.
      2. Delete the runs the experiment already contains.
      3. Build the train/test split with ``common.build_data(data_path)`` and
         call ``train`` once per depth in ``(1, 2, 4, 16)``.
      4. Search for the run with the lowest ``rmse`` metric and print it.

    Args:
        experiment_name: Name of the MLflow experiment to use.
        data_path: Passed through unchanged to ``common.build_data``.

    Returns:
        None. Progress and the best run are printed to stdout.
    """
    print(f"==== {__file__} ====")
    mlflow.set_experiment(experiment_name)
    exp = client.get_experiment_by_name(experiment_name)
    print(f"Experiment ID: {exp.experiment_id}")

    # Delete existing runs. The IDs are gathered across every result page
    # first, so deleting cannot shift runs between pages while paging, and
    # no stale run is left to compete in the "best run" query below.
    for run_id in _existing_run_ids(exp.experiment_id):
        client.delete_run(run_id)

    # Train against different parameters
    X_train, X_test, y_train, y_test = common.build_data(data_path)
    params = (1, 2, 4, 16)
    print(f"Params: {params}")
    for p in params:
        train(X_train, X_test, y_train, y_test, p)

    # Find best run: ascending RMSE order, so the first hit is the lowest.
    runs = client.search_runs(
        [exp.experiment_id],
        order_by=["metrics.rmse ASC"],
        max_results=1,
    )
    best_run = runs[0]
    print(f"Best run: {best_run.data.metrics['rmse']:5.3f} {best_run.info.run_id}")
if __name__ == "__main__":
    # Command-line entry point; both options fall back to defaults from `common`.
    import argparse

    cli = argparse.ArgumentParser()
    # argparse derives the attribute names (experiment_name, data_path) from
    # the long option strings, so no explicit dest is needed.
    cli.add_argument(
        "--experiment_name",
        help="experiment_name",
        default=common.experiment_name,
    )
    cli.add_argument(
        "--data_path",
        help="Data path",
        default=common.data_path,
    )
    args = cli.parse_args()

    # Echo the effective settings before starting any work.
    print("Arguments:")
    for arg in vars(args):
        print(f" {arg}: {getattr(args, arg)}")

    run(experiment_name=args.experiment_name, data_path=args.data_path)