forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBCECriterion.cu
134 lines (118 loc) · 3.44 KB
/
BCECriterion.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <THCUNN/THCUNN.h>
#include <THCUNN/common.h>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCApply.cuh>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/transform_reduce.h>
#include <thrust/system/cuda/execution_policy.h>
// Small positive constant used by the BCE functors to keep logarithms and
// divisions finite when a probability is exactly 0 or 1.
// Only float and double are specialized; other T have no definition.
template <typename T>
inline __host__ __device__ T eps();

template <>
inline __host__ __device__ float eps<float>() { return 1.0e-12f; }

template <>
inline __host__ __device__ double eps<double>() { return 1.0e-12; }
// Natural logarithm that maps an exact zero to log(eps<T>()) instead of -inf,
// so the BCE terms stay finite when a probability is exactly 0 or 1.
// Only an exact zero is remapped; every other value (including tiny positive
// and negative values) is passed unchanged to THCNumerics<T>::log.
template <typename T>
inline __host__ __device__ T safe_log(T a) {
  // Compare against T(0) rather than the double literal `0.` so that float
  // instantiations do not promote to double in device code.
  return (a == T(0)) ? THCNumerics<T>::log(eps<T>())
                     : THCNumerics<T>::log(a);
}
// Per-element binary cross entropy, returned in the accumulation type:
//   loss = -(t * log(x) + (1 - t) * log(1 - x))
// where x = get<0>(tuple) is the predicted probability and t = get<1>(tuple)
// is the target. safe_log substitutes log(eps) for log(0).
template <typename Dtype, typename Acctype>
struct bce_functor
{
  template <class Tuple>
  __host__ __device__
  Acctype operator()(Tuple x)
  {
    Dtype prob = thrust::get<0>(x);
    Dtype target = thrust::get<1>(x);
    // Input must be a probability; assert is compiled out under NDEBUG.
    assert(prob >= 0. && prob <= 1.);
    const Acctype logProb = safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(prob));
    const Acctype logComplement = safe_log<Acctype>(Acctype(1) - prob);
    return -(target * logProb + (Acctype(1) - target) * logComplement);
  }
};
// Unreduced BCE forward for a single element, written through pointers:
//   *output = -(t * log(x) + (1 - t) * log(1 - x))
// The loss is computed in Acctype and converted back to Dtype on store.
template <typename Dtype, typename Acctype>
struct bce_updateOutput_no_reduce_functor
{
  __forceinline__ __host__ __device__
  void operator()(
      const Dtype *input,
      const Dtype *target,
      Dtype *output)
  {
    Dtype prob = *input;
    Dtype label = *target;
    // Input must be a probability; assert is compiled out under NDEBUG.
    assert(prob >= 0. && prob <= 1.);
    const Acctype logProb = safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(prob));
    const Acctype logComplement = safe_log<Acctype>(Acctype(1) - prob);
    const Acctype loss = -(label * logProb + (Acctype(1) - label) * logComplement);
    *output = ScalarConvert<Acctype, Dtype>::to(loss);
  }
};
// Per-element weighted binary cross entropy, returned in the accumulation type:
//   loss = -w * (t * log(x) + (1 - t) * log(1 - x))
// with x = get<0>, t = get<1> and w = get<2> of the zipped tuple.
template <typename Dtype, typename Acctype>
struct bce_functor_weights
{
  template <class Tuple>
  __host__ __device__
  Acctype operator()(Tuple x)
  {
    Dtype prob = thrust::get<0>(x);
    Dtype target = thrust::get<1>(x);
    Dtype weight = thrust::get<2>(x);
    // Input must be a probability; assert is compiled out under NDEBUG.
    assert(prob >= 0. && prob <= 1.);
    const Acctype logProb = safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(prob));
    const Acctype logComplement = safe_log<Acctype>(Acctype(1) - prob);
    const Acctype unweighted = target * logProb + (Acctype(1) - target) * logComplement;
    return -weight * unweighted;
  }
};
// Element-wise gradient of the unreduced BCE loss with respect to the input:
//   dL/dx = -(t - x) / ((1 - x + eps) * (x + eps))
// eps keeps the denominator nonzero when x is exactly 0 or 1.
// Both operands are promoted to Acctype before any arithmetic. Otherwise the
// difference (t - x) would be rounded to Dtype precision, which matters for a
// reduced-precision Dtype (this file includes half-precision support headers).
template <typename Dtype, typename Acctype>
struct bce_updateGradInput_no_reduce_functor
{
  __forceinline__ __host__ __device__
  void operator()(
      const Dtype *x,
      const Dtype *t,
      Dtype *gradInput)
  {
    const Acctype input = ScalarConvert<Dtype, Acctype>::to(*x);
    const Acctype target = ScalarConvert<Dtype, Acctype>::to(*t);
    const Acctype denom =
        (Acctype(1) - input + eps<Acctype>()) * (input + eps<Acctype>());
    *gradInput = ScalarConvert<Acctype, Dtype>::to(-(target - input) / denom);
  }
};
// Element-wise BCE gradient with respect to the input, scaled by `norm`:
//   grad = -(t - o) / ((1 - o + eps) * (o + eps)) * norm
// with o = get<0> (input) and t = get<1> (target) of the zipped tuple.
// All operands are promoted to Acctype first, so that (t - o) is not rounded
// to Dtype precision when Dtype is a reduced-precision type.
template <typename Dtype, typename Acctype>
struct bce_updateGradInput_functor
{
  const Dtype norm;

