
Commit

2025-01-16 nightly release (d16e2d8)
pytorchbot committed Jan 16, 2025
1 parent 77be896 commit 738ba45
Showing 21 changed files with 527 additions and 174 deletions.
@@ -64,6 +64,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread
>
__global__ __launch_bounds__(kForwardMaxThreads) void
@@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
{%- if not dense %}
const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }},
{%- endif %}
@@ -113,17 +114,17 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
fd_B.DivMod(b_t, &t, &b);
{%- endif %}

int64_t weights_offset = weights_offsets[t];
int32_t D_start = D_offsets[t];
int32_t D_end = D_offsets[t + 1];
int32_t D = D_end - D_start;
int64_t indices_start = offsets[b_t];
int64_t indices_end = offsets[b_t + 1];
int32_t L = indices_end - indices_start;
const auto weights_offset = weights_offsets[t];
const auto D_start = D_offsets[t];
const auto D_end = D_offsets[t + 1];
const auto D = D_end - D_start;
const auto indices_start = offsets[b_t];
const auto indices_end = offsets[b_t + 1];
const auto L = indices_end - indices_start;
if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) {
// If the table does not require gradient computation, we set the gradient to zero.
for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
for (auto l_start = 0; l_start < L; l_start += kWarpSize) {
auto l = l_start + threadIdx.x;
if (l < L) {
grad_indice_weights[indices_start + l] = 0.0;
}
@@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void

for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
index_t idx = l < L ? indices[indices_start + l] : 0;
{%- if not dense %}
const auto {{ locs_or_addrs_idx }} =
(placement == PlacementType::MANAGED_CACHING && l < L)
? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
{%- endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int64_t idx_j = shfl_sync(idx, j);
auto idx_j = shfl_sync(idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
{%- endif %}
@@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
const uint32_t info_B_mask = info_B_mask_int64;
{%- endif %}

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] {
DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
aligned_grad_output.scalar_type(),
@@ -362,7 +364,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
{%- else %}
dev_weights.scalar_type(),
{%- endif %}
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel",
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2",
[&] {
{%- if vbe %}
const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1});
@@ -379,13 +381,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
mdesc, vdesc, vbdesc)
%}
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"{{ kernel_name }}";
const auto func_name = "{{ kernel_name }}";
#endif
{{ kernel_name }}<
emb_t,
grad_t,
cache_t,
index_t,
kFixedMaxVecsPerThread><<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
@@ -400,8 +402,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
{%- endif %}
@@ -421,6 +423,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
});
{%- endfor %} {# /* for use_vec_blocking */ #}
});
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return grad_indice_weights;
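Note on the change above: threading a new index_t template parameter through the grad_indice_weights kernel and wrapping the existing dispatch in AT_DISPATCH_INDEX_TYPES means the host-side indices and offsets tensors no longer have to be int64. A minimal Python-side sketch of the caller-visible contract, assuming (not verified against the op schema) that indices and offsets must share a single integer dtype because one dispatch branch covers both:

import torch

def check_index_tensors(indices: torch.Tensor, offsets: torch.Tensor) -> None:
    # Both tensors are dispatched on the same index_t, so mixing dtypes is assumed invalid.
    assert indices.dtype in (torch.int32, torch.int64), "unsupported index dtype"
    assert indices.dtype == offsets.dtype, "indices and offsets must share a dtype"

# int32 indices/offsets are now a supported specialization alongside int64.
indices = torch.tensor([1, 0, 2, 2], dtype=torch.int32)
offsets = torch.tensor([0, 2, 4], dtype=torch.int32)  # B x T + 1 entries
check_index_tensors(indices, offsets)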
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -76,6 +76,7 @@ class OptimizerArgs(NamedTuple):
weight_norm_coefficient: float
lower_bound: float
regularization_mode: int
use_rowwise_bias_correction: bool # Used for OptimType.ADAM


class Momentum(NamedTuple):
@@ -81,6 +81,9 @@ def invoke(
prev_iter_dev: Optional[torch.Tensor] = None,
{%- endif %}
gwd_lower_bound: float = 0.0,
{%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
row_counter: Optional[Momentum] = None,
{%- endif %}
) -> torch.Tensor:
{%- if is_experimental_optimizer %}
# By design, the warning only shows up once
@@ -94,7 +97,20 @@ def invoke(
{%- endif %}

vbe_metadata = common_args.vbe_metadata

{%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
if not optimizer_args.use_rowwise_bias_correction:
row_counter_dev = None
row_counter_uvm = None
row_counter_offsets = None
row_counter_placements = None
elif row_counter is None:
assert False, "use_rowwise_bias_correction is set but row_counter cannot be None"
else:
row_counter_dev = row_counter.dev
row_counter_uvm = row_counter.uvm
row_counter_offsets = row_counter.offsets
row_counter_placements = row_counter.placements
{%- endif %}
{%- if has_cpu_support and not ssd %}
if (common_args.host_weights.numel() > 0):
T = common_args.D_offsets.numel() - 1
@@ -263,7 +279,6 @@ def invoke(
{%- endfor %}
{%- endif %}


return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
# common_args
{%- if not dense %}
@@ -393,6 +408,15 @@ def invoke(
row_counter_offsets=row_counter.offsets,
row_counter_placements=row_counter.placements,
{%- endif %}
{%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
row_counter_dev=row_counter_dev,
row_counter_uvm=row_counter_uvm,
row_counter_offsets=row_counter_offsets,
row_counter_placements=row_counter_placements,
{%- endif %}
{%- if "use_rowwise_bias_correction" in args_pt2.unified_pt2.split_function_arg_names %}
use_rowwise_bias_correction=optimizer_args.use_rowwise_bias_correction,
{%- endif %}
# iter
iter=iter,
# max counter
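Note on the invoke template change: the generated invoke() now receives row_counter as an Optional and unpacks it into the four component tensors (or None) that the lookup op expects, gated on use_rowwise_bias_correction. A standalone sketch of that unpacking logic, using a stand-in RowCounter tuple (the real template uses the Momentum NamedTuple; the helper name is illustrative only) and an explicit error on the inconsistent flag/state combination:

from typing import NamedTuple, Optional, Tuple
import torch

class RowCounter(NamedTuple):
    dev: torch.Tensor
    uvm: torch.Tensor
    offsets: torch.Tensor
    placements: torch.Tensor

def unpack_row_counter(
    use_rowwise_bias_correction: bool,
    row_counter: Optional[RowCounter],
) -> Tuple[Optional[torch.Tensor], ...]:
    # Optimizer does not use rowwise bias correction: forward Nones to the op.
    if not use_rowwise_bias_correction:
        return None, None, None, None
    # Bias correction requested but no state supplied: fail loudly.
    if row_counter is None:
        raise ValueError("use_rowwise_bias_correction is set but row_counter is None")
    return row_counter.dev, row_counter.uvm, row_counter.offsets, row_counter.placements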
7 changes: 7 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -1266,6 +1266,9 @@ def matmul_fp8_row(
)

output_shape = a_shape[:-1] + (N,)
# Handle empty input tensors.
if (M == 0) or (N == 0) or (K == 0):
return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
# launch kernel
if a.device == torch.device("cpu"):
logger.info(
@@ -2084,6 +2087,10 @@ def matmul_fp8_block(
)

output_shape = a_shape[:-1] + (N,)
# Handle case where inputs are empty.
if (M == 0) or (N == 0) or (K == 0):
return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)

# launch kernel
assert device != torch.device(
"cpu"
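Note on the fp8_gemm.py change: both matmul_fp8_row and matmul_fp8_block now return torch.zeros(output_shape, dtype=torch.bfloat16) instead of launching a kernel when M, N, or K is zero. Zeros rather than an uninitialized buffer only matter for the K == 0 case, where the output still has M * N elements and the mathematically correct value of the empty reduction is 0; plain PyTorch matmul behaves the same way, as this small check illustrates:

import torch

a = torch.randn(4, 0)  # M = 4, K = 0
b = torch.randn(0, 3)  # K = 0, N = 3
out = a @ b            # shape (4, 3): the empty sum over K gives zeros
assert out.shape == (4, 3)
assert torch.all(out == 0)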
@@ -66,6 +66,10 @@ at::Tensor f8f8bf16_blockwise_impl(

// Create output tensor.
auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
// If any input dimension is zero, return the output without launching the kernel.
if (M == 0 || N == 0 || K == 0) {
return Y;
}

int StrideA = K;
int StrideB = K;
@@ -314,6 +314,14 @@ at::Tensor f8f8bf16_rowwise_wrapper(
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
// Compute target output sizes.
auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
// Handle case where an input dimension is zero.
if (M == 0 || N == 0 || K == 0) {
// Return a tensor of zeros to handle case where K is 0.
return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
}

// Prepare output tensor if needed.
at::Tensor Y;
@@ -324,10 +332,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
TORCH_CHECK(Y_M == M && Y.sizes().vec().back() == N);
TORCH_CHECK(Y.dtype() == at::kBFloat16);
} else {
// 1. If the input tensor is {M, K}, the output tensor is {M, N}.
// 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));
}

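Note on the rowwise wrapper change: moving the out_sizes computation ahead of the zero-size check keeps the existing shape rule intact — the output takes every leading dimension of XQ and replaces the trailing K dimension with N, so {M, K} and {b, M, K} inputs produce {M, N} and {b, M, N} outputs even when a dimension is zero. A one-line restatement of that rule in Python (helper name is illustrative only):

def fp8_out_shape(xq_shape: tuple, n: int) -> tuple:
    # Keep leading (batch/M) dims, replace the last (K) dim with N.
    return xq_shape[:-1] + (n,)

assert fp8_out_shape((8, 128), 64) == (8, 64)        # {M, K} -> {M, N}
assert fp8_out_shape((2, 8, 128), 64) == (2, 8, 64)  # {b, M, K} -> {b, M, N}
assert fp8_out_shape((0, 128), 64) == (0, 64)        # M == 0 is still well-formed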
@@ -47,6 +47,11 @@ at::Tensor f8f8bf16_impl(
auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;

// Handle case where inputs are empty.
if (M == 0 || N == 0 || K == 0) {
return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
}

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

@@ -60,6 +60,11 @@ at::Tensor f8f8bf16_blockwise_impl(
// 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
// Handle the case where any input dimension is zero.
if (M == 0 || N == 0 || K == 0) {
// Return a zero tensor in case K is 0.
return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
}

TORCH_CHECK(WQ.size(1) == K);
TORCH_CHECK(XQ.stride(-1) == 1);
@@ -51,12 +51,18 @@ at::Tensor f8f8bf16_rowwise_impl(
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
TORCH_CHECK(XQ.size(-1) == K);
// 1. If the input tensor is {M, K}, the output tensor is {M, N}.
// 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
// Handle the case where any dimension is zero by returning early.
if (M == 0 || N == 0 || K == 0) {
// Use zeros instead of empty for special case where K=0.
return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
}

TORCH_CHECK(XQ.size(-1) == K);
TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

@@ -48,6 +48,10 @@ at::Tensor f8f8bf16_tensorwise_impl(
// 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
// Handle case where inputs are empty.
if (M == 0 || N == 0 || K == 0) {
return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
}

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
26 changes: 21 additions & 5 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
@@ -684,19 +684,25 @@ at::Tensor get_fp8_per_tensor_scale(
std::optional<at::Tensor> scale_ub) // scale upper bound
{
CUDA_DEVICE_GUARD(input);
TORCH_CHECK(input.numel() != 0, "input should not be empty tensor");
TORCH_CHECK(
input.dim() >= 2,
"Invalid dim. The dim of input should be greater than or equal to 2");
auto _st = input.scalar_type();
TORCH_CHECK(_st == torch::kBFloat16, "Invalid datatype. input must be BF16");
int out_size = input.numel() == 0 ? 0 : 1;
at::Tensor scale = torch::empty(
{1},
{out_size},
torch::dtype(torch::kFloat32)
.device(torch::kCUDA, at::cuda::current_device())
.requires_grad(false));
// Handle case where input is empty.
if (input.numel() == 0) {
return scale;
}
const auto stream = at::cuda::getCurrentCUDAStream();
invokeComputeScale(
reinterpret_cast<float*>(scale.data_ptr()),
@@ -720,7 +726,6 @@ at::Tensor quantize_fp8_per_tensor_fixed_scale(
std::optional<at::Tensor> bs, // batch size
bool stochastic_rounding) {
CUDA_DEVICE_GUARD(input);
TORCH_CHECK(input.numel() != 0, "input should not be empty tensor");
TORCH_CHECK(
input.dim() >= 2,
"Invalid dim. The dim of input should be greater than or equal to 2");
@@ -739,6 +744,11 @@ at::Tensor quantize_fp8_per_tensor_fixed_scale(
.device(torch::kCUDA, at::cuda::current_device())
.requires_grad(false));
// When the input is empty, return the empty quantized tensor early.
if (input.numel() == 0) {
return quantized_input;
}
const auto stream = at::cuda::getCurrentCUDAStream();
invokeQuantizeMatrix(
reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr()),
@@ -761,7 +771,6 @@ std::vector<at::Tensor> quantize_fp8_per_tensor(
bool stochastic_rounding) // stochastic rounding
{
CUDA_DEVICE_GUARD(input);
TORCH_CHECK(input.numel() != 0, "input should not be empty tensor");
TORCH_CHECK(
input.dim() >= 2,
"Invalid dim. The dim of input should be greater than or equal to 2");
@@ -789,6 +798,10 @@ std::vector<at::Tensor> quantize_fp8_per_tensor(
torch::dtype(torch::kFloat32)
.device(torch::kCUDA, at::cuda::current_device())
.requires_grad(false));
// When input is empty, return empty tensors.
if (input.numel() == 0) {
return std::vector<at::Tensor>{quantized_input, scales};
}
auto* const quantized_input_ptr =
reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr());
const auto stream = at::cuda::getCurrentCUDAStream();
@@ -1177,7 +1190,6 @@ std::vector<at::Tensor> quantize_fp8_per_col(
std::optional<at::Tensor> scale_ub) // scale upperbound)
{
CUDA_DEVICE_GUARD(input);
TORCH_CHECK(input.numel() != 0, "input should not be empty tensor");
TORCH_CHECK(
input.dim() >= 2,
"Invalid dim. The dim of input should be greater than or equal to 2");
@@ -1201,6 +1213,10 @@ std::vector<at::Tensor> quantize_fp8_per_col(
torch::dtype(torch::kFloat32)
.device(torch::kCUDA, at::cuda::current_device())
.requires_grad(false));
// When input is empty, return empty tensors.
if (input.numel() == 0) {
return std::vector<at::Tensor>{quantized_input, scales};
}
auto* const quantized_input_ptr =
reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr());
const auto stream = at::cuda::getCurrentCUDAStream();
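Note on the quantize.cu changes: the quantize entry points (get_fp8_per_tensor_scale, quantize_fp8_per_tensor, quantize_fp8_per_tensor_fixed_scale, quantize_fp8_per_col) drop their numel() != 0 checks and instead return their freshly allocated outputs early when the input is empty; get_fp8_per_tensor_scale additionally sizes its scale tensor as {0} rather than {1} in that case. A small Python sketch of that scale-shape rule, mirroring the `out_size = input.numel() == 0 ? 0 : 1` line above (the helper name is illustrative, not part of the FBGEMM API):

import torch

def per_tensor_scale_shape(input_numel: int) -> torch.Size:
    # Empty input -> zero-element scale; otherwise a single-element scale.
    out_size = 0 if input_numel == 0 else 1
    return torch.empty(out_size, dtype=torch.float32).shape

assert per_tensor_scale_shape(0) == (0,)
assert per_tensor_scale_shape(4096) == (1,)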