Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 11, 2024
1 parent 757f0d8 commit b881080
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
4 changes: 3 additions & 1 deletion hudes/model_data_and_subspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def fuse_parameters(model: nn.Module, device, dtype):
return params


# @torch.jit.script
def indexed_loss(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
return pred[torch.arange(label.shape[0]), label]

Expand All @@ -72,6 +73,7 @@ def param_nn_from_sequential(model):
return param_nn.Sequential([get_param_module(m) for m in model])


@torch.jit.script
def get_confusion_matrix(preds: torch.Tensor, labels: torch.Tensor):
# (Pdb) preds.shape
# torch.Size([512, 10])
Expand All @@ -87,7 +89,7 @@ def get_confusion_matrix(preds: torch.Tensor, labels: torch.Tensor):
for idx in torch.arange(n)
]
)
return torch.nn.functional.normalize(c_matrix, p=1, dim=1)
return torch.nn.functional.normalize(c_matrix, p=1.0, dim=1)


class ModelDataAndSubspace:
Expand Down
1 change: 1 addition & 0 deletions hudes/param_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def forward(self, models_params, x):
return models_params, x


@torch.jit.script
class Linear:
def __init__(self, input_channels, output_channels):
self.input_channels = input_channels
Expand Down

0 comments on commit b881080

Please sign in to comment.