This repository has been archived by the owner on Nov 15, 2022. It is now read-only.

More efficient MHA - faster padding (#407)
cpuhrsch authored May 27, 2021
1 parent e2aaac9 commit 3e535a7
Showing 8 changed files with 157 additions and 107 deletions.
4 changes: 1 addition & 3 deletions nestedtensor/csrc/BinaryOps.cpp
@@ -31,9 +31,7 @@ Tensor NestedTensor_add_Tensor(
}
}
if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
if (!get_is_contiguous(self)) {
self = NestedTensor_contiguous(self);
}
self = NestedTensor_contiguous(self);
int64_t self_dim = get_dim(self);
auto self_opt_sizes = get_opt_sizes(self);
if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&
121 changes: 25 additions & 96 deletions nestedtensor/csrc/cuda/mha.cpp
@@ -36,6 +36,9 @@ at::Tensor bt_min_mha(
TORCH_CHECK(get_dim(query) == 3, "query needs to be 3 dim.");
TORCH_CHECK(get_dim(key) == 3, "key needs to be 3 dim.");
TORCH_CHECK(get_dim(value) == 3, "value needs to be 3 dim.");
TORCH_CHECK(get_nested_dim(query) == 1, "Query nested dim isn't 1.");
TORCH_CHECK(get_nested_dim(key) == 1, "Key nested dim isn't 1.");
TORCH_CHECK(get_nested_dim(value) == 1, "Value nested dim isn't 1.");
// TORCH_CHECK(in_proj_bias, "Input projection bias needs to be defined.");
// auto opt_sizes = get_opt_sizes(query);
// if (!opt_sizes[2]) {
@@ -57,88 +60,31 @@
at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
at::cuda::setCurrentCUDAStream(defaultStream);

int64_t input_tensor_size = batch_size * head_num * seq_len * size_per_head;
int64_t attn_tensor_size = batch_size * head_num * seq_len * seq_len;
int word_num = batch_size * seq_len;
Tensor prefix_sum = torch::zeros({word_num}, options);
Tensor batch_idx = torch::zeros({word_num}, options);
Tensor word_idx = torch::zeros({word_num}, options);
at::Tensor packed = at::matmul(query, attr_kernel.t()) + attr_bias;

int* prefix_sum_ptr = prefix_sum.data_ptr<int>();
int* batch_idx_ptr = batch_idx.data_ptr<int>();
int* word_idx_ptr = word_idx.data_ptr<int>();

at::Tensor tmp = get_buffer(query);

auto query_esize = get_efficient_nested_size(query);
TORCH_CHECK(query_esize.height() == 1, "Query nested dim isn't 1.");
auto query_esize_sizes = query_esize.sizes();

at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
attr_mask = attr_mask * attr_mask.transpose(2, 3);

nteffectivetransformer::exclusiveScan_kernelLauncher(
prefix_sum_ptr,
input_mask.data_ptr<int>(),
input_mask.size(0) * input_mask.size(1),
defaultStream);


nteffectivetransformer::compressBertInput_kernelLauncher(
input_mask.data_ptr<int>(),
prefix_sum_ptr,
batch_idx_ptr,
word_idx_ptr,
(int32_t)(batch_size),
(int32_t)(seq_len),
(int32_t)(embedding_dim),
defaultStream);

at::Tensor packed = at::matmul(query, attr_kernel.t());
// TODO: Move into implementation of chunk for NestedTensor
at::Tensor packed_buf = get_buffer(packed).contiguous().reshape({-1, 3 * embedding_dim});
std::vector<at::Tensor> packed_chunks = packed_buf.chunk(3, -1);
at::Tensor q_buf = packed_chunks[0].contiguous().reshape({-1});
at::Tensor k_buf = packed_chunks[1].contiguous().reshape({-1});
at::Tensor v_buf = packed_chunks[2].contiguous().reshape({-1});

int valid_word_num = get_numel(query) / embedding_dim;

at::Tensor query_buf = torch::zeros(
{batch_size, head_num, seq_len, size_per_head}, float_options);
at::Tensor key_buf = torch::zeros(
{batch_size, head_num, seq_len, size_per_head}, float_options);
at::Tensor val_buf = torch::zeros(
{batch_size, head_num, seq_len, size_per_head}, float_options);
at::Tensor attr_out =
torch::zeros({valid_word_num, embedding_dim}, float_options);

std::vector<at::Tensor> bias_chunks = attr_bias.chunk(3);
at::Tensor attr_bias_Q = bias_chunks[0];
at::Tensor attr_bias_K = bias_chunks[1];
at::Tensor attr_bias_V = bias_chunks[2];

nteffectivetransformer::cuda::add_QKV_bias_padding_kernelLauncher<float>(
q_buf.data_ptr<float>(),
attr_bias_Q.data_ptr<float>(),
k_buf.data_ptr<float>(),
attr_bias_K.data_ptr<float>(),
v_buf.data_ptr<float>(),
attr_bias_V.data_ptr<float>(),
query_buf.data_ptr<float>(),
key_buf.data_ptr<float>(),
val_buf.data_ptr<float>(),
valid_word_num,
batch_size,
seq_len,
head_num,
size_per_head,
batch_idx_ptr,
word_idx_ptr,
defaultStream);
at::Tensor q_buf_ = packed_chunks[0].contiguous().reshape({-1});
at::Tensor k_buf_ = packed_chunks[1].contiguous().reshape({-1});
at::Tensor v_buf_ = packed_chunks[2].contiguous().reshape({-1});
at::Tensor q = wrap_buffer(std::move(q_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
at::Tensor k = wrap_buffer(std::move(k_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
at::Tensor v = wrap_buffer(std::move(v_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));

at::Tensor query_buf = to_padded_tensor(q, 0).contiguous();
at::Tensor key_buf = to_padded_tensor(k, 0).contiguous();
at::Tensor val_buf = to_padded_tensor(v, 0).contiguous();
query_buf = query_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
key_buf = key_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
val_buf = val_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);

key_buf = key_buf.transpose(2, 3);
at::Tensor attn_output_weights = at::matmul(query_buf, key_buf).contiguous();

at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
attr_mask = attr_mask * attr_mask.transpose(2, 3);

nteffectivetransformer::cuda::softmax_kernel_kernelLauncher<float>(
attn_output_weights.data_ptr<float>(),
attr_mask.data_ptr<float>(),
@@ -148,27 +94,10 @@ at::Tensor bt_min_mha(
(float)(scaling),
defaultStream);

auto attn_output = at::matmul(attn_output_weights, val_buf);

nteffectivetransformer::cuda::transpose_rm_padding_kernelLauncher<float>(
attn_output.data_ptr<float>(),
attr_out.data_ptr<float>(),
valid_word_num,
batch_size,
seq_len,
head_num,
size_per_head,
batch_idx_ptr,
word_idx_ptr,
defaultStream);

// TODO: Bias is variably sized, need to add support for that.
at::Tensor result = at::matmul(attr_out, out_proj_weight.t());
result = result.reshape({-1});
return wrap_buffer(
std::move(result),
get_efficient_nested_size(query),
get_efficient_nested_stride(query));
auto attn_output = at::matmul(attn_output_weights, val_buf).contiguous();
attn_output = attn_output.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim}).contiguous();
at::Tensor attr_out = from_padded_tensor(attn_output, get_efficient_nested_size(query), get_efficient_nested_stride(query));
return at::matmul(attr_out, out_proj_weight.t());
}

TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
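Taken together, the new bt_min_mha path boils down to: project Q, K and V once on the packed nested buffer, pad each of them to a dense [batch, seq_len, dim] tensor, run ordinary batched attention, and strip the padding again before the output projection. The sketch below is an illustration of that flow, not the committed implementation: the function name mha_via_padding is made up, and the fused softmax_kernel_kernelLauncher is replaced by a plain additive-mask softmax in ATen.

// Sketch only. Assumes query is a contiguous, 3-dim, CUDA float nested tensor with
// nested_dim 1, as checked at the top of bt_min_mha, and that the helpers used here
// (get_buffer, wrap_buffer, to_padded_tensor, from_padded_tensor, ...) behave as in the diff.
at::Tensor mha_via_padding(
    at::Tensor query, at::Tensor attr_kernel, at::Tensor attr_bias,
    at::Tensor input_mask, at::Tensor out_proj_weight,
    int64_t batch_size, int64_t seq_len, int64_t head_num,
    int64_t size_per_head, int64_t embedding_dim, double scaling) {
  // Fused QKV projection directly on the packed nested buffer.
  at::Tensor packed = at::matmul(query, attr_kernel.t()) + attr_bias;
  at::Tensor packed_buf =
      get_buffer(packed).contiguous().reshape({-1, 3 * embedding_dim});
  std::vector<at::Tensor> chunks = packed_buf.chunk(3, -1);
  auto esize = get_efficient_nested_size(query);
  auto estride = get_efficient_nested_stride(query);
  // Re-wrap each chunk as a nested tensor sharing query's nested sizes,
  // pad once, and split out the heads.
  auto split_heads = [&](at::Tensor buf) {
    at::Tensor nt = wrap_buffer(buf.contiguous().reshape({-1}), esize, estride);
    return to_padded_tensor(nt, 0)
        .reshape({batch_size, seq_len, head_num, size_per_head})
        .transpose(1, 2);  // [B, H, S, D/H]
  };
  at::Tensor q = split_heads(chunks[0]);
  at::Tensor k = split_heads(chunks[1]);
  at::Tensor v = split_heads(chunks[2]);
  // Dense batched attention on the padded tensors; padded positions are masked out.
  at::Tensor mask = input_mask.view({-1, 1, 1, seq_len}).to(q.options());
  mask = mask * mask.transpose(2, 3);                         // [B, 1, S, S]
  at::Tensor scores = at::matmul(q, k.transpose(2, 3)) * scaling;
  scores = scores + (1 - mask) * -10000.0;                    // stand-in for the fused masked softmax
  at::Tensor attn = at::matmul(at::softmax(scores, -1), v);   // [B, H, S, D/H]
  // Strip the padding again and apply the output projection.
  attn = attn.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim}).contiguous();
  at::Tensor out = from_padded_tensor(attn, esize, estride);
  return at::matmul(out, out_proj_weight.t());
}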
73 changes: 70 additions & 3 deletions nestedtensor/csrc/cuda/padding.cu
@@ -18,8 +18,18 @@ void add_padding(
const int inner_size)
{
const int batch_id = blockIdx.x;
for (int i = 0; i < (offsets[batch_id + 1] - offsets[batch_id]) * inner_size; i++) {
output[batch_id * output_stride + i] = input[offsets[batch_id] * inner_size + i];
const int grain_size = blockDim.x;
const int tid = threadIdx.x;
const int range = (offsets[batch_id + 1] - offsets[batch_id]) * inner_size;
const int num_chunks = range / grain_size;
for (int id = 0; id < num_chunks; id++) {
output[batch_id * output_stride + id * grain_size + tid]
= input[offsets[batch_id] * inner_size + id * grain_size + tid];
}
const int leftover = num_chunks * grain_size;
if (leftover + tid < range) {
output[batch_id * output_stride + leftover + tid]
= input[offsets[batch_id] * inner_size + leftover + tid];
}
}

@@ -36,7 +46,7 @@ void add_padding_kernelLauncher(
dim3 grid;
grid.x = batch_size;

add_padding<float><<<grid, 1, 0, stream>>>(
add_padding<float><<<grid, 1024, 0, stream>>>(
input,
output,
offsets,
@@ -111,5 +121,62 @@ template void add_padding_mask_kernelLauncher<float>(
const int output_stride,
const int inner_size,
const cudaStream_t stream);

template<typename T>
__global__
void remove_padding(
const T* input,
T* output,
const int* offsets,
const int batch_size,
const int output_stride,
const int inner_size)
{
const int batch_id = blockIdx.x;
const int grain_size = blockDim.x;
const int tid = threadIdx.x;
const int range = (offsets[batch_id + 1] - offsets[batch_id]) * inner_size;
const int num_chunks = range / grain_size;
for (int id = 0; id < num_chunks; id++) {
output[offsets[batch_id] * inner_size + id * grain_size + tid]
= input[batch_id * output_stride + id * grain_size + tid];
}
const int leftover = num_chunks * grain_size;
if (leftover + tid < range) {
output[offsets[batch_id] * inner_size + leftover + tid]
= input[batch_id * output_stride + leftover + tid];
}
}

template<typename T>
void remove_padding_kernelLauncher(
T* input, // [batch_size x None]
T* output, // [batch_size x max(input.nested_size(1)) x inner_size]
const int* offsets, // [batch_size]
const int batch_size,
const int output_stride,
const int inner_size,
const cudaStream_t stream)
{
dim3 grid;
grid.x = batch_size;

remove_padding<float><<<grid, 1024, 0, stream>>>(
input,
output,
offsets,
batch_size,
output_stride,
inner_size);
}

template void remove_padding_kernelLauncher<float>(
float* input,
float* output,
const int* offsets,
const int batch_size,
const int output_stride,
const int inner_size,
const cudaStream_t stream);
}
}
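Both add_padding and remove_padding use the same copy pattern: one block per sequence, with blockDim.x threads cooperatively moving that sequence's (length * inner_size) contiguous floats in block-sized chunks and then handling the leftover tail. The standalone toy below is not part of the commit and all of its names are hypothetical; it exercises the same indexing on a two-sequence batch so the addressing is easy to check by hand.

// pad_copy_demo.cu: minimal illustration of the chunk-plus-leftover copy pattern above.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void pad_copy(const float* input, float* output,
                         const int* offsets, int output_stride, int inner_size) {
  const int batch_id = blockIdx.x;           // one block per sequence
  const int grain_size = blockDim.x;         // threads cooperate within the sequence
  const int tid = threadIdx.x;
  const int range = (offsets[batch_id + 1] - offsets[batch_id]) * inner_size;
  const int num_chunks = range / grain_size;
  for (int id = 0; id < num_chunks; id++) {  // full block-sized chunks
    output[batch_id * output_stride + id * grain_size + tid] =
        input[offsets[batch_id] * inner_size + id * grain_size + tid];
  }
  const int leftover = num_chunks * grain_size;
  if (leftover + tid < range) {              // tail shorter than one chunk
    output[batch_id * output_stride + leftover + tid] =
        input[offsets[batch_id] * inner_size + leftover + tid];
  }
}

int main() {
  // Two sequences of lengths 3 and 1, inner_size 2, padded out to length 3.
  std::vector<int> offsets = {0, 3, 4};             // exclusive prefix sums of the lengths
  std::vector<float> in = {1, 1, 2, 2, 3, 3, 4, 4}; // packed values
  std::vector<float> out(2 * 3 * 2, 0.f);           // batch x max_len x inner_size
  int* d_off; float* d_in; float* d_out;
  cudaMalloc((void**)&d_off, offsets.size() * sizeof(int));
  cudaMalloc((void**)&d_in, in.size() * sizeof(float));
  cudaMalloc((void**)&d_out, out.size() * sizeof(float));
  cudaMemcpy(d_off, offsets.data(), offsets.size() * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_in, in.data(), in.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, out.size() * sizeof(float));
  pad_copy<<<2, 1024>>>(d_in, d_out, d_off, /*output_stride=*/6, /*inner_size=*/2);
  cudaMemcpy(out.data(), d_out, out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : out) printf("%.0f ", v);           // expected: 1 1 2 2 3 3 4 4 0 0 0 0
  printf("\n");
  cudaFree(d_off); cudaFree(d_in); cudaFree(d_out);
  return 0;
}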
10 changes: 10 additions & 0 deletions nestedtensor/csrc/cuda/padding.h
@@ -28,5 +28,15 @@ void add_padding_mask_kernelLauncher(
const int inner_size,
const cudaStream_t stream);

template <typename T>
void remove_padding_kernelLauncher(
T* input,
T* output,
const int* lengths,
const int batch_size,
const int output_stride,
const int inner_size,
const cudaStream_t stream);

}
} // namespace nested_tensor
31 changes: 31 additions & 0 deletions nestedtensor/csrc/masking.cpp
@@ -391,6 +391,37 @@ Tensor to_mask(
return merge_mask(res_mask, mask_dim);
}

Tensor from_padded_tensor(Tensor padded, EfficientSizeNode target_size,
EfficientSizeNode target_stride) {
#ifdef WITH_CUDA
if (padded.dim() == 3 && target_size.dim() == 3 && get_is_contiguous(padded)) {
auto nt_opt_size = target_size.opt_sizes();
if (nt_opt_size[2] && padded.is_cuda()) {
Tensor nt_sizes_ = target_size.sizes().to(torch::kInt32);
TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor must be of nested_dim 2.")
Tensor nt_sizes = at::native::narrow(nt_sizes_, 1, 0, 1);
int max_size_1 = nt_sizes.max().item<int>();
nt_sizes =
at::native::cumsum(nt_sizes, 0).to(torch::kInt32).reshape({-1});
nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes});
Tensor output = torch::empty({target_size.numel()}, padded.options());
nt_sizes = nt_sizes.to(torch::kCUDA);
at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
nested_tensor::cuda::remove_padding_kernelLauncher(
padded.data_ptr<float>(),
output.data_ptr<float>(),
nt_sizes.data_ptr<int>(),
*nt_opt_size[0],
padded.stride(0),
*nt_opt_size[2],
defaultStream);
return wrap_buffer(std::move(output), target_size, target_stride);
}
}
#endif
TORCH_CHECK(false, "from_padded_tensor not implemented for this case.");
}

Tensor to_padded_tensor(Tensor nt, double padding) {
#ifdef WITH_CUDA
if (get_dim(nt) == 3 && get_is_contiguous(nt)) {
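If from_padded_tensor above behaves as its signature suggests, it is the inverse of to_padded_tensor for the contiguous, 3-dim, CUDA float case with a regular last dimension. A hypothetical round-trip sketch under those assumptions (nt is such a nested tensor; the helper names are taken from the diff):

// Hypothetical usage sketch, not from the commit.
at::Tensor padded = to_padded_tensor(nt, 0);   // [B, max_len, D], zero padded
at::Tensor nt2 = from_padded_tensor(
    padded, get_efficient_nested_size(nt), get_efficient_nested_stride(nt));
// get_buffer(nt2) should match get_buffer(nt) elementwise; the padding values are dropped.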
10 changes: 10 additions & 0 deletions nestedtensor/csrc/masking.h
@@ -4,6 +4,7 @@
#include <nestedtensor/csrc/python_functions.h>
#include <nestedtensor/csrc/utils/nested_node_functions.h>
#include <nestedtensor/csrc/utils/python_nested_node.h>
#include <nestedtensor/csrc/storage/EfficientSizeNode.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/autograd/python_variable_indexing.h>
#include <torch/extension.h>
@@ -16,6 +17,15 @@ at::Tensor to_mask(
at::Tensor nt,
c10::optional<int64_t> mask_dim);

at::Tensor to_padded_tensor(
at::Tensor nt,
double padding);

at::Tensor from_padded_tensor(
at::Tensor nt,
torch::nested_tensor::EfficientSizeNode target_size,
torch::nested_tensor::EfficientSizeNode target_stride);

c10::optional<at::Tensor> nt_from_tensor_mask(
at::Tensor tensor,
at::Tensor mask,
11 changes: 8 additions & 3 deletions nestedtensor/csrc/storage/EfficientSizeNode.h
@@ -133,6 +133,9 @@ struct EfficientSizeNode {
const std::vector<c10::optional<int64_t>>& opt_sizes() const {
return _opt_sizes;
}
void refresh_opt_sizes() {
_opt_sizes = impl::construct_efficient_size(_structure, _height, _sizes);
}
const at::Tensor& sizes() const {
return _sizes;
}
@@ -167,7 +170,7 @@ struct EfficientSizeNode {
std::vector<int64_t> _structure;
const at::Tensor _sizes;
bool _opt_sizes_set = false;
const std::vector<c10::optional<int64_t>> _opt_sizes;
std::vector<c10::optional<int64_t>> _opt_sizes;
};

inline bool efficient_size_structure_matches(
@@ -230,10 +233,12 @@ inline void apply_efficient_size(
}
for (int64_t i = 0; i < sizes0.size(0); i++) {
fn(sizes0_ptr + i * sizes0.size(1),
sizes0.size(0),
sizes0.size(1),
sizes1_ptr + i * sizes1.size(1),
sizes1.size(0));
sizes1.size(1));
}
size_node0.refresh_opt_sizes();
size_node1.refresh_opt_sizes();
}

} // namespace nested_tensor
4 changes: 2 additions & 2 deletions nestedtensor/version.py
@@ -1,5 +1,5 @@
__version__ = '0.1.4+e1d384f'
git_version = 'e1d384fea9d70a664b38a53768f82c81057a7d13'
__version__ = '0.1.4+3a8fd81'
git_version = '3a8fd81e999271b1ecdbf6cad8d1b6e1718d00c7'
from nestedtensor import _C
if hasattr(_C, 'CUDA_VERSION'):
cuda = _C.CUDA_VERSION
