// ParallelNative.cpp (from a fork of pytorch/pytorch)
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/PTThreadPool.h>
#include <atomic>
#ifdef _OPENMP
#include <omp.h>
#endif
#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif
namespace at {
namespace {
// Sentinel values stored in num_intraop_threads in place of a real thread count.
// Both are negative so they can never collide with a valid (positive) count.
constexpr int NOT_SET = -1;   // no user value recorded, pool not created yet
constexpr int CONSUMED = -2;  // the stored value has been taken to create the pool
// Number of intra-op threads requested by the user, doubling as the
// initialization state of the intra-op pool.
//
// Allowed transitions:
//   NOT_SET -> positive value -> CONSUMED
// or
//   NOT_SET -> CONSUMED
//
// Meaning of each state:
// - NOT_SET        - pool not initialized, user value is not set
// - positive value - pool not initialized, user value set
//                    (stored by set_num_threads via compare-exchange)
// - CONSUMED       - pool is initialized
//                    (stored by _get_intraop_pool via exchange when it
//                    creates the pool)
std::atomic<int> num_intraop_threads{NOT_SET};
// Per-thread flag written by internal::_set_in_parallel_region and read by
// in_parallel_region(). Used to mark the master thread as being in a
// parallel region while it executes parallel primitives.
thread_local bool in_parallel_region_ = false;
// Per-thread thread number (task_id) set by a parallel primitive through
// internal::_set_thread_num, reset to 0 by internal::_unset_thread_num,
// and reported by get_thread_num().
thread_local size_t thread_num_ = 0;
// Converts a requested total thread count into the number of worker threads
// the pool itself should own.
//
// nthreads: either NOT_SET (use TaskThreadPoolBase::defaultNumThreads()) or a
//           positive count; any other value trips an internal assert.
// Returns the total count minus one, because the master thread also
// participates in the work and is not a pool thread.
int _num_pool_threads(int nthreads) {
  if (nthreads != NOT_SET) {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
    return nthreads - 1;
  }
  const int default_total = TaskThreadPoolBase::defaultNumThreads();
  return default_total - 1;
}
} // namespace
namespace internal {
// Returns the intra-op thread pool, creating it on the first call.
//
// The pool is held in a function-local static, so the initializer below runs
// once. As part of that initialization num_intraop_threads is swapped to
// CONSUMED and its previous value (NOT_SET or a user-provided positive count)
// determines the pool size via _num_pool_threads.
TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      []() -> std::shared_ptr<TaskThreadPoolBase> {
        // Take the user's request and mark the setting as consumed in one step.
        const int requested_threads = num_intraop_threads.exchange(CONSUMED);
        return ThreadPoolRegistry()->Create(
            "C10",
            /* device_id */ 0,
            /* pool_size */ _num_pool_threads(requested_threads),
            /* create_new */ true); // separate thread pool dedicated to intra-op
      }();
  return *pool;
}
// Sets the calling thread's "in parallel region" flag to in_region.
// The flag is thread_local, so other threads are unaffected.
void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}
// Records thread_num as the calling thread's thread number; this is the value
// later reported by get_thread_num() on the same thread.
void _set_thread_num(size_t thread_num) {
  thread_num_ = thread_num;
}
// Resets the calling thread's thread number back to its initial value of 0.
void _unset_thread_num() {
  thread_num_ = 0;
}
} // namespace internal
// Per-thread initialization of third-party threading runtimes: when the build
// has OpenMP and/or MKL available, each is limited to a single thread.
// TODO: use OMP and MKL env. vars as default values
void init_num_threads() {
#if defined(_OPENMP)
  omp_set_num_threads(1);
#endif

#if defined(TH_BLAS_MKL)
  mkl_set_num_threads(1);
#endif
}
// Records the number of intra-op threads to use when the pool is created.
//
// nthreads: requested total thread count; must be positive.
//
// Fails (via TORCH_CHECK) when nthreads is not positive, or when the setting
// is no longer NOT_SET, i.e. a value was already stored by an earlier
// set_num_threads call or the pool has already been created (CONSUMED).
void set_num_threads(int nthreads) {
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  // compare_exchange_strong writes the observed value back into its first
  // argument on failure, so the expected value has to be a mutable local.
  int expected = NOT_SET;
  TORCH_CHECK(
      num_intraop_threads.compare_exchange_strong(expected, nthreads),
      "Error: cannot set number of intraop threads "
      "after parallel work has started or after set_num_threads call");
}
// Returns the number of intra-op threads without forcing pool creation.
//
// The pool cannot be resized once initialized, so it is only touched when the
// setting has already been consumed:
// - positive stored value -> that value (user request, pool not created yet)
// - NOT_SET               -> TaskThreadPoolBase::defaultNumThreads()
// - CONSUMED              -> pool size plus one for the master thread
int get_num_threads() {
  const int stored = num_intraop_threads.load();
  if (stored > 0) {
    return stored;
  }
  if (stored == NOT_SET) {
    return TaskThreadPoolBase::defaultNumThreads();
  }
  TORCH_INTERNAL_ASSERT(stored == CONSUMED);
  return internal::_get_intraop_pool().size() + 1;
}
// Returns the thread number recorded for the calling thread
// (0 unless internal::_set_thread_num stored something else).
int get_thread_num() {
  return static_cast<int>(thread_num_);
}
// Reports whether the calling thread is currently inside a parallel region.
//
// True when the thread-local flag is set, or when the intra-op pool has been
// created and the pool reports the calling thread as one of its own. The pool
// is only queried once the setting is CONSUMED, so this never creates it.
bool in_parallel_region() {
  if (in_parallel_region_) {
    return true;
  }
  if (num_intraop_threads.load() != CONSUMED) {
    return false;
  }
  return internal::_get_intraop_pool().inThreadPool();
}
// Launches func as an intra-op task.
//
// func: callable to execute; it is copied into the task handed to the pool.
//
// When the caller is already in a parallel region, func is executed inline on
// the calling thread. Otherwise it is passed to the intra-op pool's run(),
// which creates the pool on first use.
void intraop_launch(std::function<void()> func) {
  if (in_parallel_region()) {
    // execute inline if we're in parallel region
    func();
    return;
  }
  // _get_intraop_pool is a function: it must be called to obtain the pool
  // before run() can be invoked on it.
  internal::_get_intraop_pool().run([func]() {
    func();
  });
}
} // namespace at
#endif