-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathbatch_score.py
34 lines (29 loc) · 1.09 KB
/
batch_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import mlflow
import mlflow.pyfunc
import mlflow.sklearn

import common
# Report which MLflow release is in use when the module is loaded.
print("MLflow Version:", mlflow.__version__)
def _predict_and_report(flavor, loader, model_uri, data):
    """Load a model with ``loader``, predict on ``data``, and print the results.

    Args:
        flavor: Label printed in the section header (e.g. ``"sklearn"``).
        loader: Callable taking a model URI and returning an object with a
            ``predict`` method.
        model_uri: MLflow model URI passed to ``loader``.
        data: Input handed unchanged to ``model.predict``.
    """
    print(f"==== {flavor} score")
    model = loader(model_uri)
    print("model:", model)
    print("model.type:", type(model))
    predictions = model.predict(data)
    print("predictions:", predictions)


def score(model_uri, data_path):
    """Score data from ``data_path`` with the model at ``model_uri``.

    The same model URI is loaded twice. The first load uses the
    ``mlflow.sklearn`` flavor and the second uses the generic
    ``mlflow.pyfunc`` flavor. For each load, the model, its type, and its
    predictions are printed so the two loading paths can be compared.

    Args:
        model_uri: MLflow model URI handed to both loaders.
        data_path: Path passed to ``common.build_data``. Of the 4-tuple it
            returns, only the second element (``X_test``) is scored.
    """
    # Only the second element of build_data's 4-tuple is needed here.
    _, X_test, _, _ = common.build_data(data_path)
    _predict_and_report("sklearn", mlflow.sklearn.load_model, model_uri, X_test)
    # mlflow.pyfunc is now imported explicitly at file level instead of
    # relying on mlflow.sklearn pulling it in as a side effect.
    _predict_and_report("pyfunc", mlflow.pyfunc.load_model, model_uri, X_test)
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Each option: (flag, destination, help text, default from ``common``).
    cli = ArgumentParser()
    for flag, dest, help_text, default in (
        ("--model_uri", "model_uri", "Model URI", common.model_uri),
        ("--data_path", "data_path", "Data path", common.data_path),
    ):
        cli.add_argument(flag, dest=dest, help=help_text, default=default)
    cli_args = cli.parse_args()

    # Echo every parsed option before scoring.
    print("Arguments:")
    for name, value in vars(cli_args).items():
        print(f" {name}: {value}")

    score(cli_args.model_uri, cli_args.data_path)