[ET-VK][ez] Fix linear weight int4 test due to change in ATen API
Pull Request resolved: #7739

## Context

The ATen API for 4-bit quantized linear recently changed, so our test must be updated to match.

Concretely, the API changes were:

* The `_for_cpu` suffix was added to the operator names
* The `_convert_weight_to_int4pack_for_cpu` operator now expects unpacked 4-bit weights instead of a packed scheme where two 4-bit values are packed into a single 8-bit value (see the sketch below)
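
For illustration, a minimal sketch of the before-and-after call pattern (the wrapper `int4_linear_reference_sketch` is hypothetical; `unpack_weights_4x2` is the helper added in the diff below, and the shape comments are illustrative):

```cpp
#include <ATen/ATen.h>

// Helper added in the diff below; declared here so the sketch is self-contained.
at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2);

at::Tensor int4_linear_reference_sketch(
    const at::Tensor& x,            // [M, K], float
    const at::Tensor& weights_4x2,  // [N, K / 2], uint8, two 4-bit values per byte
    const int64_t group_size,
    const at::Tensor& scales_and_zeros) {
  // Old API: the packed uint8 weights were passed directly.
  //   at::Tensor packed =
  //       at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
  //   return at::_weight_int4pack_mm(x, packed, group_size, scales_and_zeros);

  // New API: unpack to one int32 value in [0, 15] per element first, then use
  // the _for_cpu variants.
  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
  at::Tensor packed =
      at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size);
  return at::_weight_int4pack_mm_for_cpu(x, packed, group_size, scales_and_zeros);
}
```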
ghstack-source-id: 261959346
@exported-using-ghexport

Differential Revision: [D68333687](https://our.internmc.facebook.com/intern/diff/D68333687/)
SS-JIA committed Jan 17, 2025
1 parent 1a6b7a6 commit d154714
Showing 1 changed file with 32 additions and 5 deletions:

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
```diff
@@ -30,16 +30,38 @@ at::Tensor linear_weight_int4_reference_impl(
   const size_t ndim = original_x_size.size();
   const int64_t out_features = weights_4x2.size(0);
   const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]});
-  const at::Tensor packed_weights =
-      at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
-  at::Tensor out = at::_weight_int4pack_mm(
-      x_flattened, packed_weights, groupsize, scales_and_zeros);
+  at::Tensor out = at::_weight_int4pack_mm_for_cpu(
+      x_flattened, weights_4x2, groupsize, scales_and_zeros);
   std::vector<int64_t> out_shape(
       original_x_size.begin(), original_x_size.end());
   out_shape.at(ndim - 1) = out_features;
   return out.reshape(out_shape);
 }
 
+at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_unpacked =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));
+
+  const int64_t N = weights_unpacked.size(0);
+  const int64_t K = weights_unpacked.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      weights_unpacked[n][k] = int(first_val);
+      weights_unpacked[n][k + 1] = int(second_val);
+    }
+  }
+
+  return weights_unpacked;
+}
+
 at::Tensor dequantize_and_linear(
     const at::Tensor& x,
     const at::Tensor& weights_4x2,
@@ -91,13 +113,18 @@ void test_reference_linear_int4(
   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
   at::Tensor weights_4x2 =
       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
+  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
 
   const int k_groups = K / group_size;
   at::Tensor scales_and_zeros =
       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
 
   at::Tensor out = linear_weight_int4_reference_impl(
-      x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
+      x,
+      at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
+      group_size,
+      scales_and_zeros,
+      inner_k_tiles);
 
   at::Tensor out_ref = dequantize_and_linear(
       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
```
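
For reference, the packing convention that `unpack_weights_4x2` reverses stores the first 4-bit value in the high nibble of each byte and the second in the low nibble. A standalone sketch of the arithmetic (the sample byte is illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t packed_val = 0xAB;                     // illustrative sample byte
  const uint8_t first_val = (packed_val & 0xF0) >> 4;  // high nibble -> 0xA == 10
  const uint8_t second_val = packed_val & 0x0F;        // low nibble  -> 0xB == 11
  std::printf("first=%u second=%u\n", first_val, second_val);  // first=10 second=11
  return 0;
}
```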
