From a8545d7b5bc6b18170590765d3bd00a59758111e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D1=80=D0=B0=D0=BD=D0=B8=D0=BC=D0=B8=D1=80=20=D0=9A?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D1=9F=D0=B8=D1=9B?= Date: Sun, 21 Aug 2022 16:21:51 -0700 Subject: [PATCH] Added lower/upperBound, templatizes comparison functions. --- include/bx/inline/sort.inl | 86 ++++++++++++++ include/bx/sort.h | 237 ++++++++++++++++++++++++++++++++----- src/sort.cpp | 48 ++++++++ tests/sort_test.cpp | 106 ++++++++++++++--- 4 files changed, 429 insertions(+), 48 deletions(-) diff --git a/include/bx/inline/sort.inl b/include/bx/inline/sort.inl index 5dc75dd..1a17644 100644 --- a/include/bx/inline/sort.inl +++ b/include/bx/inline/sort.inl @@ -9,6 +9,92 @@ namespace bx { + template + inline int32_t compareAscending(const void* _lhs, const void* _rhs) + { + const Ty lhs = *static_cast(_lhs); + const Ty rhs = *static_cast(_rhs); + return (lhs > rhs) - (lhs < rhs); + } + + template + inline int32_t compareDescending(const void* _lhs, const void* _rhs) + { + return compareAscending(_rhs, _lhs); + } + + template<> + inline int32_t compareAscending(const void* _lhs, const void* _rhs) + { + return strCmp(*(const char**)_lhs, *(const char**)_rhs); + } + + template<> + inline int32_t compareAscending(const void* _lhs, const void* _rhs) + { + return strCmp(*(const StringView*)_lhs, *(const StringView*)_rhs); + } + + template + void quickSort(Ty* _data, uint32_t _num, const ComparisonFn _fn) + { + quickSort( (void*)_data, _num, sizeof(Ty), _fn); + } + + template + void quickSort(void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + quickSort(_data, _num, _stride, _fn); + } + + template + uint32_t lowerBound(const Ty& _key, const Ty* _data, uint32_t _num, const ComparisonFn _fn) + { + return lowerBound( (const void*)&_key, _data, _num, sizeof(Ty), _fn); + } + + template + uint32_t lowerBound(const Ty& _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + return lowerBound( (const void*)&_key, _data, _num, _stride, _fn); + } + + template + uint32_t upperBound(const Ty& _key, const Ty* _data, uint32_t _num, const ComparisonFn _fn) + { + return upperBound( (const void*)&_key, _data, _num, sizeof(Ty), _fn); + } + + template + uint32_t upperBound(const Ty& _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + return upperBound( (const void*)&_key, _data, _num, _stride, _fn); + } + + template + bool isSorted(const Ty* _data, uint32_t _num, const ComparisonFn _fn) + { + return isSorted(_data, _num, sizeof(Ty), _fn); + } + + template + bool isSorted(const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + return isSorted(_data, _num, _stride, _fn); + } + + template + int32_t binarySearch(const Ty& _key, const Ty* _data, uint32_t _num, const ComparisonFn _fn) + { + return binarySearch( (const void*)&_key, _data, _num, sizeof(Ty), _fn); + } + + template + int32_t binarySearch(const Ty& _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + return binarySearch( (const void*)&_key, _data, _num, _stride, _fn); + } + namespace radix_sort_detail { constexpr uint32_t kBits = 11; diff --git a/include/bx/sort.h b/include/bx/sort.h index 408a0b7..a6f1fd0 100644 --- a/include/bx/sort.h +++ b/include/bx/sort.h @@ -7,6 +7,8 @@ #define BX_SORT_H_HEADER_GUARD #include "bx.h" +#include "math.h" +#include "string.h" namespace bx { @@ -19,6 +21,26 @@ namespace bx /// typedef int32_t (*ComparisonFn)(const void* _lhs, const void* _rhs); + /// The function compares the `_lhs` and `_rhs` values. + /// + /// @returns Returns value: + /// - less than zero if `_lhs` is less than `_rhs` + /// - zero if `_lhs` is equivalent to `_rhs` + /// - greater than zero if `_lhs` is greater than `_rhs` + /// + template + int32_t compareAscending(const void* _lhs, const void* _rhs); + + /// The function compares the `_lhs` and `_rhs` values. + /// + /// @returns Returns value: + /// - less than zero if `_lhs` is greated than `_rhs` + /// - zero if `_lhs` is equivalent to `_rhs` + /// - greater than zero if `_lhs` is less than `_rhs` + /// + template + int32_t compareDescending(const void* _lhs, const void* _rhs); + /// Performs sort (Quick Sort algorithm). /// /// @param _data Pointer to sorted array data. @@ -33,39 +55,24 @@ namespace bx , const ComparisonFn _fn ); + /// Performs sort (Quick Sort algorithm). /// - void radixSort( - uint32_t* _keys - , uint32_t* _tempKeys - , uint32_t _size - ); + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _fn Comparison function. + /// + template + void quickSort(Ty* _data, uint32_t _num, const ComparisonFn _fn = compareAscending); + /// Performs sort (Quick Sort algorithm). /// - template - void radixSort( - uint32_t* _keys - , uint32_t* _tempKeys - , Ty* _values - , Ty* _tempValues - , uint32_t _size - ); - + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. /// - void radixSort( - uint64_t* _keys - , uint64_t* _tempKeys - , uint32_t _size - ); - - /// - template - void radixSort( - uint64_t* _keys - , uint64_t* _tempKeys - , Ty* _values - , Ty* _tempValues - , uint32_t _size - ); + template + void quickSort(void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn = compareAscending); /// Performs check if array is sorted. /// @@ -83,6 +90,115 @@ namespace bx , const ComparisonFn _fn ); + /// Performs check if array is sorted. + /// + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _fn Comparison function. + /// + /// @returns Returns `true` if array is sorted, otherwise returns `false`. + /// + template + bool isSorted(const Ty* _data, uint32_t _num, const ComparisonFn _fn = compareAscending); + + /// Performs check if array is sorted. + /// + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. + /// + /// @returns Returns `true` if array is sorted, otherwise returns `false`. + /// + template + bool isSorted(const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn = compareAscending); + + /// Returns an index to the first element greater or equal than the `_key` value. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns an index to the first element greater or equal than the `_key` value. + /// + uint32_t lowerBound(const void* _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn); + + /// Returns an index to the first element greater or equal than the `_key` value. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns an index to the first element greater or equal than the `_key` value. + /// + template + uint32_t lowerBound(const Ty& _key, const Ty* _data, uint32_t _num, const ComparisonFn _fn = compareAscending); + + /// Returns an index to the first element greater or equal than the `_key` value. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns an index to the first element greater or equal than the `_key` value. + /// + template + uint32_t lowerBound(const Ty& _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn = compareAscending); + + /// Returns an index to the first element greater than the `_key` value. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns an index to the first element greater than the `_key` value. + /// + uint32_t upperBound(const void* _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn); + + /// Returns an index to the first element greater than the `_key` value. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns an index to the first element greater than the `_key` value. + /// + template + uint32_t upperBound(const Ty& _key, const Ty* _data, uint32_t _num, const ComparisonFn _fn = compareAscending); + + /// Returns an index to the first element greater than the `_key` value. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns an index to the first element greater than the `_key` value. + /// + template + uint32_t upperBound(const Ty& _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn = compareAscending); + /// Performs binary search of a sorted array. /// /// @param _key Pointer to the key to search for. @@ -103,6 +219,69 @@ namespace bx , const ComparisonFn _fn ); + /// Performs binary search of a sorted array. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns index of element or -1 if the key is not found in sorted array. + /// + template + int32_t binarySearch(const Ty& _key, const Ty* _data, uint32_t _num, const ComparisonFn _fn = compareAscending); + + /// Performs binary search of a sorted array. + /// + /// @param _key Pointer to the key to search for. + /// @param _data Pointer to sorted array data. + /// @param _num Number of elements. + /// @param _stride Element stride in bytes. + /// @param _fn Comparison function. + /// + /// @remarks Array must be sorted! + /// + /// @returns Returns index of element or -1 if the key is not found in sorted array. + /// + template + int32_t binarySearch(const Ty& _key, const void* _data, uint32_t _num, uint32_t _stride = sizeof(Ty), const ComparisonFn _fn = compareAscending); + + /// + void radixSort( + uint32_t* _keys + , uint32_t* _tempKeys + , uint32_t _size + ); + + /// + template + void radixSort( + uint32_t* _keys + , uint32_t* _tempKeys + , Ty* _values + , Ty* _tempValues + , uint32_t _size + ); + + /// + void radixSort( + uint64_t* _keys + , uint64_t* _tempKeys + , uint32_t _size + ); + + /// + template + void radixSort( + uint64_t* _keys + , uint64_t* _tempKeys + , Ty* _values + , Ty* _tempValues + , uint32_t _size + ); + } // namespace bx #include "inline/sort.inl" diff --git a/src/sort.cpp b/src/sort.cpp index 25f7ea9..caba004 100644 --- a/src/sort.cpp +++ b/src/sort.cpp @@ -68,6 +68,54 @@ namespace bx return true; } + uint32_t lowerBound(const void* _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + uint32_t offset = 0; + const uint8_t* data = (uint8_t*)_data; + + for (uint32_t ll = _num; offset < ll;) + { + const uint32_t idx = (offset + ll) / 2; + + int32_t result = _fn(_key, &data[idx * _stride]); + + if (result <= 0) + { + ll = idx; + } + else + { + offset = idx + 1; + } + } + + return offset; + } + + uint32_t upperBound(const void* _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) + { + uint32_t offset = 0; + const uint8_t* data = (uint8_t*)_data; + + for (uint32_t ll = _num; offset < ll;) + { + const uint32_t idx = (offset + ll) / 2; + + int32_t result = _fn(_key, &data[idx * _stride]); + + if (result < 0) + { + ll = idx; + } + else + { + offset = idx + 1; + } + } + + return offset; + } + int32_t binarySearch(const void* _key, const void* _data, uint32_t _num, uint32_t _stride, const ComparisonFn _fn) { uint32_t offset = 0; diff --git a/tests/sort_test.cpp b/tests/sort_test.cpp index feb4115..5681afc 100644 --- a/tests/sort_test.cpp +++ b/tests/sort_test.cpp @@ -18,23 +18,16 @@ TEST_CASE("quickSort", "") "jagoda", }; - auto strCmpFn = [](const void* _lhs, const void* _rhs) - { - const char* lhs = *(const char**)_lhs; - const char* rhs = *(const char**)_rhs; - return bx::strCmp(lhs, rhs); - }; + REQUIRE(!bx::isSorted(str, BX_COUNTOF(str) ) ); - REQUIRE(!bx::isSorted(str, BX_COUNTOF(str), sizeof(str[0]), strCmpFn) ); - - bx::quickSort(str, BX_COUNTOF(str), sizeof(str[0]), strCmpFn); + bx::quickSort(str, BX_COUNTOF(str) ); REQUIRE(0 == bx::strCmp(str[0], "jabuka") ); REQUIRE(0 == bx::strCmp(str[1], "jagoda") ); REQUIRE(0 == bx::strCmp(str[2], "kruska") ); REQUIRE(0 == bx::strCmp(str[3], "malina") ); - REQUIRE(bx::isSorted(str, BX_COUNTOF(str), sizeof(str[0]), strCmpFn) ); + REQUIRE(bx::isSorted(str, BX_COUNTOF(str) ) ); auto bsearchStrCmpFn = [](const void* _lhs, const void* _rhs) { @@ -50,12 +43,17 @@ TEST_CASE("quickSort", "") REQUIRE( 3 == bx::binarySearch("malina", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); REQUIRE(-1 == bx::binarySearch("kupina", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); - auto byteCmpFn = [](const void* _lhs, const void* _rhs) - { - int8_t lhs = *(const int8_t*)_lhs; - int8_t rhs = *(const int8_t*)_rhs; - return lhs - rhs; - }; + REQUIRE( 0 == bx::lowerBound("jabuka", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + REQUIRE( 1 == bx::upperBound("jabuka", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + + REQUIRE( 1 == bx::lowerBound("jagoda", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + REQUIRE( 2 == bx::upperBound("jagoda", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + + REQUIRE( 2 == bx::lowerBound("kruska", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + REQUIRE( 3 == bx::upperBound("kruska", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + + REQUIRE( 3 == bx::lowerBound("malina", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); + REQUIRE( 4 == bx::upperBound("malina", str, BX_COUNTOF(str), sizeof(str[0]), bsearchStrCmpFn) ); int8_t byte[128]; bx::RngMwc rng; @@ -64,14 +62,84 @@ TEST_CASE("quickSort", "") byte[ii] = rng.gen()&0xff; } - REQUIRE(!bx::isSorted(byte, BX_COUNTOF(byte), sizeof(byte[0]), byteCmpFn) ); + REQUIRE(!bx::isSorted(byte, BX_COUNTOF(byte) ) ); - bx::quickSort(byte, BX_COUNTOF(byte), sizeof(byte[0]), byteCmpFn); + bx::quickSort(byte, BX_COUNTOF(byte) ); for (uint32_t ii = 1; ii < BX_COUNTOF(byte); ++ii) { REQUIRE(byte[ii-1] <= byte[ii]); } - REQUIRE(bx::isSorted(byte, BX_COUNTOF(byte), sizeof(byte[0]), byteCmpFn) ); + REQUIRE(bx::isSorted(byte, BX_COUNTOF(byte) ) ); +} + +TEST_CASE("lower/upperBound int32_t", "") +{ + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 | 14 + const int32_t test[] = { 100, 101, 101, 101, 103, 104, 105, 105, 105, 106, 106, 107, 108, 109 }; + REQUIRE(bx::isSorted(test, BX_COUNTOF(test) ) ); + + const uint32_t resultLowerBound[] = { 0, 1, 4, 4, 5, 6, 9, 11, 12, 13 }; + const uint32_t resultUpperBound[] = { 1, 4, 4, 5, 6, 9, 11, 12, 13, 14 }; + + static_assert(10 == BX_COUNTOF(resultLowerBound) ); + static_assert(10 == BX_COUNTOF(resultUpperBound) ); + + for (int32_t key = test[0], keyMax = test[BX_COUNTOF(test)-1], ii = 0; key <= keyMax; ++key, ++ii) + { + REQUIRE(resultLowerBound[ii] == bx::lowerBound(key, test, BX_COUNTOF(test) ) ); + REQUIRE(resultUpperBound[ii] == bx::upperBound(key, test, BX_COUNTOF(test) ) ); + } +} + +template +int32_t compareAscendingTest(const Ty& _lhs, const Ty& _rhs) +{ + return bx::compareAscending(&_lhs, &_rhs); +} + +template +int32_t compareDescendingTest(const Ty& _lhs, const Ty& _rhs) +{ + return bx::compareDescending(&_lhs, &_rhs); +} + +template +void compareTest(const Ty& _min, const Ty& _max) +{ + REQUIRE(_min < _max); + + REQUIRE(-1 == compareAscendingTest(std::numeric_limits::min(), std::numeric_limits::max() ) ); + REQUIRE(-1 == compareAscendingTest(Ty(0), std::numeric_limits::max() ) ); + REQUIRE( 0 == compareAscendingTest(std::numeric_limits::min(), std::numeric_limits::min() ) ); + REQUIRE( 0 == compareAscendingTest(std::numeric_limits::max(), std::numeric_limits::max() ) ); + REQUIRE( 1 == compareAscendingTest(std::numeric_limits::max(), Ty(0) ) ); + REQUIRE( 1 == compareAscendingTest(std::numeric_limits::max(), std::numeric_limits::min() ) ); + + REQUIRE(-1 == compareAscendingTest(_min, _max) ); + REQUIRE( 0 == compareAscendingTest(_min, _min) ); + REQUIRE( 0 == compareAscendingTest(_max, _max) ); + REQUIRE( 1 == compareAscendingTest(_max, _min) ); + + REQUIRE( 1 == compareDescendingTest(_min, _max) ); + REQUIRE( 0 == compareDescendingTest(_min, _min) ); + REQUIRE( 0 == compareDescendingTest(_max, _max) ); + REQUIRE(-1 == compareDescendingTest(_max, _min) ); +} + +TEST_CASE("ComparisonFn", "") +{ + compareTest< int8_t>( -13, 89); + compareTest(-1389, 1389); + compareTest(-1389, 1389); + compareTest(-1389, 1389); + + compareTest< uint8_t>( 13, 89); + compareTest( 13, 1389); + compareTest( 13, 1389); + compareTest( 13, 1389); + + compareTest< float>(-13.89f, 1389.0f); + compareTest(-13.89f, 1389.0f); }