Files
fn-serverless/lb/lb.go
2017-05-22 13:00:27 -07:00

512 lines
12 KiB
Go

package main
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Sirupsen/logrus"
"github.com/dchest/siphash"
)
// TODO: consistent hashing is nice to get a cheap way to place nodes but it
// doesn't account well for certain functions that may be 'hotter' than others.
// we should very likely keep a load ordered list and distribute based on that.
// if we can get some kind of feedback from the f(x) nodes, we can use that.
// maybe it's good enough to just ch(x) + 1 if ch(x) is marked as "hot"?
// TODO the load balancers all need to have the same list of nodes. gossip?
// also gossip would handle failure detection instead of elb style
// TODO when adding nodes we should health check them once before adding them
// TODO when node goes offline should try to redirect request instead of 5xxing
// TODO config
// TODO TLS
// main parses the command-line flags into a config, builds the
// consistent-hash proxy from it and serves that proxy over HTTP on the
// configured port. It only returns when the HTTP server stops, at which point
// the server's error is printed.
func main() {
	// XXX (reed): normalize
	fnodes := flag.String("nodes", "", "comma separated list of IronFunction nodes")

	var conf config
	flag.IntVar(&conf.Port, "port", 8081, "port to run on")
	flag.IntVar(&conf.HealthcheckInterval, "hc-interval", 3, "how often to check f(x) nodes, in seconds")
	flag.StringVar(&conf.HealthcheckEndpoint, "hc-path", "/version", "endpoint to determine node health")
	flag.IntVar(&conf.HealthcheckUnhealthy, "hc-unhealthy", 2, "threshold of failed checks to declare node unhealthy")
	flag.IntVar(&conf.HealthcheckTimeout, "hc-timeout", 5, "timeout of healthcheck endpoint, in seconds")
	flag.Parse()

	// strings.Split("", ",") yields [""], so blank entries (an unset flag, a
	// trailing comma, stray spaces) are dropped instead of being registered as
	// a node with an empty host.
	for _, n := range strings.Split(*fnodes, ",") {
		if n = strings.TrimSpace(n); n != "" {
			conf.Nodes = append(conf.Nodes, n)
		}
	}

	ch := newProxy(conf)

	// Honour the -port flag; the listen address must not be hard-coded.
	addr := fmt.Sprintf(":%d", conf.Port)

	// XXX (reed): safe shutdown
	fmt.Println(http.ListenAndServe(addr, ch))
}
// config holds the load balancer settings; main fills it from command-line
// flags and hands it to newProxy.
type config struct {
// Port is the value of the -port flag ("port to run on").
Port int `json:"port"`
// Nodes lists the f(x) node addresses that newProxy registers at startup.
Nodes []string `json:"nodes"`
// HealthcheckInterval is how often nodes are health checked, in seconds.
HealthcheckInterval int `json:"healthcheck_interval"`
// HealthcheckEndpoint is the path requested on a node to determine its health.
HealthcheckEndpoint string `json:"healthcheck_endpoint"`
// HealthcheckUnhealthy is the number of failed checks at which a node is
// declared unhealthy.
HealthcheckUnhealthy int `json:"healthcheck_unhealthy"`
// HealthcheckTimeout is the timeout of one health check request, in seconds.
HealthcheckTimeout int `json:"healthcheck_timeout"`
}
// chProxy is the load balancer's http.Handler: it serves the /1/lb/nodes
// management endpoint itself and forwards everything else through a reverse
// proxy whose target node is picked by consistent hashing of the request path.
type chProxy struct {
// ch is the set of nodes requests are currently routed to.
ch consistentHash
// The embedded RWMutex guards ded.
sync.RWMutex
// TODO map[string][]time.Time
// ded maps a node to its count of failed health checks; fail increments the
// count and alive deletes the entry.
ded map[string]int64
// hcInterval is the period between health check rounds.
hcInterval time.Duration
// hcEndpoint is the path appended to the node address for a health check.
hcEndpoint string
// hcUnhealthy is the failure count at which a node is treated as dead.
hcUnhealthy int64
// hcTimeout bounds a single health check request.
hcTimeout time.Duration
// statMu and stats are not used by any live code in this file; only the
// commented-out stats sketch in ServeHTTP refers to them.
statMu sync.Mutex
stats []*stat
// proxy forwards non-management requests to the selected node.
proxy *httputil.ReverseProxy
// httpClient issues the health check requests in ping.
httpClient *http.Client
// transport is the underlying round tripper used by RoundTrip.
transport http.RoundTripper
}
// stat is the element type of chProxy.stats. Nothing in this file populates
// it yet; the field meanings below are inferred from their names and types.
// NOTE(review): confirm once stats collection is implemented.
type stat struct {
// tim is presumably the time the request was observed.
tim time.Time
// latency is presumably how long the proxied request took.
latency time.Duration
// host is presumably the node that served the request.
host string
// code is presumably the HTTP status code of the response.
code uint64
}
// newProxy builds a chProxy from conf.
//
// It creates a single http.Transport that is shared by the reverse proxy, the
// health check client and chProxy.RoundTrip, registers every non-blank entry
// of conf.Nodes on the hash ring, and starts the background health check
// loop before returning.
//
// The reverse proxy's director picks the upstream host by consistent-hashing
// the request path; when no node can be chosen the host is set to the
// sentinel "error".
func newProxy(conf config) *chProxy {
	tranny := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		Dial: (&net.Dialer{
			Timeout:   10 * time.Second,
			KeepAlive: 120 * time.Second,
		}).Dial,
		MaxIdleConnsPerHost: 512,
		TLSHandshakeTimeout: 10 * time.Second,
		TLSClientConfig: &tls.Config{
			ClientSessionCache: tls.NewLRUClientSessionCache(4096),
		},
	}

	ch := &chProxy{
		ded: make(map[string]int64),

		// XXX (reed): need to be reconfigurable at some point
		hcInterval:  time.Duration(conf.HealthcheckInterval) * time.Second,
		hcEndpoint:  conf.HealthcheckEndpoint,
		hcUnhealthy: int64(conf.HealthcheckUnhealthy),
		hcTimeout:   time.Duration(conf.HealthcheckTimeout) * time.Second,
		httpClient:  &http.Client{Transport: tranny},
		transport:   tranny,
	}

	// setLoad stores per-node load in this map; allocate it up front because
	// assigning into a nil map panics.
	ch.ch.load = make(map[string]*int64)

	director := func(req *http.Request) {
		target, err := ch.ch.get(req.URL.Path)
		if err != nil {
			target = "error"
		}
		req.URL.Scheme = "http" // XXX (reed): h2 support
		req.URL.Host = target
	}

	ch.proxy = &httputil.ReverseProxy{
		Director:   director,
		Transport:  tranny,
		BufferPool: newBufferPool(),
	}

	for _, n := range conf.Nodes {
		// A blank entry (e.g. from splitting an empty node list) is not a
		// routable host, so it must not end up on the ring.
		if n == "" {
			continue
		}
		// XXX (reed): need to health check these
		ch.ch.add(n)
	}

	go ch.healthcheck()
	return ch
}
// RoundTrip implements http.RoundTripper on top of ch.transport.
//
// A request whose URL host is the sentinel "error" (set by the director when
// no node could be chosen) is not sent anywhere: its body, if any, is drained
// and closed and ErrNoNodes is returned. Any other request is forwarded to the
// underlying transport, and a successful response is passed to intercept so
// the node's reported load can be recorded.
func (ch *chProxy) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.URL.Host == "error" {
		// Outgoing requests may legitimately carry a nil Body (no payload),
		// so guard before draining it.
		if req.Body != nil {
			io.Copy(ioutil.Discard, req.Body)
			req.Body.Close()
		}
		// XXX (reed): if we let the proxy code write the response it will be body-less. ok?
		return nil, ErrNoNodes
	}

	resp, err := ch.transport.RoundTrip(req)
	// On failure there is no response to inspect; only look at the headers of
	// a response we actually received.
	if err == nil && resp != nil {
		ch.intercept(req, resp)
	}
	return resp, err
}
// intercept reads the load a node reported in the XXX-FXLB-LOAD response
// header and records it against the node the request was sent to
// (req.URL.Host). A nil response is ignored.
func (ch *chProxy) intercept(req *http.Request, resp *http.Response) {
	// A failed round trip yields no response, so there is nothing to read.
	if resp == nil {
		return
	}

	// XXX (reed): give f(x) nodes ability to self inspect load and send it back
	// XXX (reed): we should prob clear this from user response
	// The parse error is deliberately ignored: a missing or malformed header
	// makes Atoi return 0, which is recorded as "no load".
	load, _ := strconv.Atoi(resp.Header.Get("XXX-FXLB-LOAD"))

	// XXX (reed): need to validate these prob
	ch.ch.setLoad(req.URL.Host, int64(load))

	// XXX (reed): stats data
}
// bufferPool adapts a sync.Pool of byte slices to the httputil.BufferPool
// interface so the reverse proxy can reuse its copy buffers.
type bufferPool struct {
	bufs *sync.Pool
}

