Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions basic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package pollon

// Contains tells whether a contains x.
func Contains(a []string, x string) bool {
for _, n := range a {
if x == n {
return true
}
}
return false
}
2 changes: 1 addition & 1 deletion examples/simple/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func Check(c chan pollon.ConfData) {
return
}
log.Printf("address: %s", addr)
c <- pollon.ConfData{DestAddr: addr}
c <- pollon.ConfData{DestAddr: []*net.TCPAddr{addr}}
}

func main() {
Expand Down
166 changes: 141 additions & 25 deletions pollon.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,125 @@ package pollon
import (
"fmt"
"io"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
)

type LBType string

const (
Random LBType = "random"
LeastQueue LBType = "leastqueue"
)

type ConfData struct {
DestAddr *net.TCPAddr
DestAddr []*net.TCPAddr
}

type Proxy struct {
C chan ConfData
listener *net.TCPListener
confMutex sync.Mutex
type Backend struct {
destAddr *net.TCPAddr
closeConns chan struct{}
stop chan struct{}
endCh chan error
connMutex sync.Mutex
needClean bool
connNum int32
}

type Proxy struct {
C chan ConfData
listener *net.TCPListener
confMutex sync.Mutex
connMutex sync.Mutex
stop chan struct{}
endCh chan error
keepAlive bool
keepAliveIdle time.Duration
keepAliveCount int
keepAliveInterval time.Duration
backends []*Backend
lbType LBType
}

func NewProxy(listener *net.TCPListener) (*Proxy, error) {
return &Proxy{
C: make(chan ConfData),
listener: listener,
closeConns: make(chan struct{}),
stop: make(chan struct{}),
endCh: make(chan error),
connMutex: sync.Mutex{},
C: make(chan ConfData),
listener: listener,
stop: make(chan struct{}),
endCh: make(chan error),
lbType: Random,
}, nil
}

func newBackend(destAddr *net.TCPAddr) *Backend {
return &Backend{
destAddr: destAddr,
closeConns: make(chan struct{}),
}
}

func (p *Proxy) GetBackend() *Backend {
if p.lbType == LeastQueue {
var backResult *Backend = nil
var connNum int32
for _, b := range p.backends {
if backResult == nil {
backResult = b
connNum = atomic.LoadInt32(&b.connNum)
continue
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a biggie.

Maybe rename it as leastConnNum or something to signify that's the least connection count.

}
currConnNum := atomic.LoadInt32(&b.connNum)
if connNum > currConnNum {
backResult = b
connNum = currConnNum
}
}
return backResult
}
if len(p.backends) > 0 {
return p.backends[rand.Intn(len(p.backends))]
}
return nil
}

func (b *Backend) incConn() {
atomic.AddInt32(&b.connNum, 1)
}
func (b *Backend) decConn() {
atomic.AddInt32(&b.connNum, -1)
}

// proxy client connection
func (p *Proxy) proxyConn(conn *net.TCPConn) {
p.connMutex.Lock()
closeConns := p.closeConns
destAddr := p.destAddr
p.connMutex.Unlock()
log.Printf("INFO start source connection: %v", conn)
defer func() {
log.Printf("closing source connection: %v", conn)
conn.Close()
}()
defer conn.Close()

p.connMutex.Lock()
back := p.GetBackend()
p.connMutex.Unlock()
if back == nil {
log.Printf("ERR no backends, closing source connection: %v", conn)
return
}
p.confMutex.Lock()
closeConns := back.closeConns
destAddr := back.destAddr
p.confMutex.Unlock()
if destAddr == nil {
log.Printf("ERR bad destAddr, closing source connection: %v", conn)
return
}
back.incConn()
defer back.decConn()

var d net.Dialer
d.Cancel = closeConns
destConnInterface, err := d.Dial("tcp", destAddr.String())
if err != nil {
log.Printf("ERR destAddr dial, closing source connection: %v", conn)
conn.Close()
return
}
Expand Down Expand Up @@ -119,19 +183,38 @@ func (p *Proxy) proxyConn(conn *net.TCPConn) {
}
}

//reconfig backends
func (p *Proxy) confCheck() {
for {
select {
case <-p.stop:
p.confMutex.Lock()
// Is last iteration before func() exit, use defer
defer p.confMutex.Unlock()
for _, back := range p.backends {
back.needClean = true
}
p.BackendCleaning()
return
case confData := <-p.C:
if confData.DestAddr.String() != p.destAddr.String() {
p.connMutex.Lock()
close(p.closeConns)
p.closeConns = make(chan struct{})
p.destAddr = confData.DestAddr
p.connMutex.Unlock()
var dAddrStr []string
p.confMutex.Lock()
// Add new backends
for _, dAddr := range confData.DestAddr {
// if New backend exists
if !Contains(p.GetBackendsString(), dAddr.String()) {
p.backends = append(p.backends, newBackend(dAddr))
}
dAddrStr = append(dAddrStr, dAddr.String())
}
// Delete stale backends & force close connections
for _, back := range p.backends {
if !Contains(dAddrStr, back.destAddr.String()) {
back.needClean = true
}
}
p.BackendCleaning()
p.confMutex.Unlock()
}
}
}
Expand Down Expand Up @@ -168,6 +251,35 @@ func (p *Proxy) Start() error {
return nil
}

func (p *Proxy) GetBackendsString() []string {
var result []string
for _, b := range p.backends {
result = append(result, b.destAddr.String())
}
return result
}

func (p *Proxy) BackendCleaning() {
last := len(p.backends) - 1
if last < 0 {
return
}
for i := last; i >= 0; i-- {
if p.backends[i].needClean {
close(p.backends[i].closeConns)
if i != last {
p.backends[i], p.backends[last] = p.backends[last], p.backends[i]
}
last--
}
}
if last < 0 {
p.backends = nil
} else {
p.backends = p.backends[:last+1]
}
}

func (p *Proxy) SetKeepAlive(keepalive bool) {
p.keepAlive = keepalive
}
Expand All @@ -183,3 +295,7 @@ func (p *Proxy) SetKeepAliveCount(n int) {
func (p *Proxy) SetKeepAliveInterval(d time.Duration) {
p.keepAliveInterval = d
}

func (p *Proxy) SetLBType(lbt LBType) {
p.lbType = lbt
}