Skip to content

Commit

Permalink
try to fix render time
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 8, 2024
1 parent a0e1bb4 commit 961db41
Showing 1 changed file with 42 additions and 50 deletions.
92 changes: 42 additions & 50 deletions hudes/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,23 +215,16 @@ def __init__(self, use_midi=False):
self.redraw_top = True
self.redraw_train_and_val = True

def update_examples(self, train_data: torch.Tensor):
self.axd["F"].cla()
self.axd["F"].imshow(train_data[0])
self.axd["F"].set_title("Ex. 1 img")

self.axd["G"].cla()
self.axd["G"].imshow(train_data[1])
self.axd["G"].set_title("Ex. 2 img")
for _ax in self.axd:
self.axd[_ax].redraw = True

self.axd["H"].cla()
self.axd["H"].imshow(train_data[2])
self.axd["H"].set_title("Ex. 3 img")

self.axd["M"].cla()
self.axd["M"].imshow(train_data[3])
self.axd["M"].set_title("Ex. 4 img")
self.redraw_examples = True
def update_examples(self, train_data: torch.Tensor):
for idx, _ax in enumerate(("F", "G", "H", "M")):
ax = self.axd[_ax]
ax.cla()
ax.imshow(train_data[idx])
ax.set_title(f"Ex. {idx} img")
ax.redraw = True

# self.axd["I"].cla()
# self.axd["I"].imshow(train_data[3])
Expand All @@ -252,7 +245,7 @@ def update_step_size(
self.axd["I"].set_xlim(min_log_step_size, max_log_step_size)
self.axd["I"].set_title("log(Step size)")
self.axd["I"].set_yticks([])
self.redraw_step_size = True
self.axd["I"].redraw = True

def update_confusion_matrix(self, confusion_matrix: torch.Tensor):
self.axd["E"].cla()
Expand All @@ -262,7 +255,7 @@ def update_confusion_matrix(self, confusion_matrix: torch.Tensor):
self.axd["E"].set_ylabel("Ground truth")
self.axd["E"].set_xlabel("Prediction")
self.axd["E"].set_title("Confusion matrix")
self.redraw_confusion = True
self.axd["E"].redraw = True

def update_dims_since_last_update(self, dims_and_steps_on_current_dims):
self.axd["O"].cla()
Expand All @@ -279,28 +272,15 @@ def update_dims_since_last_update(self, dims_and_steps_on_current_dims):
self.axd["O"].set_ylabel("cumulative step")
self.axd["O"].set_title("Dims and Steps")
self.axd["O"].set_yticks([])
self.redraw_dims = True
self.axd["O"].redraw = True

def update_example_preds(self, train_preds: List[float]):
self.axd["J"].cla()
self.axd["J"].bar(torch.arange(10), train_preds[0])
self.axd["J"].set_title("Ex. 1 pr(y)")

self.axd["K"].cla()
self.axd["K"].bar(torch.arange(10), train_preds[1])
self.axd["K"].set_title("Ex. 2 pr(y)")

self.axd["L"].cla()
self.axd["L"].bar(torch.arange(10), train_preds[2])
self.axd["L"].set_title("Ex. 3 pr(y)")

self.axd["N"].cla()
self.axd["N"].bar(torch.arange(10), train_preds[3])
self.axd["N"].set_title("Ex. 4 pr(y)")

# self.axd["M"].cla()
# self.axd["M"].bar(torch.arange(10), train_preds[3])
self.redraw_preds = True
for idx, _ax in enumerate(("J", "K", "L", "N")):
ax = self.axd[_ax]
ax.cla()
ax.bar(torch.arange(10), train_preds[idx])
ax.set_title(f"Ex. {idx} pr(y)")
ax.redraw = True

def plot_train_and_val(
self,
Expand All @@ -323,6 +303,7 @@ def plot_train_and_val(
self.axd["B"].set_title("Loss")
self.axd["B"].set_xlabel("Step")
self.axd["B"].set_ylabel("Loss")
self.axd["B"].redraw = True

# self.axd["C"].cla()
# self.axd["C"].plot(train_steps[n // 2 :], train_losses[n // 2 :], label="train")
Expand All @@ -336,24 +317,35 @@ def plot_train_and_val(
self.axd["D"].set_title("Loss [last 8steps]")
self.axd["D"].set_yticks([])
self.axd["D"].set_xlabel("Step")
self.redraw_train_and_val = True
self.axd["D"].redraw = True

def draw_or_restore(self):
for _ax, ax in self.axd.items():
if ax.redraw:
ax.draw(self.renderer)
ax.cache = self.fig.canvas.copy_from_bbox(
ax.get_tightbbox(self.renderer) # TODO cache bbox?
)
else:
self.fig.canvas.restore_region(ax.cache)

def draw(self):
if True: # backend.lower()=='agg':
# self.canvas.draw()
# self.canvas.update()
self.renderer.clear()
if self.redraw_train_and_val:
self.axd["B"].draw(self.renderer)
self.axd["B"].cache = self.fig.canvas.copy_from_bbox(
self.axd["B"].get_tightbbox(self.renderer)
)
self.axd["D"].draw(self.renderer)
self.axd["D"].cache = self.fig.canvas.copy_from_bbox(self.axd["D"].bbox)
self.redraw_train_and_val = False
else:
self.fig.canvas.restore_region(self.axd["B"].cache)
self.fig.canvas.restore_region(self.axd["D"].cache)
self.draw_or_restore()
# if self.redraw_train_and_val:
# self.axd["B"].draw(self.renderer)
# self.axd["B"].cache = self.fig.canvas.copy_from_bbox(
# self.axd["B"].get_tightbbox(self.renderer)
# )
# self.axd["D"].draw(self.renderer)
# self.axd["D"].cache = self.fig.canvas.copy_from_bbox(self.axd["D"].bbox)
# self.redraw_train_and_val = False
# else:
# self.fig.canvas.restore_region(self.axd["B"].cache)
# self.fig.canvas.restore_region(self.axd["D"].cache)
surf = pg.image.frombytes(
self.renderer.tostring_rgb(),
self.window_size,
Expand Down

0 comments on commit 961db41

Please sign in to comment.