Skip to content

Commit

Permalink
Rename addSetsForCacheReads to cacheOperandsToRegisters (#3380)
Browse files Browse the repository at this point in the history
As suggested by @zasdfgbnm in
https://github.com/NVIDIA/Fuser/pull/3278/files#r1833466428, this just
renames one method in the Ampere matmul scheduler for clarity. No
functional change is expected.
  • Loading branch information
jacobhinkle authored Nov 8, 2024
1 parent 2aacfd7 commit a730949
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions csrc/scheduler/ampere_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,10 @@ void AmpereMultipleMatmulScheduler::cacheInputsAndOutputs() {

void AmpereMultipleMatmulScheduler::defineOperandCaches() {
cacheOperandsToSmem(as_, acw_smems_, params_->supported_vec_size.a);
addSetsForCacheReads(acw_smems_, acrs_);
cacheOperandsToRegisters(acw_smems_, acrs_);

cacheOperandsToSmem(bs_, bcw_smems_, params_->supported_vec_size.b);
addSetsForCacheReads(bcw_smems_, bcrs_);
cacheOperandsToRegisters(bcw_smems_, bcrs_);

// Now that we are finished possibly redefining the inputs to the MmaOps,
// we can set the macro for those ops
Expand Down Expand Up @@ -551,7 +551,7 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToSmem(
}
}

void AmpereMultipleMatmulScheduler::addSetsForCacheReads(
void AmpereMultipleMatmulScheduler::cacheOperandsToRegisters(
const std::vector<TensorView*>& tv_smems,
std::vector<TensorView*>& tv_rs) {
tv_rs.resize(tv_smems.size(), nullptr);
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/ampere_multi_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
// existing LoadStoreOp present. Please note that for the second LoadStore
// we don't propagate the allocation domain, since the scheduler sets the
// allocation domain in the registers.
void addSetsForCacheReads(
void cacheOperandsToRegisters(
const std::vector<TensorView*>& tv_smems,
std::vector<TensorView*>& tv_rs);

Expand Down

0 comments on commit a730949

Please sign in to comment.