diff --git a/include/bx/inline/math.inl b/include/bx/inline/math.inl index ca73eb4..331490e 100644 --- a/include/bx/inline/math.inl +++ b/include/bx/inline/math.inl @@ -220,27 +220,52 @@ namespace bx return pow(_a, -0.5f); } - inline BX_CONST_FUNC float sqrtRef(float _a) - { - return _a*pow(_a, -0.5f); - } - inline BX_CONST_FUNC float rsqrtSimd(float _a) { - const simd128_t aa = simd_splat(_a); + if (_a < kNearZero) + { + return kFloatInfinity; + } + + const simd128_t aa = simd_splat(_a); +#if BX_SIMD_NEON const simd128_t rsqrta = simd_rsqrt_nr(aa); +#else + const simd128_t rsqrta = simd_rsqrt_ni(aa); +#endif // BX_SIMD_NEON + float result; simd_stx(&result, rsqrta); return result; } + inline BX_CONST_FUNC float sqrtRef(float _a) + { + if (_a < 0.0F) + { + return bitsToFloat(kFloatExponentMask | kFloatMantissaMask); + } + + return _a * pow(_a, -0.5f); + } + inline BX_CONST_FUNC float sqrtSimd(float _a) { - const simd128_t aa = simd_splat(_a); - const simd128_t sqrta = simd_sqrt(aa); + if (_a < 0.0F) + { + return bitsToFloat(kFloatExponentMask | kFloatMantissaMask); + } + else if (_a < kNearZero) + { + return 0.0f; + } + + const simd128_t aa = simd_splat(_a); + const simd128_t sqrt = simd_sqrt(aa); + float result; - simd_stx(&result, sqrta); + simd_stx(&result, sqrt); return result; } diff --git a/tests/math_test.cpp b/tests/math_test.cpp index ed18551..b4f63fd 100644 --- a/tests/math_test.cpp +++ b/tests/math_test.cpp @@ -72,6 +72,99 @@ TEST_CASE("log2", "") REQUIRE(8 == bx::log2(256) ); } +BX_PRAGMA_DIAGNOSTIC_PUSH(); +BX_PRAGMA_DIAGNOSTIC_IGNORED_MSVC(4723) // potential divide by 0 + +TEST_CASE("libm sqrt", "") +{ + bx::WriterI* writer = bx::getNullOut(); + bx::Error err; + + // rsqrtRef + REQUIRE(bx::isInfinite(bx::rsqrtRef(0.0f))); + + for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "rsqrtRef(%f) == %f (expected: %f)\n", xx, bx::rsqrtRef(xx), 1.0f / ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::rsqrtRef(xx), 1.0f / ::sqrtf(xx), 0.00001f)); + } + + // rsqrtSimd + REQUIRE(bx::isInfinite(bx::rsqrtSimd(0.0f))); + + for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "rsqrtSimd(%f) == %f (expected: %f)\n", xx, bx::rsqrtSimd(xx), 1.0f / ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::rsqrtSimd(xx), 1.0f / ::sqrtf(xx), 0.00001f)); + } + + // rsqrt + REQUIRE(bx::isInfinite(1.0f / ::sqrtf(0.0f))); + REQUIRE(bx::isInfinite(bx::rsqrt(0.0f))); + + for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "rsqrt(%f) == %f (expected: %f)\n", xx, bx::rsqrt(xx), 1.0f / ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::rsqrt(xx), 1.0f / ::sqrtf(xx), 0.00001f)); + } + + // sqrtRef + REQUIRE(bx::isNan(bx::sqrtRef(-1.0f))); + REQUIRE(bx::isEqual(bx::sqrtRef(0.0f), ::sqrtf(0.0f), 0.0f)); + REQUIRE(bx::isEqual(bx::sqrtRef(1.0f), ::sqrtf(1.0f), 0.0f)); + + for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) + { + bx::write(writer, &err, "sqrtRef(%f) == %f (expected: %f)\n", xx, bx::sqrtRef(xx), ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::sqrtRef(xx), ::sqrtf(xx), 0.00001f)); + } + + // sqrtSimd + REQUIRE(bx::isNan(bx::sqrtSimd(-1.0f))); + REQUIRE(bx::isEqual(bx::sqrtSimd(0.0f), ::sqrtf(0.0f), 0.0f)); + REQUIRE(bx::isEqual(bx::sqrtSimd(1.0f), ::sqrtf(1.0f), 0.0f)); + + for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) + { + bx::write(writer, &err, "sqrtSimd(%f) == %f (expected: %f)\n", xx, bx::sqrtSimd(xx), ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::sqrtSimd(xx), ::sqrtf(xx), 0.00001f)); + } + + for (float xx = 0.0f; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f)); + } + + // sqrt + REQUIRE(bx::isNan(::sqrtf(-1.0f))); + REQUIRE(bx::isNan(bx::sqrt(-1.0f))); + REQUIRE(bx::isEqual(bx::sqrt(0.0f), ::sqrtf(0.0f), 0.0f)); + REQUIRE(bx::isEqual(bx::sqrt(1.0f), ::sqrtf(1.0f), 0.0f)); + + for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) + { + bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f)); + } + + for (float xx = 0.0f; xx < 100.0f; xx += 0.1f) + { + bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx)); + REQUIRE(err.isOk()); + REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f)); + } +} + +BX_PRAGMA_DIAGNOSTIC_POP(); + TEST_CASE("libm", "") { bx::WriterI* writer = bx::getNullOut(); @@ -110,81 +203,6 @@ TEST_CASE("libm", "") REQUIRE(bx::isEqual(bx::exp(xx), ::expf(xx), 0.00001f) ); } - // rsqrt - REQUIRE(bx::isInfinite(1.0f/::sqrtf(0.0f) ) ); - REQUIRE(bx::isInfinite(bx::rsqrt(0.0f) ) ); - - for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) - { - bx::write(writer, &err, "rsqrt(%f) == %f (expected: %f)\n", xx, bx::rsqrt(xx), 1.0f/::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::rsqrt(xx), 1.0f/::sqrtf(xx), 0.00001f) ); - } - - // rsqrtRef - REQUIRE(bx::isInfinite(bx::rsqrtRef(0.0f) ) ); - - for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) - { - bx::write(writer, &err, "rsqrtRef(%f) == %f (expected: %f)\n", xx, bx::rsqrtRef(xx), 1.0f/::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::rsqrtRef(xx), 1.0f/::sqrtf(xx), 0.00001f) ); - } - - // rsqrtSimd - REQUIRE(bx::isInfinite(bx::rsqrtSimd(0.0f) ) ); - - for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f) - { - bx::write(writer, &err, "rsqrtSimd(%f) == %f (expected: %f)\n", xx, bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx), 0.00001f) ); - } - - // sqrt - REQUIRE(bx::isNan(::sqrtf(-1.0f) ) ); - REQUIRE(bx::isNan(bx::sqrt(-1.0f) ) ); - REQUIRE(bx::isEqual(bx::sqrt(0.0f), ::sqrtf(0.0f), 0.0f) ); - REQUIRE(bx::isEqual(bx::sqrt(1.0f), ::sqrtf(1.0f), 0.0f) ); - - for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) - { - bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f) ); - } - - // sqrtRef - REQUIRE(bx::isNan(bx::sqrtRef(-1.0f) ) ); - REQUIRE(bx::isEqual(bx::sqrtRef(0.0f), ::sqrtf(0.0f), 0.0f) ); - REQUIRE(bx::isEqual(bx::sqrtRef(1.0f), ::sqrtf(1.0f), 0.0f) ); - - for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) - { - bx::write(writer, &err, "sqrtRef(%f) == %f (expected: %f)\n", xx, bx::sqrtRef(xx), ::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::sqrtRef(xx), ::sqrtf(xx), 0.00001f) ); - } - - // sqrtSimd - REQUIRE(bx::isNan(bx::sqrtSimd(-1.0f) ) ); - REQUIRE(bx::isEqual(bx::sqrtSimd(0.0f), ::sqrtf(0.0f), 0.0f) ); - REQUIRE(bx::isEqual(bx::sqrtSimd(1.0f), ::sqrtf(1.0f), 0.0f) ); - - for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f) - { - bx::write(writer, &err, "sqrtSimd(%f) == %f (expected: %f)\n", xx, bx::sqrtSimd(xx), ::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::sqrtSimd(xx), ::sqrtf(xx), 0.00001f) ); - } - - for (float xx = 0.0f; xx < 100.0f; xx += 0.1f) - { - bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx) ); - REQUIRE(err.isOk() ); - REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f) ); - } - for (float xx = -100.0f; xx < 100.0f; xx += 0.1f) { bx::write(writer, &err, "pow(1.389f, %f) == %f (expected: %f)\n", xx, bx::pow(1.389f, xx), ::powf(1.389f, xx) );