forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrecord_function.cpp
98 lines (82 loc) · 2.08 KB
/
record_function.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/autograd/function.h>

#include <utility>
namespace torch { namespace autograd { namespace profiler {
// File-local state shared by the callback registration functions and by
// RecordFunction. Nothing in this file synchronizes access to the three
// non-thread_local variables.
namespace {
// Callbacks that processCallbacks() invokes, in push order, when a
// RecordFunction is started via before().
std::vector<RecordFunctionCallback> start_callbacks;
// Callbacks that ~RecordFunction() invokes, in push order, for an initialized
// RecordFunction. pushCallback()/popCallback() grow and shrink this vector
// together with start_callbacks.
std::vector<RecordFunctionCallback> end_callbacks;
// Stays 0 until a callback is pushed with needs_inputs == true; from then on
// every push increments it and every pop decrements it (never below 0).
// needsInputs() reports whether it is non-zero.
size_t callback_needs_inputs = 0;
// Per-thread pointer set to the current object by processCallbacks() and
// restored to that object's parent_ by ~RecordFunction().
thread_local RecordFunction* thread_local_func_ = nullptr;
}
// Pushes a start/end callback pair onto the callback stack.
//
// start        - appended to start_callbacks; invoked by processCallbacks().
// end          - appended to end_callbacks; invoked by ~RecordFunction().
// needs_inputs - when true, needsInputs() reports true until this pair is
//                popped again.
void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end,
    bool needs_inputs) {
  // The callbacks were already copied into the by-value parameters; move them
  // into storage instead of copying them a second time.
  start_callbacks.push_back(std::move(start));
  end_callbacks.push_back(std::move(end));
  // Once some live callback needs inputs, every later push is counted as well.
  // popCallback() decrements on every pop while the counter is non-zero, so
  // the counter returns to zero exactly when that first callback is popped.
  if (callback_needs_inputs > 0 || needs_inputs) {
    ++callback_needs_inputs;
  }
}
// Removes the most recently pushed start/end callback pair.
//
// Throws std::runtime_error if no callbacks are registered.
void popCallback() {
  const bool stack_is_empty = start_callbacks.empty();
  if (stack_is_empty) {
    throw std::runtime_error("Empty callbacks stack");
  }

  start_callbacks.pop_back();
  end_callbacks.pop_back();

  // Mirror the counting done in pushCallback(); never drop below zero.
  if (callback_needs_inputs != 0) {
    callback_needs_inputs -= 1;
  }
}
bool hasCallbacks() {
return !start_callbacks.empty();
}
// Returns true while the needs-inputs counter maintained by
// pushCallback()/popCallback() is non-zero.
bool needsInputs() {
  return callback_needs_inputs != 0;
}
// Starts this record under the given C-string name and sequence number, then
// runs the start callbacks. Does nothing when no callbacks are registered.
// Asserts (AT_ASSERT) that this object has not been initialized already.
void RecordFunction::before(const char* name, int64_t sequence_nr) {
  if (hasCallbacks()) {
    AT_ASSERT(!initialized_);
    name_ = StringView(name);
    sequence_nr_ = sequence_nr;
    // Mark as initialized only after the fields are set; the destructor keys
    // its work off this flag.
    initialized_ = true;
    processCallbacks();
  }
}
// Starts this record under the given std::string name (moved into the stored
// StringView) and sequence number, then runs the start callbacks. Does nothing
// when no callbacks are registered.
// Asserts (AT_ASSERT) that this object has not been initialized already.
void RecordFunction::before(std::string name, int64_t sequence_nr) {
  if (hasCallbacks()) {
    AT_ASSERT(!initialized_);
    name_ = StringView(std::move(name));
    sequence_nr_ = sequence_nr;
    // Mark as initialized only after the fields are set; the destructor keys
    // its work off this flag.
    initialized_ = true;
    processCallbacks();
  }
}
// Starts this record for an autograd Function, taking the name from
// fn->name(), then runs the start callbacks. Does nothing when no callbacks
// are registered.
//
// sequence_nr - used as-is when non-negative; otherwise fn->sequence_nr() is
//               used instead.
// Asserts (AT_ASSERT) that this object has not been initialized already.
void RecordFunction::before(Function* fn, int64_t sequence_nr) {
  if (hasCallbacks()) {
    AT_ASSERT(!initialized_);
    fn_ = fn;
    name_ = StringView(fn->name());
    if (sequence_nr >= 0) {
      sequence_nr_ = sequence_nr;
    } else {
      // A negative value means "take the sequence number from the function".
      sequence_nr_ = fn->sequence_nr();
    }
    initialized_ = true;
    processCallbacks();
  }
}
void RecordFunction::processCallbacks() {
parent_ = thread_local_func_;
thread_local_func_ = this;
for (const auto& cb : start_callbacks) {
cb(*this);
}
}
// For an initialized record, invokes every end callback in push order and
// then restores the thread-local current record to parent_. A record that was
// never initialized by before() is destroyed without side effects.
RecordFunction::~RecordFunction() {
  if (!initialized_) {
    return;
  }

  for (auto it = end_callbacks.cbegin(), last = end_callbacks.cend();
       it != last;
       ++it) {
    (*it)(*this);
  }
  thread_local_func_ = parent_;
}
}}}