-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_krkn_ai.py
152 lines (132 loc) · 4.99 KB
/
run_krkn_ai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import logging
import os.path
import shutil
import sys
import time
from multiprocessing import Lock
from typing import Callable
import yaml
from yaspin import yaspin
from yaspin.core import Yaspin
from yaspin.spinners import Spinners
from krkn_ai.data_retriever import DataRetriever
from krkn_ai.llm.model_factory import ModelFactory
from krkn_ai.scenarios.scenario_factory import ScenarioFactory
def update_callback(
    console: Yaspin,
    message: str,
    counter: list[int],
    total: int,
    lock: Lock,
):
    """Record one finished unit of work and refresh the spinner text.

    One element is appended to ``counter`` and ``console.text`` is set to
    ``"<label>: <done>/<total>"``, where ``<done>`` is the length of
    ``counter`` after the append. Both steps run while ``lock`` is held.

    Args:
        console: object whose ``text`` attribute is overwritten.
        message: progress label. Trailing colons and spaces are dropped
            because the separator is supplied by this function.
        counter: list used as a shared tally; it grows by one per call.
        total: expected number of units, shown after the slash.
        lock: context manager held while the tally and text are updated.
    """
    # The format below already adds ": ". Labels such as "downloaded: " would
    # otherwise be rendered as "downloaded: : 1/5".
    label = message.rstrip(": ")
    with lock:
        counter.append(1)
        console.text = f"{label}: {len(counter)}/{total}"
def error_callback(console: Yaspin, message: str, lock: Lock):
    """Write an error line through ``console`` while holding ``lock``.

    Args:
        console: object whose ``write`` method receives the formatted line.
        message: error description, prefixed with a siren emoji on output.
        lock: context manager held for the duration of the write.
    """
    with lock:
        console.write(f"🚨 {message}")
def _load_config(path):
    """Parse the YAML file at ``path``; log and exit with status 1 on failure."""
    try:
        with open(path, "r") as stream:
            config = yaml.safe_load(stream)
    except Exception as e:
        logging.error(f"failed to load config file: {e}")
        sys.exit(1)
    # An empty or scalar-only document does not parse to a mapping. Stop here
    # with a clear message instead of failing on the first config[...] lookup.
    if not isinstance(config, dict):
        logging.error(
            f"failed to load config file: {path} does not contain a mapping"
        )
        sys.exit(1)
    return config


def _fetch_or_reuse_dataset(
    config, data_retriever, sp, console_lock, on_download, on_error
):
    """Return the dataset path, downloading telemetry unless it can be reused.

    The existing dataset is reused only when ``reuse_dataset`` is truthy and
    ``data_retriever.get_lock_path()`` returns a truthy value; otherwise the
    telemetry URLs are fetched and downloaded.
    """
    reuse_dataset = config["krkn_ai"]["reuse_dataset"]
    lock_path = data_retriever.get_lock_path()

    if reuse_dataset and lock_path:
        sp.text = f"reusing dataset from lockfile {lock_path} ✅"
        return lock_path

    sp.text = "fetching telemetry download links from API..."
    if not lock_path:
        sp.write("lock file not found forced to download training data....")
    # The starting timestamp is forwarded as configured: the original code
    # branched on "> 0" but issued the same call in both branches.
    urls = data_retriever.get_telemetry_urls(
        config["krkn_ai"]["dataset_starting_timestamp"]
    )
    sp.ok(f"{len(urls)} urls fetched ✅")
    data_path = data_retriever.download_telemetry_data(
        urls,
        console_lock,
        on_download,
        on_error,
    )
    sp.ok(f"{len(urls)} files downloaded ✅")
    return data_path


def _retrain_scenario(
    scenario, data_path, threads, sp, console_lock, on_parse, on_error
):
    """Drop the scenario's vector db if present, then normalize data and train."""
    vector_db_path = scenario.get_vector_db_path()
    if os.path.exists(vector_db_path):
        sp.text = "deleting vector db..."
        shutil.rmtree(vector_db_path)
        sp.ok("vector db deleted ✅")
    parsed_file = scenario.normalize_data(
        data_path,
        threads,
        console_lock,
        on_parse,
        on_error,
    )
    sp.text = "creating embeddings and writing in the vectordb...."
    scenario.train(parsed_file)
    sp.ok("documents written in the vectordb ✅")


def main():
    """Run the krkn-ai workflow described by ``config/config.yaml``.

    Steps, in order:

    1. Load the configuration (the process exits with status 1 on failure).
    2. Obtain the dataset path, either from the existing lock file or by
       downloading telemetry data.
    3. For every entry in ``krkn_ai.scenarios``: build the model and the
       scenario, optionally retrain it, then start its interactive prompt.
       A scenario that cannot be built, or whose prompt raises
       ``ValueError``, is reported and skipped.
    """
    config = _load_config("config/config.yaml")

    # Lists act as shared tallies: the callbacks append one element per
    # finished unit and display the resulting length.
    download_counter = []
    parse_counter = []
    console_lock = Lock()
    model_factory = ModelFactory()
    scenario_factory = ScenarioFactory()
    data_retriever = DataRetriever(
        config["telemetry"]["api_url"],
        config["telemetry"]["username"],
        config["telemetry"]["password"],
        config["krkn_ai"]["threads"],
        config["krkn_ai"]["dataset_path"],
    )

    # NOTE(review): yaspin's ok()/fail() appear to finalize the spinner;
    # confirm that the sp.text updates issued after them still render.
    with yaspin(color="red") as sp:

        def update_callback_download(total: int, lock: Lock) -> None:
            update_callback(sp, "downloaded: ", download_counter, total, lock)

        def update_callback_normalize(total: int, lock: Lock) -> None:
            update_callback(sp, "file parsed: ", parse_counter, total, lock)

        def error_callback_global(message: str, lock: Lock) -> None:
            error_callback(sp, message, lock)

        data_path = _fetch_or_reuse_dataset(
            config,
            data_retriever,
            sp,
            console_lock,
            update_callback_download,
            error_callback_global,
        )

        for scenario_config in config["krkn_ai"]["scenarios"]:
            try:
                model = model_factory.get_instance(
                    scenario_config["model"]["class_name"],
                    scenario_config["model"]["package"],
                    scenario_config["model"]["endpoint"],
                    scenario_config["model"]["name"],
                )
                scenario = scenario_factory.get_instance(
                    model,
                    scenario_config["class_name"],
                    scenario_config["package"],
                    scenario_config["vector_db_path"],
                )
            except (TypeError, AttributeError) as e:
                sp.fail(
                    f"🚨 failed to run scenario {scenario_config['class_name']} : {e}"
                )
                continue

            if scenario_config["retrain"]:
                _retrain_scenario(
                    scenario,
                    data_path,
                    config["krkn_ai"]["threads"],
                    sp,
                    console_lock,
                    update_callback_normalize,
                    error_callback_global,
                )

            try:
                sp.text = "starting llm interactive prompt..."
                scenario.interactive_prompt()
            except ValueError as e:
                sp.fail(f"🚨 failed to start interactive prompt : {e}")
# Run the workflow only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()