Skip to content

Commit

Permalink
Merge pull request #356 from bluescarni/pr/relu
Browse files Browse the repository at this point in the history
Add ReLU
  • Loading branch information
bluescarni authored Nov 2, 2023
2 parents 2a0d493 + 88222ed commit ec51518
Show file tree
Hide file tree
Showing 14 changed files with 1,248 additions and 5 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ set(HEYOKA_SRC_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/log.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/pow.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sigmoid.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/relu.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sin.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sqrt.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/tan.cpp"
Expand Down
4 changes: 4 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Changelog
New
~~~

- Implement ``ReLU`` and its derivative in the expression
system (`#356 <https://github.com/bluescarni/heyoka/pull/356>`__).
- Implement the eccentric longitude :math:`F` in the expression
system (`#352 <https://github.com/bluescarni/heyoka/pull/352>`__).
- Implement the delta eccentric anomaly :math:`\Delta E` in the expression
Expand All @@ -24,6 +26,8 @@ Changes
Fix
~~~

- Fix compiler warning when building without SLEEF support
(`#356 <https://github.com/bluescarni/heyoka/pull/356>`__).
- Improve the numerical stability of the VSOP2013 model
(`#353 <https://github.com/bluescarni/heyoka/pull/353>`__).
- Improve the numerical stability of the Kepler solver
Expand Down
1 change: 1 addition & 0 deletions include/heyoka/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <heyoka/math/log.hpp>
#include <heyoka/math/pow.hpp>
#include <heyoka/math/prod.hpp>
#include <heyoka/math/relu.hpp>
#include <heyoka/math/sigmoid.hpp>
#include <heyoka/math/sin.hpp>
#include <heyoka/math/sinh.hpp>
Expand Down
97 changes: 97 additions & 0 deletions include/heyoka/math/relu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2020, 2021, 2022, 2023 Francesco Biscani ([email protected]), Dario Izzo ([email protected])
//
// This file is part of the heyoka library.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef HEYOKA_MATH_RELU_HPP
#define HEYOKA_MATH_RELU_HPP

#include <cstdint>
#include <vector>

#include <heyoka/config.hpp>
#include <heyoka/detail/fwd_decl.hpp>
#include <heyoka/detail/llvm_fwd.hpp>
#include <heyoka/detail/visibility.hpp>
#include <heyoka/func.hpp>
#include <heyoka/s11n.hpp>

HEYOKA_BEGIN_NAMESPACE

namespace detail
{

class HEYOKA_DLL_PUBLIC relu_impl : public func_base
{
friend class boost::serialization::access;
template <typename Archive>
void serialize(Archive &ar, unsigned)
{
ar &boost::serialization::base_object<func_base>(*this);
}

public:
relu_impl();
explicit relu_impl(expression);

[[nodiscard]] expression normalise() const;

[[nodiscard]] std::vector<expression> gradient() const;

[[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector<llvm::Value *> &, llvm::Value *,
llvm::Value *, llvm::Value *, std::uint32_t, bool) const;

[[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const;

llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t, bool) const;

llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, bool) const;
};

class HEYOKA_DLL_PUBLIC relup_impl : public func_base
{
friend class boost::serialization::access;
template <typename Archive>
void serialize(Archive &ar, unsigned)
{
ar &boost::serialization::base_object<func_base>(*this);
}

public:
relup_impl();
explicit relup_impl(expression);

[[nodiscard]] expression normalise() const;

[[nodiscard]] std::vector<expression> gradient() const;

[[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector<llvm::Value *> &, llvm::Value *,
llvm::Value *, llvm::Value *, std::uint32_t, bool) const;

[[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const;

llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t, bool) const;

llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, bool) const;
};

} // namespace detail

HEYOKA_DLL_PUBLIC expression relu(expression);

HEYOKA_DLL_PUBLIC expression relup(expression);

HEYOKA_END_NAMESPACE

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::relu_impl)

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::relup_impl)

#endif
2 changes: 1 addition & 1 deletion include/heyoka/math/sin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef HEYOKA_MATH_SIN_HPP
#define HEYOKA_MATH_SIN_HPP

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down
6 changes: 6 additions & 0 deletions src/detail/vector_math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ namespace

using vf_map_t = std::unordered_map<std::string, std::vector<vf_info>>;

// NOTE: make_vfinfo() is necessary if *any* vectorisation backend is active,
// but at the moment we have only SLEEF.
#if defined(HEYOKA_WITH_SLEEF)

auto make_vfinfo(const char *s_name, std::string v_name, std::uint32_t width, std::uint32_t nargs)
{
assert(nargs == 1u || nargs == 2u);
Expand All @@ -40,6 +44,8 @@ auto make_vfinfo(const char *s_name, std::string v_name, std::uint32_t width, st
return ret;
}

#endif

#if defined(HEYOKA_WITH_SLEEF)

// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
Expand Down
10 changes: 9 additions & 1 deletion src/math/cos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include <boost/numeric/conversion/cast.hpp>

#include <fmt/format.h>
#include <fmt/core.h>

#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/DerivedTypes.h>
Expand Down Expand Up @@ -223,6 +223,8 @@ llvm::Value *taylor_diff_cos_impl(llvm_state &s, llvm::Type *fp_t, const cos_imp
return llvm_fdiv(s, ret_acc, div);
}

// LCOV_EXCL_START

// All the other cases.
template <typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
llvm::Value *taylor_diff_cos_impl(llvm_state &, llvm::Type *, const cos_impl &, const std::vector<std::uint32_t> &,
Expand All @@ -233,6 +235,8 @@ llvm::Value *taylor_diff_cos_impl(llvm_state &, llvm::Type *, const cos_impl &,
"An invalid argument type was encountered while trying to build the Taylor derivative of a cosine");
}

// LCOV_EXCL_STOP

llvm::Value *taylor_diff_cos(llvm_state &s, llvm::Type *fp_t, const cos_impl &f, const std::vector<std::uint32_t> &deps,
const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
Expand Down Expand Up @@ -370,6 +374,8 @@ llvm::Function *taylor_c_diff_func_cos_impl(llvm_state &s, llvm::Type *fp_t, con
return f;
}

// LCOV_EXCL_START

// All the other cases.
template <typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
llvm::Function *taylor_c_diff_func_cos_impl(llvm_state &, llvm::Type *, const cos_impl &, const U &, std::uint32_t,
Expand All @@ -379,6 +385,8 @@ llvm::Function *taylor_c_diff_func_cos_impl(llvm_state &, llvm::Type *, const co
"of a cosine in compact mode");
}

// LCOV_EXCL_STOP

llvm::Function *taylor_c_diff_func_cos(llvm_state &s, llvm::Type *fp_t, const cos_impl &fn, std::uint32_t n_uvars,
std::uint32_t batch_size)
{
Expand Down
Loading

0 comments on commit ec51518

Please sign in to comment.