Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 11, 2024
1 parent 0ffa103 commit e28398f
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 388 deletions.
12 changes: 0 additions & 12 deletions hudes/hudes_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,32 +166,24 @@ def before_pg_event(self):
def run_loop(self):
self.before_first_loop()
while self.hudes_websocket_client.running:
# check and send local interactions(?)
self.before_pg_event()
redraw = False
for event in pg.event.get():
redraw |= self.process_key_press(event)

# logging.debug("hudes_client: receive messages")
redraw |= self.receive_messages()
# logging.debug("hudes_client: receive messages done")
if redraw:
self.view.draw()
else:
# logging.debug("hudes_client: sleep")
sleep(0.01)
# logging.debug("hudes_client: sleep up")

def receive_messages(self):
# listen from server?
received_message = False
received_train = False
received_batch = False
received_val = False
while self.hudes_websocket_client.recv_ready():
# logging.debug("hudes_client: recieve message")
received_message = True
# recv and process!
raw_msg = self.hudes_websocket_client.recv_msg()
msg = hudes_pb2.Control()
msg.ParseFromString(raw_msg)
Expand All @@ -207,15 +199,12 @@ def receive_messages(self):
self.train_losses.append(msg.train_loss_and_preds.train_loss)
self.train_steps.append(msg.request_idx)
logging.debug("hudes_client: recieve message : loss and preds : done")
# self.val_losses.append(msg.loss_and_preds.val_loss)

elif msg.type == hudes_pb2.Control.CONTROL_BATCH_EXAMPLES:
logging.debug("hudes_client: recieve message : examples")
received_batch = True
self.train_data = pickle.loads(msg.batch_examples.train_data)
# self.val_data = pickle.loads(msg.batch_examples.val_data)
self.train_labels = pickle.loads(msg.batch_examples.train_labels)
# self.val_labels = pickle.loads(msg.batch_examples.val_labels)
logging.debug("hudes_client: recieve message : done")

elif msg.type == hudes_pb2.Control.CONTROL_VAL_LOSS:
Expand All @@ -228,7 +217,6 @@ def receive_messages(self):
# called if we only changed scale etc?
elif msg.type == hudes_pb2.Control.CONTROL_MESHGRID_RESULTS:
self.view.update_mesh_grids(pickle.loads(msg.mesh_grid_results))
# print("GOT MESH GRID", self.mesh_grid.shape)

if received_message:
if received_train:
Expand Down
4 changes: 0 additions & 4 deletions hudes/hudes_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import os

