diff --git a/hudes/keyboard_client.py b/hudes/keyboard_client.py index 55fd647..80d7215 100644 --- a/hudes/keyboard_client.py +++ b/hudes/keyboard_client.py @@ -66,7 +66,7 @@ def process_key_press(self, event): if key in self.key_to_param_and_sign: dim, sign = self.key_to_param_and_sign[key] self.send_dims_and_steps({dim: self.step_size * sign}) - return True + return False elif key == "[": self.step_size_decrease() return True diff --git a/hudes/view.py b/hudes/view.py index 5c8464d..ed38112 100644 --- a/hudes/view.py +++ b/hudes/view.py @@ -26,9 +26,9 @@ update_grid_vbo, ) -# backend = "Agg" +backend = "Agg" # backend='cairo' -# matplotlib.use(backend) +matplotlib.use(backend) def surface_to_npim(surface): @@ -166,22 +166,23 @@ def __init__(self, use_midi=False): print(f"using input_id :{self.midi_input_id}:") self.midi_input = pygame.midi.Input(self.midi_input_id) - dpi = 200 # plt.rcParams["figure.dpi"] + # dpi = 200 # plt.rcParams["figure.dpi"] logging.info(f"Matplotlib backend: {plt.get_backend()}") self.window_size = (1200, 800) - self.fig = plt.figure( - figsize=(self.window_size[0] / dpi, self.window_size[1] / dpi), - ) - - if self.fig.dpi != dpi: - logging.warning( - f"DPI flag not respected by matplotlib backend ({plt.get_backend()})! Should be {dpi} but is {self.fig.dpi} " - ) - self.window_size = ( - int(self.fig.get_figwidth() * self.fig.dpi), - int(self.fig.get_figheight() * self.fig.dpi), - ) + # self.fig = plt.figure( + # figsize=(self.window_size[0] / dpi, self.window_size[1] / dpi), dpi=dpi + # ) + self.fig = plt.figure(figsize=(12, 8)) + + # if self.fig.dpi != dpi: + # logging.warning( + # f"DPI flag not respected by matplotlib backend ({plt.get_backend()})! Should be {dpi} but is {self.fig.dpi} " + # ) + # self.window_size = ( + # int(self.fig.get_figwidth() * self.fig.dpi), + # int(self.fig.get_figheight() * self.fig.dpi), + # ) self.window = pg.display.set_mode(self.window_size) @@ -205,6 +206,15 @@ def __init__(self, use_midi=False): prop_cycle = plt.rcParams["axes.prop_cycle"] self.plt_colors = prop_cycle.by_key()["color"] + self.best_score = math.inf + + self.redraw_confusion = True + self.redraw_dims = True + self.redraw_preds = True + self.redraw_step_size = True + self.redraw_top = True + self.redraw_train_and_val = True + def update_examples(self, train_data: torch.Tensor): self.axd["F"].cla() self.axd["F"].imshow(train_data[0]) @@ -221,6 +231,7 @@ def update_examples(self, train_data: torch.Tensor): self.axd["M"].cla() self.axd["M"].imshow(train_data[3]) self.axd["M"].set_title("Ex. 4 img") + self.redraw_examples = True # self.axd["I"].cla() # self.axd["I"].imshow(train_data[3]) @@ -230,6 +241,7 @@ def update_top(self, best_score): self.fig.suptitle("Human Descent: MNIST Top-score: ?") else: self.fig.suptitle(f"Human Descent: MNIST Top-score: {best_score:.5e}") + self.redraw_top = True # self.fig.tight_layout(rect=[0, 0.03, 1, 0.95]) def update_step_size( @@ -240,6 +252,7 @@ def update_step_size( self.axd["I"].set_xlim(min_log_step_size, max_log_step_size) self.axd["I"].set_title("log(Step size)") self.axd["I"].set_yticks([]) + self.redraw_step_size = True def update_confusion_matrix(self, confusion_matrix: torch.Tensor): self.axd["E"].cla() @@ -249,6 +262,7 @@ def update_confusion_matrix(self, confusion_matrix: torch.Tensor): self.axd["E"].set_ylabel("Ground truth") self.axd["E"].set_xlabel("Prediction") self.axd["E"].set_title("Confusion matrix") + self.redraw_confusion = True def update_dims_since_last_update(self, dims_and_steps_on_current_dims): self.axd["O"].cla() @@ -265,6 +279,7 @@ def update_dims_since_last_update(self, dims_and_steps_on_current_dims): self.axd["O"].set_ylabel("cumulative step") self.axd["O"].set_title("Dims and Steps") self.axd["O"].set_yticks([]) + self.redraw_dims = True def update_example_preds(self, train_preds: List[float]): self.axd["J"].cla() @@ -285,6 +300,7 @@ def update_example_preds(self, train_preds: List[float]): # self.axd["M"].cla() # self.axd["M"].bar(torch.arange(10), train_preds[3]) + self.redraw_preds = True def plot_train_and_val( self, @@ -293,8 +309,10 @@ def plot_train_and_val( val_losses: List[float], val_steps: List[int], ): - best_score = min(val_losses) if len(val_losses) > 0 else -math.inf - self.update_top(best_score) + new_best_score = min(val_losses) if len(val_losses) > 0 else -math.inf + if new_best_score < self.best_score: + self.best_score = new_best_score + self.update_top(self.best_score) n = len(train_losses) # x = torch.arange(n) @@ -312,15 +330,30 @@ def plot_train_and_val( # self.axd["C"].set_xlabel("Step") # self.axd["C"].set_yticks([]) + self.axd["D"].clear() self.axd["D"].cla() self.axd["D"].plot(train_steps[-8:], train_losses[-8:], label="train") self.axd["D"].set_title("Loss [last 8steps]") self.axd["D"].set_yticks([]) self.axd["D"].set_xlabel("Step") + self.redraw_train_and_val = True def draw(self): if True: # backend.lower()=='agg': - self.canvas.draw() + # self.canvas.draw() + # self.canvas.update() + self.renderer.clear() + if self.redraw_train_and_val: + self.axd["B"].draw(self.renderer) + self.axd["B"].cache = self.fig.canvas.copy_from_bbox( + self.axd["B"].get_tightbbox(self.renderer) + ) + self.axd["D"].draw(self.renderer) + self.axd["D"].cache = self.fig.canvas.copy_from_bbox(self.axd["D"].bbox) + self.redraw_train_and_val = False + else: + self.fig.canvas.restore_region(self.axd["B"].cache) + self.fig.canvas.restore_region(self.axd["D"].cache) surf = pg.image.frombytes( self.renderer.tostring_rgb(), self.window_size,