Skip to content

Commit

Permalink
Improvement on eve::search
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisYaroshevskiy authored Sep 24, 2024
1 parent 279fad9 commit 5895aa6
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 51 deletions.
230 changes: 180 additions & 50 deletions include/eve/module/algo/algo/search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,102 @@
namespace eve::algo
{

namespace detail
{

/*
* This is a version of eve::algo::for_each_selected, specialised for search.
*
* The problem with just using eve::algo::for_each_selected for search,
* is that, when used with zip, the tail handling gets expensive.
*
* We want to trade off the tail handling for maybe more false positives:
*  - full (main) steps precheck both the first and the last needle element;
*  - partial (tail) steps precheck only the first needle element.
* Every position that passes the precheck is handed to `verify`, which makes
* the final decision.
*/
struct for_each_possibly_matching_for_search_
{
// Delegate driven by `algo::for_each_iteration` (see operator() below).
// It exposes `step` and `unrolled_step` and records in `was_stopped`
// the result of the last `iterate_selected` call.
template<typename NeedleWide, typename Equal, typename Verify> struct delegate
{
// NOTE: `operator()` below aggregate-initialises this struct as
// {needle_front, needle_back, equal_fn, verify}; keep the member order.

// Wide compared against the haystack values at the candidate start positions.
NeedleWide needle_front;
// Wide compared against the haystack values `needle_len - 1` elements further.
NeedleWide needle_back;
// Equality predicate applied as equal_fn(haystack_wide, needle_wide).
Equal equal_fn;
// Callback invoked with an unaligned haystack iterator for every candidate;
// held by reference, so the caller can read its state afterwards.
Verify& verify;
// Result of the most recent `iterate_selected` call
// (presumably true when `verify` asked to stop - confirm against eve::iterate_selected).
bool was_stopped = false;

// Builds the callback given to `iterate_selected`: it receives a lane index `i`
// and forwards `unalign(haystack_it) + i` to `verify`.
template<typename I> EVE_FORCEINLINE auto make_verify_adapter(I haystack_it)
{
struct res_t
{
Verify& verify;
unaligned_t<I> base;

EVE_FORCEINLINE bool operator()(std::ptrdiff_t i) { return verify(base + i); }
};

return res_t {verify, unalign(haystack_it)};
}

// Partial step: only the front iterator of the zip is loaded (with `ignore`),
// so the precheck compares against `needle_front` alone.
// Returns (and stores in `was_stopped`) the result of `iterate_selected`.
EVE_FORCEINLINE bool tail(auto zip_it, eve::relative_conditional_expr auto ignore)
{
auto front_it = get<0>(zip_it);

// not loading from `zip_it` here, because it's much more expensive for tails.
auto haystack_front = eve::load[ignore](front_it);
eve::logical precheck = equal_fn(haystack_front, needle_front);

was_stopped = eve::iterate_selected[ignore](precheck, make_verify_adapter(front_it));
return was_stopped;
}

// Full step: loads both zipped wides and requires the front AND the back
// element to match before a position is passed on to `verify`.
// Returns (and stores in `was_stopped`) the result of `iterate_selected`.
EVE_FORCEINLINE bool main_part(auto zip_it)
{
auto [haystack_front, haystack_back] = eve::load(zip_it);

eve::logical precheck =
equal_fn(haystack_front, needle_front) && equal_fn(haystack_back, needle_back);
was_stopped = eve::iterate_selected(precheck, make_verify_adapter(get<0>(zip_it)));

return was_stopped;
}

// Single-step entry point: dispatches at compile time on the type of `ignore`.
// A complete, inverted condition goes to `main_part`
// (presumably this is the "ignore nothing" case - confirm against eve's conditionals);
// anything else goes to `tail`. The index argument is unused.
template<eve::relative_conditional_expr C>
EVE_FORCEINLINE bool step(auto zip_it, C ignore, auto /*idx*/)
{
if constexpr( C::is_complete && C::is_inverted ) { return main_part(zip_it); }
else { return tail(zip_it, ignore); }
}

// Unrolled entry point: no dedicated unrolled implementation,
// the work is delegated to `unroll_by_calling_single_step` with this delegate.
EVE_FORCEINLINE bool unrolled_step(auto arr)
{
return unroll_by_calling_single_step {}(arr, *this);
}
};

// Runs the search precheck over [haystack_f, haystack_l).
//
//  traits        - traits forwarded to `algo::for_each_iteration`
//  haystack_f/_l - range of candidate start positions in the haystack
//  needle_front  - wide compared with the haystack at each candidate position
//  needle_back   - wide compared with the haystack `needle_len - 1` elements further
//  needle_len    - needle length, used only to offset the second zipped iterator
//  equal_fn      - equality predicate
//  verify        - called with an unaligned haystack iterator for each candidate
//
// Returns the delegate's `was_stopped` flag after the iteration.
template<typename HaystackI,
typename HaystackS,
typename NeedleWide,
typename Equal,
typename Verify>
EVE_FORCEINLINE bool operator()(auto traits,
HaystackI haystack_f,
HaystackS haystack_l,
NeedleWide needle_front,
NeedleWide needle_back,
std::ptrdiff_t needle_len,
Equal equal_fn,
Verify& verify) const
{
// Zip each candidate position with the position `needle_len - 1` elements further.
auto haystack_front_back_range =
views::zip(as_range(haystack_f, haystack_l), unalign(haystack_f) + (needle_len - 1));

auto iteration = algo::for_each_iteration(
traits, haystack_front_back_range.begin(), haystack_front_back_range.end());
delegate<NeedleWide, Equal, Verify> d {needle_front, needle_back, equal_fn, verify};
iteration(d);
return d.was_stopped;
}
} inline constexpr for_each_possibly_matching_for_search;

}

template<typename TraitsSupport> struct search_ : TraitsSupport
{
template<typename I1, // haystack_iter
Expand All @@ -25,8 +121,7 @@ template<typename TraitsSupport> struct search_ : TraitsSupport
>
struct needle_checker
{
std::ptrdiff_t needle_len;
Equal equal;
Equal equal_fn;

wide_value_type_t<I2> first_wide;
eve::keep_first first_wide_ignore;
Expand All @@ -37,7 +132,7 @@ template<typename TraitsSupport> struct search_ : TraitsSupport

template<typename S2>
needle_checker(I2 f, S2 l, Equal _equal)
: equal(_equal)
: equal_fn(_equal)
, first_wide_ignore(iterator_cardinal_v<I2>)
{
std::ptrdiff_t needle_len = l - f;
Expand Down Expand Up @@ -70,15 +165,15 @@ template<typename TraitsSupport> struct search_ : TraitsSupport
// a register.
bool main_check(unaligned_t<I1> haystack_i) const
{
auto test_first = equal(eve::load(haystack_i), first_wide);
auto test_first = equal_fn(eve::load(haystack_i), first_wide);

if( !eve::all[first_wide_ignore](test_first) ) return false;

haystack_i += long_tail_offset;
auto needle_i = long_tail_start;
for( std::ptrdiff_t count = long_tail_n; count; --count )
{
auto test = equal(eve::load(haystack_i), eve::load(needle_i));
auto test = equal_fn(eve::load(haystack_i), eve::load(needle_i));
if( !eve::all(test) ) return false;

haystack_i += iterator_cardinal_v<I1>;
Expand All @@ -89,10 +184,9 @@ template<typename TraitsSupport> struct search_ : TraitsSupport
return true;
}

// tail handling for small needles, just looping through a register
bool small_check(wide_value_type_t<I1> haystack) const
{
auto test = equal(haystack, first_wide);
auto test = equal_fn(haystack, first_wide);
return eve::all[first_wide_ignore](test);
}
};
Expand All @@ -103,69 +197,100 @@ template<typename TraitsSupport> struct search_ : TraitsSupport
unaligned_t<I1> haystack_main_part_l,
I2 needle_f,
std::ptrdiff_t needle_len,
Equal equal,
Equal equal_fn,
Checker check) const
{
eve::wide_value_type_t<I2> needle_front(eve::read(needle_f));
eve::wide_value_type_t<I2> needle_back(eve::read(eve::unalign(needle_f) + (needle_len - 1)));

unaligned_t<I1> hastack_back_f = eve::unalign(haystack_f) + (needle_len - 1);
struct
{
std::optional<unaligned_t<I1>> res;
Checker check;

std::optional<unaligned_t<I1>> res;
for_each_selected[drop_key(divisible_by_cardinal, traits)](
views::zip(as_range(haystack_f, haystack_main_part_l), hastack_back_f),
[&](auto haystack_front_back)
{
auto [haystack_front, haystack_back] = haystack_front_back;
return equal(haystack_front, needle_front) && equal(haystack_back, needle_back);
},
[&](auto haystack_front_back_it)
EVE_FORCEINLINE bool operator()(unaligned_t<I1> haystack_it)
{
if( check.main_check(haystack_it) )
{
auto [haystack_it, _] = haystack_front_back_it;

if( check.main_check(haystack_it) )
{
res = haystack_it;
return true;
}
return false;
});
return res;
res = haystack_it;
return true;
}
return false;
}
} verify {{}, check};

detail::for_each_possibly_matching_for_search(drop_key(divisible_by_cardinal, traits),
haystack_f,
haystack_main_part_l,
needle_front,
needle_back,
needle_len,
equal_fn,
verify);

return verify.res;
}

template<typename UnalignedI1>
template<typename UnalignedI1, typename I2, typename Checker>
EVE_FORCEINLINE std::optional<UnalignedI1> small_tail(UnalignedI1 small_tail_start,
auto haystack_l,
I2 needle_f,
auto equal_fn,
std::ptrdiff_t needle_len,
auto checker) const
Checker checker) const
{
// no small tail
if( needle_len > eve::iterator_cardinal_v<UnalignedI1> ) return {};

std::ptrdiff_t iterations = (haystack_l - small_tail_start) - needle_len + 1;

auto haystack = eve::load[eve::keep_first(haystack_l - small_tail_start)](small_tail_start);
eve::wide_value_type_t<UnalignedI1> haystack =
eve::load[eve::keep_first(haystack_l - small_tail_start)](small_tail_start);
eve::wide_value_type_t<I2> needle_front(eve::read(needle_f));

for( std::ptrdiff_t i = 0; i != iterations; ++i )
struct verify_t
{
if( checker.small_check(haystack) ) return small_tail_start + i;

// TODO: use shuffle_v2 here.
//
// slide_left shifts in 0s.
// If we were to use shuffle_v2 here, we could say `we_`
// instead of 0s - which would be better.
//
// Unfortunately shuffle_v2 can't slide left yet
haystack = eve::slide_left(haystack, eve::index<1>);
}
std::optional<UnalignedI1> res;
UnalignedI1 small_tail_start;

// store small haystack in the stack buffer.
stack_buffer<wide<value_type_t<UnalignedI1>, fixed<2 * iterator_cardinal_v<UnalignedI1>>>>
buf;

Checker checker;

EVE_FORCEINLINE
verify_t(UnalignedI1 _small_tail_start,
wide_value_type_t<UnalignedI1> haystack,
Checker _checker)
: small_tail_start(_small_tail_start)
, checker(_checker)
{
eve::store(haystack, buf.ptr());
}

EVE_FORCEINLINE bool operator()(std::ptrdiff_t i)
{
// We can't slide a register by a runtime value.
// So we store the register on the stack buffer and load instead.
auto to_load = eve::unalign(buf.ptr()) + i;
if( checker.small_check(load(to_load, as<wide_value_type_t<UnalignedI1>> {})) )
{
res = small_tail_start + i;
return true;
}
return false;
}
} verify {small_tail_start, haystack, checker};

std::ptrdiff_t possible_starts = (haystack_l - small_tail_start) - needle_len + 1;

iterate_selected[eve::keep_first(possible_starts)](equal_fn(haystack, needle_front), verify);

return {};
return verify.res;
}

