Made Real_t default floating point type for CPU implementation and removed template parameter from TrainCpu(...).
Simon Pfreundschuh authored and lmoneta committed Oct 31, 2016
1 parent f66e511 commit cc5288e
Showing 3 changed files with 11 additions and 13 deletions.
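For context, here is a minimal sketch (not part of the commit) of what the new default template argument in Cpu.h does: `TCpu<>` now resolves to `TCpu<Real_t>`, and in ROOT `Real_t` is a typedef for `float`, so the CPU backend computes in single precision unless a scalar type is named explicitly.

```cpp
#include <type_traits>

// Stand-in for ROOT's typedef in RtypesCore.h, where Real_t is float.
typedef float Real_t;

// Shape of the changed declaration in Cpu.h: the scalar type now
// defaults to Real_t.
template <typename AReal = Real_t>
class TCpu {
public:
   using Scalar_t = AReal;
};

int main()
{
   // TCpu<> is shorthand for TCpu<Real_t>, i.e. single precision.
   static_assert(std::is_same<TCpu<>::Scalar_t, float>::value,
                 "default picks Real_t");
   // An explicit argument still overrides the default.
   static_assert(std::is_same<TCpu<double>::Scalar_t, double>::value,
                 "explicit argument wins");
   return 0;
}
```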
2 changes: 1 addition & 1 deletion tmva/tmva/inc/TMVA/DNN/Architectures/Cpu.h
@@ -33,7 +33,7 @@ namespace DNN
* for this architecture as well as the remaining functions in the low-level
* interface in the form of static members.
*/
-template<typename AReal>
+template<typename AReal = Real_t>
class TCpu
{
public:
1 change: 0 additions & 1 deletion tmva/tmva/inc/TMVA/MethodDNN.h
@@ -153,7 +153,6 @@ class MethodDNN : public MethodBase
TString tokenDelim);
void Train();
void TrainGpu();
-template <typename AFloat>
void TrainCpu();

virtual Double_t GetMvaValue( Double_t* err=0, Double_t* errUpper=0 );
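A hedged sketch (hypothetical, simplified names) of why the member template can go away: once the architecture class carries a default scalar type, `TrainCpu()` can pick it up via `TCpu<>` internally instead of taking it from the caller.

```cpp
typedef float Real_t;

template <typename AReal = Real_t>
struct TCpu {
   using Scalar_t = AReal;
};

struct MethodDNN {
   // Before the commit: template <typename AFloat> void TrainCpu();
   // After: a plain member function; the scalar type comes from TCpu<>.
   void TrainCpu()
   {
      using Arch = TCpu<>;        // resolves to TCpu<Real_t>
      Arch::Scalar_t error = 0;   // previously AFloat
      (void)error;                // silence unused-variable warnings
   }
};

int main()
{
   MethodDNN m;
   m.TrainCpu();   // call site no longer names a type (was TrainCpu<Double_t>())
   return 0;
}
```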
21 changes: 10 additions & 11 deletions tmva/tmva/src/MethodDNN.cxx
@@ -525,7 +525,7 @@ void TMVA::MethodDNN::Train()
Log() << kFATAL << "OpenCL backend not yet supported." << Endl;
return;
} else if (fArchitectureString == "CPU") {
-TrainCpu<Double_t>();
+TrainCpu();
if (!fExitFromTraining) fIPyMaxIter = fIPyCurrentIter;
ExitFromTraining();
return;
@@ -895,7 +895,6 @@ void TMVA::MethodDNN::TrainGpu()
}

//______________________________________________________________________________
-template<typename AFloat>
void TMVA::MethodDNN::TrainCpu()
{

@@ -919,7 +918,7 @@ void TMVA::MethodDNN::TrainCpu()
<< fTrainingSettings.size() << ":" << Endl;
trainingPhase++;

-TNet<TCpu<AFloat>> net(settings.batchSize, fNet);
+TNet<TCpu<>> net(settings.batchSize, fNet);
net.SetWeightDecay(settings.weightDecay);
net.SetRegularization(settings.regularization);
// Need to convert dropoutprobabilities to conventions used
@@ -933,7 +932,7 @@ void TMVA::MethodDNN::TrainCpu()
net.InitializeGradients();
auto testNet = net.CreateClone(settings.batchSize);

-using DataLoader_t = TDataLoader<TMVAInput_t, TCpu<AFloat>>;
+using DataLoader_t = TDataLoader<TMVAInput_t, TCpu<>>;

size_t nThreads = 1;
DataLoader_t trainingData(GetEventCollection(Types::kTraining),
@@ -946,22 +945,22 @@ void TMVA::MethodDNN::TrainCpu()
testNet.GetBatchSize(),
net.GetInputWidth(),
net.GetOutputWidth(), nThreads);
-DNN::TGradientDescent<TCpu<AFloat>> minimizer(settings.learningRate,
+DNN::TGradientDescent<TCpu<>> minimizer(settings.learningRate,
settings.convergenceSteps,
settings.testInterval);

-std::vector<TNet<TCpu<AFloat>>> nets{};
-std::vector<TBatch<TCpu<AFloat>>> batches{};
+std::vector<TNet<TCpu<>>> nets{};
+std::vector<TBatch<TCpu<>>> batches{};
nets.reserve(nThreads);
for (size_t i = 0; i < nThreads; i++) {
nets.push_back(net);
for (size_t j = 0; j < net.GetDepth(); j++)
{
auto &masterLayer = net.GetLayer(j);
auto &layer = nets.back().GetLayer(j);
-TCpu<AFloat>::Copy(layer.GetWeights(),
+TCpu<>::Copy(layer.GetWeights(),
masterLayer.GetWeights());
-TCpu<AFloat>::Copy(layer.GetBiases(),
+TCpu<>::Copy(layer.GetBiases(),
masterLayer.GetBiases());
}
}
@@ -1004,7 +1003,7 @@ void TMVA::MethodDNN::TrainCpu()
if ((stepCount % minimizer.GetTestInterval()) == 0) {

// Compute test error.
-AFloat testError = 0.0;
+Double_t testError = 0.0;
for (auto batch : testData) {
auto inputMatrix = batch.GetInput();
auto outputMatrix = batch.GetOutput();
@@ -1015,7 +1014,7 @@ void TMVA::MethodDNN::TrainCpu()
end = std::chrono::system_clock::now();

// Compute training error.
-AFloat trainingError = 0.0;
+Double_t trainingError = 0.0;
for (auto batch : trainingData) {
auto inputMatrix = batch.GetInput();
auto outputMatrix = batch.GetOutput();
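One detail worth noting in the hunks above: the per-epoch error accumulators switch from `AFloat` to `Double_t` rather than to `Real_t`, so error sums stay in double precision even when the network computes in `float`. A standalone illustration (not from the commit) of why accumulating in single precision is risky:

```cpp
#include <cstdio>

int main()
{
   float  sumF = 0.0f;
   double sumD = 0.0;
   for (int i = 0; i < 20000000; ++i) {
      sumF += 1.0f;
      sumD += 1.0;
   }
   // Above 2^24 = 16777216, adding 1.0f no longer changes a float sum.
   std::printf("float  sum: %.1f\n", sumF);  // 16777216.0
   std::printf("double sum: %.1f\n", sumD);  // 20000000.0
   return 0;
}
```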
