Skip to content

Commit 00971b8

Browse files
authored
Merge pull request #1770 from bytegrrrl/sqrt
Add pure FixedPoint sqrt implementation and tests
2 parents f227563 + 292d6da commit 00971b8

File tree

3 files changed

+96
-3
lines changed

3 files changed

+96
-3
lines changed

copying.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ _the openage authors_ are:
166166
| Eelco Empting | Eeelco | me à eelco dawt de |
167167
| Jordan Sutton | jsutCodes | jsutcodes à gmail dawt com |
168168
| Daniel Wieczorek | Danio | danielwieczorek96 à gmail dawt com |
169+
| | bytegrrrl | bytegrrrl à proton dawt me |
169170

170171
If you're a first-time committer, add yourself to the above list. This is not
171172
just for legal reasons, but also to keep an overview of all those nicknames.

libopenage/util/fixed_point.h

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include <algorithm>
6+
#include <bit>
67
#include <climits>
78
#include <cmath>
89
#include <iomanip>
@@ -446,8 +447,52 @@ class FixedPoint {
446447
return is;
447448
}
448449

449-
constexpr double sqrt() {
450-
return std::sqrt(this->to_double());
450+
/**
451+
* Pure FixedPoint sqrt implementation using Heron's Algorithm.
452+
*
453+
* Note that this function is undefined for negative values.
454+
*
455+
* There's a small loss in precision depending on the value of fractional_bits and the position of
456+
* the most significant bit: if the integer portion is very large, we won't have as much (absolute)
457+
* precision. Ideally you would want the intermediate_type to be twice the size of raw_type to avoid
458+
* any losses.
459+
*/
460+
constexpr FixedPoint sqrt() {
461+
// Zero can cause issues later, so deal with now.
462+
if (this->raw_value == 0) {
463+
return zero();
464+
}
465+
466+
// Check for negative values
467+
if constexpr (std::is_signed<raw_type>()) {
468+
ENSURE(this->raw_value > 0, "FixedPoint::sqrt() is undefined for negative values.");
469+
}
470+
471+
// A greater shift = more precision, but can overflow the intermediate type if too large.
472+
size_t max_shift = std::countl_zero(static_cast<unsigned_intermediate_type>(this->raw_value)) - 1;
473+
size_t shift = max_shift > fractional_bits ? fractional_bits : max_shift;
474+
475+
// shift + fractional bits must be an even number
476+
if ((shift + fractional_bits) % 2) {
477+
shift -= 1;
478+
}
479+
480+
// We can't use the safe shift since the shift value is unknown at compile time.
481+
intermediate_type n = static_cast<intermediate_type>(this->raw_value) << shift;
482+
intermediate_type guess = static_cast<intermediate_type>(1) << fractional_bits;
483+
484+
for (size_t i = 0; i < fractional_bits; i++) {
485+
intermediate_type prev = guess;
486+
guess = (guess + n / guess) / 2;
487+
if (guess == prev) {
488+
break;
489+
}
490+
}
491+
492+
// The sqrt operation halves the number of bits, so we'll we'll have to calculate a shift back
493+
size_t unshift = fractional_bits - (shift + fractional_bits) / 2;
494+
495+
return from_raw_value(guess << unshift);
451496
}
452497

453498
constexpr double atan2(const FixedPoint &n) {
@@ -574,7 +619,7 @@ namespace std {
574619

575620
template <typename I, unsigned F, typename Inter>
576621
constexpr double sqrt(openage::util::FixedPoint<I, F, Inter> n) {
577-
return n.sqrt();
622+
return static_cast<double>(n.sqrt());
578623
}
579624

580625
template <typename I, unsigned F, typename Inter>

libopenage/util/fixed_point_test.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,53 @@ void fixed_point() {
156156
TESTEQUALS_FLOAT((c/b).to_double(), -4.75/3.5, 0.1);
157157
}
158158

159+
// Pure FixedPoint sqrt tests
160+
{
161+
using T = FixedPoint<int64_t, 32, int64_t>;
162+
TESTEQUALS_FLOAT(T(41231.131).sqrt(), 203.0545025356, 1e-7);
163+
TESTEQUALS_FLOAT(T(547965.116).sqrt(), 740.2466588915, 1e-7);
164+
165+
TESTEQUALS_FLOAT(T(2).sqrt(), T::sqrt_2(), 1e-9);
166+
TESTEQUALS_FLOAT(2 / std::sqrt(T::pi()), T::inv2_sqrt_pi(), 1e-9);
167+
168+
// Powers of two (anything over 2^15 would overflow, since (2^16)^2 = 2^32 no longer fits in T's integer part).
169+
for (size_t i = 0; i < 15; i++) {
170+
int64_t value = 1 << i;
171+
TESTEQUALS_FLOAT(T(value * value).sqrt(), value, 1e-7);
172+
}
173+
174+
for (size_t i = 0; i < 100; i++) {
175+
double value = 14.25 * i;
176+
TESTEQUALS_FLOAT(T(value * value).sqrt(), value, 1e-7);
177+
}
178+
179+
// This one can go up to 2^63, but that would take years.
180+
for (uint32_t i = 0; i < 65536; i++) {
181+
T value = T::from_raw_value(i * i);
182+
TESTEQUALS_FLOAT(value.sqrt(), std::sqrt(value.to_double()), 1e-7);
183+
}
184+
185+
// We lose some precision when raw_type == intermediate_type
186+
for (uint64_t i = 1; i < std::numeric_limits<uint64_t>::max(); i = (i * 2) ^ i) {
187+
T value = T::from_raw_value(i * i);
188+
if (value < 0) {
189+
value = -value;
190+
}
191+
TESTEQUALS_FLOAT(value.sqrt(), std::sqrt(value.to_double()), 1e-4);
192+
}
193+
194+
using FP16_16 = FixedPoint<uint32_t, 16, uint64_t>;
195+
for (uint32_t i = 1; i < 65536; i++) {
196+
FP16_16 value = FP16_16::from_raw_value(i);
197+
TESTEQUALS_FLOAT(value.sqrt(), std::sqrt(value.to_double()), 1e-4);
198+
}
199+
200+
201+
// Test with negative number
202+
TESTTHROWS((FixedPoint<int64_t, 32>::from_float(-3.25).sqrt()));
203+
TESTNOEXCEPT((FixedPoint<int64_t, 32>::from_float(3.25).sqrt()));
204+
TESTNOEXCEPT((FixedPoint<uint64_t, 32>::from_float(-3.25).sqrt()));
205+
}
159206
}
160207

161208
}}} // openage::util::tests

0 commit comments

Comments
 (0)