Fixed rsqrt and sqrt. Added more tests.

This commit is contained in:
Бранимир Караџић
2023-04-19 19:01:00 -07:00
parent 9a9a871d9a
commit 0d51df1779
2 changed files with 77 additions and 32 deletions

View File

@@ -212,17 +212,27 @@ namespace bx
/// Reference (non-SIMD) reciprocal square root.
///
/// @param[in] _a Input value.
/// @returns `kInfinity` when `_a` is below `kNearZero`, otherwise `pow(_a, -0.5f)`.
///
inline BX_CONST_FUNC float rsqrtRef(float _a)
{
	// Inputs under the near-zero threshold short-circuit to infinity instead of
	// going through pow().
	return _a < kNearZero
		? kInfinity
		: pow(_a, -0.5f)
		;
}
/// Reference (non-SIMD) square root.
///
/// Returns 0.0f when `_a` is below `kNearZero`; otherwise evaluates
/// `_a*pow(_a, -0.5f)`.
///
/// NOTE(review): the early-out condition is `_a < kNearZero`, so negative
/// inputs (assuming kNearZero is a small positive constant -- confirm) also
/// yield 0.0f rather than NaN.
///
/// @param[in] _a Input value.
/// @returns Square root of `_a`, or 0.0f for inputs below `kNearZero`.
///
inline BX_CONST_FUNC float sqrtRef(float _a)
{
// Early-out for inputs under the near-zero threshold.
if (_a < kNearZero)
{
return 0.0f;
}
// sqrt(a) expressed as a * a^(-1/2).
return _a*pow(_a, -0.5f);
}
return 1.0f/rsqrtRef(_a);
/// Reciprocal square root computed through the SIMD helpers.
///
/// Broadcasts `_a` with `simd_splat`, runs `simd_rsqrt_nr` on the vector, and
/// writes the result into a scalar float through `simd_stx`. No near-zero
/// guard is applied here; the input is forwarded as-is.
///
/// @param[in] _a Input value.
/// @returns The scalar written by `simd_stx` from the `simd_rsqrt_nr` result.
///
inline BX_CONST_FUNC float rsqrtSimd(float _a)
{
	float out;

	const simd128_t input  = simd_splat(_a);
	const simd128_t approx = simd_rsqrt_nr(input); // presumably Newton-Raphson refined -- confirm against simd_rsqrt_nr

	simd_stx(&out, approx);

	return out;
}
inline BX_CONST_FUNC float sqrtSimd(float _a)
@@ -235,30 +245,6 @@ namespace bx
return result;
}
/// Square root of `_a`.
///
/// Selects the implementation at compile time: `sqrtSimd` when
/// `BX_CONFIG_SUPPORTS_SIMD` is non-zero, `sqrtRef` otherwise.
///
/// @param[in] _a Input value.
/// @returns Result of the selected implementation.
///
inline BX_CONST_FUNC float sqrt(float _a)
{
#if BX_CONFIG_SUPPORTS_SIMD
	const float root = sqrtSimd(_a);
#else
	const float root = sqrtRef(_a);
#endif // BX_CONFIG_SUPPORTS_SIMD

	return root;
}
/// Reciprocal square root computed through the SIMD helpers, with a near-zero guard.
///
/// @param[in] _a Input value.
/// @returns `kInfinity` when `_a` is below `kNearZero`; otherwise the scalar
///   written by `simd_stx` from `simd_rsqrt_nr` applied to the broadcast input.
///
inline BX_CONST_FUNC float rsqrtSimd(float _a)
{
	if (_a < kNearZero)
	{
		// 1/sqrt(x) diverges as x approaches zero, so the guard has to yield
		// infinity (matching rsqrtRef), not 0.0f.
		return kInfinity;
	}

	const simd128_t aa     = simd_splat(_a);
	const simd128_t rsqrta = simd_rsqrt_nr(aa);

	float result;
	simd_stx(&result, rsqrta);

	return result;
}
inline BX_CONST_FUNC float rsqrt(float _a)
{
#if BX_CONFIG_SUPPORTS_SIMD
@@ -268,6 +254,15 @@ namespace bx
#endif // BX_CONFIG_SUPPORTS_SIMD
}
/// Square root of `_a`.
///
/// Selects the implementation at compile time: `sqrtSimd` when
/// `BX_CONFIG_SUPPORTS_SIMD` is non-zero, `sqrtRef` otherwise.
///
/// @param[in] _a Input value.
/// @returns Result of the selected implementation.
///
inline BX_CONST_FUNC float sqrt(float _a)
{
#if BX_CONFIG_SUPPORTS_SIMD
	const float root = sqrtSimd(_a);
#else
	const float root = sqrtRef(_a);
#endif // BX_CONFIG_SUPPORTS_SIMD

	return root;
}
inline BX_CONSTEXPR_FUNC float trunc(float _a)
{
return float(int(_a) );

View File

@@ -106,16 +106,42 @@ TEST_CASE("libm", "")
REQUIRE(bx::isEqual(bx::exp(xx), ::expf(xx), 0.00001f) );
}
for (float xx = 0.0f; xx < 100.0f; xx += 0.1f)
// rsqrt
REQUIRE(bx::isInfinite(1.0f/::sqrtf(0.0f) ) );
REQUIRE(bx::isInfinite(bx::rsqrt(0.0f) ) );
for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
{
bx::write(writer, &err, "rsqrt(%f) == %f (expected: %f)\n", xx, bx::rsqrt(xx), 1.0f/::sqrtf(xx) );
REQUIRE(err.isOk() );
REQUIRE(bx::isEqual(bx::rsqrt(xx), 1.0f/::sqrtf(xx), 0.00001f) );
}
// rsqrtRef
REQUIRE(bx::isInfinite(bx::rsqrtRef(0.0f) ) );
for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
{
bx::write(writer, &err, "rsqrtRef(%f) == %f (expected: %f)\n", xx, bx::rsqrtRef(xx), 1.0f/::sqrtf(xx) );
REQUIRE(err.isOk() );
REQUIRE(bx::isEqual(bx::rsqrtRef(xx), 1.0f/::sqrtf(xx), 0.00001f) );
}
// rsqrtSimd
REQUIRE(bx::isInfinite(bx::rsqrtSimd(0.0f) ) );
for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
{
bx::write(writer, &err, "rsqrtSimd(%f) == %f (expected: %f)\n", xx, bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx) );
REQUIRE(err.isOk() );
REQUIRE(bx::isEqual(bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx), 0.00001f) );
}
// sqrt
REQUIRE(bx::isNan(::sqrtf(-1.0f) ) );
REQUIRE(bx::isNan(bx::sqrt(-1.0f) ) );
REQUIRE(bx::isEqual(bx::sqrt(0.0f), 0.0f, 0.0f) );
REQUIRE(bx::isEqual(bx::sqrt(1.0f), 1.0f, 0.0f) );
REQUIRE(bx::isEqual(bx::sqrt(0.0f), ::sqrtf(0.0f), 0.0f) );
REQUIRE(bx::isEqual(bx::sqrt(1.0f), ::sqrtf(1.0f), 0.0f) );
for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
{
@@ -124,6 +150,30 @@ TEST_CASE("libm", "")
REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f) );
}
// sqrtRef
REQUIRE(bx::isNan(bx::sqrtRef(-1.0f) ) );
REQUIRE(bx::isEqual(bx::sqrtRef(0.0f), ::sqrtf(0.0f), 0.0f) );
REQUIRE(bx::isEqual(bx::sqrtRef(1.0f), ::sqrtf(1.0f), 0.0f) );
for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
{
bx::write(writer, &err, "sqrtRef(%f) == %f (expected: %f)\n", xx, bx::sqrtRef(xx), ::sqrtf(xx) );
REQUIRE(err.isOk() );
REQUIRE(bx::isEqual(bx::sqrtRef(xx), ::sqrtf(xx), 0.00001f) );
}
// sqrtSimd
REQUIRE(bx::isNan(bx::sqrtSimd(-1.0f) ) );
REQUIRE(bx::isEqual(bx::sqrtSimd(0.0f), ::sqrtf(0.0f), 0.0f) );
REQUIRE(bx::isEqual(bx::sqrtSimd(1.0f), ::sqrtf(1.0f), 0.0f) );
for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
{
bx::write(writer, &err, "sqrtSimd(%f) == %f (expected: %f)\n", xx, bx::sqrtSimd(xx), ::sqrtf(xx) );
REQUIRE(err.isOk() );
REQUIRE(bx::isEqual(bx::sqrtSimd(xx), ::sqrtf(xx), 0.00001f) );
}
for (float xx = 0.0f; xx < 100.0f; xx += 0.1f)
{
bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx) );