Review CUB util.ptx for CCCL 2.x #3342

Merged · 18 commits · Jan 15, 2025
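For context: CTA_SYNC() and CTA_SYNC_AND() are CUB's thin wrappers around the CUDA barrier intrinsics, defined in cub/util_ptx.cuh; this PR inlines the intrinsics at their call sites in the agent headers. A sketch of the wrappers being replaced, paraphrased from util_ptx.cuh with the _CCCL_DEVICE/_CCCL_FORCEINLINE macros expanded for readability (treat the exact signatures as an assumption):

// Paraphrased shape of the wrappers in cub/util_ptx.cuh (not a verbatim quote):
__device__ __forceinline__ void CTA_SYNC()
{
  __syncthreads(); // barrier across all threads of the block
}

__device__ __forceinline__ int CTA_SYNC_AND(int p)
{
  return __syncthreads_and(p); // barrier + block-wide AND of the predicate
}

If the wrappers were indeed thin aliases like this, every hunk below is behavior-preserving: a call to the wrapper becomes a direct call to the intrinsic it forwarded to.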
4 changes: 2 additions & 2 deletions cub/cub/agent/agent_adjacent_difference.cuh
@@ -138,7 +138,7 @@ struct AgentDifference
BlockLoad(temp_storage.load).Load(load_it + tile_base, input);
}

- CTA_SYNC();
+ __syncthreads();

if (ReadLeft)
{
@@ -186,7 +186,7 @@ struct AgentDifference
}
}

- CTA_SYNC();
+ __syncthreads();

if (IS_LAST_TILE)
{
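The barrier placement in AgentDifference is the standard stage-then-read pattern: the tile is staged into shared memory, __syncthreads() makes the staging visible block-wide, and only then do threads read their neighbors' elements. A minimal sketch of that pattern under simplifying assumptions (one int per thread, block size fixed at 256; the kernel is illustrative, not CUB's agent):

__global__ void adjacent_difference_tile(const int* d_in, int* d_out, int n)
{
  __shared__ int tile[256]; // one element per thread; block size assumed 256

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n)
  {
    tile[threadIdx.x] = d_in[idx]; // stage this thread's element
  }

  __syncthreads(); // the whole tile must be staged before any neighbor read

  if (idx < n)
  {
    // Left neighbor comes from shared memory, except at the tile boundary,
    // where it is re-read from global memory; element 0 subtracts itself.
    int left = (threadIdx.x > 0) ? tile[threadIdx.x - 1]
             : (idx > 0)         ? d_in[idx - 1]
                                 : tile[0];
    d_out[idx] = tile[threadIdx.x] - left;
  }
}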
14 changes: 7 additions & 7 deletions cub/cub/agent/agent_batch_memcpy.cuh
@@ -834,7 +834,7 @@ private:
BlockBLevTileCountScanT(temp_storage.staged.blev.block_scan_storage)
.ExclusiveSum(block_offset, block_offset, blev_tile_prefix_op);
}
- CTA_SYNC();
+ __syncthreads();

// Read in the BLEV buffer partition (i.e., the buffers that require block-level collaboration)
blev_buffer_offset = threadIdx.x * BLEV_BUFFERS_PER_THREAD;
@@ -996,7 +996,7 @@ private:

// Ensure all threads finished collaborative BlockExchange so temporary storage can be reused
// with next iteration
- CTA_SYNC();
+ __syncthreads();
}
}

@@ -1026,7 +1026,7 @@ public:
}

// Ensure we can repurpose the BlockLoad's temporary storage
- CTA_SYNC();
+ __syncthreads();

// Count how many buffers fall into each size-class
VectorizedSizeClassCounterT size_class_histogram = GetBufferSizeClassHistogram(buffer_sizes);
@@ -1037,7 +1037,7 @@ public:
.ExclusiveSum(size_class_histogram, size_class_histogram, size_class_agg);

// Ensure we can repurpose the scan's temporary storage for scattering the buffer ids
- CTA_SYNC();
+ __syncthreads();

// Factor in the per-size-class counts / offsets
// That is, WLEV buffer offset has to be offset by the TLEV buffer count and BLEV buffer offset
@@ -1077,15 +1077,15 @@ public:

// Ensure the prefix callback has finished using its temporary storage and that it can be reused
// in the next stage
- CTA_SYNC();
+ __syncthreads();

// Scatter the buffers into one of the three partitions (TLEV, WLEV, BLEV) depending on their
// size
PartitionBuffersBySize(buffer_sizes, size_class_histogram, temp_storage.staged.buffers_by_size_class);

// Ensure all buffers have been partitioned by their size class AND
// ensure that blev_buffer_offset has been written to shared memory
- CTA_SYNC();
+ __syncthreads();

// TODO: think about prefetching tile_buffer_{srcs,dsts} into shmem
InputBufferIt tile_buffer_srcs = input_buffer_it + buffer_offset;
@@ -1104,7 +1104,7 @@ public:
tile_id);

// Ensure we can repurpose the temporary storage required by EnqueueBLEVBuffers
- CTA_SYNC();
+ __syncthreads();

// Copy warp-level buffers
BatchMemcpyWLEVBuffers(
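Most barriers in this file guard the reuse of union-aliased shared memory between collective stages. A minimal sketch of why the intermediate __syncthreads() is required when two collectives alias one temporary-storage allocation (a hypothetical two-stage kernel, not the batch-memcpy agent itself):

#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>

template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void two_stage_kernel(const int* d_in, int* d_out)
{
  using BlockLoadT = cub::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD>;
  using BlockScanT = cub::BlockScan<int, BLOCK_THREADS>;

  // One shared allocation aliased by both collectives, as in the agents.
  __shared__ union
  {
    typename BlockLoadT::TempStorage load;
    typename BlockScanT::TempStorage scan;
  } temp_storage;

  int items[ITEMS_PER_THREAD];
  int tile_base = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD;
  BlockLoadT(temp_storage.load).Load(d_in + tile_base, items);

  // Every thread must be done with BlockLoad's storage before BlockScan aliases it.
  __syncthreads();

  BlockScanT(temp_storage.scan).ExclusiveSum(items, items);

  // Blocked store: each thread writes back its ITEMS_PER_THREAD results.
  for (int i = 0; i < ITEMS_PER_THREAD; ++i)
  {
    d_out[tile_base + threadIdx.x * ITEMS_PER_THREAD + i] = items[i];
  }
}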
8 changes: 4 additions & 4 deletions cub/cub/agent/agent_histogram.cuh
@@ -320,7 +320,7 @@ struct AgentHistogram
}

// Barrier to make sure all threads are done updating counters
- CTA_SYNC();
+ __syncthreads();
}

// Initialize privatized bin counters. Specialized for privatized shared-memory counters
@@ -350,7 +350,7 @@
_CCCL_DEVICE _CCCL_FORCEINLINE void StoreOutput(CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS])
{
// Barrier to make sure all threads are done updating counters
- CTA_SYNC();
+ __syncthreads();

// Apply privatized bin counts to output bin counts
#pragma unroll
@@ -690,15 +690,15 @@ struct AgentHistogram
ConsumeTile<IS_ALIGNED, true>(tile_offset, TILE_SAMPLES);
}

- CTA_SYNC();
+ __syncthreads();

// Get next tile
if (threadIdx.x == 0)
{
temp_storage.tile_idx = tile_queue.Drain(1) + num_even_share_tiles;
}

- CTA_SYNC();
+ __syncthreads();

tile_idx = temp_storage.tile_idx;
}
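The pair of barriers around tile_queue.Drain(1) implements a single-writer broadcast: every thread must be done with the previous tile index before thread 0 overwrites the shared slot, and every thread must wait for the new value before reading it. A minimal sketch of that pattern, using a plain atomic counter as a hypothetical stand-in for the agent's tile queue:

__device__ int g_tile_counter = 0; // hypothetical global work counter

__global__ void drain_tiles(int num_tiles)
{
  __shared__ int next_tile; // written by thread 0, read by the whole block

  while (true)
  {
    __syncthreads(); // all threads are done reading the previous index

    if (threadIdx.x == 0)
    {
      next_tile = atomicAdd(&g_tile_counter, 1); // one thread claims a tile
    }

    __syncthreads(); // broadcast: the new index is now visible block-wide

    int tile_idx = next_tile;
    if (tile_idx >= num_tiles)
    {
      return;
    }
    // ... consume tile tile_idx ...
  }
}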
11 changes: 6 additions & 5 deletions cub/cub/agent/agent_merge.cuh
@@ -132,7 +132,7 @@ struct agent_t
gmem_to_reg<threads_per_block, IsFullTile>(
keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
- CTA_SYNC();
+ __syncthreads();

// use binary search in shared memory to find merge path for each of thread.
// we can use int type here, because the number of items in shared memory is limited
@@ -158,7 +158,7 @@
keys_loc,
indices,
compare_op);
- CTA_SYNC();
+ __syncthreads();

// write keys
if (IsFullTile)
@@ -182,17 +182,18 @@
item_type items_loc[items_per_thread];
gmem_to_reg<threads_per_block, IsFullTile>(
items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
- CTA_SYNC(); // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
+ __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we write
+ // to it
reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
- CTA_SYNC();
+ __syncthreads();

// gather items from shared mem
#pragma unroll
for (int i = 0; i < items_per_thread; ++i)
{
items_loc[i] = storage.items_shared[indices[i]];
}
- CTA_SYNC();
+ __syncthreads();

// write from reg to gmem
if (IsFullTile)
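The final three barriers in agent_merge bracket a gather through shared memory: items move from registers into shared memory, each thread gathers them back in merge order through its indices, and a last __syncthreads() keeps the buffer alive until every thread has finished (the storage is reused afterwards). A minimal sketch of that register-to-shared gather, assuming a per-element permutation index (names are illustrative):

template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__device__ void gather_via_shared(int (&items)[ITEMS_PER_THREAD],
                                  const int (&indices)[ITEMS_PER_THREAD],
                                  int* shared_buf) // BLOCK_THREADS * ITEMS_PER_THREAD slots
{
  // Stage registers into shared memory in blocked order.
#pragma unroll
  for (int i = 0; i < ITEMS_PER_THREAD; ++i)
  {
    shared_buf[threadIdx.x * ITEMS_PER_THREAD + i] = items[i];
  }
  __syncthreads(); // staging must complete before any cross-thread gather

  // Gather back in the permuted order computed by the merge-path search.
#pragma unroll
  for (int i = 0; i < ITEMS_PER_THREAD; ++i)
  {
    items[i] = shared_buf[indices[i]];
  }
  __syncthreads(); // keep shared_buf intact until every thread has gathered
}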
20 changes: 10 additions & 10 deletions cub/cub/agent/agent_merge_sort.cuh
@@ -187,7 +187,7 @@ struct AgentBlockSort
BlockLoadItems(storage.load_items).Load(items_in + tile_base, items_local);
}

- CTA_SYNC();
+ __syncthreads();
}

KeyT keys_local[ITEMS_PER_THREAD];
@@ -200,7 +200,7 @@ struct AgentBlockSort
BlockLoadKeys(storage.load_keys).Load(keys_in + tile_base, keys_local);
}

- CTA_SYNC();
+ __syncthreads();
_CCCL_PDL_TRIGGER_NEXT_LAUNCH();

_CCCL_IF_CONSTEXPR (IS_LAST_TILE)
@@ -212,7 +212,7 @@ struct AgentBlockSort
BlockMergeSortT(storage.block_merge).Sort(keys_local, items_local, compare_op);
}

- CTA_SYNC();
+ __syncthreads();

if (ping)
{
@@ -227,7 +227,7 @@

_CCCL_IF_CONSTEXPR (!KEYS_ONLY)
{
- CTA_SYNC();
+ __syncthreads();

_CCCL_IF_CONSTEXPR (IS_LAST_TILE)
{
@@ -252,7 +252,7 @@

_CCCL_IF_CONSTEXPR (!KEYS_ONLY)
{
- CTA_SYNC();
+ __syncthreads();

_CCCL_IF_CONSTEXPR (IS_LAST_TILE)
{
@@ -583,7 +583,7 @@ struct AgentMerge
}
}

- CTA_SYNC();
+ __syncthreads();
_CCCL_PDL_TRIGGER_NEXT_LAUNCH();

// use binary search in shared memory
@@ -616,7 +616,7 @@ struct AgentMerge
indices,
compare_op);

- CTA_SYNC();
+ __syncthreads();

// write keys
if (ping)
@@ -650,11 +650,11 @@ struct AgentMerge
_CCCL_IF_CONSTEXPR (!KEYS_ONLY)
#endif // _CCCL_CUDACC_AT_LEAST(11, 8)
{
- CTA_SYNC();
+ __syncthreads();

detail::reg_to_shared<BLOCK_THREADS>(&storage.items_shared[0], items_local);

- CTA_SYNC();
+ __syncthreads();

// gather items from shared mem
//
Expand All @@ -664,7 +664,7 @@ struct AgentMerge
items_local[item] = storage.items_shared[indices[item]];
}

- CTA_SYNC();
+ __syncthreads();

// write from reg to gmem
//
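The ping flag threaded through AgentBlockSort and AgentMerge selects between two global buffers that alternate roles across merge passes, while the barriers inside a pass protect the shared staging area. A host-side sketch of that double-buffering, assuming a hypothetical merge_pass kernel that performs one width-doubling pass (all names here are illustrative, not the dispatch layer's):

__global__ void merge_pass(const int* src, int* dst, int num_items, int width); // hypothetical

int* run_merge_passes(int* d_buf_a, int* d_buf_b, int num_items, int tile_items,
                      dim3 grid, dim3 block)
{
  int* bufs[2] = {d_buf_a, d_buf_b};
  bool ping = true;
  for (int width = tile_items; width < num_items; width *= 2)
  {
    int* src = bufs[ping ? 0 : 1];
    int* dst = bufs[ping ? 1 : 0];
    merge_pass<<<grid, block>>>(src, dst, num_items, width);
    ping = !ping; // the next pass reads what this pass wrote
  }
  return bufs[ping ? 0 : 1]; // whichever buffer the final pass wrote
}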
34 changes: 17 additions & 17 deletions cub/cub/agent/agent_radix_sort_downsweep.cuh
@@ -277,7 +277,7 @@ struct AgentRadixSortDownsweep
temp_storage.keys_and_offsets.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
}

- CTA_SYNC();
+ __syncthreads();

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
@@ -305,7 +305,7 @@ struct AgentRadixSortDownsweep
int (&ranks)[ITEMS_PER_THREAD],
OffsetT valid_items)
{
- CTA_SYNC();
+ __syncthreads();

ValueExchangeT& exchange_values = temp_storage.exchange_values.Alias();

@@ -315,7 +315,7 @@ struct AgentRadixSortDownsweep
exchange_values[ranks[ITEM]] = values[ITEM];
}

- CTA_SYNC();
+ __syncthreads();

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
@@ -342,7 +342,7 @@
{
BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + block_offset, keys);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -362,7 +362,7 @@

BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + block_offset, keys, valid_items, oob_item);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -409,7 +409,7 @@ struct AgentRadixSortDownsweep
{
BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + block_offset, values);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -428,7 +428,7 @@

BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + block_offset, values, valid_items);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -474,7 +474,7 @@ struct AgentRadixSortDownsweep
{
ValueT values[ITEMS_PER_THREAD];

- CTA_SYNC();
+ __syncthreads();

LoadValues(values, block_offset, valid_items, Int2Type<FULL_TILE>(), Int2Type<LOAD_WARP_STRIPED>());

@@ -520,7 +520,7 @@ struct AgentRadixSortDownsweep
int exclusive_digit_prefix[BINS_TRACKED_PER_THREAD];
BlockRadixRankT(temp_storage.radix_rank).RankKeys(keys, ranks, digit_extractor(), exclusive_digit_prefix);

- CTA_SYNC();
+ __syncthreads();

// Share exclusive digit prefix
#pragma unroll
@@ -534,7 +534,7 @@ struct AgentRadixSortDownsweep
}
}

- CTA_SYNC();
+ __syncthreads();

// Get inclusive digit prefix
int inclusive_digit_prefix[BINS_TRACKED_PER_THREAD];
@@ -562,7 +562,7 @@ struct AgentRadixSortDownsweep
}
}

- CTA_SYNC();
+ __syncthreads();

// Update global scatter base offsets for each digit
#pragma unroll
@@ -577,7 +577,7 @@ struct AgentRadixSortDownsweep
}
}

- CTA_SYNC();
+ __syncthreads();

// Scatter keys
ScatterKeys<FULL_TILE>(keys, relative_bin_offsets, ranks, valid_items);
@@ -602,7 +602,7 @@ struct AgentRadixSortDownsweep
T items[ITEMS_PER_THREAD];

LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items);
- CTA_SYNC();
+ __syncthreads();
StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items);

block_offset += TILE_ITEMS;
@@ -616,7 +616,7 @@ struct AgentRadixSortDownsweep
T items[ITEMS_PER_THREAD];

LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items, valid_items);
- CTA_SYNC();
+ __syncthreads();
StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items, valid_items);
}
}
@@ -670,7 +670,7 @@ struct AgentRadixSortDownsweep
}
}

- short_circuit = CTA_SYNC_AND(short_circuit);
+ short_circuit = __syncthreads_and(short_circuit);
}

/**
@@ -719,7 +719,7 @@ struct AgentRadixSortDownsweep
}
}

- short_circuit = CTA_SYNC_AND(short_circuit);
+ short_circuit = __syncthreads_and(short_circuit);
}

/**
@@ -744,7 +744,7 @@ struct AgentRadixSortDownsweep
ProcessTile<true>(block_offset);
block_offset += TILE_ITEMS;

- CTA_SYNC();
+ __syncthreads();
}

// Clean up last partial tile with guarded-I/O
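The one substitution that is not a plain barrier is CTA_SYNC_AND → __syncthreads_and: the intrinsic barriers the block and returns nonzero only if the predicate was nonzero on every participating thread, which is exactly what the downsweep's short-circuit check needs (skip ranking and scattering when every key in the tile falls into the same bin). A minimal sketch of that use, assuming a full tile and an illustrative digit test:

__global__ void short_circuit_demo(const unsigned* d_keys, unsigned* d_out, unsigned digit_mask)
{
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Per-thread predicate: are this key's digit bits all zero?
  int pred = ((d_keys[idx] & digit_mask) == 0);

  // Barrier plus block-wide AND: nonzero only if pred held on every thread.
  if (__syncthreads_and(pred))
  {
    // The whole tile shares one digit: copy through and skip rank-and-scatter.
    d_out[idx] = d_keys[idx];
  }
  // else: the full ranking/scattering path would run here (elided)
}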