Commit c660c69

fix val batching; val on boot
misko committed Oct 10, 2024
1 parent 64087e3 commit c660c69
Showing 5 changed files with 41 additions and 38 deletions.
hudes/hudes.proto (2 additions, 2 deletions)

@@ -34,9 +34,9 @@ message BatchExamples {
   required Type type = 1;
   required int32 n = 2;
   required bytes train_data = 3;
-  required bytes val_data = 4;
+  optional bytes val_data = 4;
   required bytes train_labels = 5;
-  required bytes val_labels = 6;
+  optional bytes val_labels = 6;
   required int32 batch_idx = 7;
 }
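Loosening these fields from required to optional changes proto2 presence semantics: a sender may now omit the val payload entirely, and receivers can test for it before reading. A minimal sketch using the generated hudes_pb2 module (the import path and field values are assumptions for illustration):

    from hudes import hudes_pb2  # generated module; import path assumed from the repo layout

    msg = hudes_pb2.BatchExamples(
        type=hudes_pb2.BatchExamples.Type.IMG_BW,
        n=4,
        train_data=b"placeholder",    # required fields must still be set
        train_labels=b"placeholder",
        batch_idx=0,
    )
    # val_data / val_labels omitted: legal now that they are optional.
    wire = msg.SerializeToString()  # would raise EncodeError while they were required

    parsed = hudes_pb2.BatchExamples.FromString(wire)
    print(parsed.HasField("val_data"))  # False: proto2 presence check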

hudes/hudes_client.py (3 additions, 2 deletions)

@@ -156,6 +156,7 @@ def before_first_loop(self):
             self.val_steps,
         )
         self.dims_and_steps_updated()
+        self.get_next_batch()
 
     def before_pg_event(self):
         pass
@@ -210,9 +211,9 @@ def receive_messages(self):
logging.debug("hudes_client: recieve message : examples")
received_batch = True
self.train_data = pickle.loads(msg.batch_examples.train_data)
self.val_data = pickle.loads(msg.batch_examples.val_data)
# self.val_data = pickle.loads(msg.batch_examples.val_data)
self.train_labels = pickle.loads(msg.batch_examples.train_labels)
self.val_labels = pickle.loads(msg.batch_examples.val_labels)
# self.val_labels = pickle.loads(msg.batch_examples.val_labels)
logging.debug("hudes_client: recieve message : done")

elif msg.type == hudes_pb2.Control.CONTROL_VAL_LOSS:
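If the client ever needs the validation examples again, a presence check before unpickling would tolerate senders that omit the now-optional fields. A hypothetical fragment for receive_messages (not repo code):

    # Hypothetical guard; HasField is proto2's presence check for optional fields.
    if msg.batch_examples.HasField("val_data"):
        self.val_data = pickle.loads(msg.batch_examples.val_data)
        self.val_labels = pickle.loads(msg.batch_examples.val_labels)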
hudes/hudes_pb2.py (1 addition, 1 deletion)

Generated file; diff not rendered by default.

hudes/model_data_and_subspace.py (27 additions, 27 deletions)

@@ -27,7 +27,7 @@ def get_len(self, batch_size: int):
     # TODO optionally cache this!
     @cache
     def get_batch(self, batch_size, batch_idx):
-        logging.debug(f"get_batch size: {batch_size} idx: {batch_idx}")
+        logging.debug(f"get_batch size: {self} {batch_size} idx: {batch_idx}")
         batch_idx = batch_idx % self.get_len(batch_size=batch_size)
         start_idx = batch_idx * batch_size
         end_idx = min(len(self.ds), start_idx + batch_size)
@@ -107,8 +107,7 @@ def __init__(
         self.device = device
         self._model = model  # .to(self.device)
         self.num_params = sum([p.numel() for p in self._model.parameters()])
-        self.train_data_batcher = train_data_batcher
-        self.val_data_batcher = val_data_batcher
+        self.batchers = {"train": train_data_batcher, "val": val_data_batcher}
         self.minimize = minimize
 
         self.fused = False  # Cant fuse before forking
@@ -177,18 +176,13 @@ def delta_from_dims(self, dims: dict[int, float], dtype: torch.dtype):
 
     # todo cache this?
     @cache
-    def get_batch(self, batch_size: int, batch_idx: int, dtype):
-        r = {}
-        for name, batcher in (
-            ("train", self.train_data_batcher),
-            ("val", self.val_data_batcher),
-        ):
-            batch = batcher.get_batch(batch_size, batch_idx)
-            r[name] = (
-                batch["data"].to(device=self.device, dtype=dtype),
-                batch["label"].to(device=self.device),
-            )
-        return r
+    def get_batch(self, batch_size: int, batch_idx: int, dtype, train_or_val: str):
+        assert train_or_val in self.batchers
+        batch = self.batchers[train_or_val].get_batch(batch_size, batch_idx)
+        return (
+            batch["data"].to(device=self.device, dtype=dtype),
+            batch["label"].to(device=self.device),
+        )
 
     # TODO could optimize with one large chunk of shared memory? and slice it?
     @cache
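Because get_batch is wrapped in functools.cache, the new train_or_val argument becomes part of the memoization key, so train and val batches for the same (batch_size, batch_idx) occupy separate cache entries instead of being fetched together. A self-contained toy showing the keying (names are illustrative):

    from functools import cache

    class Batches:
        @cache  # keyed on (self, batch_size, batch_idx, train_or_val)
        def get_batch(self, batch_size, batch_idx, train_or_val):
            print("cache miss:", batch_size, batch_idx, train_or_val)
            return (batch_size, batch_idx, train_or_val)

    b = Batches()
    b.get_batch(512, 0, "train")  # miss
    b.get_batch(512, 0, "train")  # hit: same key, no print
    b.get_batch(512, 0, "val")    # miss: train_or_val distinguishes the entry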
@@ -211,13 +205,16 @@ def val_model_inference_with_delta_weights(self, weights: torch.Tensor, dtype):
         self.set_parameters(weights, dtype)
         full_val_loss = 0
         n = 0
-        for batch_idx in range(self.val_data_batcher.get_len(self.val_batch_size)):
+        for batch_idx in range(self.batchers["val"].get_len(self.val_batch_size)):
             batch = self.get_batch(
-                batch_size=self.val_batch_size, batch_idx=batch_idx, dtype=dtype
+                batch_size=self.val_batch_size,
+                batch_idx=batch_idx,
+                dtype=dtype,
+                train_or_val="val",
             )
-            model_output = self.models[dtype](batch["val"][0])
-            full_val_loss += self.loss_fn(model_output, batch["val"][1]).sum().item()
-            n += batch["val"][1].shape[0]
+            model_output = self.models[dtype](batch[0])
+            full_val_loss += self.loss_fn(model_output, batch[1]).sum().item()
+            n += batch[1].shape[0]
         if not self.minimize:
             full_val_loss = -full_val_loss
         return {"val_loss": full_val_loss / n}
@@ -227,15 +224,15 @@ def train_model_inference_with_delta_weights(
         self, weights: torch.Tensor, batch_size: int, batch_idx: int, dtype
     ) -> dict[str, torch.Tensor]:
         assert self.fused
-        batch = self.get_batch(batch_size, batch_idx, dtype=dtype)
+        batch = self.get_batch(batch_size, batch_idx, dtype=dtype, train_or_val="train")
         self.set_parameters(weights, dtype)
-        model_output = self.models[dtype](batch["train"][0])
-        train_loss = self.loss_fn(model_output, batch["train"][1]).mean().item()
+        model_output = self.models[dtype](batch[0])
+        train_loss = self.loss_fn(model_output, batch[1]).mean().item()
         train_pred = self.models[dtype].probs(model_output)
         if not self.minimize:
             train_loss = -train_loss
 
-        confusion_matrix = get_confusion_matrix(train_pred, batch["train"][1])
+        confusion_matrix = get_confusion_matrix(train_pred, batch[1])
         logging.info(
             f"train loss: {train_loss} MO:{model_output.mean()}/{model_output.shape} weights {weights.mean().item()} {dtype}"
         )
@@ -289,9 +286,12 @@ def get_loss_grid(
         assert grids > 0
 
         logging.info("get_loss: get bach")
-        batch = self.get_batch(batch_size=batch_size, batch_idx=batch_idx, dtype=dtype)[
-            "train"
-        ]
+        batch = self.get_batch(
+            batch_size=batch_size,
+            batch_idx=batch_idx,
+            dtype=dtype,
+            train_or_val="train",
+        )
         logging.info("get_loss: get batch done")
         data = batch[0].unsqueeze(0)
         batch_size = data.shape[1]
hudes/websocket_server.py (8 additions, 6 deletions)

@@ -40,16 +40,18 @@ def prepare_batch_example_message(
     mad: ModelDataAndSubspace,
     n: int = 4,
 ):
-    batch = mad.get_batch(batch_size=batch_size, batch_idx=batch_idx, dtype=dtype)
+    batch = mad.get_batch(
+        batch_size=batch_size, batch_idx=batch_idx, dtype=dtype, train_or_val="train"
+    )
     return hudes_pb2.Control(
         type=hudes_pb2.Control.CONTROL_BATCH_EXAMPLES,
         batch_examples=hudes_pb2.BatchExamples(
             type=hudes_pb2.BatchExamples.Type.IMG_BW,
             n=n,
-            train_data=pickle.dumps(batch["train"][0][:n].tolist()),
-            val_data=pickle.dumps(batch["val"][0][:n].tolist()),
-            train_labels=pickle.dumps(batch["train"][1][:n].tolist()),
-            val_labels=pickle.dumps(batch["val"][1][:n].tolist()),
+            train_data=pickle.dumps(batch[0][:n].tolist()),
+            # val_data=pickle.dumps(batch["val"][0][:n].tolist()),
+            train_labels=pickle.dumps(batch[1][:n].tolist()),
+            # val_labels=pickle.dumps(batch["val"][1][:n].tolist()),
             batch_idx=batch_idx,
         ),
     )
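The example payload is built by truncating each tensor to n rows, converting to plain lists, and pickling the result into the protobuf bytes fields, so the receiving side can unpickle without torch. A toy round trip (shapes are illustrative, not the dataset's):

    import pickle
    import torch

    n = 4
    batch = (torch.randn(8, 28 * 28), torch.randint(0, 10, (8,)))  # illustrative shapes

    payload = pickle.dumps(batch[0][:n].tolist())  # nested Python lists, not tensors
    restored = pickle.loads(payload)
    assert len(restored) == n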
@@ -237,6 +239,7 @@ async def inference_runner_clients(mad, client_runner_q, inference_q, stop):
                     "dtype": client.dtype,
                 },
             )
+            client.request_full_val = False
 
 
 async def inference_result_sender(results_q, stop):
@@ -284,7 +287,6 @@ async def inference_result_sender(results_q, stop):
                 ).SerializeToString()
             )
             logging.debug(f"inference_result_sender: sent val to client : done")
-            client.request_full_val = False
         if train_or_val == "mesh":
             logging.debug(f"inference_result_sender: sent mesh to client")
             await client.websocket.send(
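Moving client.request_full_val = False from the result sender to the enqueue site means the flag is consumed the moment the full-validation job is queued, rather than after its results are sent. A toy sketch of this clear-on-enqueue pattern (names and loop structure are illustrative, not the server's actual control flow):

    import asyncio

    async def runner_loop(client, inference_q):
        # Clear-on-enqueue: consume the request flag as soon as the job is
        # queued, so in-flight validation work cannot be enqueued again.
        # Clearing only after results are sent (the old behavior) leaves a
        # window where the same request is queued on every loop iteration.
        while True:
            if client.request_full_val:
                await inference_q.put({"mode": "val", "client": client})
                client.request_full_val = False
            await asyncio.sleep(0.01)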
