Skip to content

Commit

Permalink
add ods logging for l2 cache perf (pytorch#3031)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/torchrec#2335

Pull Request resolved: pytorch#3031

X-link: facebookresearch/FBGEMM#129

Collect performance-related metrics from the KV store and export them to ODS.

Differential Revision: D61417980
  • Loading branch information
Guanqiao Wang authored and facebook-github-bot committed Aug 29, 2024
1 parent c818b87 commit f3b2a2e
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 25 deletions.
13 changes: 13 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/runtime_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def should_report(self, iteration_step: int) -> bool:
"""
...

@abc.abstractmethod
def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
    """
    Register stats_name in the whitelist of the reporter.

    Args:
        stats_name: name of the stat to register.
        amplifier: defaults to 1. NOTE(review): its meaning is not visible
            here -- presumably a scaling factor applied to reported values;
            confirm against a concrete reporter implementation.
    """
    ...

@abc.abstractmethod
def report_duration(
self,
Expand Down Expand Up @@ -68,6 +75,9 @@ def __init__(self, report_interval: int) -> None:
assert report_interval > 0, "Report interval must be positive"
self.report_interval = report_interval

def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
    """No-op: this reporter ignores both arguments and returns immediately."""
    return

def should_report(self, iteration_step: int) -> bool:
return iteration_step % self.report_interval == 0

Expand Down Expand Up @@ -96,6 +106,9 @@ def report_data_amount(
f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes"
)

def __repr__(self) -> str:
    """Return a short description of this reporter that includes its report_interval."""
    # Adjacent string literals are concatenated; only the middle piece is an f-string.
    return "StdLogStatsReporter{ " f"report_interval={self.report_interval} " "}"


@dataclass(frozen=True)
class TBEStatsReporterConfig:
Expand Down
80 changes: 78 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def __init__(
else:
logging.warning("dist is not initialized, treating as single gpu cases")
tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
self.tbe_unique_id = tbe_unique_id
logging.info(f"tbe_unique_id: {tbe_unique_id}")
if not ps_hosts:
logging.info(
Expand Down Expand Up @@ -575,7 +576,7 @@ def __init__(
)
logging.info(
f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
f"stats_reporter:{self.stats_reporter if self.stats_reporter else 'none'}, "
f"stats_reporter:{self.stats_reporter if self.stats_reporter else 'none'}"
)

# prefetch launch a series of kernels, we use AsyncSeriesTimer to track the kernel time
Expand All @@ -589,6 +590,12 @@ def __init__(
self.prefetch_parallel_stream_cnt,
0,
)
self.l2_num_cache_misses_stats_name: str = (
f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_cache_misses"
)
self.l2_num_cache_lookups_stats_name: str = (
f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_lookups"
)
if self.stats_reporter:
self.ssd_prefetch_read_timer = AsyncSeriesTimer(
functools.partial(
Expand All @@ -606,6 +613,10 @@ def __init__(
time_unit="us",
)
)
# pyre-ignore
self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
# pyre-ignore
self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)

@torch.jit.ignore
def _report_duration(
Expand Down Expand Up @@ -1596,6 +1607,7 @@ def _report_ssd_stats(self) -> None:
self._report_ssd_l1_cache_stats()
self._report_ssd_io_stats()
self._report_ssd_mem_usage()
self._report_l2_cache_perf_stats()

@torch.jit.ignore
def _report_ssd_l1_cache_stats(self) -> None:
Expand Down Expand Up @@ -1646,7 +1658,7 @@ def _report_ssd_io_stats(self) -> None:
EmbeddingRocksDB holds stats for the total read/write duration in fwd/bwd.
This function fetches those stats from EmbeddingRocksDB and reports them with stats_reporter.
"""
ssd_io_duration = self.ssd_db.get_io_duration(
ssd_io_duration = self.ssd_db.get_rocksdb_io_duration(
self.step, self.stats_reporter.report_interval # pyre-ignore
)

Expand Down Expand Up @@ -1719,6 +1731,70 @@ def _report_ssd_mem_usage(
data_bytes=block_cache_pinned_usage,
)

@torch.jit.ignore
def _report_l2_cache_perf_stats(self) -> None:
"""
EmbeddingKVDB will hold stats for L2+SSD performance in fwd/bwd
this function fetch the stats from EmbeddingKVDB and report it with stats_reporter
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
if not stats_reporter.should_report(self.step):
return

l2_cache_perf_stats = self.ssd_db.get_l2cache_perf(
self.step, stats_reporter.report_interval # pyre-ignore
)

if len(l2_cache_perf_stats) != 6:
logging.error("l2 perf stats should have 6 elements")
return

num_cache_misses = l2_cache_perf_stats[0]
num_lookups = l2_cache_perf_stats[1]
get_total_duration = l2_cache_perf_stats[2]
get_cache_lookup_total_duration = l2_cache_perf_stats[3]
get_weights_fillup_total_duration = l2_cache_perf_stats[4]
get_cache_update_total_duration = l2_cache_perf_stats[5]

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.l2_num_cache_misses_stats_name, # ods only show integer
data_bytes=num_cache_misses,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.l2_num_cache_lookups_stats_name, # ods only show integer
data_bytes=num_lookups,
)

stats_reporter.report_duration(
iteration_step=self.step,
event_name="l2_cache.perf.get.total_duration_us",
duration_ms=get_total_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="l2_cache.perf.get.cache_lookup_duration_us",
duration_ms=get_cache_lookup_total_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="l2_cache.perf.get.weights_fillup_duration_us",
duration_ms=get_weights_fillup_total_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="l2_cache.perf.get.cache_update_duration_us",
duration_ms=get_cache_update_total_duration,
time_unit="us",
)

# pyre-ignore
def _recording_to_timer(
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,47 +56,87 @@ void EmbeddingKVDB::set_cuda(
rec->record.end();
}

/// Export the L2 cache "get" performance counters accumulated since the last
/// export, averaged per training step, and reset them to zero.
///
/// @param step the training step the caller wants to report stats for
/// @param interval report interval in training steps; must be positive
///
/// @return six doubles in this order: cache misses, lookups, total get
/// duration, cache lookup duration, weights fill-up duration, cache update
/// duration -- each divided by interval. The result is empty when interval
/// is not positive, when step is not positive, or when step is not a
/// multiple of interval; the counters are left untouched in those cases.
std::vector<double> EmbeddingKVDB::get_l2cache_perf(
    const int64_t step,
    const int64_t interval) {
  std::vector<double> ret;
  // Check interval first: `step % interval` is undefined for interval == 0,
  // and a negative interval would produce negative averages.
  if (interval <= 0 || step <= 0 || step % interval != 0) {
    return ret;
  }
  ret.reserve(6); // num metrics

  // exchange() reads a counter and resets it in a single atomic operation,
  // so an increment racing with this export is not lost between the read
  // and the reset.
  constexpr int64_t reset_val = 0;
  const auto num_cache_misses = num_cache_misses_.exchange(reset_val);
  const auto num_lookups = num_lookups_.exchange(reset_val);
  const auto get_total_duration = get_total_duration_.exchange(reset_val);
  const auto get_cache_lookup_total_duration =
      get_cache_lookup_total_duration_.exchange(reset_val);
  const auto get_weights_fillup_total_duration =
      get_weights_fillup_total_duration_.exchange(reset_val);
  const auto get_cache_update_total_duration =
      get_cache_update_total_duration_.exchange(reset_val);

  // Convert before dividing so the average keeps its fractional part.
  const auto steps = static_cast<double>(interval);
  ret.push_back(static_cast<double>(num_cache_misses) / steps);
  ret.push_back(static_cast<double>(num_lookups) / steps);
  ret.push_back(static_cast<double>(get_total_duration) / steps);
  ret.push_back(static_cast<double>(get_cache_lookup_total_duration) / steps);
  ret.push_back(static_cast<double>(get_weights_fillup_total_duration) / steps);
  ret.push_back(static_cast<double>(get_cache_update_total_duration) / steps);
  return ret;
}

// Forwards indices/weights/count to set_cache() and logs how long the call took.
//
// Returns early, with a log line rate-limited by XLOG_EVERY_MS, when the value
// held in `count` is not positive.
//
// NOTE(review): `is_bwd` and `start_ts` are not read anywhere in this body as
// shown -- confirm whether they are still needed.
void EmbeddingKVDB::set(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
const bool is_bwd) {
// Started before the early-return check, so the elapsed time logged below
// includes reading `count`.
folly::stop_watch<std::chrono::microseconds> timer;
if (auto num_evictions = count.item<long>(); num_evictions <= 0) {
XLOG_EVERY_MS(INFO, 60000)
<< "[" << unique_id_ << "]skip set_cuda since number evictions is "
<< num_evictions;
return;
}
auto start_ts = facebook::WallClockUtil::NowInUsecFast();
set_cache(indices, weights, count);
XLOG_EVERY_N(INFO, 1000) << "set_cuda: finished set embeddings in "
<< timer.elapsed().count() << " us.";
}

// Looks up embeddings for `indices` and fills `weights`, going through the L2
// cache when get_cache() returns a context and straight to the KV DB otherwise.
//
// Returns early, with a log line rate-limited by XLOG_EVERY_MS, when the value
// held in `count` is not positive; no perf counter is touched in that case.
//
// Perf counters updated here: get_weights_fillup_total_duration_,
// get_cache_update_total_duration_ and get_total_duration_. Each is advanced
// by a difference of two WallClockUtil::NowInUsecFast() readings.
void EmbeddingKVDB::get(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count) {
if (auto num_lookups = count.item<long>(); num_lookups <= 0) {
XLOG_EVERY_MS(INFO, 60000)
<< "[" << unique_id_ << "]skip get_cuda since number lookups is "
<< num_lookups;
return;
}
ASSERT_EQ(max_D_, weights.size(1));
// Two clocks are taken: `timer` feeds the log line at the end, `start_ts`
// feeds get_total_duration_.
folly::stop_watch<std::chrono::microseconds> timer;
auto start_ts = facebook::WallClockUtil::NowInUsecFast();
auto cache_context = get_cache(indices, count);
if (cache_context != nullptr) {
if (cache_context->num_misses > 0) {
// Partial hit: some lookups missed the cache.
XLOG(INFO) << "[" << unique_id_
<< "]cache miss: " << cache_context->num_misses << " out of "
<< count.item().toLong() << " lookups";
std::vector<folly::coro::Task<void>> tasks;

// The KV DB read and the cache_memcpy are collected into one range and
// awaited together; the fill-up duration covers both.
auto weight_fillup_start_ts = facebook::WallClockUtil::NowInUsecFast();
tasks.emplace_back(get_kv_db_async(indices, weights, count));
tasks.emplace_back(
cache_memcpy(weights, cache_context->cached_addr_list));
folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
get_weights_fillup_total_duration_ +=
facebook::WallClockUtil::NowInUsecFast() - weight_fillup_start_ts;

// After the weights are filled, pass them to set_cache(); this step is
// timed separately as the cache update duration.
auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast();
set_cache(indices, weights, count);
get_cache_update_total_duration_ +=
facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts;
} else {
// No misses: only cache_memcpy runs, and no KV DB read is issued.
XLOG_EVERY_N(INFO, 1000) << "[" << unique_id_ << "]cache hit 100%";
auto weight_fillup_start_ts = facebook::WallClockUtil::NowInUsecFast();
folly::coro::blockingWait(
cache_memcpy(weights, cache_context->cached_addr_list));
get_weights_fillup_total_duration_ +=
facebook::WallClockUtil::NowInUsecFast() - weight_fillup_start_ts;
}
} else { // no l2 cache
// Only get_total_duration_ is updated on this path.
folly::coro::blockingWait(get_kv_db_async(indices, weights, count));
}
XLOG_EVERY_N(INFO, 1000) << "[" << unique_id_
<< "]get_cuda: finished get embeddings in "
<< timer.elapsed().count() << " us.";
get_total_duration_ += facebook::WallClockUtil::NowInUsecFast() - start_ts;
}

std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
Expand All @@ -105,7 +145,7 @@ std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
if (l2_cache_ == nullptr) {
return nullptr;
}
folly::stop_watch<std::chrono::microseconds> timer;
auto start_ts = facebook::WallClockUtil::NowInUsecFast();
auto indices_addr = indices.data_ptr<int64_t>();
auto num_lookups = count.item<long>();
auto cache_context = std::make_shared<CacheContext>(num_lookups);
Expand Down Expand Up @@ -148,9 +188,21 @@ std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
.scheduleOn(executor_tp_.get()));
}
folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
XLOG(INFO) << "[" << unique_id_ << "]finished get cache in "
<< timer.elapsed().count() << " us. " << cache_context->num_misses
<< " cache misses out of " << num_lookups << " lookups";

// These metrics are recorded here on the assumption that get_cache is only
// called from the get_cuda path. If that assumption stops holding, move this
// accounting to the caller side.
auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts;
get_cache_lookup_total_duration_ += dur;
auto cache_misses = cache_context->num_misses.load();
if (num_lookups > 0) {
num_cache_misses_ += cache_misses;
num_lookups_ += num_lookups;
} else {
XLOG_EVERY_MS(INFO, 60000)
<< "[" << unique_id_
<< "]num_lookups is 0, skip collecting the L2 cache miss stats";
}
return cache_context;
}

Expand All @@ -168,7 +220,8 @@ void EmbeddingKVDB::set_cache(
continue;
}
if (!l2_cache_->put(indices_addr[i], weights[i])) {
XLOG(ERR) << "Failed to insert into cache, this shouldn't happen";
XLOG(ERR) << "[" << unique_id_
<< "]Failed to insert into cache, this shouldn't happen";
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,16 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const int64_t timestep,
const bool is_bwd = false);

/// Export the internally collected L2 cache performance metrics and reset
/// the underlying counters.
///
/// @param step the training step for which the caller wants to report stats
/// @param interval report interval, in training steps
///
/// @return a list of doubles, each divided by interval, in this order:
/// number of cache misses, number of lookups, total get duration, cache
/// lookup duration, weights fill-up duration, cache update duration. The
/// list is empty unless step is positive and a multiple of interval.
std::vector<double> get_l2cache_perf(
const int64_t step,
const int64_t interval);

private:
/// Find non-negative embedding indices in <indices> and shard them into
/// #cachelib_pools pieces to be lookedup in parallel
Expand Down Expand Up @@ -225,6 +235,18 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const int64_t num_shards_;
const int64_t max_D_;
std::unique_ptr<folly::CPUThreadPoolExecutor> executor_tp_;
// perf stats
// get perf
// cache miss rate(cmr) is avged on cmr per iteration
// instead of SUM(cache miss per interval) / SUM(lookups per interval)
std::atomic<int64_t> num_cache_misses_{0};
std::atomic<int64_t> num_lookups_{0};
std::atomic<int64_t> get_total_duration_{0};
std::atomic<int64_t> get_cache_lookup_total_duration_{0};
std::atomic<int64_t> get_weights_fillup_total_duration_{0};
std::atomic<int64_t> get_cache_update_total_duration_{0};

// set perf
}; // class EmbeddingKVDB

} // namespace kv_db
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,16 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->get_mem_usage();
}

std::vector<double> get_io_duration(
std::vector<double> get_rocksdb_io_duration(
const int64_t step,
const int64_t interval) {
return impl_->get_io_duration(step, interval);
return impl_->get_rocksdb_io_duration(step, interval);
}

/// Forwards to impl_->get_l2cache_perf(step, interval) and returns its result
/// unchanged.
std::vector<double> get_l2cache_perf(
const int64_t step,
const int64_t interval) {
return impl_->get_l2cache_perf(step, interval);
}

void compact() {
Expand Down Expand Up @@ -365,7 +371,10 @@ static auto embedding_rocks_db_wrapper =
.def("compact", &EmbeddingRocksDBWrapper::compact)
.def("flush", &EmbeddingRocksDBWrapper::flush)
.def("get_mem_usage", &EmbeddingRocksDBWrapper::get_mem_usage)
.def("get_io_duration", &EmbeddingRocksDBWrapper::get_io_duration)
.def(
"get_rocksdb_io_duration",
&EmbeddingRocksDBWrapper::get_rocksdb_io_duration)
.def("get_l2cache_perf", &EmbeddingRocksDBWrapper::get_l2cache_perf)
.def("set", &EmbeddingRocksDBWrapper::set)
.def("get", &EmbeddingRocksDBWrapper::get);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
return mem_usages;
}

std::vector<double> get_io_duration(
std::vector<double> get_rocksdb_io_duration(
const int64_t step,
const int64_t interval) {
std::vector<double> ret;
Expand All @@ -565,9 +565,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
auto read_dur = read_total_duration_.load();
auto fwd_write_dur = fwd_write_total_duration_.load();
auto bwd_write_dur = bwd_write_total_duration_.load();
ret.push_back(double(read_dur / interval));
ret.push_back(double(fwd_write_dur / interval));
ret.push_back(double(bwd_write_dur / interval));
ret.push_back(double(read_dur) / interval);
ret.push_back(double(fwd_write_dur) / interval);
ret.push_back(double(bwd_write_dur) / interval);
read_total_duration_ = 0;
fwd_write_total_duration_ = 0;
bwd_write_total_duration_ = 0;
Expand Down
Loading

0 comments on commit f3b2a2e

Please sign in to comment.