Commit

updates
misko committed Oct 9, 2024
1 parent a93da78 commit a0a3016
Showing 8 changed files with 152 additions and 82 deletions.
2 changes: 2 additions & 0 deletions hudes/hudes.proto
@@ -13,6 +13,8 @@ message Config {
required int32 mesh_grid_size = 3;
required float mesh_step_size = 4;
required int32 mesh_grids = 5;
+  required int32 batch_size = 6;
+  required string dtype = 7;
}


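The two new fields let the client, rather than the server, choose batching and numeric precision. A minimal sketch of building the message from Python, assuming the regenerated hudes_pb2 module (only the fields visible in this hunk are shown; the remaining required fields must also be set before serializing):

    from hudes import hudes_pb2

    config = hudes_pb2.Config(
        mesh_grid_size=31,    # illustrative values
        mesh_step_size=0.01,
        mesh_grids=1,
        batch_size=128,   # new: field 6
        dtype="float16",  # new: field 7
    )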
5 changes: 5 additions & 0 deletions hudes/hudes_client.py
@@ -50,6 +50,9 @@ def __init__(
self.inital_step_size_idx = inital_step_size_idx
self.step_size_resolution = step_size_resolution

+        self.batch_size = 128
+        self.dtype = "float16"

def get_next_batch(self):
self.hudes_websocket_client.send_q.put(next_batch_message().SerializeToString())

@@ -75,6 +78,8 @@ def send_config(self):
mesh_step_size=self.step_size,
mesh_grids=self.mesh_grids,
mesh_grid_size=self.mesh_grid_size,
+            batch_size=self.batch_size,
+            dtype=self.dtype,
)

def dims_and_steps_updated(self):
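Note that dtype travels over the wire as a string ("float16"), while model_data_and_subspace.py below keys its per-precision state by torch.dtype. A conversion along these lines is assumed on the server side (this mapping is a sketch, not code from the commit):

    import torch

    # Wire string -> torch dtype; extend if more precisions are supported.
    DTYPES = {"float16": torch.float16, "float32": torch.float32}
    assert DTYPES["float16"] is torch.float16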
30 changes: 15 additions & 15 deletions hudes/hudes_pb2.py

Some generated files are not rendered by default.
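hudes_pb2.py is generated from hudes.proto, and this 30-line change just tracks the two new Config fields. Regenerating it might look like the following, assuming the grpcio-tools toolchain (the repo may instead invoke protoc directly):

    from grpc_tools import protoc  # pip install grpcio-tools

    protoc.main([
        "protoc",
        "-I.",
        "--python_out=.",
        "hudes/hudes.proto",
    ])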

6 changes: 0 additions & 6 deletions hudes/mnist.py
@@ -112,10 +112,7 @@ def mnist_model_data_and_subpace(
model: nn.Module,
seed: int = 0,
store: str = "./",
-    train_batch_size: int = 512,
-    val_batch_size: int = 1024,
    device="cpu",
-    dtype=torch.float32,
loss_fn=indexed_loss,
):
transform = transforms.Compose(
@@ -132,7 +129,6 @@ train=True,
train=True,
transform=transform,
),
-        train_batch_size,
seed=seed,
)
val_data_batcher = DatasetBatcher(
@@ -142,7 +138,6 @@ train=False,
train=False,
transform=transform,
),
-        val_batch_size,
seed=seed,
)
return ModelDataAndSubspace(
@@ -152,5 +147,4 @@ loss_fn=indexed_loss,
loss_fn=indexed_loss,
minimize=False,
device=device,
-        dtype=dtype,
)
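With batch sizes and dtype gone from the factory, the batchers are built once and all precision/batching decisions are deferred to each client request. A hedged sketch of the slimmer call under the new signature (the model argument is a stand-in, not the repo's actual network, which also needs .net and .probs):

    import torch.nn as nn
    from hudes.mnist import mnist_model_data_and_subpace

    mad = mnist_model_data_and_subpace(
        model=nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)),  # stand-in
        seed=0,
        device="cpu",
    )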
132 changes: 85 additions & 47 deletions hudes/model_data_and_subspace.py
@@ -1,3 +1,4 @@
+import copy
import logging
import math
from functools import cache
@@ -12,21 +13,24 @@


class DatasetBatcher:
-    def __init__(self, ds, batch_size: int, seed: int = 0):
-        self.len = math.ceil(len(ds) / batch_size)
-        self.batch_size = batch_size
+    def __init__(self, ds, seed: int = 0):
        self.ds = ds
+        self.seed = seed
        g = torch.Generator()
-        g.manual_seed(seed)
+        g.manual_seed(self.seed)
self.idxs = torch.randperm(len(self.ds), generator=g)

+    @cache
+    def get_len(self, batch_size: int):
+        return math.ceil(len(self.ds) / batch_size)

# TODO optionally cache this!
@cache
-    def __getitem__(self, idx: int):
-        print("BATCHER", idx)
-        idx = idx % self.len
-        start_idx = idx * self.batch_size
-        end_idx = min(len(self.ds), start_idx + self.batch_size)
+    def get_batch(self, batch_size, batch_idx):
+        logging.debug(f"get_batch size: {batch_size} idx: {batch_idx}")
+        batch_idx = batch_idx % self.get_len(batch_size=batch_size)
+        start_idx = batch_idx * batch_size
+        end_idx = min(len(self.ds), start_idx + batch_size)
x, y = torch.cat(
[self.ds[idx][0] for idx in self.idxs[start_idx:end_idx]], dim=0
), torch.tensor([self.ds[idx][1] for idx in self.idxs[start_idx:end_idx]])
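The batcher change is the core of this refactor: batch_size used to be fixed at construction and is now an argument of every call, so one DatasetBatcher can serve clients requesting different batch sizes, with epoch length derived (and cached) per batch size. A self-contained sketch of the same contract, with indices standing in for real samples:

    import math
    from functools import cache

    import torch

    class TinyBatcher:
        """Sketch of the per-call batching contract (indices only)."""

        def __init__(self, n_items: int, seed: int = 0):
            g = torch.Generator()
            g.manual_seed(seed)
            self.idxs = torch.randperm(n_items, generator=g)

        @cache  # cached per (self, batch_size), mirroring get_len above
        def get_len(self, batch_size: int) -> int:
            return math.ceil(len(self.idxs) / batch_size)

        def get_batch(self, batch_size: int, batch_idx: int) -> torch.Tensor:
            batch_idx %= self.get_len(batch_size)  # wrap around, as in the diff
            start = batch_idx * batch_size
            return self.idxs[start : start + batch_size]

    b = TinyBatcher(10)
    assert b.get_len(4) == 3
    assert b.get_batch(4, 3).equal(b.get_batch(4, 0))  # index 3 wraps to 0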
@@ -97,12 +101,12 @@ def __init__(
seed: int = 0,
minimize: bool = False,
device="cpu",
-        dtype=torch.float32,
+        val_batch_size: int = 1024,
):
+        self.val_batch_size = val_batch_size
self.device = device
-        self.dtype = dtype
-        self.model = model  # .to(self.device)
-        self.num_params = sum([p.numel() for p in self.model.parameters()])
+        self._model = model  # .to(self.device)
+        self.num_params = sum([p.numel() for p in self._model.parameters()])
self.train_data_batcher = train_data_batcher
self.val_data_batcher = val_data_batcher
self.minimize = minimize
@@ -116,29 +120,43 @@ def __init__(
) # should be good enough
self.loss_fn = loss_fn
self.return_n_preds = 5
+        self.models = {
+            torch.float32: copy.deepcopy(self._model).to(torch.float32),
+            torch.float16: copy.deepcopy(self._model).to(torch.float16),
+        }

def move_to_device(self):
-        self.model = self.model.to(self.dtype).to(self.device)
+        self.models = {k: v.to(self.device) for k, v in self.models.items()}

def fuse(self):
-        self.model_params = fuse_parameters(self.model, self.device, self.dtype)
-        self.saved_weights = self.model_params.detach().clone()
+        self.model_params = {
+            torch.float32: fuse_parameters(
+                self.models[torch.float32], self.device, torch.float32
+            ),
+            torch.float16: fuse_parameters(
+                self.models[torch.float16], self.device, torch.float16
+            ),
+        }
+        self.saved_weights = {
+            torch.float32: self.model_params[torch.float32].detach().clone(),
+            torch.float16: self.model_params[torch.float16].detach().clone(),
+        }
self.fused = True

# todo cache this?
@cache
@torch.no_grad
-    def get_dim_vec(self, dim: int):
+    def get_dim_vec(self, dim: int, dtype):
assert self.fused
g = torch.Generator(device=self.device)
g.manual_seed(self.seeds_for_dims[dim % MAX_DIMS].item())
return torch.nn.functional.normalize(
torch.rand(
1,
-                *self.model_params.shape,
+                *self.model_params[dtype].shape,
generator=g,
device=self.device,
-                dtype=self.dtype,
+                dtype=dtype,
)
- 0.5,
p=2,
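get_dim_vec now takes the dtype explicitly instead of reading self.dtype, so float16 and float32 grids share the same seeded directions at different precisions. The recipe visible in this (truncated) hunk as a standalone sketch; the normalize dim is an assumption:

    import torch

    def dim_vec(n_params: int, seed: int, dtype: torch.dtype) -> torch.Tensor:
        # Seeded uniform draw, zero-centered, then L2-normalized to a unit vector.
        g = torch.Generator()
        g.manual_seed(seed)
        v = torch.rand(1, n_params, generator=g, dtype=dtype) - 0.5
        return torch.nn.functional.normalize(v, p=2, dim=1)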
@@ -147,55 +165,57 @@ def get_dim_vec(self, dim: int):

# dims is a dictionary {dim:step_size}
@torch.no_grad
-    def delta_from_dims(self, dims: dict[int, float]):
+    def delta_from_dims(self, dims: dict[int, float], dtype: torch.dtype):
if len(dims) > 0:
return torch.cat(
[
-                    self.get_dim_vec(d) * v for d, v in dims.items()
+                    self.get_dim_vec(d, dtype=dtype) * v for d, v in dims.items()
] # , device=self.device
).sum(axis=0)
else:
return self.blank_weight_vec()

# todo cache this?
@cache
-    def get_batch(self, idx: int):
+    def get_batch(self, batch_size: int, batch_idx: int, dtype):
r = {}
for name, batcher in (
("train", self.train_data_batcher),
("val", self.val_data_batcher),
):
-            batch = batcher[idx]
+            batch = batcher.get_batch(batch_size, batch_idx)
r[name] = (
batch["data"].to(device=self.device, dtype=self.dtype),
batch["data"].to(device=self.device, dtype=dtype),
batch["label"].to(device=self.device),
)
return r

# TODO could optimize with one large chunk of shared memory? and slice it?
@cache
def blank_weight_vec(self):
-        wv = torch.zeros(*self.model_params.shape, device=self.device)
+        wv = torch.zeros(*self.model_params[torch.float32].shape, device=self.device)
# wv.share_memory_()
return wv

@torch.no_grad
-    def set_parameters(self, weights: torch.Tensor):
+    def set_parameters(self, weights: torch.Tensor, dtype):
assert self.fused
# self.model_params.copy_(weights) # segfaults?
# self.model_params *= 0
# self.model_params += weights
-        self.model_params.data.copy_(weights)
+        self.model_params[dtype].data.copy_(weights)

@torch.no_grad
-    def val_model_inference_with_delta_weights(self, weights: torch.Tensor):
+    def val_model_inference_with_delta_weights(self, weights: torch.Tensor, dtype):
assert self.fused
-        self.set_parameters(weights)
+        self.set_parameters(weights, dtype)
full_val_loss = 0
n = 0
-        for batch_idx in range(self.val_data_batcher.len):
-            batch = self.get_batch(batch_idx)
-            model_output = self.model(batch["val"][0])
+        for batch_idx in range(self.val_data_batcher.get_len(self.val_batch_size)):
+            batch = self.get_batch(
+                batch_size=self.val_batch_size, batch_idx=batch_idx, dtype=dtype
+            )
+            model_output = self.models[dtype](batch["val"][0])
full_val_loss += self.loss_fn(model_output, batch["val"][1]).sum().item()
n += batch["val"][1].shape[0]
if not self.minimize:
@@ -204,20 +224,20 @@ def val_model_inference_with_delta_weights(weights: torch.Tensor):

@torch.no_grad
def train_model_inference_with_delta_weights(
-        self, weights: torch.Tensor, batch_idx: int
+        self, weights: torch.Tensor, batch_size: int, batch_idx: int, dtype
) -> dict[str, torch.Tensor]:
assert self.fused
-        batch = self.get_batch(batch_idx)
-        self.set_parameters(weights)
-        model_output = self.model(batch["train"][0])
+        batch = self.get_batch(batch_size, batch_idx, dtype=dtype)
+        self.set_parameters(weights, dtype)
+        model_output = self.models[dtype](batch["train"][0])
train_loss = self.loss_fn(model_output, batch["train"][1]).mean().item()
-        train_pred = self.model.probs(model_output)
+        train_pred = self.models[dtype].probs(model_output)
if not self.minimize:
train_loss = -train_loss

confusion_matrix = get_confusion_matrix(train_pred, batch["train"][1])
logging.info(
f"train loss: {train_loss} MO:{model_output.mean()}/{model_output.shape} weights {weights.mean().item()}"
f"train loss: {train_loss} MO:{model_output.mean()}/{model_output.shape} weights {weights.mean().item()} {dtype}"
)
return {
"train_loss": train_loss,
@@ -226,7 +246,10 @@ def train_model_inference_with_delta_weights(
}

def init_param_model(self):
-        self.param_model = param_nn_from_sequential(self.model.net)
+        self.param_models = {
+            torch.float32: param_nn_from_sequential(self.models[torch.float32].net),
+            torch.float16: param_nn_from_sequential(self.models[torch.float16].net),
+        }

# return model parameters for given ranges
def dim_idxs_and_ranges_to_models_parms(
@@ -235,27 +258,40 @@ def dim_idxs_and_ranges_to_models_parms(
dims: torch.Tensor,
arange: torch.Tensor,
brange: torch.Tensor,
+        dtype: torch.dtype,
):
assert len(dims) == 2
-        vs = torch.vstack([self.get_dim_vec(dim) for dim in dims])
+        vs = torch.vstack([self.get_dim_vec(dim, dtype=dtype) for dim in dims])

agrid, bgrid = torch.meshgrid(torch.tensor(arange), torch.tensor(brange))
agrid = agrid.unsqueeze(2)
bgrid = bgrid.unsqueeze(2)
return torch.concatenate([agrid, bgrid], dim=2).to(
-            self.dtype
+            dtype
) @ vs + base_weights.reshape(1, 1, -1)

def get_loss_grid(
-        self, base_weights, batch_idx, dims_offset, grids, grid_size, step_size
+        self,
+        base_weights,
+        batch_idx,
+        dims_offset,
+        grids,
+        grid_size,
+        step_size,
+        batch_size,
+        dtype,
):
logging.info(f"get_loss: start base weights {base_weights.mean().item()}")
logging.info(
f"get_loss: start base weights {base_weights.mean().item()} {dtype}"
)
assert grid_size % 2 == 1
assert grid_size > 3
assert grids > 0

logging.info("get_loss: get bach")
batch = self.get_batch(batch_idx)["train"]
batch = self.get_batch(batch_size=batch_size, batch_idx=batch_idx, dtype=dtype)[
"train"
]
logging.info("get_loss: get batch done")
data = batch[0].unsqueeze(0)
batch_size = data.shape[1]
@@ -270,7 +306,7 @@
] # which dims are doing this for

mp = self.dim_idxs_and_ranges_to_models_parms(
-            base_weights, dims, arange=r, brange=r
+            base_weights, dims, arange=r, brange=r, dtype=dtype
)

# self.model.net[:2](batch[0])
@@ -283,8 +319,10 @@
mp_reshaped = mp.reshape(-1, self.num_params).contiguous()
# batch = torch.rand(1, 512, 28, 28, device=device)
logging.info(f"get_loss: fwd start , batch size {batch_size}")
-        predictions = self.param_model.forward(mp_reshaped, data)[1].reshape(
-            *mp.shape[:2], batch_size, -1
+        predictions = (
+            self.param_models[dtype]
+            .forward(mp_reshaped, data)[1]
+            .reshape(*mp.shape[:2], batch_size, -1)
)
logging.info("get_loss: done fwd")
loss = torch.gather(
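For two chosen dims, dim_idxs_and_ranges_to_models_parms builds a grid_size x grid_size sheet of candidate weight vectors: each cell's (a, b) step coefficients are matmul'd with the stacked 2 x num_params direction vectors and added to the base weights, giving one model per cell for get_loss_grid to evaluate. A toy, self-contained version of that arithmetic:

    import torch

    num_params, grid_size, step = 5, 3, 0.1
    base = torch.zeros(num_params)
    vs = torch.nn.functional.normalize(torch.randn(2, num_params), dim=1)

    r = torch.arange(-(grid_size // 2), grid_size // 2 + 1) * step  # [-0.1, 0.0, 0.1]
    agrid, bgrid = torch.meshgrid(r, r, indexing="ij")
    coeffs = torch.stack([agrid, bgrid], dim=2)        # (3, 3, 2) step coefficients
    weights = coeffs @ vs + base.reshape(1, 1, -1)     # (3, 3, 5): one model per cell
    assert weights.shape == (grid_size, grid_size, num_params)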
4 changes: 4 additions & 0 deletions hudes/view.py
@@ -32,6 +32,10 @@
# matplotlib.use(backend)
plt_backend = matplotlib.get_backend()

+import matplotlib.style as mplstyle
+
+mplstyle.use("fast")


def surface_to_npim(surface):
"""Transforms a Cairo surface into a numpy array."""
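For context, "fast" is a style bundled with matplotlib that trades drawing fidelity for speed, which suits an interactive loss-landscape view. To the best of my knowledge it amounts to roughly the rcParams below (worth verifying against the installed matplotlib version):

    import matplotlib as mpl
    import matplotlib.style as mplstyle

    mplstyle.use("fast")
    # "fast" cheapens path rendering, approximately:
    #   path.simplify = True, path.simplify_threshold = 1.0,
    #   agg.path.chunksize = 10000
    print(mpl.rcParams["path.simplify_threshold"])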
