Skip to content

Commit

Permalink
faster rendering times
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 10, 2024
1 parent a0a3016 commit e3cf2a0
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 72 deletions.
31 changes: 26 additions & 5 deletions hudes/hudes_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def __init__(
self.step_size_resolution = step_size_resolution

self.batch_size = 128
self.dtype = "float16"

self.dtype_idx = -1
self.dtypes = ("float16", "float32")

self.batch_size_idx = 3 - 1
self.batch_sizes = [2, 8, 32, 128, 512]

def get_next_batch(self):
self.hudes_websocket_client.send_q.put(next_batch_message().SerializeToString())
Expand All @@ -70,6 +75,22 @@ def set_n(self, n):

def attach_view(self, view):
self.view = view
self.toggle_dtype(init=True)
self.toggle_batch_size(init=True)

def toggle_dtype(self, init=False):
self.dtype_idx = (self.dtype_idx + 1) % len(self.dtypes)
self.dtype = self.dtypes[self.dtype_idx]
self.view.dtype = self.dtype
if not init:
self.send_config()

def toggle_batch_size(self, init=False):
self.batch_size_idx = (self.batch_size_idx + 1) % len(self.batch_sizes)
self.batch_size = self.batch_sizes[self.batch_size_idx]
self.view.batch_size = self.batch_size
if not init:
self.send_config()

def send_config(self):
self.hudes_websocket_client.send_config(
Expand All @@ -85,11 +106,11 @@ def send_config(self):
def dims_and_steps_updated(self):
self.view.update_dims_since_last_update(self.dims_and_steps_on_current_dims)

def step_size_increase(self):
self.set_step_size_idx(self.step_size_idx + 1)
def step_size_increase(self, mag: int = 1):
self.set_step_size_idx(self.step_size_idx + 1 * mag)

def step_size_decrease(self):
self.set_step_size_idx(self.step_size_idx - 1)
def step_size_decrease(self, mag: int = 1):
self.set_step_size_idx(self.step_size_idx - 1 * mag)

def set_step_size_idx(self, idx):
self.step_size_idx = idx
Expand Down
4 changes: 2 additions & 2 deletions hudes/keyboard_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def process_key_press(self, event):
self.send_dims_and_steps({dim: self.step_size * sign})
return False # we are going to get a response shortly that updates
elif key == "[":
self.step_size_decrease()
self.step_size_increase()
return True
elif key == "]":
self.step_size_increase()
self.step_size_decrease()
return True
elif key == " ":
print("getting new set of vectors")
Expand Down
14 changes: 10 additions & 4 deletions hudes/keyboard_client_openGL.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def process_key_press(self, event):
joystick.rumble(0, 0.7, 500)
self.get_next_batch()

if event.button == 1:
self.toggle_dtype()

if event.button == 3:
self.toggle_batch_size()

if event.type == pg.JOYBUTTONUP:
print("Joystick button released.")

Expand Down Expand Up @@ -83,11 +89,11 @@ def run_loop(self):
redraw = True

if joystick.get_button(9) > 0.5:
self.step_size_decrease()
self.step_size_decrease(2)
self.send_config()

if joystick.get_button(10) > 0.5:
self.step_size_increase()
self.step_size_increase(2)
self.send_config()

if joystick.get_button(11) > 0.5:
Expand Down Expand Up @@ -120,8 +126,8 @@ def run_loop(self):
# math.atan2(A[0], A[1])
self.send_dims_and_steps(
{
1 + selected_grid * 2: A[1] * self.step_size * 0.1,
0 + selected_grid * 2: A[0] * self.step_size * 0.1,
1 + selected_grid * 2: A[1] * self.step_size * 0.25,
0 + selected_grid * 2: A[0] * self.step_size * 0.25,
}
)

Expand Down
4 changes: 2 additions & 2 deletions hudes/model_data_and_subspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ def get_dim_vec(self, dim: int, dtype):
*self.model_params[dtype].shape,
generator=g,
device=self.device,
dtype=dtype,
dtype=torch.float32,
)
- 0.5,
p=2,
dim=1,
)
).to(dtype)

# dims is a dictionary {dim:step_size}
@torch.no_grad
Expand Down
2 changes: 1 addition & 1 deletion hudes/opengl_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def render_text_2d(text, font_size, screen_width, screen_height):

# Calculate the position to center the text at the top
x_position = (screen_width - text_width) // 2
y_position = text_height
y_position = screen_height - text_height

# Disable depth testing for text rendering
glDisable(GL_DEPTH_TEST)
Expand Down
164 changes: 109 additions & 55 deletions hudes/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,72 +227,114 @@ def __init__(self, use_midi=False):
self.redraw_top = True
self.redraw_train_and_val = True

self.confusion_matrix_init = False
self.example_im_show_init = False
self.example_imshow_init = False
self.dims_chart_init = False
self.init_step_size_plot = False