// newBufferPool returns a BufferPool whose freshly allocated buffers are
// 32 KiB long.
func newBufferPool() httputil.BufferPool {
	pool := new(sync.Pool)
	pool.New = func() interface{} {
		return make([]byte, 32*1024)
	}
	return &bufferPool{bufs: pool}
}

// Get hands out a buffer, either recycled from the pool or newly allocated.
func (b *bufferPool) Get() []byte {
	buf := b.bufs.Get()
	return buf.([]byte)
}

// Put gives a buffer back to the pool for later reuse.
func (b *bufferPool) Put(x []byte) {
	b.bufs.Put(x)
}
// healthcheck runs forever: once per hcInterval it pings, each in its own
// goroutine, every node currently on the ring plus every node currently
// considered dead.
func (ch *chProxy) healthcheck() {
	ticks := time.Tick(ch.hcInterval)
	for {
		<-ticks

		// Dead nodes are no longer on the ring, so they are appended
		// explicitly to give them a chance to come back.
		targets := append(ch.ch.list(), ch.dead()...)

		// XXX (reed): need to figure out elegant adding / removing better
		for _, node := range targets {
			go ch.ping(node)
		}
	}
}
// ping performs one health check against node: a GET of
// "http://" + node + hcEndpoint, bounded by hcTimeout. A transport error or a
// status outside 200-299 counts as a failure (ch.fail); anything else marks
// the node alive (ch.alive). A node whose address cannot form a valid request
// is treated as failed as well.
func (ch *chProxy) ping(node string) {
	req, err := http.NewRequest("GET", "http://"+node+ch.hcEndpoint, nil)
	if err != nil {
		// Without this check a malformed node address leaves req nil and the
		// WithContext call below dereferences it.
		logrus.WithError(err).WithFields(logrus.Fields{"node": node}).Error("health check failed")
		ch.fail(node)
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), ch.hcTimeout)
	defer cancel()

	resp, err := ch.httpClient.Do(req.WithContext(ctx))
	if resp != nil && resp.Body != nil {
		// Drain and close so the connection can be reused.
		io.Copy(ioutil.Discard, resp.Body)
		resp.Body.Close()
	}

	if err != nil || resp.StatusCode < 200 || resp.StatusCode > 299 {
		logrus.WithFields(logrus.Fields{"node": node}).Error("health check failed")
		ch.fail(node)
		return
	}
	ch.alive(node)
}
// fail records one failed health check for node and, once the node's failure
// count reaches hcUnhealthy, removes it from the ring.
//
// The count update and the ring removal happen under the same lock so that a
// concurrent alive for the same node cannot slip in between them; otherwise
// the node could end up removed from the ring with no failure record, and
// would never be health checked again.
func (ch *chProxy) fail(node string) {
	// shouldn't be a hot path so shouldn't be too contended on since health
	// checks are infrequent
	ch.Lock()
	defer ch.Unlock()

	ch.ded[node]++
	if ch.ded[node] >= ch.hcUnhealthy {
		ch.ch.remove(node)
	}
}
// alive records a successful health check for node. If the node has a
// failure record, the record is cleared and the node is (re-)added to the
// ring; a node without a failure record is left untouched.
//
// The common case (no failure record) only takes the read lock. When there is
// a record, it is re-checked under the write lock and the ring is updated
// while that lock is still held, so the record and the ring cannot be changed
// by another goroutine in between.
func (ch *chProxy) alive(node string) {
	ch.RLock()
	_, ok := ch.ded[node]
	ch.RUnlock()
	if !ok {
		return
	}

	ch.Lock()
	defer ch.Unlock()

	// The record may have been cleared while no lock was held; in that case
	// there is nothing left to recover.
	if _, ok := ch.ded[node]; !ok {
		return
	}
	delete(ch.ded, node)
	ch.ch.add(node)
}
// ServeHTTP is the load balancer's entry point. PUT, DELETE and GET on
// /1/lb/nodes are handled locally as node management calls; every other
// request, including other methods on that path, is forwarded through the
// reverse proxy.
func (ch *chProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/1/lb/nodes" {
		ch.proxy.ServeHTTP(w, r)
		return
	}

	// XXX (reed): stats?
	// XXX (reed): probably do these on a separate port to avoid conflicts
	switch r.Method {
	case "PUT":
		ch.addNode(w, r)
	case "DELETE":
		ch.removeNode(w, r)
	case "GET":
		ch.listNodes(w, r)
	default:
		ch.proxy.ServeHTTP(w, r)
	}

	// XXX (reed): for stats need our own transport
	//ch.statsMu.Lock()
	//ch.stats = append(ch.stats, &stat{
	//host: r.URL.Host,
	//}
	//ch.stats = r.URL.Host
}
// addNode handles PUT /1/lb/nodes. The request body is a JSON object of the
// form {"node": "<address>"}; the node is added to the ring and a success
// message is written. A body that cannot be decoded, or one without a node
// value, is answered with 400 Bad Request.
func (ch *chProxy) addNode(w http.ResponseWriter, r *http.Request) {
	var bod struct {
		Node string `json:"node"`
	}
	if err := json.NewDecoder(r.Body).Decode(&bod); err != nil {
		sendError(w, http.StatusBadRequest, err.Error())
		return
	}

	// An empty address would be placed on the ring and have requests routed
	// to a host that cannot exist.
	if bod.Node == "" {
		sendError(w, http.StatusBadRequest, "node is required")
		return
	}

	ch.ch.add(bod.Node)
	sendSuccess(w, "node added")
}
// removeNode handles DELETE /1/lb/nodes. The request body is a JSON object of
// the form {"node": "<address>"}; the node is removed from the ring together
// with its health check failure record, and a success message is written. A
// body that cannot be decoded is answered with 400 Bad Request.
func (ch *chProxy) removeNode(w http.ResponseWriter, r *http.Request) {
	var bod struct {
		Node string `json:"node"`
	}
	if err := json.NewDecoder(r.Body).Decode(&bod); err != nil {
		sendError(w, http.StatusBadRequest, err.Error())
		return
	}

	// Drop the failure record too: the health check loop keeps pinging every
	// node it considers dead and re-adds it on the first success, which would
	// silently undo the deletion of a node that was offline at the time.
	ch.Lock()
	delete(ch.ded, bod.Node)
	ch.ch.remove(bod.Node)
	ch.Unlock()

	sendSuccess(w, "node deleted")
}
// listNodes handles GET /1/lb/nodes. It writes a JSON object
// {"nodes": {<address>: "online"|"offline", ...}} covering every node on the
// ring and every node currently considered dead.
func (ch *chProxy) listNodes(w http.ResponseWriter, r *http.Request) {
	onRing := ch.ch.list()
	offline := ch.dead()

	status := make(map[string]string, len(onRing)+len(offline))
	for _, node := range onRing {
		state := "online"
		if ch.isDead(node) {
			state = "offline"
		}
		status[node] = state
	}
	for _, node := range offline {
		status[node] = "offline"
	}

	var body struct {
		Nodes map[string]string `json:"nodes"`
	}
	body.Nodes = status
	sendValue(w, body)
}
// isDead reports whether node has a failure record whose count has reached
// the hcUnhealthy threshold.
func (ch *chProxy) isDead(node string) bool {
	ch.RLock()
	failures, tracked := ch.ded[node]
	ch.RUnlock()

	if !tracked {
		return false
	}
	return failures >= ch.hcUnhealthy
}
// dead returns the nodes whose failure count has reached hcUnhealthy. The
// result is a fresh, non-nil slice in no particular order.
func (ch *chProxy) dead() []string {
	ch.RLock()
	defer ch.RUnlock()

	offline := make([]string, 0, len(ch.ded))
	for node, failures := range ch.ded {
		if failures < ch.hcUnhealthy {
			continue
		}
		offline = append(offline, node)
	}
	return offline
}
// sendValue writes v to w as JSON. The response is labelled
// application/json unless the headers have already been sent. An encoding or
// write failure is logged, not returned.
func sendValue(w http.ResponseWriter, v interface{}) {
	// Without an explicit type net/http sniffs the body and labels the JSON
	// as text/plain. Setting it after the headers were written is a no-op.
	w.Header().Set("Content-Type", "application/json")

	if err := json.NewEncoder(w).Encode(v); err != nil {
		logrus.WithError(err).Error("error writing response")
	}
}
func sendSuccess(w http.ResponseWriter, msg string) {
err := json.NewEncoder(w).Encode(struct {
Msg string `json:"msg"`
}{
Msg: msg,
})
if err != nil {
logrus.WithError(err).Error("error writing response response")
}
}
func sendError(w http.ResponseWriter, code int, msg string) {
w.WriteHeader(code)
err := json.NewEncoder(w).Encode(struct {
Msg string `json:"msg"`
}{
Msg: msg,
})
if err != nil {
logrus.WithError(err).Error("error writing response response")
}
}
// consistentHash will maintain a list of strings which can be accessed by
// keying them with a separate group of strings
//
// It also tracks a load value per node, which besti consults when choosing a
// node.
type consistentHash struct {
// protects nodes
sync.RWMutex
// nodes is the list of node addresses; add keeps it in sorted order.
nodes []string
// loadMu protects the load map itself (lookups and insertions).
loadMu sync.RWMutex
// load maps a node to its last recorded load. The pointed-to value is
// read and written with sync/atomic by besti and setLoad.
load map[string]*int64
}
// add puts newb on the ring unless it is already there, keeping the node list
// sorted.
func (ch *consistentHash) add(newb string) {
	ch.Lock()
	defer ch.Unlock()

	// The list is sorted, so a binary search is enough to spot a duplicate.
	pos := sort.SearchStrings(ch.nodes, newb)
	alreadyPresent := pos < len(ch.nodes) && ch.nodes[pos] == newb
	if alreadyPresent {
		return
	}

	// Sorted order is what makes a hash index mean the same node everywhere.
	ch.nodes = append(ch.nodes, newb)
	sort.Strings(ch.nodes)
}
// remove takes ded off the ring; a node that is not on the ring is ignored.
func (ch *consistentHash) remove(ded string) {
	ch.Lock()
	defer ch.Unlock()

	// The list is sorted, so the node can only be at its search position.
	pos := sort.SearchStrings(ch.nodes, ded)
	if pos >= len(ch.nodes) || ch.nodes[pos] != ded {
		return
	}
	ch.nodes = append(ch.nodes[:pos], ch.nodes[pos+1:]...)
}
// list returns a copy of the node list, so callers may keep or modify the
// result without affecting the ring.
func (ch *consistentHash) list() []string {
	ch.RLock()
	defer ch.RUnlock()

	snapshot := make([]string, len(ch.nodes))
	copy(snapshot, ch.nodes)
	return snapshot
}
// get picks the node for key: the key is hashed with SipHash, the hash is
// mapped onto the current number of nodes with jumpConsistentHash, and besti
// makes the final choice starting from that index. The read lock is held for
// the whole lookup so the index refers to the list besti sees.
func (ch *consistentHash) get(key string) (string, error) {
	// crc not unique enough & sha is too slow, it's 1 import
	hashed := siphash.Hash(0, 0x4c617279426f6174, []byte(key))

	ch.RLock()
	defer ch.RUnlock()

	buckets := int32(len(ch.nodes))
	start := jumpConsistentHash(hashed, buckets)
	return ch.besti(int(start))
}
// A Fast, Minimal Memory, Consistent Hash Algorithm:
// https://arxiv.org/ftp/arxiv/papers/1406/1406.2294.pdf
//
// jumpConsistentHash maps key onto a bucket in [0, numBuckets). When the
// number of buckets grows from n to n+1, a key either keeps its bucket or
// moves to the new bucket n. For numBuckets <= 0 the result is -1.
func jumpConsistentHash(key uint64, numBuckets int32) int32 {
	var b, j int64 = -1, 0
	for j < int64(numBuckets) {
		b = j
		key = key*2862933555777941757 + 1
		// The jump must be computed in floating point with the +1 inside the
		// divisor, as in the paper. An integer (1<<31)/(key>>33)+1 divides by
		// zero when key>>33 is 0 and always jumps to at least 2*(b+1), which
		// makes bucket 1 unreachable.
		j = int64(float64(b+1) * (float64(1<<31) / float64((key>>33)+1)))
	}
	return int32(b)
}
// setLoad records load as the current load of node key.
//
// An existing entry is updated atomically under the read lock only. A new
// entry is inserted under the write lock; the map is allocated on first use,
// and if another goroutine inserted the key in the meantime the value is
// stored into that entry instead of being dropped.
func (ch *consistentHash) setLoad(key string, load int64) {
	ch.loadMu.RLock()
	l, ok := ch.load[key]
	ch.loadMu.RUnlock()
	if ok {
		atomic.StoreInt64(l, load)
		return
	}

	ch.loadMu.Lock()
	defer ch.loadMu.Unlock()

	// Assigning into a nil map panics, so make sure it exists.
	if ch.load == nil {
		ch.load = make(map[string]*int64)
	}
	// Re-check: the key may have appeared between the two lock sections.
	if l, ok := ch.load[key]; ok {
		atomic.StoreInt64(l, load)
		return
	}
	ch.load[key] = &load
}
var (
// ErrNoNodes is returned by besti when it cannot pick a node, and by
// RoundTrip when the request carries the "error" host set by the director.
ErrNoNodes = errors.New("no nodes available")
)
// XXX (reed): push down fails / load into ch

