Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 184 additions & 106 deletions field/eisenstein/eisenstein.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,28 @@ package eisenstein

import (
"math/big"
"sync"
)

// A ComplexNumber represents an arbitrary-precision Eisenstein integer.
type ComplexNumber struct {
A0, A1 *big.Int
A0, A1 big.Int
t0, t1, t2, t3, t4 big.Int // temporary variables
_ sync.Mutex // to ensure there is no accidental value copy
}

// ──────────────────────────────────────────────────────────────────────────────
// helpers – hex-lattice geometry & symmetric rounding
// ──────────────────────────────────────────────────────────────────────────────

// six axial directions of the hexagonal lattice
var neighbours = [][2]int64{
{1, 0}, {0, 1}, {-1, 1}, {-1, 0}, {0, -1}, {1, -1},
}

// roundNearest returns ⌊(z + d/2) / d⌋ for *any* sign of z, d>0
func roundNearest(z, d *big.Int) *big.Int {
half := new(big.Int).Rsh(d, 1) // d / 2
if z.Sign() >= 0 {
return new(big.Int).Div(new(big.Int).Add(z, half), d)
}
tmp := new(big.Int).Neg(z)
tmp.Add(tmp, half)
tmp.Div(tmp, d)
return tmp.Neg(tmp)
}

