From f3b2a2ef62e470e2ec6b2900f5269725f39c4a9e Mon Sep 17 00:00:00 2001
From: Guanqiao Wang
Date: Thu, 29 Aug 2024 12:38:59 -0700
Subject: [PATCH] add ods logging for l2 cache perf (#3031)

Summary:
X-link: https://github.com/pytorch/torchrec/pull/2335

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3031

X-link: https://github.com/facebookresearch/FBGEMM/pull/129

collect performance related metrics from KV store and export them to ODS

Differential Revision: D61417980
---
 fbgemm_gpu/fbgemm_gpu/runtime_monitor.py      | 13 +++
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py     | 80 ++++++++++++++++-
 .../kv_db_table_batched_embeddings.cpp        | 85 +++++++++++++++----
 .../kv_db_table_batched_embeddings.h          | 22 +++++
 .../ssd_split_table_batched_embeddings.cpp    | 15 +++-
 .../ssd_table_batched_embeddings.h            |  8 +-
 fbgemm_gpu/test/tbe/cache/cache_common.py     |  3 +
 7 files changed, 201 insertions(+), 25 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py b/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py
index de4ec26ce8..89e4fea4ef 100644
--- a/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py
+++ b/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py
@@ -33,6 +33,13 @@ def should_report(self, iteration_step: int) -> bool:
         """
         ...
 
+    @abc.abstractmethod
+    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
+        """
+        Register stats_name in the whitelist of the reporter
+        """
+        ...
+
     @abc.abstractmethod
     def report_duration(
         self,
@@ -68,6 +75,9 @@ def __init__(self, report_interval: int) -> None:
         assert report_interval > 0, "Report interval must be positive"
         self.report_interval = report_interval
 
+    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
+        return
+
     def should_report(self, iteration_step: int) -> bool:
         return iteration_step % self.report_interval == 0
 
@@ -96,6 +106,9 @@ def report_data_amount(
             f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes"
         )
 
+    def __repr__(self) -> str:
+        return "StdLogStatsReporter{ " f"report_interval={self.report_interval} " "}"
+
 
 @dataclass(frozen=True)
 class TBEStatsReporterConfig:
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 6546806f94..6ba3d9a6b6 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -355,6 +355,7 @@ def __init__(
         else:
             logging.warning("dist is not initialized, treating as single gpu cases")
             tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
+        self.tbe_unique_id = tbe_unique_id
         logging.info(f"tbe_unique_id: {tbe_unique_id}")
         if not ps_hosts:
             logging.info(
@@ -575,7 +576,7 @@ def __init__(
         )
         logging.info(
             f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
-            f"stats_reporter:{self.stats_reporter if self.stats_reporter else 'none'}, "
+            f"stats_reporter:{self.stats_reporter if self.stats_reporter else 'none'}"
         )
 
         # prefetch launch a series of kernels, we use AsyncSeriesTimer to track the kernel time
@@ -589,6 +590,12 @@ def __init__(
             self.prefetch_parallel_stream_cnt,
             0,
         )
+        self.l2_num_cache_misses_stats_name: str = (
+            f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_cache_misses"
+        )
+        self.l2_num_cache_lookups_stats_name: str = (
+            f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_lookups"
+        )
         if self.stats_reporter:
             self.ssd_prefetch_read_timer = AsyncSeriesTimer(
                 functools.partial(
@@ -606,6 +613,10 @@ def __init__(
                     time_unit="us",
                 )
             )
+            # pyre-ignore
+            self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
+            # pyre-ignore
+            self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)
 
     @torch.jit.ignore
     def _report_duration(
@@ -1596,6 +1607,7 @@ def _report_ssd_stats(self) -> None:
         self._report_ssd_l1_cache_stats()
         self._report_ssd_io_stats()
         self._report_ssd_mem_usage()
+        self._report_l2_cache_perf_stats()
 
     @torch.jit.ignore
     def _report_ssd_l1_cache_stats(self) -> None:
@@ -1646,7 +1658,7 @@ def _report_ssd_io_stats(self) -> None:
         EmbeddingRocksDB will hold stats for total read/write duration in fwd/bwd
         this function fetch the stats from EmbeddingRocksDB and report it with stats_reporter
         """
-        ssd_io_duration = self.ssd_db.get_io_duration(
+        ssd_io_duration = self.ssd_db.get_rocksdb_io_duration(
             self.step, self.stats_reporter.report_interval  # pyre-ignore
         )
 
@@ -1719,6 +1731,70 @@ def _report_ssd_mem_usage(
             data_bytes=block_cache_pinned_usage,
         )
 
+    @torch.jit.ignore
+    def _report_l2_cache_perf_stats(self) -> None:
+        """
+        EmbeddingKVDB will hold stats for L2+SSD performance in fwd/bwd;
+        this function fetches the stats from EmbeddingKVDB and reports them with stats_reporter
+        """
+        if self.stats_reporter is None:
+            return
+
+        stats_reporter: TBEStatsReporter = self.stats_reporter
+        if not stats_reporter.should_report(self.step):
+            return
+
+        l2_cache_perf_stats = self.ssd_db.get_l2cache_perf(
+            self.step, stats_reporter.report_interval  # pyre-ignore
+        )
+
+        if len(l2_cache_perf_stats) != 6:
+            logging.error("l2 perf stats should have 6 elements")
+            return
+
+        num_cache_misses = l2_cache_perf_stats[0]
+        num_lookups = l2_cache_perf_stats[1]
+        get_total_duration = l2_cache_perf_stats[2]
+        get_cache_lookup_total_duration = l2_cache_perf_stats[3]
+        get_weights_fillup_total_duration = l2_cache_perf_stats[4]
+        get_cache_update_total_duration = l2_cache_perf_stats[5]
+
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name=self.l2_num_cache_misses_stats_name,  # ODS only shows integers
+            data_bytes=num_cache_misses,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name=self.l2_num_cache_lookups_stats_name,  # ODS only shows integers
+            data_bytes=num_lookups,
+        )
+
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="l2_cache.perf.get.total_duration_us",
+            duration_ms=get_total_duration,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="l2_cache.perf.get.cache_lookup_duration_us",
+            duration_ms=get_cache_lookup_total_duration,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="l2_cache.perf.get.weights_fillup_duration_us",
+            duration_ms=get_weights_fillup_total_duration,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="l2_cache.perf.get.cache_update_duration_us",
+            duration_ms=get_cache_update_total_duration,
+            time_unit="us",
+        )
+
     # pyre-ignore
     def _recording_to_timer(
         self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
index d2869f10e2..75c74f605c 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
@@ -56,47 +56,87 @@ void EmbeddingKVDB::set_cuda(
   rec->record.end();
 }
 
+std::vector<double> EmbeddingKVDB::get_l2cache_perf(
+    const int64_t step,
+    const int64_t interval) {
+  std::vector<double> ret;
+  ret.reserve(6); // num metrics
+  if (step > 0 && step % interval == 0) {
+    int reset_val = 0;
+    auto num_cache_misses = num_cache_misses_.exchange(reset_val);
+    auto num_lookups = num_lookups_.exchange(reset_val);
+    auto get_total_duration = get_total_duration_.exchange(reset_val);
+    auto get_cache_lookup_total_duration =
+        get_cache_lookup_total_duration_.exchange(reset_val);
+    auto get_weights_fillup_total_duration =
+        get_weights_fillup_total_duration_.exchange(reset_val);
+    auto get_cache_update_total_duration =
+        get_cache_update_total_duration_.exchange(reset_val);
+    ret.push_back(double(num_cache_misses) / interval);
+    ret.push_back(double(num_lookups) / interval);
+    ret.push_back(double(get_total_duration) / interval);
+    ret.push_back(double(get_cache_lookup_total_duration) / interval);
+    ret.push_back(double(get_weights_fillup_total_duration) / interval);
+    ret.push_back(double(get_cache_update_total_duration) / interval);
+  }
+  return ret;
+}
+
 void EmbeddingKVDB::set(
     const at::Tensor& indices,
     const at::Tensor& weights,
     const at::Tensor& count,
     const bool is_bwd) {
-  folly::stop_watch<std::chrono::microseconds> timer;
+  if (auto num_evictions = count.item<long>(); num_evictions <= 0) {
+    XLOG_EVERY_MS(INFO, 60000)
+        << "[" << unique_id_ << "]skip set_cuda since number of evictions is "
+        << num_evictions;
+    return;
+  }
+  auto start_ts = facebook::WallClockUtil::NowInUsecFast();
   set_cache(indices, weights, count);
-  XLOG_EVERY_N(INFO, 1000) << "set_cuda: finished set embeddings in "
-                           << timer.elapsed().count() << " us.";
 }
 
 void EmbeddingKVDB::get(
     const at::Tensor& indices,
     const at::Tensor& weights,
     const at::Tensor& count) {
+  if (auto num_lookups = count.item<long>(); num_lookups <= 0) {
+    XLOG_EVERY_MS(INFO, 60000)
+        << "[" << unique_id_ << "]skip get_cuda since number of lookups is "
+        << num_lookups;
+    return;
+  }
   ASSERT_EQ(max_D_, weights.size(1));
-  folly::stop_watch<std::chrono::microseconds> timer;
+  auto start_ts = facebook::WallClockUtil::NowInUsecFast();
   auto cache_context = get_cache(indices, count);
   if (cache_context != nullptr) {
     if (cache_context->num_misses > 0) {
-      XLOG(INFO) << "[" << unique_id_
-                 << "]cache miss: " << cache_context->num_misses << " out of "
-                 << count.item().toLong() << " lookups";
       std::vector<folly::coro::Task<void>> tasks;
+      auto weight_fillup_start_ts = facebook::WallClockUtil::NowInUsecFast();
       tasks.emplace_back(get_kv_db_async(indices, weights, count));
       tasks.emplace_back(
          cache_memcpy(weights, cache_context->cached_addr_list));
       folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
+      get_weights_fillup_total_duration_ +=
+          facebook::WallClockUtil::NowInUsecFast() - weight_fillup_start_ts;
+
+      auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast();
       set_cache(indices, weights, count);
+      get_cache_update_total_duration_ +=
+          facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts;
     } else {
-      XLOG_EVERY_N(INFO, 1000) << "[" << unique_id_ << "]cache hit 100%";
+      auto weight_fillup_start_ts = facebook::WallClockUtil::NowInUsecFast();
       folly::coro::blockingWait(
           cache_memcpy(weights, cache_context->cached_addr_list));
+      get_weights_fillup_total_duration_ +=
+          facebook::WallClockUtil::NowInUsecFast() - weight_fillup_start_ts;
     }
   } else { // no l2 cache
     folly::coro::blockingWait(get_kv_db_async(indices, weights, count));
   }
-  XLOG_EVERY_N(INFO, 1000) << "[" << unique_id_
-                           << "]get_cuda: finished get embeddings in "
-                           << timer.elapsed().count() << " us.";
+  get_total_duration_ += facebook::WallClockUtil::NowInUsecFast() - start_ts;
 }
 
 std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
@@ -105,7 +145,7 @@ std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
   if (l2_cache_ == nullptr) {
     return nullptr;
   }
-  folly::stop_watch<std::chrono::microseconds> timer;
+  auto start_ts = facebook::WallClockUtil::NowInUsecFast();
   auto indices_addr = indices.data_ptr<int64_t>();
   auto num_lookups = count.item<long>();
   auto cache_context = std::make_shared<CacheContext>(num_lookups);
@@ -148,9 +188,21 @@ std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
             .scheduleOn(executor_tp_.get()));
   }
   folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
-  XLOG(INFO) << "[" << unique_id_ << "]finished get cache in "
-             << timer.elapsed().count() << " us. " << cache_context->num_misses
-             << " cache misses out of " << num_lookups << " lookups";
+
+  // The following metrics are collected here under the assumption that
+  // get_cache is only called on the get_cuda path; if that assumption no
+  // longer holds, we should wrap this up on the caller side
+  auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts;
+  get_cache_lookup_total_duration_ += dur;
+  auto cache_misses = cache_context->num_misses.load();
+  if (num_lookups > 0) {
+    num_cache_misses_ += cache_misses;
+    num_lookups_ += num_lookups;
+  } else {
+    XLOG_EVERY_MS(INFO, 60000)
+        << "[" << unique_id_
+        << "]num_lookups is 0, skip collecting the L2 cache miss stats";
+  }
   return cache_context;
 }
 
@@ -168,7 +220,8 @@ void EmbeddingKVDB::set_cache(
       continue;
     }
     if (!l2_cache_->put(indices_addr[i], weights[i])) {
-      XLOG(ERR) << "Failed to insert into cache, this shouldn't happen";
+      XLOG(ERR) << "[" << unique_id_
+                << "]Failed to insert into cache, this shouldn't happen";
     }
   }
 }
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
index d34dd7617e..00efe7bcd8 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
@@ -168,6 +168,16 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const int64_t timestep,
       const bool is_bwd = false);
 
+  /// Export the internally collected L2 performance metrics
+  ///
+  /// @param step the training step at which the caller wants to report the stats
+  /// @param interval report interval in terms of training steps
+  ///
+  /// @return a list of doubles in a predefined order, one entry per metric
+  std::vector<double> get_l2cache_perf(
+      const int64_t step,
+      const int64_t interval);
+
  private:
   /// Find non-negative embedding indices in <indices> and shard them into
   /// #cachelib_pools pieces to be lookedup in parallel
@@ -225,6 +235,18 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   const int64_t num_shards_;
   const int64_t max_D_;
   std::unique_ptr<folly::CPUThreadPoolExecutor> executor_tp_;
+  // perf stats
+  // get perf
+  // cache miss rate (cmr) is averaged over the per-iteration cmr,
+  // instead of SUM(cache misses per interval) / SUM(lookups per interval)
+  std::atomic<int64_t> num_cache_misses_{0};
+  std::atomic<int64_t> num_lookups_{0};
+  std::atomic<int64_t> get_total_duration_{0};
+  std::atomic<int64_t> get_cache_lookup_total_duration_{0};
+  std::atomic<int64_t> get_weights_fillup_total_duration_{0};
+  std::atomic<int64_t> get_cache_update_total_duration_{0};
+
+  // set perf
 }; // class EmbeddingKVDB
 
 } // namespace kv_db
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index c6db3984c9..64a7a99c1b 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -286,10 +286,16 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
     return impl_->get_mem_usage();
   }
 
-  std::vector<double> get_io_duration(
+  std::vector<double> get_rocksdb_io_duration(
      const int64_t step,
      const int64_t interval) {
-    return impl_->get_io_duration(step, interval);
+    return impl_->get_rocksdb_io_duration(step, interval);
+  }
+
+  std::vector<double> get_l2cache_perf(
+      const int64_t step,
+      const int64_t interval) {
+    return impl_->get_l2cache_perf(step, interval);
   }
 
   void compact() {
@@ -365,7 +371,10 @@ static auto embedding_rocks_db_wrapper =
         .def("compact", &EmbeddingRocksDBWrapper::compact)
         .def("flush", &EmbeddingRocksDBWrapper::flush)
         .def("get_mem_usage", &EmbeddingRocksDBWrapper::get_mem_usage)
-        .def("get_io_duration", &EmbeddingRocksDBWrapper::get_io_duration)
+        .def(
+            "get_rocksdb_io_duration",
+            &EmbeddingRocksDBWrapper::get_rocksdb_io_duration)
+        .def("get_l2cache_perf", &EmbeddingRocksDBWrapper::get_l2cache_perf)
         .def("set", &EmbeddingRocksDBWrapper::set)
         .def("get", &EmbeddingRocksDBWrapper::get);
 
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
index 4efb13b782..30031fb0e2 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
@@ -556,7 +556,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
     return mem_usages;
   }
 
-  std::vector<double> get_io_duration(
+  std::vector<double> get_rocksdb_io_duration(
       const int64_t step,
       const int64_t interval) {
     std::vector<double> ret;
@@ -565,9 +565,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
       auto read_dur = read_total_duration_.load();
       auto fwd_write_dur = fwd_write_total_duration_.load();
       auto bwd_write_dur = bwd_write_total_duration_.load();
-      ret.push_back(double(read_dur / interval));
-      ret.push_back(double(fwd_write_dur / interval));
-      ret.push_back(double(bwd_write_dur / interval));
+      ret.push_back(double(read_dur) / interval);
+      ret.push_back(double(fwd_write_dur) / interval);
+      ret.push_back(double(bwd_write_dur) / interval);
       read_total_duration_ = 0;
       fwd_write_total_duration_ = 0;
       bwd_write_total_duration_ = 0;
diff --git a/fbgemm_gpu/test/tbe/cache/cache_common.py b/fbgemm_gpu/test/tbe/cache/cache_common.py
index a677dfd998..f744186693 100644
--- a/fbgemm_gpu/test/tbe/cache/cache_common.py
+++ b/fbgemm_gpu/test/tbe/cache/cache_common.py
@@ -54,6 +54,9 @@ def __init__(self, reporting_interval: int = 1) -> None:
     def should_report(self, iteration_step: int) -> bool:
         return (iteration_step - 1) % self.reporting_interval == 0
 
+    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
+        return
+
     def report_duration(
         self,
         iteration_step: int,
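
For reference, the sketch below is a minimal, hypothetical stand-in reporter; it is not part of this patch, and the real ODS-backed reporter is internal. It follows the TBEStatsReporter surface used above (register_stats, should_report, report_duration, report_data_amount, called with the same keyword arguments that _report_l2_cache_perf_stats passes); the extra embedding_id/tbe_id parameters and default values are assumptions, since the full abstract signatures are not shown in this diff. It records in memory the values that get_l2cache_perf exports as per-iteration averages over the report interval, in the order [num_cache_misses, num_lookups, get_total_duration, cache_lookup, weights_fillup, cache_update], where the first two are counters and the four durations are in microseconds.

# recording_stats_reporter.py -- hypothetical helper, not part of this patch.
# A minimal TBEStatsReporter-style sink that records what
# _report_l2_cache_perf_stats would otherwise push to ODS.
from typing import List, Tuple


class RecordingStatsReporter:
    """Accumulates reported stats in memory for tests or local debugging."""

    def __init__(self, report_interval: int = 1) -> None:
        assert report_interval > 0, "Report interval must be positive"
        self.report_interval = report_interval
        self.registered_stats: List[str] = []
        self.data_amounts: List[Tuple[int, str, int]] = []
        self.durations: List[Tuple[int, str, float, str]] = []

    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
        # The whitelist only matters for the real ODS reporter; here we just
        # remember the names (e.g. l2_cache.perf.get.tbe_id0.num_cache_misses).
        self.registered_stats.append(stats_name)

    def should_report(self, iteration_step: int) -> bool:
        return iteration_step % self.report_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
    ) -> None:
        self.durations.append((iteration_step, event_name, duration_ms, time_unit))

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
    ) -> None:
        self.data_amounts.append((iteration_step, event_name, data_bytes))


if __name__ == "__main__":
    # Feed the reporter a made-up get_l2cache_perf()-shaped result to show the
    # mapping; each value is a per-iteration average over the report interval.
    reporter = RecordingStatsReporter(report_interval=10)
    step = 10
    l2_stats = [3.0, 128.0, 950.0, 400.0, 450.0, 100.0]  # fabricated example values
    reporter.report_data_amount(step, "num_cache_misses", int(l2_stats[0]))
    reporter.report_data_amount(step, "num_lookups", int(l2_stats[1]))
    for name, value in zip(
        ["total", "cache_lookup", "weights_fillup", "cache_update"], l2_stats[2:]
    ):
        reporter.report_duration(
            step, f"l2_cache.perf.get.{name}_duration_us", value, time_unit="us"
        )
    print(reporter.data_amounts)
    print(reporter.durations)

A reporter like this can be handed to the TBE stats plumbing in tests (similar in spirit to the stub in cache_common.py above) to assert that the two counter events and four duration events are emitted once per report interval.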