Review/Deprecate CUB util.ptx for CCCL 2.x (#3342)
fbusato authored Jan 15, 2025
1 parent 1d426b6 commit 43fb061
Showing 58 changed files with 434 additions and 386 deletions.
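The substantive change across these files is mechanical: calls to the CUB macro CTA_SYNC() (from CUB's util_ptx header) are replaced with the underlying CUDA intrinsic __syncthreads(); the aggregated variant CTA_SYNC_AND() is handled the same way in agent_radix_sort_downsweep.cuh (see the sketch after that file's diff). Below is a minimal sketch of the pattern, assuming a hypothetical kernel launched with 256 threads per block; the kernel and buffer names are illustrative and not part of the commit.

// Minimal sketch of the CTA_SYNC() -> __syncthreads() replacement.
// Hypothetical kernel, assumed to be launched with 256 threads per block;
// only the barrier call mirrors the actual change in this commit.
__global__ void block_sum(const int* in, int* out, int n)
{
  __shared__ int tile[256];

  int idx           = blockIdx.x * blockDim.x + threadIdx.x;
  tile[threadIdx.x] = (idx < n) ? in[idx] : 0;

  // Previously: CTA_SYNC();  (CUB macro wrapping the barrier)
  __syncthreads(); // block-wide barrier before reading the shared tile

  if (threadIdx.x == 0)
  {
    int sum = 0;
    for (int i = 0; i < blockDim.x; ++i)
    {
      sum += tile[i];
    }
    out[blockIdx.x] = sum;
  }
}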
4 changes: 2 additions & 2 deletions cub/cub/agent/agent_adjacent_difference.cuh
@@ -138,7 +138,7 @@ struct AgentDifference
BlockLoad(temp_storage.load).Load(load_it + tile_base, input);
}

- CTA_SYNC();
+ __syncthreads();

if (ReadLeft)
{
@@ -186,7 +186,7 @@ struct AgentDifference
}
}

- CTA_SYNC();
+ __syncthreads();

if (IS_LAST_TILE)
{
14 changes: 7 additions & 7 deletions cub/cub/agent/agent_batch_memcpy.cuh
@@ -834,7 +834,7 @@ private:
BlockBLevTileCountScanT(temp_storage.staged.blev.block_scan_storage)
.ExclusiveSum(block_offset, block_offset, blev_tile_prefix_op);
}
- CTA_SYNC();
+ __syncthreads();

// Read in the BLEV buffer partition (i.e., the buffers that require block-level collaboration)
blev_buffer_offset = threadIdx.x * BLEV_BUFFERS_PER_THREAD;
@@ -996,7 +996,7 @@ private:

// Ensure all threads finished collaborative BlockExchange so temporary storage can be reused
// with next iteration
- CTA_SYNC();
+ __syncthreads();
}
}

@@ -1026,7 +1026,7 @@ public:
}

// Ensure we can repurpose the BlockLoad's temporary storage
- CTA_SYNC();
+ __syncthreads();

// Count how many buffers fall into each size-class
VectorizedSizeClassCounterT size_class_histogram = GetBufferSizeClassHistogram(buffer_sizes);
@@ -1037,7 +1037,7 @@ public:
.ExclusiveSum(size_class_histogram, size_class_histogram, size_class_agg);

// Ensure we can repurpose the scan's temporary storage for scattering the buffer ids
- CTA_SYNC();
+ __syncthreads();

// Factor in the per-size-class counts / offsets
// That is, WLEV buffer offset has to be offset by the TLEV buffer count and BLEV buffer offset
@@ -1077,15 +1077,15 @@ public:

// Ensure the prefix callback has finished using its temporary storage and that it can be reused
// in the next stage
- CTA_SYNC();
+ __syncthreads();

// Scatter the buffers into one of the three partitions (TLEV, WLEV, BLEV) depending on their
// size
PartitionBuffersBySize(buffer_sizes, size_class_histogram, temp_storage.staged.buffers_by_size_class);

// Ensure all buffers have been partitioned by their size class AND
// ensure that blev_buffer_offset has been written to shared memory
- CTA_SYNC();
+ __syncthreads();

// TODO: think about prefetching tile_buffer_{srcs,dsts} into shmem
InputBufferIt tile_buffer_srcs = input_buffer_it + buffer_offset;
@@ -1104,7 +1104,7 @@ public:
tile_id);

// Ensure we can repurpose the temporary storage required by EnqueueBLEVBuffers
- CTA_SYNC();
+ __syncthreads();

// Copy warp-level buffers
BatchMemcpyWLEVBuffers(
8 changes: 4 additions & 4 deletions cub/cub/agent/agent_histogram.cuh
@@ -320,7 +320,7 @@ struct AgentHistogram
}

// Barrier to make sure all threads are done updating counters
- CTA_SYNC();
+ __syncthreads();
}

// Initialize privatized bin counters. Specialized for privatized shared-memory counters
@@ -350,7 +350,7 @@ struct AgentHistogram
_CCCL_DEVICE _CCCL_FORCEINLINE void StoreOutput(CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS])
{
// Barrier to make sure all threads are done updating counters
- CTA_SYNC();
+ __syncthreads();

// Apply privatized bin counts to output bin counts
#pragma unroll
@@ -690,15 +690,15 @@ struct AgentHistogram
ConsumeTile<IS_ALIGNED, true>(tile_offset, TILE_SAMPLES);
}

- CTA_SYNC();
+ __syncthreads();

// Get next tile
if (threadIdx.x == 0)
{
temp_storage.tile_idx = tile_queue.Drain(1) + num_even_share_tiles;
}

- CTA_SYNC();
+ __syncthreads();

tile_idx = temp_storage.tile_idx;
}
11 changes: 6 additions & 5 deletions cub/cub/agent/agent_merge.cuh
@@ -132,7 +132,7 @@ struct agent_t
gmem_to_reg<threads_per_block, IsFullTile>(
keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
- CTA_SYNC();
+ __syncthreads();

// use binary search in shared memory to find merge path for each of thread.
// we can use int type here, because the number of items in shared memory is limited
@@ -158,7 +158,7 @@ struct agent_t
keys_loc,
indices,
compare_op);
- CTA_SYNC();
+ __syncthreads();

// write keys
if (IsFullTile)
@@ -182,17 +182,18 @@ struct agent_t
item_type items_loc[items_per_thread];
gmem_to_reg<threads_per_block, IsFullTile>(
items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
- CTA_SYNC(); // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
+ __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we write
+ // to it
reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
- CTA_SYNC();
+ __syncthreads();

// gather items from shared mem
#pragma unroll
for (int i = 0; i < items_per_thread; ++i)
{
items_loc[i] = storage.items_shared[indices[i]];
}
- CTA_SYNC();
+ __syncthreads();

// write from reg to gmem
if (IsFullTile)
20 changes: 10 additions & 10 deletions cub/cub/agent/agent_merge_sort.cuh
@@ -187,7 +187,7 @@ struct AgentBlockSort
BlockLoadItems(storage.load_items).Load(items_in + tile_base, items_local);
}

- CTA_SYNC();
+ __syncthreads();
}

KeyT keys_local[ITEMS_PER_THREAD];
@@ -200,7 +200,7 @@ struct AgentBlockSort
BlockLoadKeys(storage.load_keys).Load(keys_in + tile_base, keys_local);
}

- CTA_SYNC();
+ __syncthreads();
_CCCL_PDL_TRIGGER_NEXT_LAUNCH();

_CCCL_IF_CONSTEXPR (IS_LAST_TILE)
@@ -212,7 +212,7 @@ struct AgentBlockSort
BlockMergeSortT(storage.block_merge).Sort(keys_local, items_local, compare_op);
}

- CTA_SYNC();
+ __syncthreads();

if (ping)
{
@@ -227,7 +227,7 @@ struct AgentBlockSort

_CCCL_IF_CONSTEXPR (!KEYS_ONLY)
{
- CTA_SYNC();
+ __syncthreads();

_CCCL_IF_CONSTEXPR (IS_LAST_TILE)
{
@@ -252,7 +252,7 @@ struct AgentBlockSort

_CCCL_IF_CONSTEXPR (!KEYS_ONLY)
{
- CTA_SYNC();
+ __syncthreads();

_CCCL_IF_CONSTEXPR (IS_LAST_TILE)
{
@@ -583,7 +583,7 @@ struct AgentMerge
}
}

- CTA_SYNC();
+ __syncthreads();
_CCCL_PDL_TRIGGER_NEXT_LAUNCH();

// use binary search in shared memory
@@ -616,7 +616,7 @@ struct AgentMerge
indices,
compare_op);

- CTA_SYNC();
+ __syncthreads();

// write keys
if (ping)
@@ -650,11 +650,11 @@ struct AgentMerge
_CCCL_IF_CONSTEXPR (!KEYS_ONLY)
#endif // _CCCL_CUDACC_AT_LEAST(11, 8)
{
- CTA_SYNC();
+ __syncthreads();

detail::reg_to_shared<BLOCK_THREADS>(&storage.items_shared[0], items_local);

- CTA_SYNC();
+ __syncthreads();

// gather items from shared mem
//
Expand All @@ -664,7 +664,7 @@ struct AgentMerge
items_local[item] = storage.items_shared[indices[item]];
}

- CTA_SYNC();
+ __syncthreads();

// write from reg to gmem
//
34 changes: 17 additions & 17 deletions cub/cub/agent/agent_radix_sort_downsweep.cuh
@@ -277,7 +277,7 @@ struct AgentRadixSortDownsweep
temp_storage.keys_and_offsets.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
}

- CTA_SYNC();
+ __syncthreads();

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
@@ -305,7 +305,7 @@ struct AgentRadixSortDownsweep
int (&ranks)[ITEMS_PER_THREAD],
OffsetT valid_items)
{
- CTA_SYNC();
+ __syncthreads();

ValueExchangeT& exchange_values = temp_storage.exchange_values.Alias();

Expand All @@ -315,7 +315,7 @@ struct AgentRadixSortDownsweep
exchange_values[ranks[ITEM]] = values[ITEM];
}

- CTA_SYNC();
+ __syncthreads();

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
@@ -342,7 +342,7 @@ struct AgentRadixSortDownsweep
{
BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + block_offset, keys);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -362,7 +362,7 @@ struct AgentRadixSortDownsweep

BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + block_offset, keys, valid_items, oob_item);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -409,7 +409,7 @@ struct AgentRadixSortDownsweep
{
BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + block_offset, values);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -428,7 +428,7 @@ struct AgentRadixSortDownsweep

BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + block_offset, values, valid_items);

- CTA_SYNC();
+ __syncthreads();
}

/**
@@ -474,7 +474,7 @@ struct AgentRadixSortDownsweep
{
ValueT values[ITEMS_PER_THREAD];

- CTA_SYNC();
+ __syncthreads();

LoadValues(values, block_offset, valid_items, Int2Type<FULL_TILE>(), Int2Type<LOAD_WARP_STRIPED>());

@@ -520,7 +520,7 @@ struct AgentRadixSortDownsweep
int exclusive_digit_prefix[BINS_TRACKED_PER_THREAD];
BlockRadixRankT(temp_storage.radix_rank).RankKeys(keys, ranks, digit_extractor(), exclusive_digit_prefix);

- CTA_SYNC();
+ __syncthreads();

// Share exclusive digit prefix
#pragma unroll
@@ -534,7 +534,7 @@ struct AgentRadixSortDownsweep
}
}

- CTA_SYNC();
+ __syncthreads();

// Get inclusive digit prefix
int inclusive_digit_prefix[BINS_TRACKED_PER_THREAD];
@@ -562,7 +562,7 @@ struct AgentRadixSortDownsweep
}
}

- CTA_SYNC();
+ __syncthreads();

// Update global scatter base offsets for each digit
#pragma unroll
Expand All @@ -577,7 +577,7 @@ struct AgentRadixSortDownsweep
}
}

- CTA_SYNC();
+ __syncthreads();

// Scatter keys
ScatterKeys<FULL_TILE>(keys, relative_bin_offsets, ranks, valid_items);
@@ -602,7 +602,7 @@ struct AgentRadixSortDownsweep
T items[ITEMS_PER_THREAD];

LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items);
- CTA_SYNC();
+ __syncthreads();
StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items);

block_offset += TILE_ITEMS;
@@ -616,7 +616,7 @@ struct AgentRadixSortDownsweep
T items[ITEMS_PER_THREAD];

LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items, valid_items);
- CTA_SYNC();
+ __syncthreads();
StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items, valid_items);
}
}
@@ -670,7 +670,7 @@ struct AgentRadixSortDownsweep
}
}

- short_circuit = CTA_SYNC_AND(short_circuit);
+ short_circuit = __syncthreads_and(short_circuit);
}

/**
Expand Down Expand Up @@ -719,7 +719,7 @@ struct AgentRadixSortDownsweep
}
}

- short_circuit = CTA_SYNC_AND(short_circuit);
+ short_circuit = __syncthreads_and(short_circuit);
}

/**
Expand All @@ -744,7 +744,7 @@ struct AgentRadixSortDownsweep
ProcessTile<true>(block_offset);
block_offset += TILE_ITEMS;

- CTA_SYNC();
+ __syncthreads();
}

// Clean up last partial tile with guarded-I/O
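The two hunks above at original lines 670 and 719 of agent_radix_sort_downsweep.cuh also swap the aggregated barrier: CTA_SYNC_AND(pred) becomes the CUDA intrinsic __syncthreads_and(pred), which synchronizes the block and returns non-zero only when the predicate is non-zero on every thread. A hedged sketch of that intrinsic in isolation follows; the kernel, buffer, and flag names are illustrative and not taken from the commit.

// Illustrative only: the kernel, buffers, and flag names are hypothetical;
// just the __syncthreads_and() call mirrors the change made in this commit.
__global__ void block_sorted_check(const unsigned int* keys, int* result, int n)
{
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Each thread checks a local condition: its key is not greater than the
  // next one, i.e. the adjacent pair is already in order.
  int locally_sorted = 1;
  if (idx + 1 < n)
  {
    locally_sorted = (keys[idx] <= keys[idx + 1]);
  }

  // Previously: short_circuit = CTA_SYNC_AND(short_circuit);
  // Barrier plus a block-wide AND reduction of the predicate.
  int block_sorted = __syncthreads_and(locally_sorted);

  if (threadIdx.x == 0)
  {
    result[blockIdx.x] = block_sorted; // 1 only if every thread's check passed
  }
}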
Diffs for the remaining 52 changed files are not shown here.