Skip to content

Commit

Permalink
try to fix nano performance
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 6, 2024
1 parent d7f22a0 commit 5dca4a7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
47 changes: 29 additions & 18 deletions hudes/hudes_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,52 +144,63 @@ def run_loop(self):
def receive_messages(self):
# listen from server?
received_message = False
received_train = False
received_batch = False
received_val = False
while self.hudes_websocket_client.recv_ready():
received_message = True
# recv and process!
raw_msg = self.hudes_websocket_client.recv_msg()
msg = hudes_pb2.Control()
msg.ParseFromString(raw_msg)
if msg.type == hudes_pb2.Control.CONTROL_TRAIN_LOSS_AND_PREDS:
received_train = True

train_preds = pickle.loads(msg.train_loss_and_preds.preds)
confusion_matrix = pickle.loads(
self.train_preds = pickle.loads(msg.train_loss_and_preds.preds)
self.confusion_matrix = pickle.loads(
msg.train_loss_and_preds.confusion_matrix
)

self.train_losses.append(msg.train_loss_and_preds.train_loss)
self.train_steps.append(msg.request_idx)
# self.val_losses.append(msg.loss_and_preds.val_loss)

self.view.plot_train_and_val(
self.train_losses,
self.train_steps,
self.val_losses,
self.val_steps,
)
self.view.update_example_preds(train_preds=train_preds)
self.view.update_confusion_matrix(confusion_matrix)
elif msg.type == hudes_pb2.Control.CONTROL_BATCH_EXAMPLES:
received_batch = True
self.train_data = pickle.loads(msg.batch_examples.train_data)
self.val_data = pickle.loads(msg.batch_examples.val_data)
self.train_labels = pickle.loads(msg.batch_examples.train_labels)
self.val_labels = pickle.loads(msg.batch_examples.val_labels)
self.view.update_examples(
train_data=self.train_data,
val_data=self.val_data,
)

elif msg.type == hudes_pb2.Control.CONTROL_VAL_LOSS:
received_val = True
self.val_losses.append(msg.val_loss.val_loss)
self.val_steps.append(msg.request_idx)

# called if we only changed scale etc?
elif msg.type == hudes_pb2.Control.CONTROL_MESHGRID_RESULTS:
self.view.update_mesh_grids(pickle.loads(msg.mesh_grid_results))
# print("GOT MESH GRID", self.mesh_grid.shape)

if received_message:
if received_train:
self.view.plot_train_and_val(
self.train_losses,
self.train_steps,
self.val_losses,
self.val_steps,
)
self.view.update_example_preds(train_preds=self.train_preds)
self.view.update_confusion_matrix(self.confusion_matrix)
if received_batch:
self.view.update_examples(
train_data=self.train_data,
)
if received_val:
self.view.plot_train_and_val(
self.train_losses,
self.train_steps,
self.val_losses,
self.val_steps,
)
# called if we only changed scale etc?
elif msg.type == hudes_pb2.Control.CONTROL_MESHGRID_RESULTS:
self.view.update_mesh_grids(pickle.loads(msg.mesh_grid_results))
# print("GOT MESH GRID", self.mesh_grid.shape)
return received_message
4 changes: 2 additions & 2 deletions hudes/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self, use_midi=False):
prop_cycle = plt.rcParams["axes.prop_cycle"]
self.plt_colors = prop_cycle.by_key()["color"]

def update_examples(self, train_data: torch.Tensor, val_data: torch.Tensor):
def update_examples(self, train_data: torch.Tensor):
self.axd["F"].cla()
self.axd["F"].imshow(train_data[0])
self.axd["F"].set_title("Ex. 1 img")
Expand Down Expand Up @@ -472,7 +472,7 @@ def __init__(self, grid_size, grids):
# ax.plot(x, y)
# ax.set_title("Sine Wave")

def update_examples(self, train_data: torch.Tensor, val_data: torch.Tensor):
def update_examples(self, train_data: torch.Tensor):
self.axd["F"].cla()
self.axd["F"].imshow(train_data[0])
self.axd["F"].set_title("Ex. 1 img")
Expand Down

0 comments on commit 5dca4a7

Please sign in to comment.