diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu
index efef42164b..5c9e320d5c 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu
@@ -164,159 +164,6 @@ using namespace fbgemm_gpu;
 {%- endif %}
 {%- endmacro %}
 
-{#-/*
-    Split version of the load_and_accumulate macro. This code chunk describes
-    the weights load in the forward kernel. Set up the WeightRow and load the
-    quantization parameters. Shortcut store for nobag mode.
-
-    The main difference is in whether the slices are loaded from the embedding
-    table or the cache.
-
-    NOTE: The decision was made to define this code chunk as a Jinja macro
-    instead of an inline C++ function, since the compiler might not be able to
-    inline the code.
-
-    In-code variables that are defined outside:
-        emb_t, cache_t
-        idx_j
-        inner_j
-        D_emb
-        lxu_cache_weights
-        {{ locs_or_addrs_idx }}_j
-        idx_weight_j
-        VEC_WIDTH
-        D
-        kThreadGroupSize
-        output_j
-*/#}
-{%- macro load_weights(from_cache) %}
-    {%- if from_cache %}
-    const cache_t* cache_weights;
-    {%- if ssd %}
-    cache_weights = reinterpret_cast<const cache_t*>(
-        *reinterpret_cast<const uint64_t*>(&{{ locs_or_addrs_idx }}_j));
-    {%- else %}
-    cache_weights = reinterpret_cast<const cache_t*>(
-        &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]);
-    {%- endif %}
-    {%- endif %}
-    {#-/* Set the weights row */#}
-    {%- if is_rocm %}
-    const auto weights_row = rocm::WeightRowAccessorVec2
-    {%- else %}
-    const auto weights_row = WeightRowAccessor
-    {%- endif %}
-        <
-            emb_t,
-            cache_t,
-            cache_t,
-            {%- if from_cache %}
-            true
-            {%- else %}
-            false
-            {%- endif %}
-        >(
-        {%- if from_cache %}
-        // Pass nullptr to avoid calling &weights[idx_j * D_emb], which loads
-        // memory into the registers as a side effect
-        nullptr,
-        // Load from the cache
-        cache_weights,
-        {%- else %}
-        // Load from the embedding table
-        &weights[idx_j * D_emb],
-        // Pass nullptr because we are loading from the embedding table
-        nullptr,
-        {%- endif %}
-        D);
-
-    {#-/* Set the quantization params */#}
-    {%- if from_cache %}
-    // Assume the cache is FP16/FP32, which doesn't require quantization params
-    const auto qparams = make_float2(0.0f, 0.0f);
-    {%- else %}
-    // Load the quantization params from the embedding table row if emb_t == uint8_t
-    const auto qparams = weights_row.load_qparams();
-    {%- endif %}
-
-    {%- if not nobag %}
-    // Iterate over the row in the weights table, in 4-element strides
-    #pragma unroll kMaxVecsPerThread
-    for (int32_t i = 0; i < kMaxVecsPerThread; ++i)
-    {
-        // Load the slice of the weights
-        int32_t d = (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH;
-        d = (d < D) ? d : 0;
-        const auto weights_slice = weights_row.load(d, qparams);
-        vals[inner_j * kMaxVecsPerThread + i] = weights_slice;
-    }
-
-    {%- else %}
-    for (int32_t i = 0; i < D; i += kThreadGroupSize * VEC_WIDTH) {
-        const int32_t d = i + threadIdx.x * VEC_WIDTH;
-        if (d < D) {
-            // Since there is no pooling, simply copy the weights to the output
-            const auto weights_slice = weights_row.load(d, qparams);
-            {%- if is_index_select %}
-            // output is 1D (because the stride can be irregular)
-            weights_slice.store(&output[output_offset + output_j * output_stride + d]);
-            {%- else %}
-            // output is 2D
-            weights_slice.store(&output[output_j][d]);
-            {%- endif %}
-        }
-    }
-    {%- endif %}
-{%- endmacro %}
-
-{#-/*
-    Split version of the load_and_accumulate macro. This code chunk
-    describes the weights accumulate step in the forward kernel. Accumulate
-    the slices of values from the row. Does nothing for nobag mode, assuming
-    all the work is done in the load_weights() macro.
-
-    The main difference is in whether the slices are loaded from the embedding
-    table or the cache.
-
-    NOTE: The decision was made to define this code chunk as a Jinja macro
-    instead of an inline C++ function, since the compiler might not be able to
-    inline the code.
-
-    In-code variables that are defined outside:
-        emb_t, cache_t
-        idx_j
-        inner_j
-        D_emb
-        lxu_cache_weights
-        cache_idx_j
-        idx_weight_j
-        VEC_WIDTH
-        D
-        kThreadGroupSize
-        output_j
-*/#}
-{%- macro accumulate_and_store(from_cache) %}
-    {%- if not nobag %}
-    // Iterate over the row in the weights table, in 4-element strides
-    #pragma unroll kMaxVecsPerThread
-    for (int32_t i = 0;
-         i < kMaxVecsPerThread && (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D;
-         ++i) {
-        {%- if is_gwd_kernel %}
-        // Scale the weights by the global weight decay
-        vals[inner_j * kMaxVecsPerThread + i].mul_(global_weight_decay_j);
-        {%- endif %}
-        {%- if weighted %}
-        // Accumulate the weights * positional weight
-        accumulators[i].fma_(vals[inner_j * kMaxVecsPerThread + i], idx_weight_j);
-        {%- else %}
-        // Accumulate the weights
-        accumulators[i].add_(vals[inner_j * kMaxVecsPerThread + i]);
-        {%- endif %}
-    }
-    {%- endif %}
-{%- endmacro %}
-
 {#-/*
     This code chunk contains the implementation body of the kernel, and is
     defined as a Jinja macro to be copy-pasted directly into the kernel as
@@ -356,162 +203,8 @@ using namespace fbgemm_gpu;
         at::acc_type<cache_t, true> idx_weight = l < L ? indice_weights[indices_start + l] : 0;
     {%- endif %}
-        {%- if is_rocm %}
-        {%- if not nobag %}
-        rocm::Vec2T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
-        {%- endif %}
-        // Iterate over kThreadGroupSize indices
-        for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength)
-        {
-            {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-            // Load index from thread j in the group
-            [[maybe_unused]] int64_t idx_j_[kManualUnrollLength];
-            for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
-            {
-                idx_j_[inner_j] = SHFL_SYNC(idx, outer_j + inner_j);
-            }
-            {%- endif %}
-            {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
-            // Load the cache index from thread j in the group
-            [[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j_[kManualUnrollLength];
-            for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
-            {
-                {{ locs_or_addrs_idx }}_j_[inner_j] = use_lxu_cache ?
-                    SHFL_SYNC({{ locs_or_addrs_idx }}, outer_j + inner_j) : 0;
-            }
-            {%- endif %}
-
-            {%- if weighted %}
-            // Load the positional weight from thread j in the group
-            at::acc_type<cache_t, true> idx_weight_j_[kManualUnrollLength];
-            for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
-            {
-                idx_weight_j_[inner_j] = SHFL_SYNC(idx_weight, outer_j + inner_j);
-            }
-            {%- endif %}
-
-            for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
-            {
-                auto j = outer_j + inner_j;
-                {%- if is_index_select %}
-                int64_t output_j = L_start + l_start + j;
-                {%- elif nobag %}
-                int64_t output_j = indices_start + l_start + j;
-                {%- endif %}
-
-                {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-                [[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
-                {%- endif %}
-                {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
-                [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j
-                    = use_lxu_cache ? {{ locs_or_addrs_idx }}_j_[inner_j] : 0;
-                {%- endif %}
-                {%- if weighted %}
-                at::acc_type<cache_t, true> idx_weight_j = idx_weight_j_[inner_j];
-                {%- endif %}
-
-                {#/**************************************************************/#}
-                {#-/*
-                    This is the main switch that determines how we are to load
-                    and accumulate weights, and is determined by Jinja-time,
-                    compile-time, and run-time variables.
-                */#}
-
-                {%- if dense %}
-                {#-/* If it's dense, the cache is not supported, so load from the embedding table */#}
-                {{- load_weights(false) }}
-
-                {%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %}
-                {#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#}
-                {{- load_weights(false) }}
-
-                {%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %}
-                {#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#}
-                {{ load_weights(true) }}
-                {%- else %}
-                {#-/* Else we defer to run-time selection */#}
-                if (placement == PlacementType::MANAGED_CACHING
-                    && {{ locs_or_addrs_idx }}_j != kCacheLocationMissing
-                ) {
-                    {#-/* If the row is available in the cache, fetch from the cache */#}
-                    {{ load_weights(true) }}
-                } else {
-                    {#-/* Else fetch from the embedding table */#}
-                    {{ load_weights(false) }}
-                }
-                {%- endif %}
-                {#/**************************************************************/#}
-            }
-            {%- if not nobag %}
-            for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
-            {
-                auto j = outer_j + inner_j;
-
-                {%- if is_index_select %}
-                int64_t output_j = L_start + l_start + j;
-                {%- elif nobag %}
-                int64_t output_j = indices_start + l_start + j;
-                {%- endif %}
-
-                {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-                [[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
-                {%- endif %}
-                {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
-                [[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j = {{ locs_or_addrs_idx }}_j_[inner_j];
-                {%- endif %}
-                {%- if weighted %}
-                at::acc_type<cache_t, true> idx_weight_j = idx_weight_j_[inner_j];
-                {%- endif %}
-                {%- if is_gwd_kernel %}
-                const auto global_weight_decay_j = SHFL_SYNC(global_weight_decay, j);
-                {%- endif %}
-
-                {#/**************************************************************/#}
-                {#-/*
-                    This is the main switch that determines how we are to load
-                    and accumulate weights, and is determined by Jinja-time,
-                    compile-time, and run-time variables.
-                */#}
-
-                {%- if dense %}
-                {#-/* If it's dense, the cache is not supported, so load from the embedding table */#}
-                {{- accumulate_and_store(false) }}
-
-                {%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %}
-                {#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#}
-                {{- accumulate_and_store(false) }}
-
-                {%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %}
-                {#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#}
-                {{ accumulate_and_store(true) }}
-                {%- else %}
-                {#-/* Else we defer to run-time selection */#}
-                if (placement == PlacementType::MANAGED_CACHING
-                    && {{ locs_or_addrs_idx }}_j != kCacheLocationMissing) {
-                    {#-/* If the row is available in the cache, fetch from the cache */#}
-                    {{ accumulate_and_store(true) }}
-                } else {
-                    {#-/* Else fetch from the embedding table */#}
-                    {{ accumulate_and_store(false) }}
-                }
-                {%- endif %}
-                {#/**************************************************************/#}
-            }
-            {%- endif %}
-        }
-        {%- endif %}
-
-        {%- if is_rocm %}
-        for (auto j = L - L % kManualUnrollLength; j < kThreadGroupSize && l_start + j < L; ++j) {
-        {%- else %}
         // Iterate over kThreadGroupSize indices
         for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) {
-        {%- endif %}
             {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
             // Load index from thread j in the group
             [[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j);
@@ -677,10 +370,6 @@ batch_index_select_dim0_codegen_forward_kernel(
     {%- else %}
     constexpr int VEC_WIDTH = 4;
     {%- endif %}
-    {%- if is_rocm %}
-    // Unroll factor for ROCm devices
-    constexpr int kManualUnrollLength = 4;
-    {%- endif %}
 
     // Determine the linearized warp ID, and exit early if needed
    int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
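
Note on the removed dispatch logic: `load_weights()` / `accumulate_and_store()` resolve their load source through a mix of Jinja-time (`dense`, `lxu_miss_rate`), compile-time (the `from_cache` template flag), and run-time (`placement`, per-row cache location) decisions. Below is a minimal, hypothetical sketch of that pattern, not the real `WeightRowAccessor` API: `RowAccessor`, `load_elem`, and the cut-down `PlacementType` are illustrative names, and uint8 dequantization via qparams is omitted.

```cpp
#include <cstdint>

// Illustrative stand-ins; the real enum and sentinel live in FBGEMM headers.
constexpr int32_t kCacheLocationMissing = -1;
enum class PlacementType { DEVICE, MANAGED_CACHING };

// Compile-time flag picks the source pointer; each instantiation reads
// exactly one of the two rows, so the dead pointer is never dereferenced.
template <typename emb_t, typename cache_t, bool FromCache>
struct RowAccessor {
  const emb_t* emb_row;      // valid only when FromCache == false
  const cache_t* cache_row;  // valid only when FromCache == true
  __device__ float load(int d) const {
    return FromCache ? static_cast<float>(cache_row[d])
                     : static_cast<float>(emb_row[d]);
  }
};

// Run-time fallback: under MANAGED_CACHING a row may or may not be resident,
// so both instantiations exist and a per-row branch selects between them
// (mirroring the `{%- else %}` arm of the Jinja switch).
template <typename emb_t, typename cache_t>
__device__ float load_elem(PlacementType placement, int32_t cache_loc,
                           const emb_t* emb_row, const cache_t* cache_row,
                           int d) {
  if (placement == PlacementType::MANAGED_CACHING &&
      cache_loc != kCacheLocationMissing) {
    return RowAccessor<emb_t, cache_t, true>{nullptr, cache_row}.load(d);
  }
  return RowAccessor<emb_t, cache_t, false>{emb_row, nullptr}.load(d);
}
```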
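The `SHFL_SYNC(idx, j)` calls in the template implement a warp-level broadcast: every lane pre-loads one bag index, and on iteration j the whole group receives lane j's index so it can load that row cooperatively. A self-contained sketch of the pattern, assuming a single 32-lane group, sum pooling, `float` weights, `L <= 32`, and `D` a multiple of `VEC_WIDTH` (the real kernel also tiles over `l_start` and supports other dtypes):

```cpp
#include <cuda_runtime.h>

constexpr int kThreadGroupSize = 32;  // one warp
constexpr int VEC_WIDTH = 4;          // elements per lane per load

// One warp pools L rows of width D <= kThreadGroupSize * VEC_WIDTH.
__global__ void warp_pool(const float* weights, const long long* indices,
                          float* output, int L, int D) {
  const int lane = threadIdx.x;
  float acc[VEC_WIDTH] = {};

  // Each lane pre-loads a distinct index; lanes past L hold a dummy.
  long long idx = (lane < L) ? indices[lane] : 0;

  for (int j = 0; j < kThreadGroupSize && j < L; ++j) {
    // Broadcast lane j's index to the whole warp (SHFL_SYNC in the template)
    const long long idx_j = __shfl_sync(0xffffffffu, idx, j);
    const int d = lane * VEC_WIDTH;  // this lane's 4-element slice of the row
    if (d < D) {
      #pragma unroll
      for (int v = 0; v < VEC_WIDTH; ++v) {
        acc[v] += weights[idx_j * D + d + v];  // sum pooling
      }
    }
  }

  const int d = lane * VEC_WIDTH;
  if (d < D) {
    for (int v = 0; v < VEC_WIDTH; ++v) {
      output[d + v] = acc[v];
    }
  }
}
```

Loading the indices once and broadcasting them with shuffles keeps the index reads coalesced while letting all lanes of the group share the work on each embedding row.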
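Finally, the removed ROCm path restructured the bag loop into an unrolled main body of `kManualUnrollLength` indices per chunk (with the index/weight shuffles hoisted per chunk) plus the scalar tail loop `for (auto j = L - L % kManualUnrollLength; ...)` that this diff folds back into the common loop. A schematic of that split, with a hypothetical `accumulate_one` callback standing in for the per-index load/accumulate work and the thread-group tiling omitted:

```cpp
constexpr int kManualUnrollLength = 4;  // the unroll factor being removed

template <typename F>
__device__ void process_bag(int L, F accumulate_one) {
  // Unrolled main body: chunks of kManualUnrollLength cover
  // j in [0, L - L % kManualUnrollLength). The inner loop has a
  // compile-time trip count, so the compiler can fully unroll it.
  for (int outer_j = 0; outer_j < L - L % kManualUnrollLength;
       outer_j += kManualUnrollLength) {
    #pragma unroll
    for (int inner_j = 0; inner_j < kManualUnrollLength; ++inner_j) {
      accumulate_one(outer_j + inner_j);
    }
  }
  // Scalar tail: the remaining L % kManualUnrollLength indices, matching
  // the `for (auto j = L - L % kManualUnrollLength; ...)` loop in the
  // removed `{%- if is_rocm %}` branch.
  for (int j = L - L % kManualUnrollLength; j < L; ++j) {
    accumulate_one(j);
  }
}
```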