  bce_updateGradInput_functor(Dtype norm_)
    : norm(norm_)
  {}

  template <class Tuple>
  __host__ __device__
  Dtype operator()(Tuple x)
  {
    Dtype o_raw = thrust::get<0>(x);
    Dtype t_raw = thrust::get<1>(x);
    const Acctype o = ScalarConvert<Dtype, Acctype>::to(o_raw);
    const Acctype t = ScalarConvert<Dtype, Acctype>::to(t_raw);
    const Acctype scale = ScalarConvert<Dtype, Acctype>::to(norm);
    const Acctype denom = (Acctype(1) - o + eps<Acctype>()) * (o + eps<Acctype>());
    return ScalarConvert<Acctype, Dtype>::to(-(t - o) / denom * scale);
  }
};
// Element-wise weighted BCE gradient with respect to the input, scaled by `norm`:
//   grad = -(t - o) / ((1 - o + eps) * (o + eps)) * norm * w
// with o = get<0> (input), t = get<1> (target) and w = get<2> (weight).
// All operands are promoted to Acctype first, so that (t - o) is not rounded
// to Dtype precision when Dtype is a reduced-precision type.
template <typename Dtype, typename Acctype>
struct bce_updateGradInput_functor_weights
{
  const Dtype norm;

  bce_updateGradInput_functor_weights(Dtype norm_)
    : norm(norm_)
  {}

  template <class Tuple>
  __host__ __device__
  Dtype operator()(Tuple x)
  {
    Dtype o_raw = thrust::get<0>(x);
    Dtype t_raw = thrust::get<1>(x);
    Dtype w_raw = thrust::get<2>(x);
    const Acctype o = ScalarConvert<Dtype, Acctype>::to(o_raw);
    const Acctype t = ScalarConvert<Dtype, Acctype>::to(t_raw);
    const Acctype w = ScalarConvert<Dtype, Acctype>::to(w_raw);
    const Acctype scale = ScalarConvert<Dtype, Acctype>::to(norm);
    const Acctype denom = (Acctype(1) - o + eps<Acctype>()) * (o + eps<Acctype>());
    return ScalarConvert<Acctype, Dtype>::to(-(t - o) / denom * scale * w);
  }
};
#include <THCUNN/generic/BCECriterion.cu>
#include <THC/THCGenerateFloatTypes.h>