Today I wrote a merge sort implementation in C++, using C++20 features (concepts and ranges):
namespace frozenca::hard {
using namespace std;
namespace {
struct merge_sort_func {
template <bidirectional_iterator Iter, sentinel_for<Iter> Sentinel,
typename Comp = ranges::less, typename Proj = identity>
requires sortable<Iter, Comp, Proj>
constexpr Iter merge_impl(Iter first, Sentinel last,
iter_value_t<Iter> *temp_first,
iter_value_t<Iter> *temp_middle,
iter_value_t<Iter> *temp_last, Comp comp = {},
Proj proj = {}) const {
uninitialized_move(first, last, temp_first);
auto l_curr = temp_first;
auto r_curr = temp_middle;
auto A_curr = first;
while (l_curr != temp_middle && r_curr != temp_last) {
if (invoke(comp, invoke(proj, *l_curr), invoke(proj, *r_curr))) {
*A_curr = move(*l_curr);
++l_curr;
} else {
*A_curr = move(*r_curr);
++r_curr;
}
++A_curr;
}
while (l_curr != temp_middle) {
*A_curr = move(*l_curr);
++l_curr;
++A_curr;
}
while (r_curr != temp_last) {
*A_curr = move(*r_curr);
++r_curr;
++A_curr;
}
return A_curr;
}
template <bidirectional_iterator Iter, sentinel_for<Iter> Sentinel,
typename Comp = ranges::less, typename Proj = identity>
requires sortable<Iter, Comp, Proj>
constexpr Iter operator()(Iter first, Sentinel last, Comp comp = {},
Proj proj = {},
iter_value_t<Iter> *temp_buffer = nullptr) const {
const auto len = ranges::distance(first, last);
assert(len >= 0);
if (len < 2) {
return last;
}
using value_t = iter_value_t<Iter>;
bool to_delete = false;
if (!temp_buffer) {
temp_buffer = new value_t[len];
to_delete = true;
}
const auto mid = next(first, len / 2);
(*this)(first, mid, ref(comp), ref(proj), temp_buffer);
(*this)(mid, last, ref(comp), ref(proj), temp_buffer + (len / 2));
const auto ret =
merge_impl(first, last, temp_buffer, temp_buffer + (len / 2),
temp_buffer + len, move(comp), move(proj));
if (to_delete) {
delete[] temp_buffer;
}
return ret;
}
template <ranges::bidirectional_range Range, typename Comp = ranges::less,
typename Proj = identity>
requires sortable<ranges::iterator_t<Range>, Comp, Proj>
constexpr auto operator()(Range &&r, Comp comp = {}, Proj proj = {}) const {
using value_t = ranges::range_value_t<Range>;
value_t *temp_buffer = new value_t[ranges::size(r)];
const auto ret = (*this)(ranges::begin(r), ranges::end(r), move(comp),
move(proj), temp_buffer);
delete[] temp_buffer;
return ret;
}
};
} // anonymous namespace
inline constexpr merge_sort_func merge_sort{};
} // namespace frozenca::hard
I also wrote some logging, unit-testing, and performance-benchmarking code:
namespace frozenca {
using namespace std;
// Severity of a log message. The numeric order matters: a message is emitted
// only when its level is not below curr_log_level.
enum class log_level {
  debug = 0,
  error = 1,
  all = 2,
};

// Prefix printed in front of a located message, keyed by level.
// log_level::all has no entry here.
static const std::map<log_level, std::string> log_level_str{
    {log_level::debug, "[D]"},
    {log_level::error, "[E]"},
};

// Minimum level that is emitted: debug unless NDEBUG is defined, in which
// case only error and above.
#ifndef NDEBUG
static constexpr log_level curr_log_level = log_level::debug;
#else
static constexpr log_level curr_log_level = log_level::error;
#endif
namespace {
// Formats `message` with `args` (std::format syntax) and writes one line to
// `os`, provided `level` is not below curr_log_level.
//
// - log_level::all prints the bare formatted message.
// - Any other level prints the level prefix, the source file, line:column,
//   the function name and then the message.
//
// Throws std::format_error if an emitted message has a malformed format
// string. Filtered-out messages are not formatted at all.
template <typename... Args>
void log_impl(const std::string_view message, log_level level,
              const std::source_location location, std::ostream &os,
              Args &&...args) {
  // Check the threshold first so suppressed messages cost nothing.
  if (level < curr_log_level) {
    return;
  }
  const std::string formatted_message =
      std::vformat(message, std::make_format_args(args...));
  if (level == log_level::all) {
    os << formatted_message << '\n';
    return;
  }
  // The source file recorded at build time may not exist where the program
  // runs; use the non-throwing overload and fall back to the recorded name
  // so that logging itself never raises a filesystem error.
  std::error_code ec;
  std::filesystem::path path =
      std::filesystem::canonical(location.file_name(), ec);
  if (ec) {
    path = location.file_name();
  }
  os << log_level_str.at(level) << ":" << path << " (" << location.line()
     << ":" << location.column() << ") " << location.function_name()
     << " : " << formatted_message << '\n';
}
} // anonymous namespace
template <typename... Args>
constexpr void log(const string_view message,
log_level level = log_level::debug, Args &&...args) {
ostream& log_stream = level == log_level::all ? cout : clog;
log_impl(message, level, source_location::current(), log_stream, args...);
}
template <ranges::input_range R> void print(R &&r, ostream &os = cout) {
for (auto elem : r) {
os << elem << ' ';
}
os << '\n';
}
// Shared pseudo-random engine, seeded once from std::random_device.
std::mt19937 gen{std::random_device{}()};
namespace {
template <ranges::forward_range Range, typename Func1, typename Func2,
typename... Args>
requires regular_invocable<Func1, Range, Args...> &&
regular_invocable<Func2, Range, Args...>
void range_verify(Func1 &&f1, Func2 &&f2, int num_trials, int max_length,
Args &&...args) {
uniform_int_distribution<> len_dist(0, max_length);
for (int i = 0; i < num_trials; ++i) {
Range v;
int n = len_dist(gen);
generate_n(back_inserter(v), n, ref(gen));
f1(v, args...);
if (!f2(v, args...)) {
throw runtime_error("Verification failed");
}
}
log("Verification success!\n", log_level::all);
}
template <ranges::forward_range Range, typename Func, typename... Args>
requires regular_invocable<Func, Range, Args...>
void range_check_perf(Func &&f, int num_trials, const vector<int> &max_lengths,
Args &&...args) {
for (auto max_length : max_lengths) {
chrono::duration<double, micro> curr_length_duration(0);
uniform_int_distribution<> len_dist(0, max_length);
for (int i = 0; i < num_trials; ++i) {
Range v;
int n = len_dist(gen);
generate_n(back_inserter(v), n, ref(gen));
auto start = chrono::steady_clock::now();
f(v, args...);
auto end = chrono::steady_clock::now();
curr_length_duration += (end - start);
}
log("Time to process a range of {:6} elements : {:10.4f} us\n",
log_level::all, max_length,
(curr_length_duration.count() / num_trials));
}
}
} // anonymous namespace
template <ranges::forward_range Range = vector<int>, typename Func,
typename Comp = ranges::less, typename Proj = identity>
requires sortable<ranges::iterator_t<Range>, Comp, Proj> &&
regular_invocable<Func, Range, Comp, Proj>
void verify_sorting(Func &&f, int num_trials = 1'000, int max_length = 1'000,
Comp comp = {}, Proj proj = {}) {
range_verify<Range>(f, ranges::is_sorted, num_trials, max_length, comp, proj);
}
template <ranges::forward_range Range = vector<int>, typename Func,
typename Comp = ranges::less, typename Proj = identity>
requires sortable<ranges::iterator_t<Range>, Comp, Proj> &&
regular_invocable<Func, Range, Comp, Proj>
void perf_check_sorting(Func &&f, int num_trials = 1'000,
const vector<int> &max_lengths = {10, 30, 100, 300,
1'000, 3'000, 10'000},
Comp comp = {}, Proj proj = {}) {
range_check_perf<Range>(f, num_trials, max_lengths, comp, proj);
}
} // namespace frozenca
I then verified that my code actually sorts the range correctly, and compared the performance of my merge sort with std::ranges::sort, using this code:
int main() {
namespace fc = frozenca;
using namespace std;
{
vector<int> v{2, 3, 1, 6, 5, 4};
fc::hard::merge_sort(v);
fc::print(v);
fc::verify_sorting(ranges::sort);
fc::verify_sorting(fc::hard::merge_sort);
fc::perf_check_sorting(ranges::sort);
fc::perf_check_sorting(fc::hard::merge_sort);
}
}
The result: (MSVC 19.31 /Ox)
1 2 3 4 5 6
Verification success!
// This is std::ranges::sort. For each k, the time to process k elements, averaged over 10000 runs
Time to process a range of 10 elements : 0.1045 us
Time to process a range of 30 elements : 0.2472 us
Time to process a range of 100 elements : 1.2097 us
Time to process a range of 300 elements : 4.8376 us
Time to process a range of 1000 elements : 20.0194 us
Time to process a range of 3000 elements : 76.9456 us
Time to process a range of 10000 elements : 282.3116 us
// this is my merge sort
Time to process a range of 10 elements : 0.2778 us
Time to process a range of 30 elements : 0.8240 us
Time to process a range of 100 elements : 2.5889 us
Time to process a range of 300 elements : 8.4630 us
Time to process a range of 1000 elements : 31.1639 us
Time to process a range of 3000 elements : 97.9112 us
Time to process a range of 10000 elements : 369.3333 us
This performance is not terrible, but I feel it is still not efficient enough.
How can I improve both my code quality and performance?
EDIT:
More comments: I referred to https://github.com/microsoft/STL/blob/main/stl/inc/algorithm#L7070-L7110 to rewrite my merge_impl into something almost the same as MSVC's std::inplace_merge, but its performance became much worse. (With that version, sorting a vector<int> of length 10000 takes around 500 us on average.)
My merge_impl is faster than the MSVC std::inplace_merge implementation (but the MSVC implementation allocates a temporary buffer only if necessary, and it uses less temporary buffer space, so it's a trade-off).
I think MSVC's std::ranges::sort is very difficult to beat with merge sort.