diff --git a/hudes/hudes_client.py b/hudes/hudes_client.py index 6270018..840fe92 100644 --- a/hudes/hudes_client.py +++ b/hudes/hudes_client.py @@ -166,32 +166,24 @@ def before_pg_event(self): def run_loop(self): self.before_first_loop() while self.hudes_websocket_client.running: - # check and send local interactions(?) self.before_pg_event() redraw = False for event in pg.event.get(): redraw |= self.process_key_press(event) - # logging.debug("hudes_client: receive messages") redraw |= self.receive_messages() - # logging.debug("hudes_client: receive messages done") if redraw: self.view.draw() else: - # logging.debug("hudes_client: sleep") sleep(0.01) - # logging.debug("hudes_client: sleep up") def receive_messages(self): - # listen from server? received_message = False received_train = False received_batch = False received_val = False while self.hudes_websocket_client.recv_ready(): - # logging.debug("hudes_client: recieve message") received_message = True - # recv and process! raw_msg = self.hudes_websocket_client.recv_msg() msg = hudes_pb2.Control() msg.ParseFromString(raw_msg) @@ -207,15 +199,12 @@ def receive_messages(self): self.train_losses.append(msg.train_loss_and_preds.train_loss) self.train_steps.append(msg.request_idx) logging.debug("hudes_client: recieve message : loss and preds : done") - # self.val_losses.append(msg.loss_and_preds.val_loss) elif msg.type == hudes_pb2.Control.CONTROL_BATCH_EXAMPLES: logging.debug("hudes_client: recieve message : examples") received_batch = True self.train_data = pickle.loads(msg.batch_examples.train_data) - # self.val_data = pickle.loads(msg.batch_examples.val_data) self.train_labels = pickle.loads(msg.batch_examples.train_labels) - # self.val_labels = pickle.loads(msg.batch_examples.val_labels) logging.debug("hudes_client: recieve message : done") elif msg.type == hudes_pb2.Control.CONTROL_VAL_LOSS: @@ -228,7 +217,6 @@ def receive_messages(self): # called if we only changed scale etc? elif msg.type == hudes_pb2.Control.CONTROL_MESHGRID_RESULTS: self.view.update_mesh_grids(pickle.loads(msg.mesh_grid_results)) - # print("GOT MESH GRID", self.mesh_grid.shape) if received_message: if received_train: diff --git a/hudes/hudes_play.py b/hudes/hudes_play.py index 9907661..0b32897 100644 --- a/hudes/hudes_play.py +++ b/hudes/hudes_play.py @@ -2,7 +2,6 @@ import logging import os -# from hudes.akai_client import AkaiClient from hudes.keyboard_client import KeyboardClient from hudes.keyboard_client_openGL import KeyboardClientGL from hudes.view import OpenGLView, View @@ -41,9 +40,6 @@ def main(): joystick_controller_key=args.controller, ) view = OpenGLView(grid_size=args.grid_size, grids=args.grids) - # elif args.input == "akai": - # controller = AkaiClient(seed=args.seed) - # view = View() elif args.input == "xtouch": controller = XTouchClient( addr=args.addr, diff --git a/hudes/keyboard_client_openGL.py b/hudes/keyboard_client_openGL.py index bb94d74..43be1e0 100644 --- a/hudes/keyboard_client_openGL.py +++ b/hudes/keyboard_client_openGL.py @@ -135,12 +135,6 @@ def run_loop(self): ct = time.time() for joystick in self.joysticks.values(): - # axes = joystick.get_numaxes() - # print(f"Number of axes: {axes}") - # for i in range(axes): - # axis = joystick.get_axis(i) - # print(f"Axis {i} value: {axis:>6.3f}") - if ( joystick.get_button(self.joystick_controller.right_js_press_button) > 0.5 @@ -156,14 +150,6 @@ def run_loop(self): self.step_size_increase(2) self.send_config() - # if joystick.get_button(11) > 0.5: - # self.view.increase_zoom() - # redraw = True - - # if joystick.get_button(12) > 0.5: - # self.view.decrease_zoom() - # redraw = True - if ( joystick.get_axis(self.joystick_controller.left_trig_axis) > 0.5 and (ct - last_select_press) > 0.2 @@ -206,15 +192,6 @@ def run_loop(self): adjustH = B[0] * 2 adjustV = B[1] - # if np.abs(np.abs(angle) - 90 / 2) < 40: - # adjustH += np.sign(angle) - # redraw = True - - # if np.abs(angle) < 40: - # adjustV += 1 - # redraw = True - # elif np.abs(np.abs(angle) - 180) < 40: - # adjustV += -1 redraw = True self.view.adjust_angles(adjustH, adjustV) redraw = redraw | self.view.is_mouse_dragging | self.receive_messages() diff --git a/hudes/mnist.py b/hudes/mnist.py index ab6c48c..f6bba38 100644 --- a/hudes/mnist.py +++ b/hudes/mnist.py @@ -113,7 +113,6 @@ def mnist_model_data_and_subpace( seed: int = 0, store: str = "./", device="cpu", - loss_fn=indexed_loss, ): transform = transforms.Compose( [ diff --git a/hudes/model_data_and_subspace.py b/hudes/model_data_and_subspace.py index aa680ed..56de8c3 100644 --- a/hudes/model_data_and_subspace.py +++ b/hudes/model_data_and_subspace.py @@ -51,11 +51,15 @@ def fuse_parameters(model: nn.Module, device, dtype): return params -# @torch.jit.script -def indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor: +@torch.jit.script +def jit_indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor: return pred[torch.arange(label.shape[0]), label] +def indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor: + return jit_indexed_loss(pred, label) + + def get_param_module(module): if isinstance(module, torch.nn.Flatten): return torch.nn.Flatten(start_dim=module.start_dim + 1) @@ -75,10 +79,6 @@ def param_nn_from_sequential(model): @torch.jit.script def get_confusion_matrix(preds: torch.Tensor, labels: torch.Tensor): - # (Pdb) preds.shape - # torch.Size([512, 10]) - # (Pdb) labels.shape - # torch.Size([512]) assert preds.ndim == 2 and labels.ndim == 1 n = preds.shape[1] c_matrix = torch.vstack( @@ -164,7 +164,6 @@ def get_dim_vec(self, dim: int, dtype): dim=1, ).to(dtype) - # dims is a dictionary {dim:step_size} @torch.no_grad def delta_from_dims(self, dims: dict[int, float], dtype: torch.dtype): if len(dims) > 0: @@ -311,15 +310,9 @@ def get_loss_grid( base_weights, dims, arange=r, brange=r, dtype=dtype ) - # self.model.net[:2](batch[0]) - # self.param_model.modules_list=self.param_model.modules_list[:2] - # base_model = mp[grid_size//2,grid_size//2].reshape(1,-1) - # (self.model.net[:1](batch[0])==self.param_model.forward(bm,data)[1]).all() - logging.info(f"get_loss: mp done {mp.device} {mp.dtype}") mp_reshaped = mp.reshape(-1, self.num_params).contiguous() - # batch = torch.rand(1, 512, 28, 28, device=device) logging.info(f"get_loss: fwd start , batch size {batch_size}") predictions = ( self.param_models[dtype] @@ -333,26 +326,8 @@ def get_loss_grid( label.reshape(1, 1, -1, 1).expand(*mp.shape[:2], batch_size, 1), ).mean(axis=[2, 3]) logging.info("get_loss: gather done") - # logging.info( - # f"get_loss: {label.shape} {predictions.shape} {label.device} {predictions.device} {mp.device}" - # ) - # logging.info( - # f"grid {grid_idx} MO:{predictions[grid_size//2,grid_size//2].mean()} {predictions[grid_size//2,grid_size//2].shape}" - # ) - - # loss_np = loss.detach().cpu().numpy() - # breakpoint() - # loss -= loss.mean() - # loss /= loss.std() + 1e-5 - # print(loss.mean()) - # invert loss_np loss = -loss grid_losses.append(loss.unsqueeze(0)) - # - # loss = -loss.cpu() - # breakpoint() - # a = 1 loss = torch.concatenate(grid_losses, dim=0).cpu() - # breakpoint() logging.info(f"get_loss: return loss {loss[:, grid_size // 2, grid_size // 2]}") return loss diff --git a/hudes/view.py b/hudes/view.py index ad8be33..4ddc871 100644 --- a/hudes/view.py +++ b/hudes/view.py @@ -27,9 +27,6 @@ update_grid_vbo, ) -# backend = "Agg" -# backend = "cairo" -# matplotlib.use(backend) plt_backend = matplotlib.get_backend() import matplotlib.style as mplstyle @@ -37,78 +34,6 @@ mplstyle.use("fast") -def surface_to_npim(surface): - """Transforms a Cairo surface into a numpy array.""" - im = +np.frombuffer(surface.get_data(), np.uint8) - H, W = surface.get_height(), surface.get_width() - im.shape = (H, W, 4) # for RGBA - return im[:, :, :3] - - -def svg_to_npim(svg_bytestring, dpi): - """Renders a svg bytestring as a RGB image in a numpy array""" - tree = cairosvg.parser.Tree(bytestring=svg_bytestring) - surf = cairosvg.surface.PNGSurface(tree, None, dpi).cairo - return surface_to_npim(surf) - - -# Shader creation helper functions -def create_shader(shader_type, source): - shader = glCreateShader(shader_type) - glShaderSource(shader, source) - glCompileShader(shader) - - # Check for compilation errors - if not glGetShaderiv(shader, GL_COMPILE_STATUS): - error = glGetShaderInfoLog(shader).decode() - raise RuntimeError(f"Shader compilation error: {error}") - return shader - - -def create_shader_program(vertex_source, fragment_source): - program = glCreateProgram() - - # Create vertex and fragment shaders - vertex_shader = create_shader(GL_VERTEX_SHADER, vertex_source) - fragment_shader = create_shader(GL_FRAGMENT_SHADER, fragment_source) - - # Attach shaders to the program - glAttachShader(program, vertex_shader) - glAttachShader(program, fragment_shader) - - # Link the program - glLinkProgram(program) - - # Check for linking errors - if not glGetProgramiv(program, GL_LINK_STATUS): - error = glGetProgramInfoLog(program).decode() - raise RuntimeError(f"Program linking error: {error}") - - # Clean up shaders (they are now linked into the program) - glDeleteShader(vertex_shader) - glDeleteShader(fragment_shader) - - return program - - -# Vertex and fragment shader source code (compatible with OpenGL 2.1) -vertex_shader_src = """ -#version 120 -attribute vec3 vertexPosition; // Position of each vertex (x, y, z) - -void main() { - gl_Position = vec4(vertexPosition, 1.0); // Set the vertex position -} -""" - -fragment_shader_src = """ -#version 120 -void main() { - gl_FragColor = vec4(1.0, 0.0, 0.0, 1.0); // Set all fragments to red -} -""" - - def norm_mesh(mesh_grid): # self.mesh_grid /= 3 # self.origin_loss = mesh_grid[self.grid_size // 2, self.grid_size // 2] @@ -120,46 +45,6 @@ def norm_mesh(mesh_grid): return mesh_grid -# Helper function used for visualization in the following examples -def identify_axes(ax_dict, fontsize=48): - """ - Helper to identify the Axes in the examples below. - - Draws the label in a large font in the center of the Axes. - - Parameters - ---------- - ax_dict : dict[str, Axes] - Mapping between the title / label and the Axes. - fontsize : int, optional - How big the label should be. - """ - kw = dict(ha="center", va="center", fontsize=fontsize, color="darkgrey") - for k, ax in ax_dict.items(): - ax.text(0.5, 0.5, k, transform=ax.transAxes, **kw) - - # pygame.init() - # pygame.display.set_mode(self.window_size) - return pg.display.get_surface() - - -def _print_device_info(): - for i in range(pygame.midi.get_count()): - r = pygame.midi.get_device_info(i) - (interf, name, input, output, opened) = r - - in_out = "" - if input: - in_out = "(input)" - if output: - in_out = "(output)" - - print( - "%2i: interface :%s:, name :%s:, opened :%s: %s" - % (i, interf, name, opened, in_out) - ) - - class View: def __init__(self, use_midi=False): @@ -172,27 +57,13 @@ def __init__(self, use_midi=False): print(f"using input_id :{self.midi_input_id}:") self.midi_input = pygame.midi.Input(self.midi_input_id) - # dpi = 200 # plt.rcParams["figure.dpi"] logging.info(f"Matplotlib backend: {plt.get_backend()}") self.window_size = (1200, 800) - # self.fig = plt.figure( - # figsize=(self.window_size[0] / dpi, self.window_size[1] / dpi), dpi=dpi - # ) self.fig = plt.figure(figsize=(12, 8), facecolor="white") - # if self.fig.dpi != dpi: - # logging.warning( - # f"DPI flag not respected by matplotlib backend ({plt.get_backend()})! Should be {dpi} but is {self.fig.dpi} " - # ) - # self.window_size = ( - # int(self.fig.get_figwidth() * self.fig.dpi), - # int(self.fig.get_figheight() * self.fig.dpi), - # ) - self.window = pg.display.set_mode(self.window_size) - # self.canvas = agg.FigureCanvasAgg(self.fig) self.canvas = self.fig.canvas self.renderer = self.canvas.get_renderer() if "cairo" in plt_backend.lower(): @@ -255,9 +126,6 @@ def update_examples(self, train_data: torch.Tensor): ax.im.set_data(train_data[idx]) ax.redraw = True - # self.axd["I"].cla() - # self.axd["I"].imshow(train_data[3]) - def update_top(self, maybe_new_best_score): if self.best_score is None or maybe_new_best_score > self.best_score: self.best_score = maybe_new_best_score @@ -346,8 +214,6 @@ def plot_train_and_val( maybe_new_best_score = min(val_losses) if len(val_losses) > 0 else -math.inf self.update_top(maybe_new_best_score=maybe_new_best_score) - n = len(train_losses) - # x = torch.arange(n) self.axd["B"].cla() self.axd["B"].plot(train_steps, train_losses, label="train") self.axd["B"].plot(val_steps, val_losses, label="val") @@ -357,12 +223,6 @@ def plot_train_and_val( self.axd["B"].set_ylabel("Loss") self.axd["B"].redraw = True - # self.axd["C"].cla() - # self.axd["C"].plot(train_steps[n // 2 :], train_losses[n // 2 :], label="train") - # self.axd["C"].set_title("Loss [half time]") - # self.axd["C"].set_xlabel("Step") - # self.axd["C"].set_yticks([]) - self.axd["D"].cla() self.axd["D"].plot(train_steps[-8:], train_losses[-8:], label="train") self.axd["D"].set_title("Loss [last 8steps]") @@ -384,69 +244,27 @@ def draw_or_restore(self): def draw(self): logging.debug("hudes_client: redraw") - # cairo - # np.frombuffer(self.canvas._get_printed_image_surface().get_data(),np.uint8) - # https://www.pygame.org/wiki/CairoPygame if "cairo" in plt_backend.lower(): - # self.renderer.gc.ctx = fake_ctx() - # self.canvas.draw() - if True: - self.draw_or_restore() + self.draw_or_restore() - # self.update_top(self.best_score) - surf = pygame.image.frombuffer( - self.surface.get_data(), self.window_size, "RGBA" - ) - self.screen.blit(surf, (0, 0)) - self.screen.blit(self.top_title_rendered, (0, 0)) - # self.draw_or_restore() - # self.renderer = self.canvas._renderer - - # self.draw_or_restore() - if False: - self.canvas.draw() - surf = pg.image.frombuffer( - # self.renderer.tostring_rgb(), - self.canvas._get_printed_image_surface().get_data(), - self.window_size, - "RGBA", - ) - self.screen.blit(surf, (0, 0)) - else: # backend.lower()=='agg': - # self.canvas.draw() - # self.canvas.update() - # breakpoint() + surf = pygame.image.frombuffer( + self.surface.get_data(), self.window_size, "RGBA" + ) + self.screen.blit(surf, (0, 0)) + self.screen.blit(self.top_title_rendered, (0, 0)) + else: self.renderer.clear() self.draw_or_restore() - # self.update_top(self.best_score) - # if self.redraw_train_and_val: - # self.axd["B"].draw(self.renderer) - # self.axd["B"].cache = self.fig.canvas.copy_from_bbox( - # self.axd["B"].get_tightbbox(self.renderer) - # ) - # self.axd["D"].draw(self.renderer) - # self.axd["D"].cache = self.fig.canvas.copy_from_bbox(self.axd["D"].bbox) - # self.redraw_train_and_val = False - # else: - # self.fig.canvas.restore_region(self.axd["B"].cache) - # self.fig.canvas.restore_region(self.axd["D"].cache) - # TODO TRY FORM BUFFER AND USE THE TOSTRING BUFFER!!! surf = pg.image.frombytes( self.renderer.tostring_rgb(), self.window_size, "RGB", ) - # surf = pg.image.frombuffer( - # self.renderer.buffer_rgba(), - # self.window_size, - # "RGBA", - # ) self.screen.blit(surf, (0, 0)) self.screen.blit(self.top_title_rendered, (0, 0)) - # else: - # self.draw_text() + pg.display.flip() # draws whole screen vs update that draws a parts logging.debug("hudes_client: redraw done") @@ -461,7 +279,6 @@ def __init__(self, grid_size, grids): pg.init() pg.font.init() - # display = (800, 600) self.window_size = (1200, 800) @@ -573,10 +390,6 @@ def __init__(self, grid_size, grids): GL_ELEMENT_ARRAY_BUFFER, self.indices.nbytes, self.indices, GL_STATIC_DRAW ) - # # Enable vertex attribute 0 (positions) - # glEnableClientState(GL_VERTEX_ARRAY) - # glVertexPointer(3, GL_FLOAT, 0, None) - # Enable vertex arrays and color arrays glEnableClientState(GL_VERTEX_ARRAY) glEnableClientState(GL_COLOR_ARRAY) @@ -593,7 +406,6 @@ def __init__(self, grid_size, grids): self.angleH = 0.0 self.angleV = self.default_angleV self.origin_loss = 0.0 - self.target = (0.0, 0.0, 0.0) # init plt plt.style.use("dark_background") @@ -612,15 +424,6 @@ def __init__(self, grid_size, grids): self.dtype = "?" self.batch_size = "?" - # self.screen = pg.display.get_surface() - - # # # Step 2: Create the Matplotlib figure - # self.fig, ax = plt.subplots(figsize=(4, 3), facecolor="white") - # x = np.linspace(0, 10, 100) - # y = np.sin(x) - # ax.plot(x, y) - # ax.set_title("Sine Wave") - def update_examples(self, train_data: torch.Tensor): self.axd["F"].cla() self.axd["F"].imshow(train_data[0]) @@ -635,16 +438,10 @@ def update_top(self, best_score): self.fig.suptitle("Human Descent: MNIST Top-score: ?") else: self.fig.suptitle(f"Human Descent: MNIST Top-score: {best_score:.5e}") - # self.fig.tight_layout(rect=[0, 0.03, 1, 0.95]) def update_step_size( self, log_step_size: float, max_log_step_size: float, min_log_step_size: float ): - # self.axd["I"].cla() - # self.axd["I"].bar([0], log_step_size) - # self.axd["I"].set_ylim(min_log_step_size, max_log_step_size) - # self.axd["I"].set_title("log(Step size)") - # self.axd["I"].set_xticks([]) pass def update_confusion_matrix(self, confusion_matrix: torch.Tensor): @@ -678,8 +475,6 @@ def plot_train_and_val( best_score = min(val_losses) if len(val_losses) > 0 else -math.inf self.update_top(best_score) - n = len(train_losses) - # x = torch.arange(n) self.axd["B"].cla() self.axd["B"].plot(train_steps, train_losses, label="train") self.axd["B"].plot(val_steps, val_losses, label="val") @@ -697,77 +492,33 @@ def plot_train_and_val( def update_mesh_grids(self, mesh_grids): self.raw_mesh_grids = mesh_grids - # if self.grids > 1: - # normalized_grids.append(norm_mesh(mesh_grids.sum(axis=0)).unsqueeze(0)) - # for grid_idx in range(mesh_grids.shape[0]): - # normalized_grids.append( - # norm_mesh(mesh_grids[grid_idx].clone()).unsqueeze(0) - # ) # might not need clone if we are safe here - - # if self.grids > 1: - # normalized_grids.append(norm_mesh(mesh_grids.sum(axis=0)).unsqueeze(0)) - origin_loss = mesh_grids[0, self.grid_size // 2, self.grid_size // 2].item() mesh_grids -= origin_loss - # _u, _std, _mx = mesh_grids.mean(), mesh_grids.var() + _mx = mesh_grids.abs().max() eps = 1e-3 mesh_grids *= self.grid_width / (_mx + eps) - # for grid_idx in range(mesh_grids.shape[0]): - # normalized_grids.append( - # # ((mesh_grids[grid_idx] - _u) / (_std + eps)).unsqueeze(0) - # self.grid_width - # * ((mesh_grids[grid_idx] - _u) / (_mx + eps)).unsqueeze(0) - # ) # might not need clone if we are safe here - - # # origin_loss = mesh_grid[grid_size // 2, grid_size // 2].item() - # self.normalized_grids = torch.concatenate(normalized_grids, dim=0) self.normalized_grids = mesh_grids - # breakpoint() - # Define the center point (where the red sphere is located) and relative target position (A, B) - center_row, center_col = self.grid_size // 2, self.grid_size // 2 - - # Relative target offset (A, B) in grid units - A, B = 0, 5 # Example: One cell to the right of the red sphere - - # Convert relative target (A, B) to absolute grid coordinates - target_row = center_row + A - target_col = center_col + B - - self.target = (0, 1, 0) - # # Ensure target is within bounds of the height_map - # if 0 <= target_row < self.grid_size and 0 <= target_col < self.grid_size: - # # Convert grid position to 3D coordinates - # target_x = (target_row - (self.grid_size - 1) / 2.0) * self.spacing - # target_y = self.origin_loss # 3 # self.mesh_grid[-target_row, -target_col] - # target_z = (target_col - (self.grid_size - 1) / 2.0) * self.spacing - # self.target = (target_x, target_y, target_z) # Target grid point - # else: - # self.target = (0, self.origin_loss, 0) # Default to center if out of bounds - - # self.target = (target_x, target_y, target_z) # Target grid point self.update_points_and_colors() def update_points_and_colors(self): if self.use_surface: new_points = create_surface_grid_points(self.normalized_grids, self.spacing) else: - # new_points_old = create_grid_points(self.normalized_grids, self.spacing) new_points, new_colors = create_grid_points_with_colors( self.normalized_grids, self.spacing, self.grid_colors, selected_grid=self.selected_grid, ) - # print(new_points.shape) - print(new_points.shape[0] / self.grid_size) + start_idx = self.selected_grid * self.grid_size * self.grid_size new_points[ start_idx : start_idx + self.grid_size * self.grid_size, [0, 2] ] *= self.selected_grid_multiplier - # breakpoint() + update_grid_vbo(self.vbo, new_points) update_grid_cbo(self.cbo, new_colors) @@ -795,15 +546,12 @@ def adjust_angles(self, angle_H, angle_V): self.angleV = norm_deg(self.angleV) # % 360 self.angleH = norm_deg(self.angleH) # % 360 self.angleV = np.sign(self.angleV) * min(np.abs(self.angleV), self.max_angleV) - # print(self.angleH, self.angleV) def reset_angle(self): self.angleH = 0 self.angleV = self.default_angleV def draw_all_text(self): - - # render_text_2d("Batch size:", 36, self.window_size[0], self.window_size[1]) render_text_2d( f"batch-size: {self.batch_size}, dtype: {self.dtype}", 20, @@ -812,8 +560,6 @@ def draw_all_text(self): ) def draw(self): - - # glClearColor(1.0, 1.0, 1.0, 1.0) # Set the clear color to white # Handle mouse motion for rotation if self.is_mouse_dragging: mouse_pos = pg.mouse.get_pos() @@ -835,10 +581,6 @@ def draw(self): self.camera_distance = self.total_width * self.scale_factor glTranslatef(0, 0.0, -self.camera_distance) - # Apply rotations - # glRotatef(self.angleV, 1, 0, 0) - # glRotatef(self.angleH, 0, 1, 0) - # Translate to center the grids # -3 for now, moves it up glTranslatef(-self.total_width / 2.0 + self.grid_width / 2, -3, 0.0) @@ -912,13 +654,7 @@ def draw(self): ) # Draw the grid as a surface using triangles - draw_red_sphere(0.0) - # draw_red_plane( - # 0.0, - # grid_size=self.grid_size, - # spacing=self.spacing, - # ) # Restore the previous matrix state glPopMatrix() diff --git a/hudes/websocket_client.py b/hudes/websocket_client.py index ba95e59..b1b8fd0 100644 --- a/hudes/websocket_client.py +++ b/hudes/websocket_client.py @@ -15,37 +15,6 @@ from hudes import hudes_pb2 -""" -Websocket client has a queue in and out -can push many messages down into ws client, but only sends as fast as possible - -ws client loop (thread) - - -While True: - #prepare and send - while get(timeout=0.1) from queue(): - check how many events - goup as many as possible - - send - - #recv and pipe back up - recv(timeout=0) - - - -""" - - -# def mesh_grid_config_message(dimA, dimB, grid_size, step_size): -# return hudes_pb2.Control( -# type=hudes_pb2.Control.CONTROL_MESHGRID_CONFIG, -# mesh_grid_config=hudes_pb2.MeshGridConfig( -# dimA=dimA, dimB=dimB, grid_size=grid_size, step_size=step_size -# ), -# ) - @cache def next_batch_message(): @@ -180,9 +149,7 @@ async def send_dims(n: int = 10): for _ in range(n): msg = hudes_pb2.Control( type=hudes_pb2.Control.CONTROL_DIMS, - # dims_and_steps=[hudes_pb2.DimAndStep(dim=1, step=0.01)], ) - # msg = {"type": "control", "dims": {1: 0.1, 2: 0.3}} logging.info(msg.SerializeToString()) await websocket.send(msg.SerializeToString()) await asyncio.sleep(0.01) diff --git a/hudes/websocket_server.py b/hudes/websocket_server.py index 1d07b47..31120d3 100644 --- a/hudes/websocket_server.py +++ b/hudes/websocket_server.py @@ -49,9 +49,7 @@ def prepare_batch_example_message( type=hudes_pb2.BatchExamples.Type.IMG_BW, n=n, train_data=pickle.dumps(batch[0][:n].tolist()), - # val_data=pickle.dumps(batch["val"][0][:n].tolist()), train_labels=pickle.dumps(batch[1][:n].tolist()), - # val_labels=pickle.dumps(batch["val"][1][:n].tolist()), batch_idx=batch_idx, ), ) @@ -65,7 +63,6 @@ def listen_and_run( mad.move_to_device() mad.fuse() mad.init_param_model() - # breakpoint() client_weights = {} # TODO memory leak, but prevents continously copying models @@ -377,7 +374,6 @@ async def process_client(websocket, client_runner_q): logging.debug(f"process_client: {client_idx} : next batch") client.batch_idx += 1 client.request_full_val = True - # send back batch examples elif msg.type == hudes_pb2.Control.CONTROL_NEXT_DIMS: logging.debug(f"process_client: {client_idx} : next dims") @@ -394,14 +390,9 @@ async def process_client(websocket, client_runner_q): client.batch_size = msg.config.batch_size client.dtype = getattr(torch, msg.config.dtype) - # send back batch examples elif msg.type == hudes_pb2.Control.CONTROL_QUIT: logging.debug(f"process_client: {client_idx} : quit") break - # elif msg.type == hudes_pb2.Control.CONTROL_MESHGRID_CONFIG: - # logging.info( - # f"process_client: {client_idx} : meshgrid, {client.dims_offset}, {config['dims_at_a_time']}" - # ) else: logging.warning("received invalid type from client")