self.best_score = None

self.font = pygame.font.SysFont("Comic Sans MS", 30)
self.update_top(-math.inf)
for _ax in self.axd:
self.axd[_ax].redraw = True

def update_examples(self, train_data: torch.Tensor):
for idx, _ax in enumerate(("F", "G", "H", "M")):
ax = self.axd[_ax]
ax.cla()
ax.imshow(train_data[idx])
ax.set_title(f"Ex. {idx} img")
ax.redraw = True
if not self.example_im_show_init:
for idx, _ax in enumerate(("F", "G", "H", "M")):
ax = self.axd[_ax]
ax.cla()
ax.im = ax.imshow(train_data[idx])
ax.set_title(f"Ex. {idx} img")
ax.redraw = True
self.example_im_show_init = True
else:
for idx, _ax in enumerate(("F", "G", "H", "M")):
ax = self.axd[_ax]
ax.im.set_data(train_data[idx])
ax.redraw = True

# self.axd["I"].cla()
# self.axd["I"].imshow(train_data[3])

def update_top(self, best_score):
if best_score is None:
self.fig.suptitle("Human Descent: MNIST Top-score: ?")
else:
self.fig.suptitle(f"Human Descent: MNIST Top-score: {best_score:.5e}")
self.redraw_top = True
# self.fig.tight_layout(rect=[0, 0.03, 1, 0.95])
def update_top(self, maybe_new_best_score):
if self.best_score is None or maybe_new_best_score > self.best_score:
self.best_score = maybe_new_best_score
render_str = f"Human Descent: MNIST Top-score: {self.best_score:.5e}"
self.top_title_rendered = self.font.render(render_str, False, (0, 0, 0))

def update_step_size(
self, log_step_size: float, max_log_step_size: float, min_log_step_size: float
):
self.axd["I"].cla()
self.axd["I"].barh([0], log_step_size)
self.axd["I"].set_xlim(min_log_step_size, max_log_step_size)
self.axd["I"].set_title("log(Step size)")
self.axd["I"].set_yticks([])
if not self.init_step_size_plot:
self.axd["I"].cla()
self.axd["I"].bars = self.axd["I"].barh([0], log_step_size)
self.axd["I"].set_xlim(min_log_step_size, max_log_step_size)
self.axd["I"].set_title("log(Step size)")
self.axd["I"].set_yticks([])
self.init_step_size_plot = True
else:
self.axd["I"].bars[0].set_width(log_step_size)
self.axd["I"].redraw = True

def update_confusion_matrix(self, confusion_matrix: torch.Tensor):
self.axd["E"].cla()
self.axd["E"].imshow(confusion_matrix)
self.axd["E"].set_yticks(range(10))
self.axd["E"].set_xticks(range(10))
self.axd["E"].set_ylabel("Ground truth")
self.axd["E"].set_xlabel("Prediction")
self.axd["E"].set_title("Confusion matrix")
if not self.confusion_matrix_init:
self.axd["E"].cla()
self.axd["E"].im = self.axd["E"].imshow(confusion_matrix)
self.axd["E"].set_yticks(range(10))
self.axd["E"].set_xticks(range(10))
self.axd["E"].set_ylabel("Ground truth")
self.axd["E"].set_xlabel("Prediction")
self.axd["E"].set_title("Confusion matrix")
self.confusion_matrix_init = True
else:
self.axd["E"].im.set_data(confusion_matrix)
self.axd["E"].redraw = True

def update_dims_since_last_update(self, dims_and_steps_on_current_dims):
self.axd["O"].cla()
colors = [
self.plt_colors[idx % len(self.plt_colors)]
for idx in range(dims_and_steps_on_current_dims.shape[0])
]
self.axd["O"].bar(
range(dims_and_steps_on_current_dims.shape[0]),
dims_and_steps_on_current_dims,
color=colors,
)
self.axd["O"].set_xlabel("dimension #")
self.axd["O"].set_ylabel("cumulative step")
self.axd["O"].set_title("Dims and Steps")
self.axd["O"].set_yticks([])
if not self.dims_chart_init:
self.axd["O"].cla()
colors = [
self.plt_colors[idx % len(self.plt_colors)]
for idx in range(dims_and_steps_on_current_dims.shape[0])
]
self.axd["O"].bars = self.axd["O"].bar(
range(dims_and_steps_on_current_dims.shape[0]),
dims_and_steps_on_current_dims,
color=colors,
)
self.axd["O"].set_xlabel("dimension #")
self.axd["O"].set_ylabel("cumulative step")
self.axd["O"].set_title("Dims and Steps")
self.axd["O"].set_yticks([])
self.dims_chart_init = True
else:
max_mag = np.abs(dims_and_steps_on_current_dims).max()
self.axd["O"].set_ylim([-max_mag, max_mag])
for bar, new_height in zip(
self.axd["O"].bars, dims_and_steps_on_current_dims
):
bar.set_height(new_height)
self.axd["O"].redraw = True

