diff --git a/fnlb/lb/allgrouper.go b/fnlb/lb/allgrouper.go index e23254cd7..bd7b36115 100644 --- a/fnlb/lb/allgrouper.go +++ b/fnlb/lb/allgrouper.go @@ -2,24 +2,40 @@ package lb import ( "context" + "database/sql" "encoding/json" + "errors" "io" "io/ioutil" "net/http" + "net/url" + "os" + "path/filepath" "sort" + "strings" "sync" "time" "github.com/Sirupsen/logrus" + "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx" + "github.com/lib/pq" + "github.com/mattn/go-sqlite3" ) // NewAllGrouper returns a Grouper that will return the entire list of nodes // that are being maintained, regardless of key. An 'AllGrouper' will health // check servers at a specified interval, taking them in and out as they // pass/fail and exposes endpoints for adding, removing and listing nodes. -func NewAllGrouper(conf Config) Grouper { +func NewAllGrouper(conf Config) (Grouper, error) { + db, err := db(conf.DBurl) + if err != nil { + return nil, err + } + a := &allGrouper{ ded: make(map[string]int64), + db: db, // XXX (reed): need to be reconfigurable at some point hcInterval: time.Duration(conf.HealthcheckInterval) * time.Second, @@ -31,18 +47,35 @@ func NewAllGrouper(conf Config) Grouper { httpClient: &http.Client{Transport: conf.Transport}, } for _, n := range conf.Nodes { - a.add(n) + err := a.add(n) + if err != nil { + // XXX (reed): could prob ignore these but meh + logrus.WithError(err).WithFields(logrus.Fields{"node": n}).Error("error adding node") + } } go a.healthcheck() - return a + return a, nil } -// TODO +// allGrouper will return all healthy nodes it is tracking from List. +// nodes may be added / removed through the HTTP api. each allGrouper will +// poll its database for the full list of nodes, and then run its own +// health checks on those nodes to maintain a list of healthy nodes. 
+// the list of healthy nodes will be maintained in sorted order so that, +// without any network partitions, all lbs may consistently hash with the +// same backing list, such that H(k) -> v for any k->v pair (vs attempting +// to maintain a list among nodes in the db, which could have thrashing +// due to network connectivity between any pair). type allGrouper struct { - // protects nodes & ded + // protects allNodes, healthy & ded sync.RWMutex - nodes []string - ded map[string]int64 // [node] -> failedCount + // TODO rename nodes to 'allNodes' or something so everything breaks and then stitch + // ded is the set of disjoint nodes nodes from intersecting nodes & healthy + allNodes, healthy []string + ded map[string]int64 // [node] -> failedCount + + // allNodes is a cache of db.List, we can probably trash it.. + db DBStore httpClient *http.Client @@ -52,44 +85,170 @@ type allGrouper struct { hcTimeout time.Duration } -func (a *allGrouper) add(newb string) { - if newb == "" { - return // we can't really do a lot of validation since hosts could be an ip or domain but we have health checks - } - a.Lock() - a.addNoLock(newb) - a.Unlock() +// TODO put this somewhere better +type DBStore interface { + Add(string) error + Delete(string) error + List() ([]string, error) } -func (a *allGrouper) addNoLock(newb string) { +// implements DBStore +type sqlStore struct { + db *sqlx.DB + + // TODO we should prepare all of the statements, rebind them + // and store them all here. +} + +// New will open the db specified by url, create any tables necessary +// and return a models.Datastore safe for concurrent usage. 
+func db(uri string) (DBStore, error) { + url, err := url.Parse(uri) + if err != nil { + return nil, err + } + + driver := url.Scheme + // driver must be one of these for sqlx to work, double check: + switch driver { + case "postgres", "pgx", "mysql", "sqlite3", "oci8", "ora", "goracle": + default: + return nil, errors.New("invalid db driver, refer to the code") + } + + if driver == "sqlite3" { + // make all the dirs so we can make the file.. + dir := filepath.Dir(url.Path) + err := os.MkdirAll(dir, 0755) + if err != nil { + return nil, err + } + } + + uri = url.String() + if driver != "postgres" { + // postgres seems to need this as a prefix in lib/pq, everyone else wants it stripped of scheme + uri = strings.TrimPrefix(url.String(), url.Scheme+"://") + } + + sqldb, err := sql.Open(driver, uri) + if err != nil { + logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't open db") + return nil, err + } + + db := sqlx.NewDb(sqldb, driver) + // force a connection and test that it worked + err = db.Ping() + if err != nil { + logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't ping db") + return nil, err + } + + maxIdleConns := 30 // c.MaxIdleConnections + db.SetMaxIdleConns(maxIdleConns) + logrus.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed") + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS lb_nodes ( + address text NOT NULL PRIMARY KEY + );`) + if err != nil { + return nil, err + } + + return &sqlStore{db: db}, nil +} + +func (s *sqlStore) Add(node string) error { + query := s.db.Rebind("INSERT INTO lb_nodes (address) VALUES (?);") + _, err := s.db.Exec(query, node) + if err != nil { + // if it already exists, just filter that error out + switch err := err.(type) { + case *mysql.MySQLError: + if err.Number == 1062 { + return nil + } + case *pq.Error: + if err.Code == "23505" { + return nil + } + case sqlite3.Error: + if err.ExtendedCode == 
sqlite3.ErrConstraintUnique || err.ExtendedCode == sqlite3.ErrConstraintPrimaryKey { + return nil + } + } + } + return err +} + +func (s *sqlStore) Delete(node string) error { + query := s.db.Rebind(`DELETE FROM lb_nodes WHERE address=?`) + _, err := s.db.Exec(query, node) + // TODO we can filter if it didn't exist, too... + return err +} + +func (s *sqlStore) List() ([]string, error) { + query := s.db.Rebind(`SELECT DISTINCT address FROM lb_nodes`) + rows, err := s.db.Query(query) + if err != nil { + return nil, err + } + + var nodes []string + for rows.Next() { + var node string + err := rows.Scan(&node) + if err == nil { + nodes = append(nodes, node) + } + } + + err = rows.Err() + if err == sql.ErrNoRows { + err = nil // don't care... + } + + return nodes, err +} + +func (a *allGrouper) add(newb string) error { + if newb == "" { + return nil // we can't really do a lot of validation since hosts could be an ip or domain but we have health checks + } + return a.db.Add(newb) +} + +func (a *allGrouper) remove(ded string) error { + return a.db.Delete(ded) +} + +// call with a.Lock held +func (a *allGrouper) addHealthy(newb string) { // filter dupes, under lock. sorted, so binary search - i := sort.SearchStrings(a.nodes, newb) - if i < len(a.nodes) && a.nodes[i] == newb { + i := sort.SearchStrings(a.healthy, newb) + if i < len(a.healthy) && a.healthy[i] == newb { return } - a.nodes = append(a.nodes, newb) + a.healthy = append(a.healthy, newb) // need to keep in sorted order so that hash index works across nodes - sort.Sort(sort.StringSlice(a.nodes)) + sort.Sort(sort.StringSlice(a.healthy)) } -func (a *allGrouper) remove(ded string) { - a.Lock() - a.removeNoLock(ded) - a.Unlock() -} - -func (a *allGrouper) removeNoLock(ded string) { - i := sort.SearchStrings(a.nodes, ded) - if i < len(a.nodes) && a.nodes[i] == ded { - a.nodes = append(a.nodes[:i], a.nodes[i+1:]...) 
// call with a.Lock held
func (a *allGrouper) removeHealthy(ded string) {
	// sorted slice, so binary search for the victim and splice it out.
	i := sort.SearchStrings(a.healthy, ded)
	if i < len(a.healthy) && a.healthy[i] == ded {
		a.healthy = append(a.healthy[:i], a.healthy[i+1:]...)
	}
}

// healthcheck re-reads the full node list from the db on every tick,
// caches it in a.allNodes, and pings each node to maintain the healthy
// list. Runs forever; start it in its own goroutine.
// NOTE(review): time.Tick's ticker is never stopped; fine only because
// this loop lives for the whole process.
func (a *allGrouper) healthcheck() {
	for range time.Tick(a.hcInterval) {
		// health check the entire list of nodes [from db]
		list, err := a.db.List()
		if err != nil {
			logrus.WithError(err).Error("error checking db for nodes")
			continue
		}

		a.Lock()
		a.allNodes = list
		a.Unlock()

		for _, n := range list {
			// pinged concurrently; ping is expected to call fail/alive.
			go a.ping(n)
		}
	}
}

// alive clears the node's failure count and (re-)adds it to the healthy
// list. Called on every successful health check.
func (a *allGrouper) alive(node string) {
	// TODO alive is gonna get called a lot, should maybe start w/ every node in ded
	// so we can RLock (but lock contention should be low since these are ~quick) --
	// "a lot" being every 1s per node, so not too crazy really, but 1k nodes @ ms each...
	a.Lock()
	delete(a.ded, node)
	a.addHealthy(node)
	a.Unlock()
}
+ out := make(map[string]string, len(nodes)) for _, n := range nodes { if a.isDead(n) { out[n] = "offline" @@ -215,10 +392,6 @@ func (a *allGrouper) listNodes(w http.ResponseWriter, r *http.Request) { } } - for _, n := range dead { - out[n] = "offline" - } - sendValue(w, struct { Nodes map[string]string `json:"nodes"` }{ @@ -232,15 +405,3 @@ func (a *allGrouper) isDead(node string) bool { a.RUnlock() return ok && val >= a.hcUnhealthy } - -func (a *allGrouper) dead() []string { - a.RLock() - defer a.RUnlock() - nodes := make([]string, 0, len(a.ded)) - for n, val := range a.ded { - if val >= a.hcUnhealthy { - nodes = append(nodes, n) - } - } - return nodes -} diff --git a/fnlb/lb/proxy.go b/fnlb/lb/proxy.go index 8a281a3d3..dfb413268 100644 --- a/fnlb/lb/proxy.go +++ b/fnlb/lb/proxy.go @@ -26,6 +26,7 @@ import ( // TODO TLS type Config struct { + DBurl string `json:"db_url"` Listen string `json:"port"` Nodes []string `json:"nodes"` HealthcheckInterval int `json:"healthcheck_interval"` diff --git a/fnlb/main.go b/fnlb/main.go index 94bd1adff..652f41f8b 100644 --- a/fnlb/main.go +++ b/fnlb/main.go @@ -24,6 +24,7 @@ func main() { fnodes := flag.String("nodes", "", "comma separated list of functions nodes") var conf lb.Config + flag.StringVar(&conf.DBurl, "db", "sqlite3://:memory:", "backend to store nodes, default to in memory") flag.StringVar(&conf.Listen, "listen", ":8081", "port to run on") flag.IntVar(&conf.HealthcheckInterval, "hc-interval", 3, "how often to check f(x) nodes, in seconds") flag.StringVar(&conf.HealthcheckEndpoint, "hc-path", "/version", "endpoint to determine node health") @@ -49,7 +50,11 @@ func main() { }, } - g := lb.NewAllGrouper(conf) + g, err := lb.NewAllGrouper(conf) + if err != nil { + logrus.WithError(err).Fatal("error setting up grouper") + } + r := lb.NewConsistentRouter(conf) k := func(r *http.Request) (string, error) { return r.URL.Path, nil @@ -59,9 +64,9 @@ func main() { h = g.Wrap(h) // add/del/list endpoints h = r.Wrap(h) // stats / 
dash endpoint - err := serve(conf.Listen, h) + err = serve(conf.Listen, h) if err != nil { - logrus.WithError(err).Error("server error") + logrus.WithError(err).Fatal("server error") } }