func (z *ComplexNumber) init() {
if z.A0 == nil {
z.A0 = new(big.Int)

}
if z.A1 == nil {
z.A1 = new(big.Int)

}
var neighbours = [6][2]*big.Int{
{big.NewInt(1), big.NewInt(0)},
{big.NewInt(0), big.NewInt(1)},
{big.NewInt(-1), big.NewInt(1)},
{big.NewInt(-1), big.NewInt(0)},
{big.NewInt(0), big.NewInt(-1)},
{big.NewInt(1), big.NewInt(-1)},
}

// String implements Stringer interface for fancy printing
Expand All @@ -48,152 +33,246 @@ func (z *ComplexNumber) String() string {

// Equal returns true if z equals x, false otherwise
func (z *ComplexNumber) Equal(x *ComplexNumber) bool {
return z.A0.Cmp(x.A0) == 0 && z.A1.Cmp(x.A1) == 0
return z.A0.Cmp(&x.A0) == 0 && z.A1.Cmp(&x.A1) == 0
}

// Set sets z to x, and returns z.
func (z *ComplexNumber) Set(x *ComplexNumber) *ComplexNumber {
z.init()
z.A0.Set(x.A0)
z.A1.Set(x.A1)
z.A0.Set(&x.A0)
z.A1.Set(&x.A1)
return z
}

// SetZero sets z to 0, and returns z.
func (z *ComplexNumber) SetZero() *ComplexNumber {
z.A0 = big.NewInt(0)
z.A1 = big.NewInt(0)
z.A0.SetUint64(0)
z.A1.SetUint64(0)
return z
}

// SetOne sets z to 1, and returns z.
func (z *ComplexNumber) SetOne() *ComplexNumber {
z.A0 = big.NewInt(1)
z.A1 = big.NewInt(0)
z.A0.SetUint64(1)
z.A1.SetUint64(0)
return z
}

// Neg sets z to the negative of x, and returns z.
func (z *ComplexNumber) Neg(x *ComplexNumber) *ComplexNumber {
z.init()
z.A0.Neg(x.A0)
z.A1.Neg(x.A1)
z.A0.Neg(&x.A0)
z.A1.Neg(&x.A1)
return z
}

// Conjugate sets z to the conjugate of x, and returns z.
// The conjugate of an Eisenstein integer x₀ + x₁ω is defined as:
// (x₀ - x₁) - x₁ω
func (z *ComplexNumber) Conjugate(x *ComplexNumber) *ComplexNumber {
z.init()
z.A0.Sub(x.A0, x.A1)
z.A1.Neg(x.A1)
z.A0.Sub(&x.A0, &x.A1)
z.A1.Neg(&x.A1)
return z
}

// Add sets z to the sum of x and y, and returns z.
func (z *ComplexNumber) Add(x, y *ComplexNumber) *ComplexNumber {
z.init()
z.A0.Add(x.A0, y.A0)
z.A1.Add(x.A1, y.A1)
z.A0.Add(&x.A0, &y.A0)
z.A1.Add(&x.A1, &y.A1)
return z
}

// Sub sets z to the difference of x and y, and returns z.
func (z *ComplexNumber) Sub(x, y *ComplexNumber) *ComplexNumber {
z.init()
z.A0.Sub(x.A0, y.A0)
z.A1.Sub(x.A1, y.A1)
z.A0.Sub(&x.A0, &y.A0)
z.A1.Sub(&x.A1, &y.A1)
return z
}

// Mul sets z to the product of x and y, and returns z.
//
// Given that ω²+ω+1=0, the explicit formula is:
//
// (x0+x1ω)(y0+y1ω) = (x0y0-x1y1) + (x0y1+x1y0-x1y1)ω
// (x₀ + x₁ω)(y₀ + y₁ω) = (x₀y₀ - x₁y₁) + (x₀y₁ + x₁y₀ - x₁y₁)ω
//
// We use Karatsuba multiplication to compute the product efficiently.
func (z *ComplexNumber) Mul(x, y *ComplexNumber) *ComplexNumber {
z.init()
var t [3]big.Int
var z0, z1 big.Int
t[0].Mul(x.A0, y.A0)
t[1].Mul(x.A1, y.A1)
z0.Sub(&t[0], &t[1])
t[0].Mul(x.A0, y.A1)
t[2].Mul(x.A1, y.A0)
t[0].Add(&t[0], &t[2])
z1.Sub(&t[0], &t[1])
z.A0.Set(&z0)
z.A1.Set(&z1)
z.t0.Mul(&x.A0, &y.A0) // t0 = x₀y₀
z.t1.Mul(&x.A1, &y.A1) // t1 = x₁y₁
z.t2.Add(&x.A0, &x.A1) // t2 = x₀ + x₁
z.t3.Add(&y.A0, &y.A1) // t3 = y₀ + y₁
z.t2.Mul(&z.t2, &z.t3) // t2 = (x₀ + x₁)(y₀ + y₁)

z.A0.Sub(&z.t0, &z.t1) // A0 = x₀y₀ - x₁y₁
z.t3.Add(&z.t1, &z.t1)
z.t3.Add(&z.t3, &z.t0)

z.A1.Sub(&z.t2, &z.t3) // A1 = (x₀ + x₁)(y₀ + y₁) - x₀y₀ - x₁y₁

return z
}

// MulByConjugate sets z to the product of x and the conjugate of y
//
// x * ȳ = (x₀ + x₁ω)((y₀ - y₁) - y₁ω) = (x₀(y₀-y₁) + x₁y₁) + (-x₀y₁ + x₁(y₀-y₁) + x₁y₁)ω
// = (x₀y₀ + x₁y₁ - x₀y₁) + (x₁y₀ - x₀y₁)ω
func (z *ComplexNumber) MulByConjugate(x, y *ComplexNumber) *ComplexNumber {
z.t0.Mul(&x.A1, &y.A0) // t0 = x₁y₀
z.t1.Mul(&x.A0, &y.A1) // t1 = x₀y₁
z.t2.Add(&x.A0, &x.A1) // t2 = x₀ + x₁
z.t3.Add(&y.A0, &y.A1) // t3 = y₀ + y₁
z.t2.Mul(&z.t2, &z.t3) // t2 = (x₀ + x₁)(y₀ + y₁) = x₀y₀ + x₁y₁ + x₀y₁ + x₁y₀

z.t3.Add(&z.t1, &z.t1)
z.t3.Add(&z.t3, &z.t0)
z.A0.Sub(&z.t2, &z.t3) // A0 = x₀y₀ + x₁y₁ - x₀y₁ = t₂ - t₀ - 2t₁

z.A1.Sub(&z.t0, &z.t1) // A1 = x₁y₀ - x₀y₁ = t₀ - t₁

return z
}

// Norm returns the norm of z.
//
// The explicit formula is:
//
// N(x0+x1ω) = x0² + x1² - x0*x1
func (z *ComplexNumber) Norm() *big.Int {
norm := new(big.Int)
temp := new(big.Int)
norm.Add(
norm.Mul(z.A0, z.A0),
temp.Mul(z.A1, z.A1),
)
norm.Sub(
norm,
temp.Mul(z.A0, z.A1),
)
// N(x0+x1ω) = x₀² + x₁² - x₀x₁
//
// We rearrange into it (x₀-x₁)² + x₀x₁
func (z *ComplexNumber) Norm(norm *big.Int) *big.Int {
z.t1.Sub(&z.A0, &z.A1).Mul(&z.t1, &z.t1)
z.t2.Mul(&z.A0, &z.A1)
norm.Add(&z.t1, &z.t2)
return norm
}

func (z *ComplexNumber) roundNearest(num *ComplexNumber, d *big.Int) {
z.t1.Abs(d)
dBitLen := z.t1.BitLen()

// Helper function for rounding one component
roundComp := func(result, comp *big.Int) {
isNegativeResult := (comp.Sign() < 0) != (d.Sign() < 0)
z.t0.Abs(comp)

// Bit length shortcut before full comparison
t0BitLen := z.t0.BitLen()
if t0BitLen < dBitLen || (t0BitLen == dBitLen && z.t0.Cmp(&z.t1) < 0) {
// |a| < |b|
z.t2.Lsh(&z.t0, 1) // t2 = 2 * |a|
if z.t2.BitLen() > dBitLen || (z.t2.BitLen() == dBitLen && z.t2.Cmp(&z.t1) >= 0) {
if isNegativeResult {
result.SetInt64(-1)
} else {
result.SetInt64(1)
}
} else {
result.SetInt64(0)
}
} else {
// division and rounding
z.t2.Set(&z.t0) // remainder = |a|
k := t0BitLen - dBitLen
z.t3.Lsh(&z.t1, uint(k))
if z.t3.Cmp(&z.t0) > 0 {
k--
}
result.SetInt64(0)
for i := k; i >= 0; i-- {
z.t3.Lsh(&z.t1, uint(i))
if z.t2.Cmp(&z.t3) >= 0 {
z.t2.Sub(&z.t2, &z.t3)
result.SetBit(result, i, 1)
}
}
z.t3.Lsh(&z.t2, 1)
if z.t3.Cmp(&z.t1) >= 0 {
z.t4.SetUint64(1)
result.Add(result, &z.t4)
}
if isNegativeResult {
result.Neg(result)
}
}
}

// Round both components
roundComp(&z.A0, &num.A0)
roundComp(&z.A1, &num.A1)
}

// Quo sets z to the Euclidean quotient of x / y
// and guarantees ‖r‖ < ‖y‖ (true Euclidean division in ℤ[ω]).
func (z *ComplexNumber) Quo(x, y *ComplexNumber) *ComplexNumber {

// y.t0 = Norm(y)
y.Norm(&y.t0)
if y.t0.Sign() == 0 {
panic("division by zero")
}

// z = x * ȳ
z.MulByConjugate(x, y)

// rounding of both coordinates
z.roundNearest(z, &y.t0)

return z
}

// QuoRem sets z to the Euclidean quotient of x / y, r to the remainder,
// and guarantees ‖r‖ < ‖y‖ (true Euclidean division in ℤ[ω]).
func (z *ComplexNumber) QuoRem(x, y, r *ComplexNumber) (*ComplexNumber, *ComplexNumber) {

norm := y.Norm() // > 0 (Eisenstein norm is always non-neg)
if norm.Sign() == 0 {
y.Norm(&y.t0) // > 0 (Eisenstein norm is always non-neg)
if y.t0.Sign() == 0 {
panic("division by zero")
}

// num = x * ȳ (ȳ computed in a fresh variable → y unchanged)
var yConj, num ComplexNumber
yConj.Conjugate(y)
num.Mul(x, &yConj)
// z = x * ȳ
z.MulByConjugate(x, y)

// first guess by *symmetric* rounding of both coordinates
q0 := roundNearest(num.A0, norm)
q1 := roundNearest(num.A1, norm)
z.A0, z.A1 = q0, q1
z.roundNearest(z, &y.t0)

// r = x – q*y
r.Mul(y, z)
r.Sub(x, r)

// If Euclidean inequality already holds we're done.
if r.Norm(&x.t1).Cmp(&y.t0) < 0 {
return z, r
}

// Otherwise walk ≤2 unit steps in the hex lattice until N(r) < N(y).
if r.Norm().Cmp(norm) >= 0 {
bestQ0, bestQ1 := new(big.Int).Set(z.A0), new(big.Int).Set(z.A1)
bestR := new(ComplexNumber).Set(r)
bestN2 := bestR.Norm()

for _, dir := range neighbours {
candQ0 := new(big.Int).Add(z.A0, big.NewInt(dir[0]))
candQ1 := new(big.Int).Add(z.A1, big.NewInt(dir[1]))
var candQ ComplexNumber
candQ.A0, candQ.A1 = candQ0, candQ1

var candR ComplexNumber
candR.Mul(y, &candQ)
candR.Sub(x, &candR)

if candR.Norm().Cmp(bestN2) < 0 {
bestQ0, bestQ1 = candQ0, candQ1
bestR.Set(&candR)
bestN2 = bestR.Norm()
}
bestNorm := &z.t0
bestQ0, bestQ1 := &z.t1, &z.t2
a0, a1 := &z.t3, &z.t4
a0.Set(&z.A0)
a1.Set(&z.A1)
bestQ0.Set(a0)
bestQ1.Set(a1)

bestNorm.Set(&x.t1) // bestNorm = N(r)

// six axial directions of the hexagonal lattice
// {1, 0}, {0, 1}, {-1, 1}, {-1, 0}, {0, -1}, {1, -1}
var candR ComplexNumber
for _, dir := range neighbours {
z.A0.Add(a0, dir[0])
z.A1.Add(a1, dir[1])

candR.Mul(y, z)
candR.Sub(x, &candR)

if candR.Norm(&x.t1).Cmp(bestNorm) < 0 {
bestQ0.Set(&z.A0)
bestQ1.Set(&z.A1)
r.Set(&candR)
bestNorm.Set(&x.t1)
}
z.A0, z.A1 = bestQ0, bestQ1
r.Set(bestR) // update remainder and retry; Euclidean property ⇒ ≤ 2 loops
}
z.A0.Set(bestQ0)
z.A1.Set(bestQ1)

return z, r
}

Expand All @@ -202,7 +281,6 @@ func (z *ComplexNumber) QuoRem(x, y, r *ComplexNumber) (*ComplexNumber, *Complex
func HalfGCD(a, b *ComplexNumber) [3]*ComplexNumber {

var aRun, bRun, u, v, u_, v_, quotient, remainder, t, t1, t2 ComplexNumber
var sqrt big.Int

aRun.Set(a)
bRun.Set(b)
Expand All @@ -211,9 +289,9 @@ func HalfGCD(a, b *ComplexNumber) [3]*ComplexNumber {
u_.SetZero()
v_.SetOne()

// Eisenstein integers form an Euclidean domain for the norm
sqrt.Sqrt(a.Norm())
for bRun.Norm().Cmp(&sqrt) >= 0 {
// Eisenstein integers form an Euclidean domain for the norm = a.t0
a.t0.Sqrt(a.Norm(&a.t1))
for bRun.Norm(&a.t1).Cmp(&a.t0) >= 0 {
quotient.QuoRem(&aRun, &bRun, &remainder)
t.Mul(&u_, &quotient)
t1.Sub(&u, &t)
Expand Down
Loading