diff --git a/lb/lb.go b/lb/lb.go index 131a383f5..ffd3672a0 100644 --- a/lb/lb.go +++ b/lb/lb.go @@ -4,14 +4,17 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "flag" "fmt" "io" "io/ioutil" + "math/rand" "net" "net/http" "net/http/httputil" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -77,8 +80,19 @@ type chProxy struct { hcUnhealthy int64 hcTimeout time.Duration + statMu sync.Mutex + stats []*stat + proxy *httputil.ReverseProxy httpClient *http.Client + transport http.RoundTripper +} + +type stat struct { + tim time.Time + latency time.Duration + host string + code uint64 } func newProxy(conf config) *chProxy { @@ -104,10 +118,14 @@ func newProxy(conf config) *chProxy { hcUnhealthy: int64(conf.HealthcheckUnhealthy), hcTimeout: time.Duration(conf.HealthcheckTimeout) * time.Second, httpClient: &http.Client{Transport: tranny}, + transport: tranny, } director := func(req *http.Request) { - target := ch.ch.get(req.URL.Path) + target, err := ch.ch.get(req.URL.Path) + if err != nil { + target = "error" + } req.URL.Scheme = "http" // XXX (reed): h2 support req.URL.Host = target @@ -127,6 +145,29 @@ func newProxy(conf config) *chProxy { return ch } +func (ch *chProxy) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL.Host == "error" { + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + // XXX (reed): if we let the proxy code write the response it will be body-less. ok? + return nil, ErrNoNodes + } + + resp, err := ch.transport.RoundTrip(req) + ch.intercept(req, resp) + return resp, err +} + +func (ch *chProxy) intercept(req *http.Request, resp *http.Response) { + // XXX (reed): give f(x) nodes ability to self inspect load and send it back + // XXX (reed): we should prob clear this from user response + load, _ := strconv.Atoi(resp.Header.Get("XXX-FXLB-LOAD")) + // XXX (reed): need to validate these prob + ch.ch.setLoad(req.URL.Host, int64(load)) + + // XXX (reed): stats data +} + type bufferPool struct { bufs *sync.Pool } @@ -217,6 +258,13 @@ func (ch *chProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { } ch.proxy.ServeHTTP(w, r) + + // XXX (reed): for stats need our own transport + //ch.statsMu.Lock() + //ch.stats = append(ch.stats, &stat{ + //host: r.URL.Host, + //} + //ch.stats = r.URL.Host } func (ch *chProxy) addNode(w http.ResponseWriter, r *http.Request) { @@ -330,6 +378,9 @@ type consistentHash struct { // protects nodes sync.RWMutex nodes []string + + loadMu sync.RWMutex + load map[string]*int64 } func (ch *consistentHash) add(newb string) { @@ -364,14 +415,14 @@ func (ch *consistentHash) list() []string { return ret } -func (ch *consistentHash) get(key string) string { +func (ch *consistentHash) get(key string) (string, error) { // crc not unique enough & sha is too slow, it's 1 import sum64 := siphash.Hash(0, 0x4c617279426f6174, []byte(key)) ch.RLock() defer ch.RUnlock() i := int(jumpConsistentHash(sum64, int32(len(ch.nodes)))) - return ch.nodes[i] + return ch.besti(i) } // A Fast, Minimal Memory, Consistent Hash Algorithm: @@ -386,28 +437,75 @@ func jumpConsistentHash(key uint64, num_buckets int32) int32 { return int32(b) } -func besti() string { +func (ch *consistentHash) setLoad(key string, load int64) { + ch.loadMu.RLock() + l, ok := ch.load[key] + ch.loadMu.RUnlock() + if ok { + atomic.StoreInt64(l, load) + } else { + ch.loadMu.Lock() + if _, ok := ch.load[key]; !ok { + ch.load[key] = &load + } + ch.loadMu.Unlock() + } +} + +var ( + ErrNoNodes = errors.New("no nodes available") +) + +// XXX (reed): push down fails / load into ch +func (ch *consistentHash) besti(i int) (string, error) { ch.RLock() defer ch.RUnlock() - for _, n := range ch.nodes[i:] { - load := atomic.LoadInt64(&ch.load[n]) + if len(ch.nodes) < 1 { + return "", ErrNoNodes + } + + f := func(n string) string { + var load int64 + ch.loadMu.RLock() + loadPtr := ch.load[n] + ch.loadMu.RUnlock() + if loadPtr != nil { + load = atomic.LoadInt64(loadPtr) + } // TODO flesh out these values with some testing // back off loaded nodes slightly to spread load - if load < .7 { + if load < 70 { return n - } else if load > .9 { - if rand.Float64() < .6 { + } else if load > 90 { + // XXX (reed): seed rand + if rand.Intn(100) < 60 { return n } - } else if load > .7 { - if rand.Float64() < .8 { + } else if load > 70 { + if rand.Float64() < 80 { return n } } // otherwise loop until we find a sufficiently unloaded node or a lucky coin flip + return "" } - panic("XXX: (reed) need to 503 or try with higher tolerance") + for _, n := range ch.nodes[i:] { + node := f(n) + if node != "" { + return node, nil + } + } + + // try the other half of the ring + for _, n := range ch.nodes[:i] { + node := f(n) + if node != "" { + return node, nil + } + } + + return "", ErrNoNodes }