Skip to content

Commit

Permalink
Merge pull request #390 from bluescarni/pr/tweaks
Browse files Browse the repository at this point in the history
Tweaks and improvements to llvm_state
  • Loading branch information
bluescarni authored Jan 23, 2024
2 parents 5df3043 + 95e1f2f commit d55165f
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 121 deletions.
122 changes: 70 additions & 52 deletions include/heyoka/llvm_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <heyoka/config.hpp>

#include <concepts>
#include <cstdint>
#include <initializer_list>
#include <memory>
Expand All @@ -20,6 +21,8 @@
#include <type_traits>
#include <utility>

#include <boost/numeric/conversion/cast.hpp>

#if defined(HEYOKA_HAVE_REAL128)

#include <mp++/real128.hpp>
Expand Down Expand Up @@ -147,63 +150,84 @@ class HEYOKA_DLL_PUBLIC llvm_state

// Implementation details for the variadic constructor.
template <typename... KwArgs>
static auto kw_args_ctor_impl(KwArgs &&...kw_args)
static auto kw_args_ctor_impl(const KwArgs &...kw_args)
{
igor::parser p{kw_args...};

if constexpr (p.has_unnamed_arguments()) {
static_assert(detail::always_false_v<KwArgs...>,
"The variadic arguments in the construction of an llvm_state contain "
"unnamed arguments.");
} else {
// Module name (defaults to empty string).
auto mod_name = [&p]() -> std::string {
if constexpr (p.has(kw::mname)) {
static_assert(!p.has_unnamed_arguments(), "The variadic arguments in the construction of an llvm_state contain "
"unnamed arguments.");

// Module name (defaults to empty string).
auto mod_name = [&p]() -> std::string {
if constexpr (p.has(kw::mname)) {
if constexpr (std::convertible_to<decltype(p(kw::mname)), std::string>) {
return p(kw::mname);
} else {
return "";
static_assert(detail::always_false_v<KwArgs...>, "Invalid type for the 'mname' keyword argument.");
}
}();

// Optimisation level (defaults to 3).
auto opt_level = [&p]() -> unsigned {
if constexpr (p.has(kw::opt_level)) {
return p(kw::opt_level);
} else {
return {};
}
}();

// Optimisation level (defaults to 3).
auto opt_level = [&p]() -> unsigned {
if constexpr (p.has(kw::opt_level)) {
if constexpr (std::integral<std::remove_cvref_t<decltype(p(kw::opt_level))>>) {
return boost::numeric_cast<unsigned>(p(kw::opt_level));
} else {
return 3;
static_assert(detail::always_false_v<KwArgs...>,
"Invalid type for the 'opt_level' keyword argument.");
}
}();
opt_level = clamp_opt_level(opt_level);

// Fast math flag (defaults to false).
auto fmath = [&p]() -> bool {
if constexpr (p.has(kw::fast_math)) {
} else {
return 3;
}
}();
opt_level = clamp_opt_level(opt_level);

// Fast math flag (defaults to false).
auto fmath = [&p]() -> bool {
if constexpr (p.has(kw::fast_math)) {
if constexpr (std::convertible_to<decltype(p(kw::fast_math)), bool>) {
return p(kw::fast_math);
} else {
return false;
static_assert(detail::always_false_v<KwArgs...>,
"Invalid type for the 'fast_math' keyword argument.");
}
}();

// Force usage of AVX512 registers (defaults to false).
auto force_avx512 = [&p]() -> bool {
if constexpr (p.has(kw::force_avx512)) {
} else {
return false;
}
}();

// Force usage of AVX512 registers (defaults to false).
auto force_avx512 = [&p]() -> bool {
if constexpr (p.has(kw::force_avx512)) {
if constexpr (std::convertible_to<decltype(p(kw::force_avx512)), bool>) {
return p(kw::force_avx512);
} else {
return false;
static_assert(detail::always_false_v<KwArgs...>,
"Invalid type for the 'force_avx512' keyword argument.");
}
}();

// Enable SLP vectorization (defaults to false).
auto slp_vectorize = [&p]() -> bool {
if constexpr (p.has(kw::slp_vectorize)) {
} else {
return false;
}
}();

// Enable SLP vectorization (defaults to false).
auto slp_vectorize = [&p]() -> bool {
if constexpr (p.has(kw::slp_vectorize)) {
if constexpr (std::convertible_to<decltype(p(kw::slp_vectorize)), bool>) {
return p(kw::slp_vectorize);
} else {
return false;
static_assert(detail::always_false_v<KwArgs...>,
"Invalid type for the 'slp_vectorize' keyword argument.");
}
}();
} else {
return false;
}
}();

return std::tuple{std::move(mod_name), opt_level, fmath, force_avx512, slp_vectorize};
}
return std::tuple{std::move(mod_name), opt_level, fmath, force_avx512, slp_vectorize};
}
explicit llvm_state(std::tuple<std::string, unsigned, bool, bool, bool> &&);

Expand All @@ -216,21 +240,17 @@ class HEYOKA_DLL_PUBLIC llvm_state
HEYOKA_DLL_LOCAL void compile_impl();
HEYOKA_DLL_LOCAL void add_obj_trigger();

// Meta-programming for the kwargs ctor. Enabled if:
public:
llvm_state();
// NOTE: the constructor is enabled if:
// - there is at least 1 argument (i.e., cannot act as a def ctor),
// - if there is only 1 argument, it cannot be of type llvm_state
// (so that it does not interfere with copy/move ctors).
template <typename... KwArgs>
using kwargs_ctor_enabler = std::enable_if_t<
(sizeof...(KwArgs) > 0u)
&& (sizeof...(KwArgs) > 1u
|| std::conjunction_v<std::negation<std::is_same<detail::uncvref_t<KwArgs>, llvm_state>>...>),
int>;

public:
llvm_state();
template <typename... KwArgs, kwargs_ctor_enabler<KwArgs...> = 0>
explicit llvm_state(KwArgs &&...kw_args) : llvm_state(kw_args_ctor_impl(std::forward<KwArgs>(kw_args)...))
requires(sizeof...(KwArgs) > 0u)
&& ((sizeof...(KwArgs) > 1u) || (!std::same_as<std::remove_cvref_t<KwArgs>, llvm_state> && ...))
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init, hicpp-member-init)
explicit llvm_state(const KwArgs &...kw_args) : llvm_state(kw_args_ctor_impl(kw_args...))
{
}
llvm_state(const llvm_state &);
Expand All @@ -250,9 +270,7 @@ class HEYOKA_DLL_PUBLIC llvm_state
[[nodiscard]] bool fast_math() const;
[[nodiscard]] bool force_avx512() const;
[[nodiscard]] unsigned get_opt_level() const;
void set_opt_level(unsigned);
[[nodiscard]] bool get_slp_vectorize() const;
void set_slp_vectorize(bool);

[[nodiscard]] std::string get_ir() const;
[[nodiscard]] std::string get_bc() const;
Expand Down
12 changes: 6 additions & 6 deletions src/expression_cfunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1374,12 +1374,7 @@ void add_cfunc_c_mode(llvm_state &s, llvm::Type *fp_type, llvm::Value *out_ptr,
get_logger()->trace("cfunc IR creation compact mode runtime: {}", sw);
}

// NOTE: add_cfunc() will add two functions, one called 'name'
// and the other called 'name' + '.strided'. The first function
// indexes into the input/output/par buffers contiguously (that is,
// it assumes the input/output/par scalar/vector values are stored one
// after the other without "holes" between them).
// The second function has an extra trailing argument, the stride
// NOTE: in strided mode, the compiled function has an extra trailing argument, the stride
// value, which indicates the distance between consecutive
// input/output/par values in the buffers. The stride is measured in the number
// of *scalar* values between input/output/par values.
Expand All @@ -1388,6 +1383,11 @@ void add_cfunc_c_mode(llvm_state &s, llvm::Type *fp_type, llvm::Value *out_ptr,
// in the input array. For a batch size of 2 and a stride value of 3,
// the input vector values (of size 2) will be read from indices
// [0, 1], [3, 4], [6, 7], [9, 10], ... in the input array.
//
// In non-strided mode, the compiled function indexes into the
// input/output/par buffers contiguously (that is,
// it assumes the input/output/par scalar/vector values are stored one
// after the other without "holes" between them).
template <typename T, typename F>
auto add_cfunc_impl(llvm_state &s, const std::string &name, const F &fn, std::uint32_t batch_size, bool high_accuracy,
bool compact_mode, bool parallel_mode, [[maybe_unused]] long long prec, bool strided)
Expand Down
4 changes: 1 addition & 3 deletions src/func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,7 @@ llvm::Function *llvm_c_eval_func_helper(const std::string &name,
// Fetch the vector floating-point type.
auto *val_t = make_vector_type(fp_t, batch_size);

const auto na_pair = llvm_c_eval_func_name_args(context, fp_t, name, batch_size, fb.args());
const auto &fname = na_pair.first;
const auto &fargs = na_pair.second;
const auto [fname, fargs] = llvm_c_eval_func_name_args(context, fp_t, name, batch_size, fb.args());

// Try to see if we already created the function.
auto *f = md.getFunction(fname);
Expand Down
43 changes: 26 additions & 17 deletions src/llvm_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include <llvm/IR/Value.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/TargetSelect.h>
Expand Down Expand Up @@ -347,8 +348,12 @@ struct llvm_state::jit {
std::unique_ptr<llvm::orc::ThreadSafeContext> m_ctx;
std::optional<std::string> m_object_file;

jit()
explicit jit(unsigned opt_level)
{
// NOTE: we assume here the opt level has already been clamped
// from the outside.
assert(opt_level <= 3u);

// Ensure the native target is inited.
detail::init_native_target();

Expand All @@ -359,8 +364,21 @@ struct llvm_state::jit {
throw std::invalid_argument("Error creating a JITTargetMachineBuilder for the host system");
}
// LCOV_EXCL_STOP
// Set the codegen optimisation level to aggressive.
jtmb->setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
// Set the codegen optimisation level.
switch (opt_level) {
case 0u:
jtmb->setCodeGenOptLevel(llvm::CodeGenOpt::None);
break;
case 1u:
jtmb->setCodeGenOptLevel(llvm::CodeGenOpt::Less);
break;
case 2u:
jtmb->setCodeGenOptLevel(llvm::CodeGenOpt::Default);
break;
default:
assert(opt_level == 3u);
jtmb->setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
}

// Create the jit builder.
llvm::orc::LLJITBuilder lljit_builder;
Expand Down Expand Up @@ -598,7 +616,7 @@ auto llvm_state_bc_to_module(const std::string &module_name, const std::string &

// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
llvm_state::llvm_state(std::tuple<std::string, unsigned, bool, bool, bool> &&tup)
: m_jitter(std::make_unique<jit>()), m_opt_level(std::get<1>(tup)), m_fast_math(std::get<2>(tup)),
: m_jitter(std::make_unique<jit>(std::get<1>(tup))), m_opt_level(std::get<1>(tup)), m_fast_math(std::get<2>(tup)),
m_force_avx512(std::get<3>(tup)), m_slp_vectorize(std::get<4>(tup)), m_module_name(std::move(std::get<0>(tup)))
{
// Create the module.
Expand All @@ -622,8 +640,9 @@ llvm_state::llvm_state(const llvm_state &other)
// NOTE: start off by:
// - creating a new jit,
// - copying over the options from other.
: m_jitter(std::make_unique<jit>()), m_opt_level(other.m_opt_level), m_fast_math(other.m_fast_math),
m_force_avx512(other.m_force_avx512), m_slp_vectorize(other.m_slp_vectorize), m_module_name(other.m_module_name)
: m_jitter(std::make_unique<jit>(other.m_opt_level)), m_opt_level(other.m_opt_level),
m_fast_math(other.m_fast_math), m_force_avx512(other.m_force_avx512), m_slp_vectorize(other.m_slp_vectorize),
m_module_name(other.m_module_name)
{
if (other.is_compiled()) {
// 'other' was compiled.
Expand Down Expand Up @@ -825,7 +844,7 @@ void llvm_state::load_impl(Archive &ar, unsigned version)
m_builder.reset();

// Reset the jit with a new one.
m_jitter = std::make_unique<jit>();
m_jitter = std::make_unique<jit>(opt_level);

if (cmp) {
// Assign the snapshots.
Expand Down Expand Up @@ -909,11 +928,6 @@ unsigned llvm_state::get_opt_level() const
return m_opt_level;
}

void llvm_state::set_opt_level(unsigned opt_level)
{
m_opt_level = clamp_opt_level(opt_level);
}

bool llvm_state::fast_math() const
{
return m_fast_math;
Expand All @@ -929,11 +943,6 @@ bool llvm_state::get_slp_vectorize() const
return m_slp_vectorize;
}

void llvm_state::set_slp_vectorize(bool flag)
{
m_slp_vectorize = flag;
}

unsigned llvm_state::clamp_opt_level(unsigned opt_level)
{
return std::min<unsigned>(opt_level, 3u);
Expand Down
Loading

0 comments on commit d55165f

Please sign in to comment.