From 0d51df17795897bf151ce75484fcd39ebcac34fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D1=80=D0=B0=D0=BD=D0=B8=D0=BC=D0=B8=D1=80=20=D0=9A?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D1=9F=D0=B8=D1=9B?= Date: Wed, 19 Apr 2023 19:01:00 -0700 Subject: [PATCH] Fixed rsqrt, and sqrt. Added more tests. --- include/bx/inline/math.inl | 53 ++++++++++++++++-------------------- tests/math_test.cpp | 56 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 77 insertions(+), 32 deletions(-) diff --git a/include/bx/inline/math.inl b/include/bx/inline/math.inl index f73a548..19fb03d 100644 --- a/include/bx/inline/math.inl +++ b/include/bx/inline/math.inl @@ -212,17 +212,27 @@ namespace bx inline BX_CONST_FUNC float rsqrtRef(float _a) { + if (_a < kNearZero) + { + return kInfinity; + } + return pow(_a, -0.5f); } inline BX_CONST_FUNC float sqrtRef(float _a) { - if (_a < kNearZero) - { - return 0.0f; - } + return _a*pow(_a, -0.5f); + } - return 1.0f/rsqrtRef(_a); + inline BX_CONST_FUNC float rsqrtSimd(float _a) + { + const simd128_t aa = simd_splat(_a); + const simd128_t rsqrta = simd_rsqrt_nr(aa); + float result; + simd_stx(&result, rsqrta); + + return result; } inline BX_CONST_FUNC float sqrtSimd(float _a) @@ -235,30 +245,6 @@ namespace bx return result; } - inline BX_CONST_FUNC float sqrt(float _a) - { -#if BX_CONFIG_SUPPORTS_SIMD - return sqrtSimd(_a); -#else - return sqrtRef(_a); -#endif // BX_CONFIG_SUPPORTS_SIMD - } - - inline BX_CONST_FUNC float rsqrtSimd(float _a) - { - if (_a < kNearZero) - { - return 0.0f; - } - - const simd128_t aa = simd_splat(_a); - const simd128_t rsqrta = simd_rsqrt_nr(aa); - float result; - simd_stx(&result, rsqrta); - - return result; - } - inline BX_CONST_FUNC float rsqrt(float _a) { #if BX_CONFIG_SUPPORTS_SIMD @@ -268,6 +254,15 @@ namespace bx #endif // BX_CONFIG_SUPPORTS_SIMD } + inline BX_CONST_FUNC float sqrt(float _a) + { +#if BX_CONFIG_SUPPORTS_SIMD + return sqrtSimd(_a); +#else + return sqrtRef(_a); +#endif // BX_CONFIG_SUPPORTS_SIMD + } + inline BX_CONSTEXPR_FUNC float trunc(float _a) { return float(int(_a) ); diff --git a/tests/math_test.cpp b/tests/math_test.cpp index 7a1de3f..702dbd1 100644 --- a/tests/math_test.cpp +++ b/tests/math_test.cpp @@ -106,16 +106,42 @@ TEST_CASE("libm", "") REQUIRE(bx::isEqual(bx::exp(xx), ::expf(xx), 0.00001f) ); } - for (float xx = 0.0f; xx < 100.0f; xx += 0.1f) + // rsqrt + REQUIRE(bx::isInfinite(1.0f/::sqrtf(0.0f) ) ); + REQUIRE(bx::isInfinite(bx::rsqrt(0.0f) ) ); + + for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) { bx::write(writer, &err, "rsqrt(%f) == %f (expected: %f)\n", xx, bx::rsqrt(xx), 1.0f/::sqrtf(xx) ); REQUIRE(err.isOk() ); REQUIRE(bx::isEqual(bx::rsqrt(xx), 1.0f/::sqrtf(xx), 0.00001f) ); } + // rsqrtRef + REQUIRE(bx::isInfinite(bx::rsqrtRef(0.0f) ) ); + + for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "rsqrtRef(%f) == %f (expected: %f)\n", xx, bx::rsqrtRef(xx), 1.0f/::sqrtf(xx) ); + REQUIRE(err.isOk() ); + REQUIRE(bx::isEqual(bx::rsqrtRef(xx), 1.0f/::sqrtf(xx), 0.00001f) ); + } + + // rsqrtSimd + REQUIRE(bx::isInfinite(bx::rsqrtSimd(0.0f) ) ); + + for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "rsqrtSimd(%f) == %f (expected: %f)\n", xx, bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx) ); + REQUIRE(err.isOk() ); + REQUIRE(bx::isEqual(bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx), 0.00001f) ); + } + + // sqrt + REQUIRE(bx::isNan(::sqrtf(-1.0f) ) ); REQUIRE(bx::isNan(bx::sqrt(-1.0f) ) ); - REQUIRE(bx::isEqual(bx::sqrt(0.0f), 0.0f, 0.0f) ); - REQUIRE(bx::isEqual(bx::sqrt(1.0f), 1.0f, 0.0f) ); + REQUIRE(bx::isEqual(bx::sqrt(0.0f), ::sqrtf(0.0f), 0.0f) ); + REQUIRE(bx::isEqual(bx::sqrt(1.0f), ::sqrtf(1.0f), 0.0f) ); for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) { @@ -124,6 +150,30 @@ TEST_CASE("libm", "") REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f) ); } + // sqrtRef + REQUIRE(bx::isNan(bx::sqrtRef(-1.0f) ) ); + REQUIRE(bx::isEqual(bx::sqrtRef(0.0f), ::sqrtf(0.0f), 0.0f) ); + REQUIRE(bx::isEqual(bx::sqrtRef(1.0f), ::sqrtf(1.0f), 0.0f) ); + + for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) + { + bx::write(writer, &err, "sqrtRef(%f) == %f (expected: %f)\n", xx, bx::sqrtRef(xx), ::sqrtf(xx) ); + REQUIRE(err.isOk() ); + REQUIRE(bx::isEqual(bx::sqrtRef(xx), ::sqrtf(xx), 0.00001f) ); + } + + // sqrtSimd + REQUIRE(bx::isNan(bx::sqrtSimd(-1.0f) ) ); + REQUIRE(bx::isEqual(bx::sqrtSimd(0.0f), ::sqrtf(0.0f), 0.0f) ); + REQUIRE(bx::isEqual(bx::sqrtSimd(1.0f), ::sqrtf(1.0f), 0.0f) ); + + for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) + { + bx::write(writer, &err, "sqrtSimd(%f) == %f (expected: %f)\n", xx, bx::sqrtSimd(xx), ::sqrtf(xx) ); + REQUIRE(err.isOk() ); + REQUIRE(bx::isEqual(bx::sqrtSimd(xx), ::sqrtf(xx), 0.00001f) ); + } + for (float xx = 0.0f; xx < 100.0f; xx += 0.1f) { bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx) );