// besti picks the node to use for hash index i. Starting at i and wrapping
// around the ring, it returns the first node that accepts the request: a node
// with load below 70 always accepts, a node with load above 90 accepts 60% of
// the time, and a node in between accepts 80% of the time. If every node
// declines, the node at index i is returned anyway so that a request is never
// rejected while nodes exist. ErrNoNodes is returned only for an empty ring.
//
// The caller must hold ch's read lock (get does). The lock is deliberately
// not taken again here: sync.RWMutex prohibits recursive read locking because
// a writer queued between the two acquisitions deadlocks both sides.
func (ch *consistentHash) besti(i int) (string, error) {
	if len(ch.nodes) < 1 {
		return "", ErrNoNodes
	}
	// Guard the slicing below against an index that does not match the list.
	if i < 0 || i >= len(ch.nodes) {
		i = 0
	}

	accepts := func(n string) bool {
		var load int64
		ch.loadMu.RLock()
		loadPtr := ch.load[n]
		ch.loadMu.RUnlock()
		if loadPtr != nil {
			load = atomic.LoadInt64(loadPtr)
		}

		// TODO flesh out these values with some testing
		// back off loaded nodes slightly to spread load
		switch {
		case load < 70:
			return true
		case load > 90:
			// XXX (reed): seed rand
			return rand.Intn(100) < 60
		default:
			// 70 <= load <= 90. Both sides of the comparison are percentages;
			// comparing rand.Float64() (always below 1) against 80 would
			// accept every time.
			return rand.Intn(100) < 80
		}
	}

	for _, n := range ch.nodes[i:] {
		if accepts(n) {
			return n, nil
		}
	}

	// try the other half of the ring
	for _, n := range ch.nodes[:i] {
		if accepts(n) {
			return n, nil
		}
	}

	// Every node lost its coin flip; serving from the hashed node beats
	// failing the request.
	return ch.nodes[i], nil
}