Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 52 additions & 50 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,23 @@ type conn interface {

var _ conn = (*mux)(nil)

type muxwire struct {
wire atomic.Value
sc *singleconnect
mu sync.Mutex
}

type mux struct {
init wire
dead wire
clhks atomic.Value
dpool *pool
spool *pool
wireFn wireFn
dst string
wire []atomic.Value
sc []*singleconnect
mu []sync.Mutex
maxp int
maxm int
init wire
dead wire
clhks atomic.Value
dpool *pool
spool *pool
wireFn wireFn
dst string
muxwires []muxwire
maxp int
maxm int

usePool bool
optIn bool
Expand Down Expand Up @@ -91,18 +95,16 @@ func newMux(dst string, option *ClientOption, init, dead wire, wireFn wireFn, wi
multiplex = 1
}
m := &mux{dst: dst, init: init, dead: dead, wireFn: wireFn,
wire: make([]atomic.Value, multiplex),
mu: make([]sync.Mutex, multiplex),
sc: make([]*singleconnect, multiplex),
maxp: runtime.GOMAXPROCS(0),
maxm: option.BlockingPipeline,
muxwires: make([]muxwire, multiplex),
maxp: runtime.GOMAXPROCS(0),
maxm: option.BlockingPipeline,

usePool: option.DisableAutoPipelining,
optIn: isOptIn(option.ClientTrackingOptions),
}
m.clhks.Store(emptyclhks)
for i := 0; i < len(m.wire); i++ {
m.wire[i].Store(init)
for i := 0; i < len(m.muxwires); i++ {
m.muxwires[i].wire.Store(init)
}

m.dpool = newPool(option.BlockingPoolSize, dead, option.BlockingPoolCleanup, option.BlockingPoolMinSize, wireFn)
Expand Down Expand Up @@ -134,7 +136,7 @@ func (m *mux) setCloseHookOnWire(i uint16, w wire) {
if w != m.dead && w != m.init {
w.SetOnCloseHook(func(err error) {
if err != ErrClosing {
if m.wire[i].CompareAndSwap(w, m.init) {
if m.muxwires[i].wire.CompareAndSwap(w, m.init) {
m.clhks.Load().(func(error))(err)
}
}
Expand All @@ -144,47 +146,47 @@ func (m *mux) setCloseHookOnWire(i uint16, w wire) {

func (m *mux) Override(cc conn) {
if m2, ok := cc.(*mux); ok {
for i := 0; i < len(m.wire) && i < len(m2.wire); i++ {
w := m2.wire[i].Load().(wire)
for i := 0; i < len(m.muxwires) && i < len(m2.muxwires); i++ {
w := m2.muxwires[i].wire.Load().(wire)
m.setCloseHookOnWire(uint16(i), w) // bind the new m to the old w
m.wire[i].CompareAndSwap(m.init, w)
m.muxwires[i].wire.CompareAndSwap(m.init, w)
}
}
}

func (m *mux) _pipe(ctx context.Context, i uint16) (w wire, err error) {
if w = m.wire[i].Load().(wire); w != m.init {
if w = m.muxwires[i].wire.Load().(wire); w != m.init {
return w, nil
}

m.mu[i].Lock()
sc := m.sc[i]
if m.sc[i] == nil {
m.sc[i] = &singleconnect{}
m.sc[i].g.Add(1)
m.muxwires[i].mu.Lock()
sc := m.muxwires[i].sc
if m.muxwires[i].sc == nil {
m.muxwires[i].sc = &singleconnect{}
m.muxwires[i].sc.g.Add(1)
}
m.mu[i].Unlock()
m.muxwires[i].mu.Unlock()

if sc != nil {
sc.g.Wait()
return sc.w, sc.e
}

if w = m.wire[i].Load().(wire); w == m.init {
if w = m.muxwires[i].wire.Load().(wire); w == m.init {
if w = m.wireFn(ctx); w != m.dead {
m.setCloseHookOnWire(i, w)
m.wire[i].Store(w)
m.muxwires[i].wire.Store(w)
} else {
if err = w.Error(); err != ErrClosing {
m.clhks.Load().(func(error))(err)
}
}
}

m.mu[i].Lock()
sc = m.sc[i]
m.sc[i] = nil
m.mu[i].Unlock()
m.muxwires[i].mu.Lock()
sc = m.muxwires[i].sc
m.muxwires[i].sc = nil
m.muxwires[i].mu.Unlock()

sc.w = w
sc.e = err
Expand Down Expand Up @@ -285,46 +287,46 @@ func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (r
}

func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) {
slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply())
slot := slotfn(len(m.muxwires), cmd.Slot(), cmd.NoReply())
wire := m.pipe(ctx, slot)
if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
m.muxwires[slot].wire.CompareAndSwap(wire, m.init)
}
return resp
}

func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) {
slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply())
slot := slotfn(len(m.muxwires), cmd[0].Slot(), cmd[0].NoReply())
wire := m.pipe(ctx, slot)
resp = wire.DoMulti(ctx, cmd...)
for _, r := range resp.s {
if isBroken(r.NonRedisError(), wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
m.muxwires[slot].wire.CompareAndSwap(wire, m.init)
return resp
}
}
return resp
}

func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult {
slot := cmd.Slot() & uint16(len(m.wire)-1)
slot := cmd.Slot() & uint16(len(m.muxwires)-1)
wire := m.pipe(ctx, slot)
resp := wire.DoCache(ctx, cmd, ttl)
if isBroken(resp.NonRedisError(), wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
m.muxwires[slot].wire.CompareAndSwap(wire, m.init)
}
return resp
}

func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results *redisresults) {
var slots *muxslots
var mask = uint16(len(m.wire) - 1)
var mask = uint16(len(m.muxwires) - 1)

if mask == 0 {
return m.doMultiCache(ctx, 0, multi)
}

slots = muxslotsp.Get(len(m.wire), len(m.wire))
slots = muxslotsp.Get(len(m.muxwires), len(m.muxwires))
for _, cmd := range multi {
slots.s[cmd.Cmd.Slot()&mask]++
}
Expand All @@ -333,7 +335,7 @@ func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results
return m.doMultiCache(ctx, multi[0].Cmd.Slot()&mask, multi)
}

batches := batchcachemaps.Get(len(m.wire), len(m.wire))
batches := batchcachemaps.Get(len(m.muxwires), len(m.muxwires))
for slot, count := range slots.s {
if count > 0 {
batches.m[uint16(slot)] = batchcachep.Get(0, count)
Expand Down Expand Up @@ -370,19 +372,19 @@ func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTT
resps = wire.DoMultiCache(ctx, multi...)
for _, r := range resps.s {
if isBroken(r.NonRedisError(), wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
m.muxwires[slot].wire.CompareAndSwap(wire, m.init)
return resps
}
}
return resps
}

func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply())
slot := slotfn(len(m.muxwires), subscribe.Slot(), subscribe.NoReply())
wire := m.pipe(ctx, slot)
err := wire.Receive(ctx, subscribe, fn)
if isBroken(err, wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
m.muxwires[slot].wire.CompareAndSwap(wire, m.init)
}
return err
}
Expand All @@ -398,8 +400,8 @@ func (m *mux) Store(w wire) {
}

func (m *mux) Close() {
for i := 0; i < len(m.wire); i++ {
if prev := m.wire[i].Swap(m.dead).(wire); prev != m.init && prev != m.dead {
for i := 0; i < len(m.muxwires); i++ {
if prev := m.muxwires[i].wire.Swap(m.dead).(wire); prev != m.init && prev != m.dead {
prev.Close()
}
}
Expand Down
4 changes: 2 additions & 2 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ func TestNewMuxPipelineMultiplex(t *testing.T) {
defer ShouldNotLeak(SetupLeakDetection())
for _, v := range []int{-1, 0, 1, 2} {
m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return nil, nil })
if (v < 0 && len(m.wire) != 1) || (v >= 0 && len(m.wire) != 1<<v) {
t.Fatalf("unexpected len(m.wire): %v", len(m.wire))
if (v < 0 && len(m.muxwires) != 1) || (v >= 0 && len(m.muxwires) != 1<<v) {
t.Fatalf("unexpected len(m.muxwires): %v", len(m.muxwires))
}
}
}
Expand Down
Loading