Skip to content

Commit

Permalink
try to optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 8, 2024
1 parent 0bb124a commit a0e1bb4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
2 changes: 1 addition & 1 deletion hudes/keyboard_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def process_key_press(self, event):
if key in self.key_to_param_and_sign:
dim, sign = self.key_to_param_and_sign[key]
self.send_dims_and_steps({dim: self.step_size * sign})
return True
return False
elif key == "[":
self.step_size_decrease()
return True
Expand Down
69 changes: 51 additions & 18 deletions hudes/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
update_grid_vbo,
)

# backend = "Agg"
backend = "Agg"
# backend='cairo'
# matplotlib.use(backend)
matplotlib.use(backend)


def surface_to_npim(surface):
Expand Down Expand Up @@ -166,22 +166,23 @@ def __init__(self, use_midi=False):
print(f"using input_id :{self.midi_input_id}:")
self.midi_input = pygame.midi.Input(self.midi_input_id)

dpi = 200 # plt.rcParams["figure.dpi"]
# dpi = 200 # plt.rcParams["figure.dpi"]
logging.info(f"Matplotlib backend: {plt.get_backend()}")

self.window_size = (1200, 800)
self.fig = plt.figure(
figsize=(self.window_size[0] / dpi, self.window_size[1] / dpi),
)

if self.fig.dpi != dpi:
logging.warning(
f"DPI flag not respected by matplotlib backend ({plt.get_backend()})! Should be {dpi} but is {self.fig.dpi} "
)
self.window_size = (
int(self.fig.get_figwidth() * self.fig.dpi),
int(self.fig.get_figheight() * self.fig.dpi),
)
# self.fig = plt.figure(
# figsize=(self.window_size[0] / dpi, self.window_size[1] / dpi), dpi=dpi
# )
self.fig = plt.figure(figsize=(12, 8))

# if self.fig.dpi != dpi:
# logging.warning(
# f"DPI flag not respected by matplotlib backend ({plt.get_backend()})! Should be {dpi} but is {self.fig.dpi} "
# )
# self.window_size = (
# int(self.fig.get_figwidth() * self.fig.dpi),
# int(self.fig.get_figheight() * self.fig.dpi),
# )

self.window = pg.display.set_mode(self.window_size)

Expand All @@ -205,6 +206,15 @@ def __init__(self, use_midi=False):
prop_cycle = plt.rcParams["axes.prop_cycle"]
self.plt_colors = prop_cycle.by_key()["color"]

self.best_score = math.inf

self.redraw_confusion = True
self.redraw_dims = True
self.redraw_preds = True
self.redraw_step_size = True
self.redraw_top = True
self.redraw_train_and_val = True

def update_examples(self, train_data: torch.Tensor):
self.axd["F"].cla()
self.axd["F"].imshow(train_data[0])
Expand All @@ -221,6 +231,7 @@ def update_examples(self, train_data: torch.Tensor):
self.axd["M"].cla()
self.axd["M"].imshow(train_data[3])
self.axd["M"].set_title("Ex. 4 img")
self.redraw_examples = True

# self.axd["I"].cla()
# self.axd["I"].imshow(train_data[3])
Expand All @@ -230,6 +241,7 @@ def update_top(self, best_score):
self.fig.suptitle("Human Descent: MNIST Top-score: ?")
else:
self.fig.suptitle(f"Human Descent: MNIST Top-score: {best_score:.5e}")
self.redraw_top = True
# self.fig.tight_layout(rect=[0, 0.03, 1, 0.95])

def update_step_size(
Expand All @@ -240,6 +252,7 @@ def update_step_size(
self.axd["I"].set_xlim(min_log_step_size, max_log_step_size)
self.axd["I"].set_title("log(Step size)")
self.axd["I"].set_yticks([])
self.redraw_step_size = True

def update_confusion_matrix(self, confusion_matrix: torch.Tensor):
self.axd["E"].cla()
Expand All @@ -249,6 +262,7 @@ def update_confusion_matrix(self, confusion_matrix: torch.Tensor):
self.axd["E"].set_ylabel("Ground truth")
self.axd["E"].set_xlabel("Prediction")
self.axd["E"].set_title("Confusion matrix")
self.redraw_confusion = True

def update_dims_since_last_update(self, dims_and_steps_on_current_dims):
self.axd["O"].cla()
Expand All @@ -265,6 +279,7 @@ def update_dims_since_last_update(self, dims_and_steps_on_current_dims):
self.axd["O"].set_ylabel("cumulative step")
self.axd["O"].set_title("Dims and Steps")
self.axd["O"].set_yticks([])
self.redraw_dims = True

def update_example_preds(self, train_preds: List[float]):
self.axd["J"].cla()
Expand All @@ -285,6 +300,7 @@ def update_example_preds(self, train_preds: List[float]):

# self.axd["M"].cla()
# self.axd["M"].bar(torch.arange(10), train_preds[3])
self.redraw_preds = True

def plot_train_and_val(
self,
Expand All @@ -293,8 +309,10 @@ def plot_train_and_val(
val_losses: List[float],
val_steps: List[int],
):
best_score = min(val_losses) if len(val_losses) > 0 else -math.inf
self.update_top(best_score)
new_best_score = min(val_losses) if len(val_losses) > 0 else -math.inf
if new_best_score < self.best_score:
self.best_score = new_best_score
self.update_top(self.best_score)

n = len(train_losses)
# x = torch.arange(n)
Expand All @@ -312,15 +330,30 @@ def plot_train_and_val(
# self.axd["C"].set_xlabel("Step")
# self.axd["C"].set_yticks([])

self.axd["D"].clear()
self.axd["D"].cla()
self.axd["D"].plot(train_steps[-8:], train_losses[-8:], label="train")
self.axd["D"].set_title("Loss [last 8steps]")
self.axd["D"].set_yticks([])
self.axd["D"].set_xlabel("Step")
self.redraw_train_and_val = True

def draw(self):
if True: # backend.lower()=='agg':
self.canvas.draw()
# self.canvas.draw()
# self.canvas.update()
self.renderer.clear()
if self.redraw_train_and_val:
self.axd["B"].draw(self.renderer)
self.axd["B"].cache = self.fig.canvas.copy_from_bbox(
self.axd["B"].get_tightbbox(self.renderer)
)
self.axd["D"].draw(self.renderer)
self.axd["D"].cache = self.fig.canvas.copy_from_bbox(self.axd["D"].bbox)
self.redraw_train_and_val = False
else:
self.fig.canvas.restore_region(self.axd["B"].cache)
self.fig.canvas.restore_region(self.axd["D"].cache)
surf = pg.image.frombytes(
self.renderer.tostring_rgb(),
self.window_size,
Expand Down

0 comments on commit a0e1bb4

Please sign in to comment.