def update_example_preds(self, train_preds: List[float]):
for idx, _ax in enumerate(("J", "K", "L", "N")):
ax = self.axd[_ax]
ax.cla()
ax.bar(torch.arange(10), train_preds[idx])
ax.set_title(f"Ex. {idx} pr(y)")
ax.redraw = True
if not self.example_imshow_init:
for idx, _ax in enumerate(("J", "K", "L", "N")):
ax = self.axd[_ax]
ax.cla()
ax.bars = ax.bar(torch.arange(10), train_preds[idx])
ax.set_xlim([0, 9])
ax.set_ylim([0, 1.0])
ax.set_title(f"Ex. {idx} pr(y)")
ax.redraw = True
self.example_imshow_init = True
else:
for idx, _ax in enumerate(("J", "K", "L", "N")):
ax = self.axd[_ax]
for bar, new_height in zip(ax.bars, train_preds[idx]):
bar.set_height(new_height)
ax.redraw = True

def plot_train_and_val(
self,
Expand All @@ -301,10 +343,8 @@ def plot_train_and_val(
val_losses: List[float],
val_steps: List[int],
):
new_best_score = min(val_losses) if len(val_losses) > 0 else -math.inf
if new_best_score < self.best_score:
self.best_score = new_best_score
self.update_top(self.best_score)
maybe_new_best_score = min(val_losses) if len(val_losses) > 0 else -math.inf
self.update_top(maybe_new_best_score=maybe_new_best_score)

n = len(train_losses)
# x = torch.arange(n)
Expand Down Expand Up @@ -353,10 +393,13 @@ def draw(self):
# self.canvas.draw()
if True:
self.draw_or_restore()

# self.update_top(self.best_score)
surf = pygame.image.frombuffer(
self.surface.get_data(), self.window_size, "RGBA"
)
self.screen.blit(surf, (0, 0))
self.screen.blit(self.top_title_rendered, (0, 0))
# self.draw_or_restore()
# self.renderer = self.canvas._renderer

Expand All @@ -376,6 +419,8 @@ def draw(self):
# breakpoint()
self.renderer.clear()
self.draw_or_restore()

# self.update_top(self.best_score)
# if self.redraw_train_and_val:
# self.axd["B"].draw(self.renderer)
# self.axd["B"].cache = self.fig.canvas.copy_from_bbox(
Expand All @@ -399,8 +444,9 @@ def draw(self):
# "RGBA",
# )
self.screen.blit(surf, (0, 0))
self.screen.blit(self.top_title_rendered, (0, 0))
# else:

# self.draw_text()
pg.display.flip() # draws whole screen vs update that draws a parts

logging.debug("hudes_client: redraw done")
Expand Down Expand Up @@ -543,7 +589,7 @@ def __init__(self, grid_size, grids):

self.running = True
self.default_angleV = 20
self.max_angleV = 40
self.max_angleV = 25
self.angleH = 0.0
self.angleV = self.default_angleV
self.origin_loss = 0.0
Expand All @@ -563,6 +609,9 @@ def __init__(self, grid_size, grids):
"""
)

self.dtype = "?"
self.batch_size = "?"

# self.screen = pg.display.get_surface()

# # # Step 2: Create the Matplotlib figure
Expand Down Expand Up @@ -734,20 +783,25 @@ def decrement_selected_grid(self):
self.selected_grid = (self.selected_grid - 1) % self.effective_grids

def adjust_angles(self, angle_H, angle_V):
self.angleH += angle_H
self.angleV += angle_V
self.angleH += 2 * angle_H
self.angleV += 2 * angle_V
self.angleV = norm_deg(self.angleV) # % 360
self.angleH = norm_deg(self.angleH) # % 360
self.angleV = np.sign(self.angleV) * min(np.abs(self.angleV), self.max_angleV)
# print(self.angleH, self.angleV)

def reset_angle(self):
self.angleH = 0
self.angleV = self.default_angleV

def draw_all_text(self):

# render_text_2d("Batch size:", 36, self.window_size[0], self.window_size[1])
render_text_2d(
"Human Descent: MNIST", 36, self.window_size[0], self.window_size[1]
f"batch-size: {self.batch_size}, dtype: {self.dtype}",
20,
self.window_size[0],
self.window_size[1],
)

def draw(self):
Expand Down Expand Up @@ -853,7 +907,7 @@ def draw(self):
glDisableClientState(GL_VERTEX_ARRAY)
glDisableClientState(GL_COLOR_ARRAY)

# self.draw_all_text()
self.draw_all_text()
# # Render the texture from the Matplotlib figure in 2D
window_size = pg.display.get_surface().get_size() # Get window size

Expand Down
Loading

0 comments on commit e3cf2a0

Please sign in to comment.