Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Back out "Manual loop unroll for rocm inference" #3506

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,70 +136,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow;
uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow);
bool load_idx_valid = row_load_idx < uint4_loads_per_row;

{%- if is_rocm %}
// ROCm-only path: process the per-thread output rows in groups of kRowUnroll,
// splitting the work for each group into four separate passes (gather indices,
// resolve row pointers, load row data, store to buffers) instead of doing all
// four steps for one row before moving to the next.
// NOTE(review): presumably the staging lets the kRowUnroll loads be issued
// back to back before any result is consumed -- confirm against the original
// "Manual loop unroll for rocm inference" change.
//
// kRowUnroll is min(OutputRowsPerThread, kMaxRowUnroll), i.e. at most 4 rows per group.
constexpr uint32_t kMaxRowUnroll = 4;
constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll;

// The loop bound is OutputRowsPerThread rounded down to a multiple of
// kRowUnroll, so only whole groups are handled here; any leftover rows
// (OutputRowsPerThread % kRowUnroll) are not touched by this loop.
#pragma unroll
for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) {
// Per-group scratch, indexed by position within the group (inner_i).
uint4 row_data_v[kRowUnroll];
const uint4* row_v[kRowUnroll];
int32_t idx_v[kRowUnroll];
int32_t cache_idx_v[kRowUnroll];
// Pass 1: read the embedding index and (when the cache applies) the cache
// location for every row of the group. -1 is stored when the slot is not valid.
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
// A slot is valid when this thread's load index is inside the row and the
// input row lies within the length Ls[i] for output row i.
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
// Cache lookup applies only when not DeviceOnly and placement is MANAGED_CACHING.
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1;
}


// Pass 2: turn each (idx, cache_idx) pair into a source row pointer.
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
// valid / cache_valid are recomputed with the same expressions as in pass 1.
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
// An index of -1 read from indices_ is also treated as "no row".
valid = valid && (idx_v[inner_i] != -1);
if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) {
// Cache hit: read the row from lxu_cache_weights.
row_v[inner_i] = reinterpret_cast<const uint4*>(&lxu_cache_weights[static_cast<int64_t>(cache_idx_v[inner_i])][0]);
} else
if (valid) {
// Otherwise read the row from weights at byte offset idx * D_bytes.
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[static_cast<int64_t>(idx_v[inner_i]) * D_bytes]);
} else {
// Invalid slot: pass 3 dereferences row_v[inner_i] unconditionally, so it is
// pointed at the start of weights; the loaded value is replaced by zeros in pass 4.
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[0]);
}
}
// Pass 3: load one uint4 (at row_load_idx) from each resolved row pointer.
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
// NOTE(review): `i` is not used in this loop body.
uint32_t i = outer_i + inner_i;
row_data_v[inner_i] = row_v[inner_i][row_load_idx];
}
uint4 zeros = {0, 0, 0, 0};
// Pass 4: write the loaded data (or zeros for invalid slots) into buffers,
// and, in the weighted variant, the matching per-index weight (or 0.0).
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1);
uint4 data = valid ? row_data_v[inner_i] : zeros;
buffers[warp_idx][i][input_row_idx][row_load_idx] = data;
{% if weighted %}
buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
{% endif %}
}
}
{%- endif %}

{%- if is_rocm %}
if constexpr (OutputRowsPerThread % kRowUnroll)
{
#pragma unroll
for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) {
{%- else %}
#pragma unroll OutputRowsPerThread
for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
{%- endif %}
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
Expand All @@ -219,9 +157,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
{% endif %}
}
{%- if is_rocm %}
} // constexpr if (OutputRowsPerThread % kRowUnroll)
{%- endif %}
}
// equivalent to fence + wait.
cp_async_wait<0>();
Expand Down Expand Up @@ -429,4 +364,4 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else ""

}

// clang-format on
// clang-format on
Loading