Question: apply trainable scale for qdq linear and matmul #51

Open · yiliu30 opened this issue Apr 18, 2024 · 3 comments

@yiliu30 (Contributor) commented Apr 18, 2024

In a quantization scenario where fake quantization is used to assess the accuracy of a new algorithm with a trainable scale, we can implement it for an eager model by replacing each Linear module with a QDQLinear, as demonstrated below:

import torch


class QDQLinear(torch.nn.Module):
    def __init__(self, orign_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orign_linear = orign_linear
        # the scale must be a floating-point tensor to be trainable
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor) -> torch.Tensor:
        # ... new qdq method that uses `self.trainable_scale` to quantize-dequantize the tensor:
        # int_input = q(input, self.trainable_scale)
        # qdq_input = dq(int_input)
        # return qdq_input
        return input  # placeholder

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = self.qdq_tensor(input)
        return torch.nn.functional.linear(input, self._orign_linear.weight, self._orign_linear.bias)


### replace all `Linear` with `QDQLinear`
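
For completeness, the swap itself can be a small recursive helper (a sketch; the name `replace_linear_with_qdq` is just illustrative, not from any library):

def replace_linear_with_qdq(model: torch.nn.Module) -> None:
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(model, name, QDQLinear(child))
        else:
            replace_linear_with_qdq(child)  # recurse into submodules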

However, some models use torch.matmul to perform a similar operation to torch.nn.Linear. We also want to apply the QDQ method above to torch.matmul, but that cannot be achieved through module swapping.

We could probably write a custom TorchDispatchMode that replaces every aten.mm with qdq + aten.mm, i.e. applies q-dq to all input tensors of torch.matmul or torch.nn.Linear. However, I'm currently unsure how to handle the trainable_scale in that setting. Do you happen to have any suggestions?
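
For reference, the dispatch-mode idea would look roughly like this (a sketch only; `qdq_tensor` is a placeholder, and where the trainable scale would come from, and how it would get gradients, is exactly the open question):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class QDQMatmulMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in (torch.ops.aten.mm.default, torch.ops.aten.bmm.default):
            # apply q-dq to the tensor inputs before the real matmul runs
            args = tuple(qdq_tensor(a) if isinstance(a, torch.Tensor) else a for a in args)
        return func(*args, **kwargs)

# with QDQMatmulMode():
#     out = model(example_input)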

Thank you very much!

@albanD (Owner) commented Apr 18, 2024

Hi

If you want to pre-process the input to the module, I think a forward pre-hook would work?

def qdq_tensor(module, args):
    # `args` is the tuple of positional inputs to forward();
    # return a (possibly modified) tuple to replace them.
    return args

your_layer.register_forward_pre_hook(qdq_tensor)
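
If the scale also needs to stay trainable with this approach, one possible variant (just a sketch; `fake_qdq` is a hypothetical stand-in for the real q-dq function) is to register the scale on the layer and read it from the `module` argument inside the hook:

your_layer.scale = torch.nn.Parameter(torch.tensor(1.0))  # registered as a parameter of the layer

def qdq_pre_hook(module, args):
    x = args[0]
    # fake_qdq is a placeholder for the real q-dq using module.scale
    return (fake_qdq(x, module.scale),) + args[1:]

your_layer.register_forward_pre_hook(qdq_pre_hook)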

Is that enough for you?

@yiliu30 (Contributor, Author) commented Apr 19, 2024

Oh, sorry, there were a few errors in the question: I actually want to pre-process the module's weight (for weight-only quantization), not its input.

class QDQLinear(torch.nn.Module):
    def __init__(self, orign_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orign_linear = orign_linear
        # the scale must be a floating-point tensor to be trainable
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # ... new qdq method that uses `scale` to quantize-dequantize the tensor:
        # int_input = q(input, scale)
        # qdq_input = dq(int_input)
        # return qdq_input
        return input  # placeholder

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_weight = self.qdq_tensor(self._orign_linear.weight, self.trainable_scale)  # <---------- q-dq weight
        return torch.nn.functional.linear(input, qdq_weight, self._orign_linear.bias)


### replace all `Linear` with `QDQLinear`

@albanD (Owner) commented Apr 19, 2024

If you do want these two params and want to combine them only when mod.weight is used, I would suggest a parametrization (torch.nn.utils.parametrize):

import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

class QDQParam(nn.Module):
    def forward(self, orig_linear_weight, scale):
        # qdq_tensor is your fake-quant function from above
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # initial (weight, scale) pair; the scale must be floating point to be trainable
        return orig_linear_weight, torch.tensor(1.0)

m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())
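
With this in place, m.weight is recomputed from the two originals on every access, and both of them are registered under the parametrization (a quick illustration; whether gradients actually reach the scale depends on how qdq_tensor uses it):

print(m.parametrizations.weight.original0)  # the raw weight
print(m.parametrizations.weight.original1)  # the scale, initialized by right_inverse
print(m.weight)                             # computed as QDQParam()(original0, original1)

# both originals are parameters, so an optimizer over m.parameters() trains them
out = m(torch.randn(3, 2))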

More details at https://pytorch.org/tutorials/intermediate/parametrizations.html
