diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 8e1db3675..5c4e783d3 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -64,6 +64,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, int32_t kFixedMaxVecsPerThread > __global__ __launch_bounds__(kForwardMaxThreads) void @@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, const pta::PackedTensorAccessor32 D_offsets, - const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] - const pta::PackedTensorAccessor32 offsets, // [B x T + 1] + const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] + const pta::PackedTensorAccessor32 offsets, // [B x T + 1] {%- if not dense %} const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} @@ -113,17 +114,17 @@ __global__ __launch_bounds__(kForwardMaxThreads) void fd_B.DivMod(b_t, &t, &b); {%- endif %} - int64_t weights_offset = weights_offsets[t]; - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - int64_t indices_start = offsets[b_t]; - int64_t indices_end = offsets[b_t + 1]; - int32_t L = indices_end - indices_start; + const auto weights_offset = weights_offsets[t]; + const auto D_start = D_offsets[t]; + const auto D_end = D_offsets[t + 1]; + const auto D = D_end - D_start; + const auto indices_start = offsets[b_t]; + const auto indices_end = offsets[b_t + 1]; + const auto L = indices_end - indices_start; if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) { // If the table does not require gradient computation, we set the gradient to zero. - for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { - int32_t l = l_start + threadIdx.x; + for (auto l_start = 0; l_start < L; l_start += kWarpSize) { + auto l = l_start + threadIdx.x; if (l < L) { grad_indice_weights[indices_start + l] = 0.0; } @@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { int32_t l = l_start + threadIdx.x; - int64_t idx = l < L ? indices[indices_start + l] : 0; + index_t idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} const auto {{ locs_or_addrs_idx }} = (placement == PlacementType::MANAGED_CACHING && l < L) ? 
{{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - int64_t idx_j = shfl_sync(idx, j); + auto idx_j = shfl_sync(idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} @@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const uint32_t info_B_mask = info_B_mask_int64; {%- endif %} + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] { DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -362,7 +364,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( {%- else %} dev_weights.scalar_type(), {%- endif %} - "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel", + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2", [&] { {%- if vbe %} const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1}); @@ -379,13 +381,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( mdesc, vdesc, vbdesc) %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "{{ kernel_name }}"; + const auto func_name = "{{ kernel_name }}"; #endif {{ kernel_name }}< emb_t, grad_t, cache_t, + index_t, kFixedMaxVecsPerThread><<< div_round_up(total_B, kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -400,8 +402,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( {%- endif %} MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), {%- if not dense %} MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), {%- endif %} @@ -421,6 +423,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( }); {%- endfor %} {# /* for use_vec_blocking */ #} }); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); return grad_indice_weights; diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template index f3fd7aa87..ca79a16f8 100644 --- a/fbgemm_gpu/codegen/training/python/lookup_args.template +++ b/fbgemm_gpu/codegen/training/python/lookup_args.template @@ -76,6 +76,7 @@ class OptimizerArgs(NamedTuple): weight_norm_coefficient: float lower_bound: float regularization_mode: int + use_rowwise_bias_correction: bool # Used for OptimType.ADAM class Momentum(NamedTuple): diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index b55b850c5..a2c1b7695 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -81,6 +81,9 @@ def invoke( prev_iter_dev: Optional[torch.Tensor] = None, {%- endif %} gwd_lower_bound: float = 0.0, + {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} + row_counter: Optional[Momentum] = None, + {%- endif %} ) -> torch.Tensor: {%- if is_experimental_optimizer %} # By design, the warning only shows up once @@ -94,7 +97,20 @@ 
def invoke( {%- endif %} vbe_metadata = common_args.vbe_metadata - + {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} + if not optimizer_args.use_rowwise_bias_correction or row_counter is None: + row_counter_dev = None + row_counter_uvm = None + row_counter_offsets = None + row_counter_placements = None + elif optimizer_args.use_rowwise_bias_correction and row_counter is None: + assert False, "use_rowwise_bias_correction is set but row_counter cannot be None" + else: + row_counter_dev = row_counter.dev + row_counter_uvm = row_counter.uvm + row_counter_offsets = row_counter.offsets + row_counter_placements = row_counter.placements + {%- endif %} {%- if has_cpu_support and not ssd %} if (common_args.host_weights.numel() > 0): T = common_args.D_offsets.numel() - 1 @@ -263,7 +279,6 @@ def invoke( {%- endfor %} {%- endif %} - return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( # common_args {%- if not dense %} @@ -393,6 +408,15 @@ def invoke( row_counter_offsets=row_counter.offsets, row_counter_placements=row_counter.placements, {%- endif %} + {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} + row_counter_dev=row_counter_dev, + row_counter_uvm=row_counter_uvm, + row_counter_offsets=row_counter_offsets, + row_counter_placements=row_counter_placements, + {%- endif %} + {%- if "use_rowwise_bias_correction" in args_pt2.unified_pt2.split_function_arg_names %} + use_rowwise_bias_correction=optimizer_args.use_rowwise_bias_correction, + {%- endif %} # iter iter=iter, # max counter diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 35555ee65..b554fe9a1 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -1266,6 +1266,9 @@ def matmul_fp8_row( ) output_shape = a_shape[:-1] + (N,) + # Handle tensor with empty inputs. + if (M == 0) or (N == 0) or (K == 0): + return torch.zeros(output_shape, device=device, dtype=torch.bfloat16) # launch kernel if a.device == torch.device("cpu"): logger.info( @@ -2084,6 +2087,10 @@ def matmul_fp8_block( ) output_shape = a_shape[:-1] + (N,) + # Handle case where inputs are empty. + if (M == 0) or (N == 0) or (K == 0): + return torch.zeros(output_shape, device=device, dtype=torch.bfloat16) + # launch kernel assert device != torch.device( "cpu" diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip index 53b8020c6..38ee3c9b2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip @@ -66,6 +66,10 @@ at::Tensor f8f8bf16_blockwise_impl( // Create output tensor. auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + // If inputs are empty return an empty tensor. 
+ if (M == 0 || N == 0 || K == 0) { + return Y; + } int StrideA = K; int StrideB = K; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip index 1f2d347d2..9ddfc87d8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip @@ -314,6 +314,14 @@ at::Tensor f8f8bf16_rowwise_wrapper( int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); int N = WQ.size(0); int K = WQ.size(1); + // Compute target output sizes. + auto out_sizes = XQ.sizes().vec(); + out_sizes.back() = N; + // Handle case where an input dimension is zero. + if (M == 0 || N == 0 || K == 0) { + // Return a tensor of zeros to handle case where K is 0. + return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16)); + } // Prepare output tensor if needed. at::Tensor Y; @@ -324,10 +332,6 @@ at::Tensor f8f8bf16_rowwise_wrapper( TORCH_CHECK(Y_M == M && Y.sizes().vec().back() == N); TORCH_CHECK(Y.dtype() == at::kBFloat16); } else { - // 1. If the input tensor is {M, K}, the output tensor is {M, N}. - // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. - auto out_sizes = XQ.sizes().vec(); - out_sizes.back() = N; Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu index ed3bb462d..b205a64fa 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu @@ -47,6 +47,11 @@ at::Tensor f8f8bf16_impl( auto out_sizes = XQ.sizes().vec(); out_sizes.back() = N; + // Handle case where inputs are empty. + if (M == 0 || N == 0 || K == 0) { + return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16)); + } + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu index 2836a2352..6614b02eb 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu @@ -60,6 +60,11 @@ at::Tensor f8f8bf16_blockwise_impl( // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. auto out_sizes = XQ.sizes().vec(); out_sizes.back() = N; + // Handle case where input shapes are empty. + if (M == 0 || N == 0 || K == 0) { + // Return a zero tensor in case K is 0. 
+ return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16)); + } TORCH_CHECK(WQ.size(1) == K); TORCH_CHECK(XQ.stride(-1) == 1); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu index 809aa7f22..84edb5221 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu @@ -51,12 +51,18 @@ at::Tensor f8f8bf16_rowwise_impl( int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); int N = WQ.size(0); int K = WQ.size(1); - TORCH_CHECK(XQ.size(-1) == K); // 1. If the input tensor is {M, K}, the output tensor is {M, N}. // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. auto out_sizes = XQ.sizes().vec(); out_sizes.back() = N; + // Handle case where there is a zero dimension, we simply return an empty + // tensor. + if (M == 0 || N == 0 || K == 0) { + // Use zeros instead of empty for special case where K=0. + return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16)); + } + TORCH_CHECK(XQ.size(-1) == K); TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu index b9e567b47..9388fbff0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu @@ -48,6 +48,10 @@ at::Tensor f8f8bf16_tensorwise_impl( // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. auto out_sizes = XQ.sizes().vec(); out_sizes.back() = N; + // Handle case where inputs are empty. + if (M == 0 || N == 0 || K == 0) { + return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16)); + } TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu index 6b104a769..31cdb554a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu @@ -684,19 +684,25 @@ at::Tensor get_fp8_per_tensor_scale( std::optional scale_ub) // scale upper bound { CUDA_DEVICE_GUARD(input); - TORCH_CHECK(input.numel() != 0, "input should not be empty tensor"); TORCH_CHECK( input.dim() >= 2, "Invalid dim. The dim of input should be greater than or equal to 2"); auto _st = input.scalar_type(); TORCH_CHECK(_st == torch::kBFloat16, "Invalid datatype. input must be BF16"); + int out_size = input.numel() == 0 ? 0 : 1; + at::Tensor scale = torch::empty( - {1}, + {out_size}, torch::dtype(torch::kFloat32) .device(torch::kCUDA, at::cuda::current_device()) .requires_grad(false)); + // Handle case where input is empty. + if (input.numel() == 0) { + return scale; + } + const auto stream = at::cuda::getCurrentCUDAStream(); invokeComputeScale( reinterpret_cast(scale.data_ptr()), @@ -720,7 +726,6 @@ at::Tensor quantize_fp8_per_tensor_fixed_scale( std::optional bs, // batch size bool stochastic_rounding) { CUDA_DEVICE_GUARD(input); - TORCH_CHECK(input.numel() != 0, "input should not be empty tensor"); TORCH_CHECK( input.dim() >= 2, "Invalid dim. 
The dim of input should be greater than or equal to 2"); @@ -739,6 +744,11 @@ at::Tensor quantize_fp8_per_tensor_fixed_scale( .device(torch::kCUDA, at::cuda::current_device()) .requires_grad(false)); + // When input is empty, return empty scale as well. + if (input.numel() == 0) { + return quantized_input; + } + const auto stream = at::cuda::getCurrentCUDAStream(); invokeQuantizeMatrix( reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr()), @@ -761,7 +771,6 @@ std::vector quantize_fp8_per_tensor( bool stochastic_rounding) // stochastic rounding { CUDA_DEVICE_GUARD(input); - TORCH_CHECK(input.numel() != 0, "input should not be empty tensor"); TORCH_CHECK( input.dim() >= 2, "Invalid dim. The dim of input should be greater than or equal to 2"); @@ -789,6 +798,10 @@ std::vector quantize_fp8_per_tensor( torch::dtype(torch::kFloat32) .device(torch::kCUDA, at::cuda::current_device()) .requires_grad(false)); + // When input is empty, return empty tensors. + if (input.numel() == 0) { + return std::vector{quantized_input, scales}; + } auto* const quantized_input_ptr = reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr()); const auto stream = at::cuda::getCurrentCUDAStream(); @@ -1177,7 +1190,6 @@ std::vector quantize_fp8_per_col( std::optional scale_ub) // scale upperbound) { CUDA_DEVICE_GUARD(input); - TORCH_CHECK(input.numel() != 0, "input should not be empty tensor"); TORCH_CHECK( input.dim() >= 2, "Invalid dim. The dim of input should be greater than or equal to 2"); @@ -1201,6 +1213,10 @@ std::vector quantize_fp8_per_col( torch::dtype(torch::kFloat32) .device(torch::kCUDA, at::cuda::current_device()) .requires_grad(false)); + // When input is empty, return empty tensors. + if (input.numel() == 0) { + return std::vector{quantized_input, scales}; + } auto* const quantized_input_ptr = reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr()); const auto stream = at::cuda::getCurrentCUDAStream(); diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index ac21c5706..96d59339e 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -224,7 +224,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None: ) @settings(deadline=None) @given( - B_T=st.sampled_from([2048, 4096]), + B_T=st.sampled_from([0, 2048, 4096]), D=st.sampled_from([128, 256]), HD_L=st.sampled_from([256, 512]), Mode=st.sampled_from( @@ -433,8 +433,8 @@ def test_quantize_fp8_matmul( # Blockwise seems to have slightly more noisy outputs. # Special case correctness to avoid flakiness. 
if Mode == "blockwise": - atol = 1.2e-1 - rtol = 1.2e-1 + atol = 1.3e-1 + rtol = 1.3e-1 else: atol = 9.0e-2 rtol = 9.0e-2 @@ -919,29 +919,28 @@ def test_bf16_grouped_gemm( for i in range(len(x_group)): x_group[i][zero_start_index_M[i] :, :] = 0 + bf16_op = ( + torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic + if use_padding_zeros + else torch.ops.fbgemm.bf16bf16bf16_grouped + ) + bf16_args = ( + (x_group, w_group, zero_start_index_M) + if use_padding_zeros + else (x_group, w_group) + ) + # BF16 grouped gemm kernel if use_cudagraph: # warmup - torch.ops.fbgemm.bf16bf16bf16_grouped( - x_group, - w_group, - zero_start_index_M if use_padding_zeros else None, - ) + bf16_op(*bf16_args) # With cudagraph g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): - y_bf16_group = torch.ops.fbgemm.bf16bf16bf16_grouped( - x_group, - w_group, - zero_start_index_M if use_padding_zeros else None, - ) + y_bf16_group = bf16_op(*bf16_args) g.replay() else: - y_bf16_group = torch.ops.fbgemm.bf16bf16bf16_grouped( - x_group, - w_group, - zero_start_index_M if use_padding_zeros else None, - ) + y_bf16_group = bf16_op(*bf16_args) # BF16 loopover gemm reference y_group_ref = torch.bmm(xs, ws.transpose(1, 2)) @@ -1100,6 +1099,7 @@ def test_quantize_compile(self) -> None: @unittest.skipIf( torch.version.hip, "Skip on AMD: cuda quantize op is yet suported." ) + @settings(deadline=None) @given( K=st.sampled_from([0, 128]), ) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 8f8d5779e..b048bd953 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -145,6 +145,17 @@ class GlobalWeightDecayDefinition: lower_bound: float = 0.0 +@dataclass(frozen=True) +class UserEnabledConfigDefinition: + """ + This class is used to configure whether certain modes are to be enabled + """ + + # This is used in Adam to perform rowwise bias correction using `row_counter` + # More details can be found in D64848802. + use_rowwise_bias_correction: bool = False + + @dataclass(frozen=True) class EnsembleModeDefinition: step_ema: float = 10000 @@ -564,6 +575,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): using `malloc` + `cudaHostRegister`. Otherwise use `cudaMallocManaged` + extra_optimizer_config Optional[UserEnabledConfigDefinition] = None): + An extra config to enable certain modes for optimizer. These modes + are not enabled by default. 
+ - `use_rowwise_bias_correction` is used in Adam to enable rowwise + bias correction computation """ embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]] @@ -630,6 +646,7 @@ def __init__( # noqa C901 multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, global_weight_decay: Optional[GlobalWeightDecayDefinition] = None, uvm_host_mapped: bool = False, + extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() @@ -1006,6 +1023,20 @@ def __init__( # noqa C901 # and CowClipDefinition are not used counter_halflife = -1 + # TO DO: Enable this on the new interface + # learning_rate_tensor = torch.tensor( + # learning_rate, device=torch.device("cpu"), dtype=torch.float + # ) + if extra_optimizer_config is None: + extra_optimizer_config = UserEnabledConfigDefinition() + self.use_rowwise_bias_correction: bool = ( + extra_optimizer_config.use_rowwise_bias_correction + ) + if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM: + raise AssertionError( + "`use_rowwise_bias_correction` is only supported for OptimType.ADAM", + ) + self.optimizer_args = invokers.lookup_args.OptimizerArgs( stochastic_rounding=stochastic_rounding, gradient_clipping=gradient_clipping, @@ -1032,6 +1063,7 @@ def __init__( # noqa C901 weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient, lower_bound=cowclip_regularization.lower_bound, regularization_mode=weight_decay_mode.value, + use_rowwise_bias_correction=self.use_rowwise_bias_correction, ) if optimizer != OptimType.NONE: @@ -1168,6 +1200,19 @@ def __init__( # noqa C901 torch.ones(1, dtype=torch.float32, device=self.current_device), persistent=False, ) + elif optimizer == OptimType.ADAM and self.use_rowwise_bias_correction: + self._apply_split( + construct_split_state( + embedding_specs, + rowwise=True, + cacheable=False, + ), + prefix="row_counter", + # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param + # but got `Type[torch.float32]`. + dtype=torch.float32, + uvm_host_mapped=self.uvm_host_mapped, + ) else: self._register_nonpersistent_buffers("prev_iter") self._register_nonpersistent_buffers("row_counter") @@ -1192,7 +1237,6 @@ def __init__( # noqa C901 "iter", torch.zeros(1, dtype=torch.int64, device=self.current_device), ) - else: self.register_buffer( "iter", @@ -1895,6 +1939,24 @@ def forward( # noqa: C901 iter_int = int(self.iter_cpu.add_(1).item()) # used for local computation self.iter.add_(1) # used for checkpointing + row_counter = invokers.lookup_args.Momentum( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. + dev=self.row_counter_dev, + # pyre-fixme[6]: For 2nd argument expected `Tensor` but got + # `Union[Module, Tensor]`. + host=self.row_counter_host, + # pyre-fixme[6]: For 3rd argument expected `Tensor` but got + # `Union[Module, Tensor]`. + uvm=self.row_counter_uvm, + # pyre-fixme[6]: For 4th argument expected `Tensor` but got + # `Union[Module, Tensor]`. + offsets=self.row_counter_offsets, + # pyre-fixme[6]: For 5th argument expected `Tensor` but got + # `Union[Module, Tensor]`. 
+ placements=self.row_counter_placements, + ) + if self.optimizer == OptimType.ADAM: return self._report_io_size_count( "fwd_output", @@ -1904,6 +1966,10 @@ def forward( # noqa: C901 momentum1, momentum2, iter_int, + self.use_rowwise_bias_correction, + row_counter=( + row_counter if self.use_rowwise_bias_correction else None + ), ), ) if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM: @@ -1957,23 +2023,6 @@ def forward( # noqa: C901 # `Union[Module, Tensor]`. placements=self.prev_iter_placements, ) - row_counter = invokers.lookup_args.Momentum( - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `Union[Module, Tensor]`. - dev=self.row_counter_dev, - # pyre-fixme[6]: For 2nd argument expected `Tensor` but got - # `Union[Module, Tensor]`. - host=self.row_counter_host, - # pyre-fixme[6]: For 3rd argument expected `Tensor` but got - # `Union[Module, Tensor]`. - uvm=self.row_counter_uvm, - # pyre-fixme[6]: For 4th argument expected `Tensor` but got - # `Union[Module, Tensor]`. - offsets=self.row_counter_offsets, - # pyre-fixme[6]: For 5th argument expected `Tensor` but got - # `Union[Module, Tensor]`. - placements=self.row_counter_placements, - ) if self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD: with torch.no_grad(): @@ -2543,6 +2592,15 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]: list_of_state_dict = [ {"momentum_buffer": states[0]} for states in split_optimizer_states ] + elif self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction: + list_of_state_dict = [ + { + "exp_avg": states[0], + "exp_avg_sq": states[1], + "row_counter": states[2], + } + for states in split_optimizer_states + ] elif ( self.optimizer == OptimType.ADAM or self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM @@ -2717,7 +2775,9 @@ def get_optimizer_states( rowwise=True, ) ) - if self._used_rowwise_adagrad_with_counter: + if self._used_rowwise_adagrad_with_counter or ( + self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction + ): states.append( get_optimizer_states( # pyre-fixme[6]: For 1st argument expected `Tensor` but got diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 36dae3e11..a4ef78d88 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -570,6 +570,7 @@ def __init__( weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient, lower_bound=cowclip_regularization.lower_bound, regularization_mode=weight_decay_mode.value, + use_rowwise_bias_correction=False, # Unused, this is used in TBE's Adam ) table_embedding_dtype = weights_precision.as_dtype() diff --git a/fbgemm_gpu/requirements.txt b/fbgemm_gpu/requirements.txt index cc0f1329c..f7f492194 100644 --- a/fbgemm_gpu/requirements.txt +++ b/fbgemm_gpu/requirements.txt @@ -23,3 +23,4 @@ scikit-build setuptools setuptools_git_versioning tabulate +patchelf diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h new file mode 100644 index 000000000..c25d04121 --- /dev/null +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include // @manual=//caffe2:ATen-core +#include + +namespace ssd { + +class EmbeddingRocksDB; +class EmbeddingRocksDBWrapper; +class SnapshotHandle; + +// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions +struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder { + explicit EmbeddingSnapshotHandleWrapper( + const SnapshotHandle* handle, + std::shared_ptr db); + + ~EmbeddingSnapshotHandleWrapper(); + + const SnapshotHandle* handle; + std::shared_ptr db; +}; + +class KVTensorWrapper : public torch::jit::CustomClassHolder { + public: + explicit KVTensorWrapper( + c10::intrusive_ptr db, + std::vector shape, + int64_t dtype, + int64_t row_offset, + std::optional> + snapshot_handle); + + at::Tensor narrow(int64_t dim, int64_t start, int64_t length); + + void set_range( + int64_t dim, + const int64_t start, + const int64_t length, + const at::Tensor& weights); + + c10::IntArrayRef size(); + + c10::ScalarType dtype(); + + std::string_view dtype_str(); + + c10::Device device(); + + std::string device_str(); + + std::string layout_str(); + + private: + std::shared_ptr db_; + c10::intrusive_ptr snapshot_handle_; + at::TensorOptions options_; + std::vector shape_; + int64_t row_offset_; +}; + +} // namespace ssd diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp new file mode 100644 index 000000000..4ae614f4d --- /dev/null +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include "./kv_tensor_wrapper.h" +#include "common/base/Exception.h" + +using namespace at; +using namespace ssd; + +namespace ssd { +class EmbeddingRocksDB {}; + +// @lint-ignore CLANGTIDY facebook-hte-ShadowingClass +class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { + private: + friend class KVTensorWrapper; + std::shared_ptr impl_; +}; + +class SnapshotHandle {}; + +KVTensorWrapper::KVTensorWrapper( + c10::intrusive_ptr db, + std::vector shape, + [[maybe_unused]] int64_t dtype, + int64_t row_offset, + [[maybe_unused]] std::optional< + c10::intrusive_ptr> snapshot_handle) + // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn + : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { + FBEXCEPTION("Not implemented"); +} + +at::Tensor KVTensorWrapper::narrow( + [[maybe_unused]] int64_t dim, + [[maybe_unused]] int64_t start, + [[maybe_unused]] int64_t length) { + FBEXCEPTION("Not implemented"); + return at::empty(c10::IntArrayRef({1, 1}), options_); +} + +void KVTensorWrapper::set_range( + [[maybe_unused]] int64_t dim, + [[maybe_unused]] const int64_t start, + [[maybe_unused]] const int64_t length, + // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn + [[maybe_unused]] const at::Tensor& weights) { + FBEXCEPTION("Not implemented"); +} + +c10::IntArrayRef KVTensorWrapper::size() { + FBEXCEPTION("Not implemented"); + return shape_; +} + +c10::ScalarType KVTensorWrapper::dtype() { + FBEXCEPTION("Not implemented"); + return options_.dtype().toScalarType(); +} + +std::string_view KVTensorWrapper::dtype_str() { + FBEXCEPTION("Not implemented"); + return scalarTypeToTypeMeta(dtype()).name(); +} + +c10::Device KVTensorWrapper::device() { + FBEXCEPTION("Not implemented"); + return options_.device(); +} + +std::string KVTensorWrapper::device_str() { + FBEXCEPTION("Not implemented"); + return device().str(); +} + +std::string KVTensorWrapper::layout_str() { + FBEXCEPTION("Not implemented"); + std::ostringstream oss; + oss << options_.layout(); + return oss.str(); +} +} // namespace ssd diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index c45380a9e..45d4d2e34 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -386,6 +386,46 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { std::shared_ptr impl_; }; +SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { + auto num_shards = db->num_shards(); + CHECK_GT(num_shards, 0); + shard_snapshots_.reserve(num_shards); + for (auto shard = 0; shard < num_shards; ++shard) { + const auto* snapshot = db->dbs_[shard]->GetSnapshot(); + CHECK(snapshot != nullptr) + << "ERROR: create_snapshot fails to create a snapshot " + << "for db shard " << shard << ". 
Please make sure that " + << "inplace_update_support is set to false" << std::endl; + shard_snapshots_.push_back(snapshot); + } +} + +SnapshotHandle::~SnapshotHandle() { + for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { + snapshot_ptr_t snapshot = shard_snapshots_[shard]; + CHECK(snapshot != nullptr) << "Unexpected nullptr for snapshot " << shard; + db_->dbs_[shard]->ReleaseSnapshot(snapshot); + } +} + +void SnapshotHandle::release() { + db_->release_snapshot(this); +} + +snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const { + CHECK_LE(shard, shard_snapshots_.size()); + return shard_snapshots_[shard]; +} + +EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper( + const SnapshotHandle* handle, + std::shared_ptr db) + : handle(handle), db(std::move(db)) {} + +EmbeddingSnapshotHandleWrapper::~EmbeddingSnapshotHandleWrapper() { + db->release_snapshot(handle); +} + KVTensorWrapper::KVTensorWrapper( c10::intrusive_ptr db, std::vector shape, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 5bb7358cf..a84ac1698 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -25,6 +25,7 @@ #endif #include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h" #include "kv_db_table_batched_embeddings.h" +#include "kv_tensor_wrapper.h" #include "torch/csrc/autograd/record_function_ops.h" namespace ssd { @@ -124,55 +125,29 @@ class Initializer { std::unique_ptr producer_; }; +class EmbeddingRocksDB; +using snapshot_ptr_t = const rocksdb::Snapshot*; +// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions +class SnapshotHandle { + public: + explicit SnapshotHandle(EmbeddingRocksDB* db); + ~SnapshotHandle(); + void release(); + snapshot_ptr_t get_snapshot_for_shard(size_t shard) const; + + private: + friend class EmbeddingRocksDB; + + EmbeddingRocksDB* db_; + std::vector shard_snapshots_; +}; // class SnapshotHandle + /// @ingroup embedding-ssd /// /// @brief An implementation of EmbeddingKVDB for RocksDB /// class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { - using snapshot_ptr_t = const rocksdb::Snapshot*; - public: - class SnapshotHandle { - public: - explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { - auto num_shards = db->num_shards(); - CHECK_GT(num_shards, 0); - shard_snapshots_.reserve(num_shards); - for (auto shard = 0; shard < num_shards; ++shard) { - const auto* snapshot = db->dbs_[shard]->GetSnapshot(); - CHECK(snapshot != nullptr) - << "ERROR: create_snapshot fails to create a snapshot " - << "for db shard " << shard << ". 
Please make sure that " - << "inplace_update_support is set to false" << std::endl; - shard_snapshots_.push_back(snapshot); - } - } - - ~SnapshotHandle() { - for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { - snapshot_ptr_t snapshot = shard_snapshots_[shard]; - CHECK(snapshot != nullptr) - << "Unexpected nullptr for snapshot " << shard; - db_->dbs_[shard]->ReleaseSnapshot(snapshot); - } - } - - void release() { - db_->release_snapshot(this); - } - - snapshot_ptr_t get_snapshot_for_shard(size_t shard) const { - CHECK_LE(shard, shard_snapshots_.size()); - return shard_snapshots_[shard]; - } - - private: - friend class EmbeddingRocksDB; - - EmbeddingRocksDB* db_; - std::vector shard_snapshots_; - }; - explicit EmbeddingRocksDB( std::string path, int64_t num_shards, @@ -934,6 +909,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { return folly::collect(futures); } + friend class SnapshotHandle; + std::vector> dbs_; std::vector> initializers_; std::unique_ptr executor_; @@ -960,58 +937,4 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { int64_t elem_size_; }; // class EmbeddingRocksDB -class EmbeddingRocksDBWrapper; - -struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder { - explicit EmbeddingSnapshotHandleWrapper( - const EmbeddingRocksDB::SnapshotHandle* handle, - std::shared_ptr db) - : handle(handle), db(std::move(db)) {} - - ~EmbeddingSnapshotHandleWrapper() { - db->release_snapshot(handle); - } - - const EmbeddingRocksDB::SnapshotHandle* handle; - std::shared_ptr db; -}; - -class KVTensorWrapper : public torch::jit::CustomClassHolder { - public: - explicit KVTensorWrapper( - c10::intrusive_ptr db, - std::vector shape, - int64_t dtype, - int64_t row_offset, - std::optional> - snapshot_handle); - - at::Tensor narrow(int64_t dim, int64_t start, int64_t length); - - void set_range( - int64_t dim, - const int64_t start, - const int64_t length, - const at::Tensor& weights); - - c10::IntArrayRef size(); - - c10::ScalarType dtype(); - - std::string_view dtype_str(); - - c10::Device device(); - - std::string device_str(); - - std::string layout_str(); - - private: - std::shared_ptr db_; - c10::intrusive_ptr snapshot_handle_; - at::TensorOptions options_; - std::vector shape_; - int64_t row_offset_; -}; - } // namespace ssd diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index 91b0d95fd..d5bdf4eb6 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -33,6 +33,7 @@ SplitTableBatchedEmbeddingBagsCodegen, StepMode, TailIdThreshold, + UserEnabledConfigDefinition, WeightDecayMode, ) @@ -105,6 +106,7 @@ def execute_backward_optimizers_( # noqa C901 weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, uvm_non_rowwise_momentum: bool = False, optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None, + use_rowwise_bias_correction: bool = False, ) -> None: # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! 
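For reference, the rowwise bias correction exercised by this test works as follows: `row_counter` records, per embedding row, how many updates that row has received, and the Adam bias-correction denominators use `beta ** row_counter` in place of `beta ** iter`, so rarely touched rows are not under-corrected. Below is a minimal sketch of that reference update, assuming dense gradients; the function and variable names are illustrative only and are not FBGEMM APIs.

import torch

def adam_step_rowwise_bias_correction(
    weight, grad, m1, m2, row_counter, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8
):
    # Count one update for every row that received a gradient this step.
    touched = (grad.abs().sum(dim=1, keepdim=True) > 0).to(row_counter.dtype)
    row_counter += touched  # shape (num_rows, 1), broadcasts over the embedding dim
    m1.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    m2.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    # Per-row bias correction instead of the usual global (1 - beta**iter) terms.
    m_hat = m1 / (1.0 - beta1**row_counter)
    v_hat = m2 / (1.0 - beta2**row_counter)
    weight.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)

In the embedding module itself the mode is opted into via `extra_optimizer_config=UserEnabledConfigDefinition(use_rowwise_bias_correction=True)` together with `OptimType.ADAM`, which is what `execute_backward_optimizers_` passes through in the hunks that follow.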
@@ -307,6 +309,9 @@ def execute_backward_optimizers_( # noqa C901 optimizer_kwargs["weight_decay"] = weight_decay optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes + extra_optimizer_config = UserEnabledConfigDefinition( + use_rowwise_bias_correction=use_rowwise_bias_correction + ) if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB): optimizer_kwargs["eps"] = eps optimizer_kwargs["beta1"] = beta1 @@ -335,6 +340,10 @@ def execute_backward_optimizers_( # noqa C901 step_ema_coef=momentum, step_mode=step_mode, ) + row_counter_ref = [torch.zeros(E, dtype=torch.float32) for E in Es] + if optimizer == OptimType.ADAM and use_rowwise_bias_correction: + for i, indices in enumerate(xs): + row_counter_ref[i][indices.cpu()] += 1 if optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD: (eps, step_ema, step_start) = ( @@ -356,6 +365,7 @@ def execute_backward_optimizers_( # noqa C901 optimizer=optimizer, pooling_mode=pooling_mode, uvm_non_rowwise_momentum=uvm_non_rowwise_momentum, + extra_optimizer_config=extra_optimizer_config, **optimizer_kwargs, ) @@ -539,8 +549,21 @@ def execute_backward_optimizers_( # noqa C901 if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM): rowwise = optimizer == OptimType.PARTIAL_ROWWISE_ADAM + row_counter: Optional[torch.Tensor] = None for t in range(T): - (m1, m2) = split_optimizer_states[t] + if rowwise or not use_rowwise_bias_correction: + (m1, m2) = split_optimizer_states[t] + else: # Full adam with rowwise bias correction + (m1, m2, row_counter) = split_optimizer_states[t] + # check row counter + row_counter = row_counter.cpu() + torch.testing.assert_close( + row_counter, + row_counter_ref[t], + atol=0, + rtol=0, + ) + row_counter = row_counter.reshape(row_counter.size(0), 1) # Some optimizers have non-float momentums dense_cpu_grad = bs[t].weight.grad.cpu().to_dense() m2_ref = ( @@ -552,9 +575,10 @@ def execute_backward_optimizers_( # noqa C901 m1_ref = dense_cpu_grad * (1.0 - beta1) self.assert_close_optim_state(m1, m1_ref) iter_ = cc.iter.item() - v_hat_t = m2_ref / (1 - beta2**iter_) + power = row_counter if use_rowwise_bias_correction else iter_ + v_hat_t = m2_ref / (1 - beta2**power) v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1) - m_hat_t = m1_ref / (1 - beta1**iter_) + m_hat_t = m1_ref / (1 - beta1**power) weights_new = split_weights[t] weights_ref = ( torch.addcdiv( @@ -574,10 +598,16 @@ def execute_backward_optimizers_( # noqa C901 if get_optimizer_states is not None: optimizer_states_dict = get_optimizer_states[t] - assert set(optimizer_states_dict.keys()) == { - "exp_avg", - "exp_avg_sq", - } + state_keys = ( + { + "exp_avg", + "exp_avg_sq", + "row_counter", + } + if use_rowwise_bias_correction + else {"exp_avg", "exp_avg_sq"} + ) + assert set(optimizer_states_dict.keys()) == state_keys if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: for t in range(T): @@ -938,6 +968,63 @@ def test_backward_optimizers_adam( # noqa C901 uvm_non_rowwise_momentum=uvm_non_rowwise_momentum, ) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + weighted=st.booleans(), + mixed=st.booleans(), + mixed_B=st.booleans(), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ), + uvm_non_rowwise_momentum=st.booleans(), + ) + @settings( + verbosity=VERBOSITY, + 
max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + @unittest.skipIf(*gpu_unavailable) + def test_backward_optimizers_adam_rowwise_bias_correction( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + mixed: bool, + mixed_B: bool, + long_segments: bool, + pooling_mode: PoolingMode, + uvm_non_rowwise_momentum: bool, + ) -> None: + self.execute_backward_optimizers_( + T, + D, + B, + log_E, + L, + weighted, + mixed, + mixed_B, + OptimType.ADAM, + long_segments, + pooling_mode, + False, # use_cpu + uvm_non_rowwise_momentum=uvm_non_rowwise_momentum, + use_rowwise_bias_correction=True, + ) + @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), diff --git a/fbgemm_gpu/test/tbe/training/forward_test.py b/fbgemm_gpu/test/tbe/training/forward_test.py index bb519864d..1832db0ef 100644 --- a/fbgemm_gpu/test/tbe/training/forward_test.py +++ b/fbgemm_gpu/test/tbe/training/forward_test.py @@ -541,7 +541,7 @@ def test_forward_gpu_no_cache_fp16( @unittest.skipIf(*gpu_unavailable) @given( - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_experimental_tbe=st.booleans(), ) @settings( verbosity=VERBOSITY, @@ -670,7 +670,7 @@ def test_forward_gpu_uvm_cache_int8( @unittest.skipIf(*gpu_unavailable) @given( cache_algorithm=st.sampled_from(CacheAlgorithm), - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_experimental_tbe=st.booleans(), ) @settings( verbosity=VERBOSITY, @@ -740,7 +740,7 @@ def test_forward_gpu_uvm_cache_fp16( @unittest.skipIf(*gpu_unavailable) @given( cache_algorithm=st.sampled_from(CacheAlgorithm), - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_experimental_tbe=st.booleans(), ) @settings( verbosity=VERBOSITY,
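The empty-input guards added across the FP8 GEMM entry points (matmul_fp8_row, matmul_fp8_block, the CUTLASS f8f8bf16* kernels, and the CK rowwise/blockwise paths) skip the kernel launch when M, N, or K is zero. Most of them return zeros rather than an uninitialized tensor because a matmul whose reduction dimension K is zero is still well defined and its result is identically zero, which the zero-size test inputs above exercise. The sketch below shows the behavior and the guard pattern; `fp8_gemm_with_empty_guard` and its arguments are illustrative names, not part of the FBGEMM API.

import torch

# An empty reduction dimension still yields a well-defined, all-zero result.
M, N, K = 4, 8, 0
ref = torch.randn(M, K) @ torch.randn(K, N)
assert ref.shape == (M, N) and torch.equal(ref, torch.zeros(M, N))

def fp8_gemm_with_empty_guard(xq: torch.Tensor, wq: torch.Tensor) -> torch.Tensor:
    # Mirrors the early-return pattern in the kernels: the output keeps the leading
    # dims of `xq`, with the last dim replaced by N = wq.size(0).
    M, K, N = xq.shape[-2], xq.shape[-1], wq.shape[0]
    out_shape = xq.shape[:-1] + (N,)
    if M == 0 or N == 0 or K == 0:
        return torch.zeros(out_shape, dtype=torch.bfloat16, device=xq.device)
    raise NotImplementedError("dispatch to the real FP8 kernel here")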