# from hudes.akai_client import AkaiClient
from hudes.keyboard_client import KeyboardClient
from hudes.keyboard_client_openGL import KeyboardClientGL
from hudes.view import OpenGLView, View
Expand Down Expand Up @@ -41,9 +40,6 @@ def main():
joystick_controller_key=args.controller,
)
view = OpenGLView(grid_size=args.grid_size, grids=args.grids)
# elif args.input == "akai":
# controller = AkaiClient(seed=args.seed)
# view = View()
elif args.input == "xtouch":
controller = XTouchClient(
addr=args.addr,
Expand Down
23 changes: 0 additions & 23 deletions hudes/keyboard_client_openGL.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,6 @@ def run_loop(self):

ct = time.time()
for joystick in self.joysticks.values():
# axes = joystick.get_numaxes()
# print(f"Number of axes: {axes}")
# for i in range(axes):
# axis = joystick.get_axis(i)
# print(f"Axis {i} value: {axis:>6.3f}")

if (
joystick.get_button(self.joystick_controller.right_js_press_button)
> 0.5
Expand All @@ -156,14 +150,6 @@ def run_loop(self):
self.step_size_increase(2)
self.send_config()

# if joystick.get_button(11) > 0.5:
# self.view.increase_zoom()
# redraw = True

# if joystick.get_button(12) > 0.5:
# self.view.decrease_zoom()
# redraw = True

if (
joystick.get_axis(self.joystick_controller.left_trig_axis) > 0.5
and (ct - last_select_press) > 0.2
Expand Down Expand Up @@ -206,15 +192,6 @@ def run_loop(self):

adjustH = B[0] * 2
adjustV = B[1]
# if np.abs(np.abs(angle) - 90 / 2) < 40:
# adjustH += np.sign(angle)
# redraw = True

# if np.abs(angle) < 40:
# adjustV += 1
# redraw = True
# elif np.abs(np.abs(angle) - 180) < 40:
# adjustV += -1
redraw = True
self.view.adjust_angles(adjustH, adjustV)
redraw = redraw | self.view.is_mouse_dragging | self.receive_messages()
Expand Down
1 change: 0 additions & 1 deletion hudes/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def mnist_model_data_and_subpace(
seed: int = 0,
store: str = "./",
device="cpu",
loss_fn=indexed_loss,
):
transform = transforms.Compose(
[
Expand Down
37 changes: 6 additions & 31 deletions hudes/model_data_and_subspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@ def fuse_parameters(model: nn.Module, device, dtype):
return params


# @torch.jit.script
def indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
@torch.jit.script
def jit_indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
return pred[torch.arange(label.shape[0]), label]


def indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
return jit_indexed_loss(pred, label)


def get_param_module(module):
if isinstance(module, torch.nn.Flatten):
return torch.nn.Flatten(start_dim=module.start_dim + 1)
Expand All @@ -75,10 +79,6 @@ def param_nn_from_sequential(model):

@torch.jit.script
def get_confusion_matrix(preds: torch.Tensor, labels: torch.Tensor):
# (Pdb) preds.shape
# torch.Size([512, 10])
# (Pdb) labels.shape
# torch.Size([512])
assert preds.ndim == 2 and labels.ndim == 1
n = preds.shape[1]
c_matrix = torch.vstack(
Expand Down Expand Up @@ -164,7 +164,6 @@ def get_dim_vec(self, dim: int, dtype):
dim=1,
).to(dtype)

# dims is a dictionary {dim:step_size}
@torch.no_grad
def delta_from_dims(self, dims: dict[int, float], dtype: torch.dtype):
if len(dims) > 0:
Expand Down Expand Up @@ -311,15 +310,9 @@ def get_loss_grid(
base_weights, dims, arange=r, brange=r, dtype=dtype
)

# self.model.net[:2](batch[0])
# self.param_model.modules_list=self.param_model.modules_list[:2]
# base_model = mp[grid_size//2,grid_size//2].reshape(1,-1)
# (self.model.net[:1](batch[0])==self.param_model.forward(bm,data)[1]).all()

logging.info(f"get_loss: mp done {mp.device} {mp.dtype}")

mp_reshaped = mp.reshape(-1, self.num_params).contiguous()
# batch = torch.rand(1, 512, 28, 28, device=device)
logging.info(f"get_loss: fwd start , batch size {batch_size}")
predictions = (
self.param_models[dtype]
Expand All @@ -333,26 +326,8 @@ def get_loss_grid(
label.reshape(1, 1, -1, 1).expand(*mp.shape[:2], batch_size, 1),
).mean(axis=[2, 3])
logging.info("get_loss: gather done")
# logging.info(
# f"get_loss: {label.shape} {predictions.shape} {label.device} {predictions.device} {mp.device}"
# )
# logging.info(
# f"grid {grid_idx} MO:{predictions[grid_size//2,grid_size//2].mean()} {predictions[grid_size//2,grid_size//2].shape}"
# )

# loss_np = loss.detach().cpu().numpy()
# breakpoint()
# loss -= loss.mean()
# loss /= loss.std() + 1e-5
# print(loss.mean())
# invert loss_np
loss = -loss
grid_losses.append(loss.unsqueeze(0))
#
# loss = -loss.cpu()
# breakpoint()
# a = 1
loss = torch.concatenate(grid_losses, dim=0).cpu()
# breakpoint()
logging.info(f"get_loss: return loss {loss[:, grid_size // 2, grid_size // 2]}")
return loss
Loading

0 comments on commit e28398f

Please sign in to comment.