-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathkeras_predict.py
33 lines (26 loc) · 941 Bytes
/
keras_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import click
import mlflow
import mlflow.keras
import utils
# Module-level side effect: runs on import as well as when executed as a script.
# Delegates to the project-local utils.display_versions() (presumably reports
# library versions — what it prints is defined in utils, not visible here).
utils.display_versions()
# CLI entry point: load the Keras model at --model-uri via MLflow, run it on the
# data returned by utils.get_prediction_data(--data-path), and print diagnostics
# about the inputs, the loaded model and the predictions.
# Documented with comments rather than a docstring because click uses a
# command's docstring as its --help text.
@click.command()
@click.option("--model-uri", help="Model URI", required=True, type=str)
@click.option("--data-path", help="Data path", default=None, type=str)
def main(model_uri, data_path):
    # Echo the options from an explicit mapping (same entries and order that
    # locals() held at this point). Iterating locals() directly is fragile: on
    # CPython < 3.13 it is the frame's live f_locals dict, which a tracing
    # debugger can grow mid-loop ("dictionary changed size during iteration"),
    # and any local added above would leak into this printout.
    options = {"model_uri": model_uri, "data_path": data_path}
    print("Options:")
    for name, value in options.items():
        print(f" {name}: {value}")

    # data_path is None when --data-path is omitted; utils.get_prediction_data
    # decides what that means. The result is assumed to expose .shape
    # (e.g. a NumPy array or DataFrame) — TODO confirm against utils.
    data = utils.get_prediction_data(data_path)
    print("data.type:", type(data))
    print("data.shape:", data.shape)

    print("\n**** mlflow.keras.load_model\n")
    model = mlflow.keras.load_model(model_uri)
    print("model:", type(model))

    print("\n== model.predict")
    predictions = model.predict(data)
    print("predictions.type:", type(predictions))
    print("predictions.shape:", predictions.shape)
    utils.display_predictions(predictions)

    # Hand the same URI and data to utils.predict_pyfunc (presumably a second
    # pass through MLflow's generic pyfunc flavor — confirm in utils).
    utils.predict_pyfunc(model_uri, data)


if __name__ == "__main__":
    main()