template<relaxed_range R1, relaxed_range R2, typename Equal>
EVE_FORCEINLINE auto
operator()(R1&& haystack, R2&& needle, Equal equal) const -> unaligned_iterator_t<R1>
operator()(R1&& haystack, R2&& needle, Equal equal_fn) const -> unaligned_iterator_t<R1>
{
std::ptrdiff_t needle_len = (needle.end() - needle.begin());
std::ptrdiff_t haystack_len = (haystack.end() - haystack.begin());
Expand All @@ -183,7 +308,7 @@ template<typename TraitsSupport> struct search_ : TraitsSupport
using I2 = decltype(processed_needle.begin());

needle_checker<I1, I2, Equal> needle_checker(
processed_needle.begin(), processed_needle.end(), equal);
processed_needle.begin(), processed_needle.end(), equal_fn);

auto haystack_main_part_l = eve::unalign(haystack_f);

Expand All @@ -196,14 +321,19 @@ template<typename TraitsSupport> struct search_ : TraitsSupport
haystack_main_part_l,
processed_needle.begin(),
needle_len,
equal,
equal_fn,
needle_checker) )
{
return eve::unalign(haystack.begin()) + (*res - haystack_f);
}
}

if( auto res = small_tail(haystack_main_part_l, haystack_l, needle_len, needle_checker) )
if( auto res = small_tail(haystack_main_part_l,
haystack_l,
processed_needle.begin(),
equal_fn,
needle_len,
needle_checker) )
{
return eve::unalign(haystack.begin()) + (*res - haystack_f);
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit/module/algo/algorithm/search_one_generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ TTS_CASE_TPL("eve.algo.find_if_with_search generic", algo_test::selected_types)
[](auto f, auto l, auto expected, auto actual)
{
TTS_EQUAL(actual, expected, REQUIRED)
<< "l - f: " << (l - f) << " expected: " << (expected - f);
<< "l - f: " << (l - f) << " expected: " << (expected - f) << " actual: " << (actual - f);
});
};

0 comments on commit 5895aa6

Please sign in to comment.