diff --git a/fnlb/lb/allgrouper.go b/fnlb/lb/allgrouper.go
index 28f825c36..195bc7e6c 100644
--- a/fnlb/lb/allgrouper.go
+++ b/fnlb/lb/allgrouper.go
@@ -2,26 +2,16 @@ package lb
 
 import (
 	"context"
-	"database/sql"
 	"encoding/json"
-	"errors"
 	"io"
 	"io/ioutil"
 	"net/http"
-	"net/url"
-	"os"
-	"path/filepath"
 	"sort"
-	"strings"
 	"sync"
 	"time"
 
 	"fmt"
 
 	"github.com/coreos/go-semver/semver"
-	"github.com/go-sql-driver/mysql"
-	"github.com/jmoiron/sqlx"
-	"github.com/lib/pq"
-	"github.com/mattn/go-sqlite3"
 	"github.com/sirupsen/logrus"
 )
 
@@ -29,26 +19,24 @@ import (
 // that are being maintained, regardless of key. An 'AllGrouper' will health
 // check servers at a specified interval, taking them in and out as they
 // pass/fail and exposes endpoints for adding, removing and listing nodes.
-func NewAllGrouper(conf Config) (Grouper, error) {
-	db, err := db(conf.DBurl)
-	if err != nil {
-		return nil, err
-	}
-
+func NewAllGrouper(conf Config, db DBStore) (Grouper, error) {
 	a := &allGrouper{
-		ded: make(map[string]int64),
-		db:  db,
+		nodeList:        make(map[string]nodeState),
+		nodeHealthyList: make([]string, 0),
+		db:              db,
 
 		// XXX (reed): need to be reconfigurable at some point
 		hcInterval:  time.Duration(conf.HealthcheckInterval) * time.Second,
 		hcEndpoint:  conf.HealthcheckEndpoint,
 		hcUnhealthy: int64(conf.HealthcheckUnhealthy),
+		hcHealthy:   int64(conf.HealthcheckHealthy),
 		hcTimeout:   time.Duration(conf.HealthcheckTimeout) * time.Second,
 
 		minAPIVersion: *conf.MinAPIVersion,
 
 		// for health checks
 		httpClient: &http.Client{Transport: conf.Transport},
 	}
+
 	for _, n := range conf.Nodes {
 		err := a.add(n)
 		if err != nil {
@@ -60,6 +48,17 @@ func NewAllGrouper(conf Config) (Grouper, error) {
 	return a, nil
 }
 
+// nodeState is used to store success/fail counts and other health related data.
+type nodeState struct {
+
+	// num of consecutive successes & failures
+	success uint64
+	fail    uint64
+
+	// current health state
+	healthy bool
+}
+
 // allGrouper will return all healthy nodes it is tracking from List.
 // nodes may be added / removed through the HTTP api. each allGrouper will
 // poll its database for the full list of nodes, and then run its own
@@ -70,14 +69,12 @@ func NewAllGrouper(conf Config) (Grouper, error) {
 // to maintain a list among nodes in the db, which could have thrashing
 // due to network connectivity between any pair).
 type allGrouper struct {
-	// protects allNodes, healthy & ded
-	sync.RWMutex
-	// TODO rename nodes to 'allNodes' or something so everything breaks and then stitch
-	// ded is the set of disjoint nodes nodes from intersecting nodes & healthy
-	allNodes, healthy []string
-	ded               map[string]int64 // [node] -> failedCount
-	// allNodes is a cache of db.List, we can probably trash it..
+	// health checker state and lock
+	nodeLock        sync.RWMutex
+	nodeList        map[string]nodeState
+	nodeHealthyList []string
+
 	db DBStore
 
 	httpClient *http.Client
@@ -85,138 +82,11 @@ type allGrouper struct {
 	hcInterval  time.Duration
 	hcEndpoint  string
 	hcUnhealthy int64
+	hcHealthy   int64
 	hcTimeout   time.Duration
 
 	minAPIVersion semver.Version
 }
 
-// TODO put this somewhere better
-type DBStore interface {
-	Add(string) error
-	Delete(string) error
-	List() ([]string, error)
-}
-
-// implements DBStore
-type sqlStore struct {
-	db *sqlx.DB
-
-	// TODO we should prepare all of the statements, rebind them
-	// and store them all here.
-}
-
-// New will open the db specified by url, create any tables necessary
-// and return a models.Datastore safe for concurrent usage.
-func db(uri string) (DBStore, error) {
-	url, err := url.Parse(uri)
-	if err != nil {
-		return nil, err
-	}
-
-	driver := url.Scheme
-	// driver must be one of these for sqlx to work, double check:
-	switch driver {
-	case "postgres", "pgx", "mysql", "sqlite3", "oci8", "ora", "goracle":
-	default:
-		return nil, errors.New("invalid db driver, refer to the code")
-	}
-
-	if driver == "sqlite3" {
-		// make all the dirs so we can make the file..
-		dir := filepath.Dir(url.Path)
-		err := os.MkdirAll(dir, 0755)
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	uri = url.String()
-	if driver != "postgres" {
-		// postgres seems to need this as a prefix in lib/pq, everyone else wants it stripped of scheme
-		uri = strings.TrimPrefix(url.String(), url.Scheme+"://")
-	}
-
-	sqldb, err := sql.Open(driver, uri)
-	if err != nil {
-		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't open db")
-		return nil, err
-	}
-
-	db := sqlx.NewDb(sqldb, driver)
-	// force a connection and test that it worked
-	err = db.Ping()
-	if err != nil {
-		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't ping db")
-		return nil, err
-	}
-
-	maxIdleConns := 30 // c.MaxIdleConnections
-	db.SetMaxIdleConns(maxIdleConns)
-	logrus.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed")
-
-	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS lb_nodes (
-	address text NOT NULL PRIMARY KEY
-);`)
-	if err != nil {
-		return nil, err
-	}
-
-	return &sqlStore{db: db}, nil
-}
-
-func (s *sqlStore) Add(node string) error {
-	query := s.db.Rebind("INSERT INTO lb_nodes (address) VALUES (?);")
-	_, err := s.db.Exec(query, node)
-	if err != nil {
-		// if it already exists, just filter that error out
-		switch err := err.(type) {
-		case *mysql.MySQLError:
-			if err.Number == 1062 {
-				return nil
-			}
-		case *pq.Error:
-			if err.Code == "23505" {
-				return nil
-			}
-		case sqlite3.Error:
-			if err.ExtendedCode == sqlite3.ErrConstraintUnique || err.ExtendedCode == sqlite3.ErrConstraintPrimaryKey {
-				return nil
-			}
-		}
-	}
-	return err
-}
-
-func (s *sqlStore) Delete(node string) error {
-	query := s.db.Rebind(`DELETE FROM lb_nodes WHERE address=?`)
-	_, err := s.db.Exec(query, node)
-	// TODO we can filter if it didn't exist, too...
-	return err
-}
-
-func (s *sqlStore) List() ([]string, error) {
-	query := s.db.Rebind(`SELECT DISTINCT address FROM lb_nodes`)
-	rows, err := s.db.Query(query)
-	if err != nil {
-		return nil, err
-	}
-
-	var nodes []string
-	for rows.Next() {
-		var node string
-		err := rows.Scan(&node)
-		if err == nil {
-			nodes = append(nodes, node)
-		}
-	}
-
-	err = rows.Err()
-	if err == sql.ErrNoRows {
-		err = nil // don't care...
-	}
-
-	return nodes, err
-}
-
 func (a *allGrouper) add(newb string) error {
 	if newb == "" {
 		return nil // we can't really do a lot of validation since hosts could be an ip or domain but we have health checks
@@ -232,32 +102,33 @@ func (a *allGrouper) remove(ded string) error {
 	return a.db.Delete(ded)
 }
 
-// call with a.Lock held
-func (a *allGrouper) addHealthy(newb string) {
-	// filter dupes, under lock. sorted, so binary search
-	i := sort.SearchStrings(a.healthy, newb)
-	if i < len(a.healthy) && a.healthy[i] == newb {
-		return
-	}
-	a.healthy = append(a.healthy, newb)
-	// need to keep in sorted order so that hash index works across nodes
-	sort.Sort(sort.StringSlice(a.healthy))
-}
+func (a *allGrouper) publishHealth() {
 
-// call with a.Lock held
-func (a *allGrouper) removeHealthy(ded string) {
-	i := sort.SearchStrings(a.healthy, ded)
-	if i < len(a.healthy) && a.healthy[i] == ded {
-		a.healthy = append(a.healthy[:i], a.healthy[i+1:]...)
+	a.nodeLock.Lock()
+
+	// get a list of healthy nodes
+	newList := make([]string, 0, len(a.nodeList))
+	for key, value := range a.nodeList {
+		if value.healthy {
+			newList = append(newList, key)
+		}
 	}
+
+	// sort and update healthy List
+	sort.Strings(newList)
+	a.nodeHealthyList = newList
+
+	a.nodeLock.Unlock()
 }
 
 // return a copy
 func (a *allGrouper) List(string) ([]string, error) {
-	a.RLock()
-	ret := make([]string, len(a.healthy))
-	copy(ret, a.healthy)
-	a.RUnlock()
+
+	a.nodeLock.RLock()
+	ret := make([]string, len(a.nodeHealthyList))
+	copy(ret, a.nodeHealthyList)
+	a.nodeLock.RUnlock()
+
 	var err error
 	if len(ret) == 0 {
 		err = ErrNoNodes
@@ -265,22 +136,76 @@ func (a *allGrouper) List(string) ([]string, error) {
 	return ret, err
 }
 
+func (a *allGrouper) runHealthCheck() {
+
+	// fetch a list of nodes from DB
+	list, err := a.db.List()
+	if err != nil {
+		// if DB fails, the show must go on, report it but perform HC
+		logrus.WithError(err).Error("error checking db for nodes")
+
+		// compile a list of nodes to be health checked
+		a.nodeLock.RLock()
+		list = make([]string, 0, len(a.nodeList))
+		for key, _ := range a.nodeList {
+			list = append(list, key)
+		}
+		a.nodeLock.RUnlock()
+
+	} else {
+
+		isChanged := false
+
+		// compile a map of DB nodes for deletion check
+		deleteCheck := make(map[string]bool, len(list))
+		for _, node := range list {
+			deleteCheck[node] = true
+		}
+
+		a.nodeLock.Lock()
+
+		// handle new nodes
+		for _, node := range list {
+			_, ok := a.nodeList[node]
+			if !ok {
+				// add new node
+				a.nodeList[node] = nodeState{
+					healthy: true,
+				}
+				isChanged = true
+			}
+		}
+
+		// handle deleted nodes: purge unmarked nodes
+		for key, _ := range a.nodeList {
+			_, ok := deleteCheck[key]
+			if !ok {
+				delete(a.nodeList, key)
+				isChanged = true
+			}
+		}
+
+		a.nodeLock.Unlock()
+
+		// publish if add/deleted nodes
+		if isChanged {
+			a.publishHealth()
+		}
+	}
+
+	// spawn health checkers
+	for _, key := range list {
+		go a.ping(key)
+	}
+}
+
 func (a *allGrouper) healthcheck() {
+
+	// run hc immediately upon startup
+	a.runHealthCheck()
+
 	for range time.Tick(a.hcInterval) {
-		// health check the entire list of nodes [from db]
-		list, err := a.db.List()
-		if err != nil {
-			logrus.WithError(err).Error("error checking db for nodes")
-			continue
-		}
-
-		a.Lock()
-		a.allNodes = list
-		a.Unlock()
-
-		for _, n := range list {
-			go a.ping(n)
-		}
+		a.runHealthCheck()
 	}
 }
 
@@ -312,14 +237,18 @@ func (a *allGrouper) getVersion(urlString string) (string, error) {
 }
 
 func (a *allGrouper) checkAPIVersion(node string) error {
-	versionURL := "http://" + node + "/version"
+	versionURL := "http://" + node + a.hcEndpoint
 
 	version, err := a.getVersion(versionURL)
 	if err != nil {
 		return err
 	}
 
-	nodeVer := semver.New(version)
+	nodeVer, err := semver.NewVersion(version)
+	if err != nil {
+		return err
+	}
+
 	if nodeVer.LessThan(a.minAPIVersion) {
 		return fmt.Errorf("incompatible API version: %v", nodeVer)
 	}
@@ -336,32 +265,79 @@ func (a *allGrouper) ping(node string) {
 	}
 }
 
-func (a *allGrouper) fail(node string) {
-	// shouldn't be a hot path so shouldn't be too contended on since health
-	// checks are infrequent
-	a.Lock()
-	a.ded[node]++
-	failed := a.ded[node]
-	if failed >= a.hcUnhealthy {
-		a.removeHealthy(node)
+func (a *allGrouper) fail(key string) {
+
+	isChanged := false
+
+	a.nodeLock.Lock()
+
+	// if deleted, skip
+	node, ok := a.nodeList[key]
+	if !ok {
+		a.nodeLock.Unlock()
+		return
+	}
+
+	node.success = 0
+	node.fail++
+
+	// overflow case
+	if node.fail == 0 {
+		node.fail = uint64(a.hcUnhealthy)
+	}
+
+	if node.healthy && node.fail >= uint64(a.hcUnhealthy) {
+		node.healthy = false
+		isChanged = true
+	}
+
+	a.nodeList[key] = node
+	a.nodeLock.Unlock()
+
+	if isChanged {
+		logrus.WithFields(logrus.Fields{"node": key}).Info("is unhealthy")
+		a.publishHealth()
 	}
-	a.Unlock()
 }
 
-func (a *allGrouper) alive(node string) {
-	// TODO alive is gonna get called a lot, should maybe start w/ every node in ded
-	// so we can RLock (but lock contention should be low since these are ~quick) --
-	// "a lot" being every 1s per node, so not too crazy really, but 1k nodes @ ms each...
-	a.Lock()
-	delete(a.ded, node)
-	a.addHealthy(node)
-	a.Unlock()
+func (a *allGrouper) alive(key string) {
+
+	isChanged := false
+
+	a.nodeLock.Lock()
+
+	// if deleted, skip
+	node, ok := a.nodeList[key]
+	if !ok {
+		a.nodeLock.Unlock()
+		return
+	}
+
+	node.fail = 0
+	node.success++
+
+	// overflow case
+	if node.success == 0 {
+		node.success = uint64(a.hcHealthy)
+	}
+
+	if !node.healthy && node.success >= uint64(a.hcHealthy) {
+		node.healthy = true
+		isChanged = true
+	}
+
+	a.nodeList[key] = node
+	a.nodeLock.Unlock()
+
+	if isChanged {
+		logrus.WithFields(logrus.Fields{"node": key}).Info("is healthy")
+		a.publishHealth()
+	}
 }
 
 func (a *allGrouper) Wrap(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		switch r.URL.Path {
-		// XXX (reed): probably do these on a separate port to avoid conflicts
 		case "/1/lb/nodes":
 			switch r.Method {
 			case "PUT":
@@ -415,33 +391,24 @@ func (a *allGrouper) removeNode(w http.ResponseWriter, r *http.Request) {
 }
 
 func (a *allGrouper) listNodes(w http.ResponseWriter, r *http.Request) {
-	a.RLock()
-	nodes := make([]string, len(a.allNodes))
-	copy(nodes, a.allNodes)
-	a.RUnlock()
-	// TODO this isn't correct until at least one health check has hit all nodes (on start up).
-	// seems like not a huge deal, but here's a note anyway (every node will simply 'appear' healthy
-	// from this api even if we aren't routing to it [until first health check]).
-	out := make(map[string]string, len(nodes))
-	for _, n := range nodes {
-		if a.isDead(n) {
-			out[n] = "offline"
+	a.nodeLock.RLock()
+
+	out := make(map[string]string, len(a.nodeList))
+
+	for key, value := range a.nodeList {
+		if value.healthy {
+			out[key] = "online"
 		} else {
-			out[n] = "online"
+			out[key] = "offline"
 		}
 	}
+	a.nodeLock.RUnlock()
+
 	sendValue(w, struct {
 		Nodes map[string]string `json:"nodes"`
 	}{
 		Nodes: out,
 	})
 }
-
-func (a *allGrouper) isDead(node string) bool {
-	a.RLock()
-	val, ok := a.ded[node]
-	a.RUnlock()
-	return ok && val >= a.hcUnhealthy
-}
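Taken together, nodeState's consecutive counters and the hcUnhealthy/hcHealthy thresholds above give each node a small hysteresis: it leaves the rotation only after hcUnhealthy consecutive failed checks and re-enters only after hcHealthy consecutive passes, so one flapping probe no longer toggles membership. A minimal sketch of that transition logic, restated without the locking, counter-overflow guards, and publishHealth plumbing that fail()/alive() add:

```go
// Sketch only: the per-node hysteresis applied by fail()/alive().
type probeState struct {
	success, fail uint64
	healthy       bool
}

func observe(s probeState, passed bool, hcHealthy, hcUnhealthy uint64) probeState {
	if passed {
		s.fail, s.success = 0, s.success+1
		if !s.healthy && s.success >= hcHealthy {
			s.healthy = true // back in rotation after N consecutive passes
		}
	} else {
		s.success, s.fail = 0, s.fail+1
		if s.healthy && s.fail >= hcUnhealthy {
			s.healthy = false // out of rotation after M consecutive failures
		}
	}
	return s
}
```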
diff --git a/fnlb/lb/allgrouper_test.go b/fnlb/lb/allgrouper_test.go
new file mode 100644
index 000000000..6c11d831f
--- /dev/null
+++ b/fnlb/lb/allgrouper_test.go
@@ -0,0 +1,480 @@
+package lb
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"sort"
+	"testing"
+	"time"
+
+	"github.com/coreos/go-semver/semver"
+)
+
+type mockDB struct {
+	isAddError    bool
+	isDeleteError bool
+	isListError   bool
+	nodeList      map[string]bool
+}
+
+func (mock *mockDB) Add(node string) error {
+	if mock.isAddError {
+		return errors.New("simulated add error")
+	}
+	mock.nodeList[node] = true
+	return nil
+}
+func (mock *mockDB) Delete(node string) error {
+	if mock.isDeleteError {
+		return errors.New("simulated delete error")
+	}
+	delete(mock.nodeList, node)
+	return nil
+}
+func (mock *mockDB) List() ([]string, error) {
+	if mock.isListError {
+		return nil, errors.New("simulated list error")
+	}
+	list := make([]string, 0, len(mock.nodeList))
+	for key, _ := range mock.nodeList {
+		list = append(list, key)
+	}
+	return list, nil
+}
+
+func initializeRunner() (Grouper, error) {
+	db := &mockDB{
+		nodeList: make(map[string]bool),
+	}
+
+	conf := Config{
+		HealthcheckInterval:  1,
+		HealthcheckEndpoint:  "/version",
+		HealthcheckUnhealthy: 1,
+		HealthcheckHealthy:   1,
+		HealthcheckTimeout:   1,
+		MinAPIVersion:        semver.New("0.0.104"),
+		Transport:            &http.Transport{},
+	}
+
+	return NewAllGrouper(conf, db)
+}
+
+type testServer struct {
+	addr     string
+	version  string
+	healthy  bool
+	inPool   bool
+	listener *net.Listener
+	server   *http.Server
+}
+
+func (s *testServer) getHandler() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if "/version" == r.URL.Path {
+			if s.healthy {
+				sendValue(w, fnVersion{Version: s.version})
+			} else {
+				sendError(w, http.StatusServiceUnavailable, "service unhealthy")
+			}
+		} else {
+			sendError(w, http.StatusNotFound, "unknown uri")
+		}
+	})
+}
+
+// return the list of nodes expected to be healthy (good version and in pool)
+func getCurrentHealthySet(list []*testServer) []string {
+
+	out := make([]string, 0)
+
+	for _, val := range list {
+		if val.healthy && val.inPool {
+			out = append(out, val.addr)
+		}
+	}
+
+	sort.Strings(out)
+	return out
+}
+
+// shutdown a server
+func teardownServer(t *testing.T, server *http.Server) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(100)*time.Millisecond)
+	defer cancel()
+
+	err := server.Shutdown(ctx)
+	if err != nil {
+		t.Logf("shutdown error: %s", err.Error())
+	}
+}
+
+// spin up the management API server
+func initializeAPIServer(t *testing.T, grouper Grouper) (*http.Server, string, error) {
+
+	listener, err := net.Listen("tcp", ":0")
+	if err != nil {
+		return nil, "", err
+	}
+
+	addr := listener.Addr().String()
+	handler := NullHandler()
+	handler = grouper.Wrap(handler) // add/del/list endpoints
+	server := &http.Server{Handler: handler}
+
+	go func(srv *http.Server, addr string) {
+		err := server.Serve(listener)
+		if err != nil {
+			t.Logf("server exited %s with %s", addr, err.Error())
+		}
+	}(server, addr)
+
+	return server, addr, nil
+}
+
+// spin up backend servers
+func initializeTestServers(t *testing.T, numOfServers uint64) ([]*testServer, error) {
+
+	list := make([]*testServer, 0)
+
+	for i := uint64(0); i < numOfServers; i++ {
+
+		listener, err := net.Listen("tcp", ":0")
+		if err != nil {
+			return list, err
+		}
+
+		server := &testServer{
+			addr:     listener.Addr().String(),
+			version:  "0.0.104",
+			healthy:  true,
+			inPool:   false,
+			listener: &listener,
+		}
+
+		server.server = &http.Server{Handler: server.getHandler()}
+
+		go func(srv *testServer) {
+			err := srv.server.Serve(listener)
+			if err != nil {
+				t.Logf("server exited %s with %s", srv.addr, err.Error())
+			}
+		}(server)
+
+		list = append(list, server)
+	}
+
+	return list, nil
+}
+
+// tear down backend servers
+func shutdownTestServers(t *testing.T, servers []*testServer) {
+	for _, srv := range servers {
+		teardownServer(t, srv.server)
+	}
+}
+
+func testCompare(t *testing.T, grouper Grouper, servers []*testServer, ctx string) {
+
+	// compare current supposed to be healthy VS healthy list from allGrouper
+	current := getCurrentHealthySet(servers)
+	t.Logf("%s Expecting healthy servers %v", ctx, current)
+
+	round, err := grouper.List("ignore")
+	if err != nil {
+		if len(current) != 0 {
+			t.Errorf("%s Not expected error %s", ctx, err.Error())
+		}
+	} else {
+		t.Logf("%s Detected healthy servers %v", ctx, round)
+
+		if len(current) != len(round) {
+			t.Errorf("%s Got %d servers, expected: %d", ctx, len(round), len(current))
+		}
+		for idx, srv := range round {
+			if srv != current[idx] {
+				t.Errorf("%s Mismatch idx: %d %s != %s", ctx, idx, srv, current[idx])
+			}
+		}
+	}
+}
+
+// using mgmt API modify (add/remove) a node
+func mgmtModServer(t *testing.T, addr string, operation string, node string) error {
+	client := &http.Client{}
+	url := "http://" + addr + "/1/lb/nodes"
+
+	str := fmt.Sprintf("{\"Node\":\"%s\"}", node)
+	body := []byte(str)
+	req, err := http.NewRequest(operation, url, bytes.NewBuffer(body))
+	if err != nil {
+		return err
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := client.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	respBody, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return err
+	}
+
+	t.Logf("%s node=%s response=%s status=%d", operation, node, respBody, resp.StatusCode)
+
+	if resp.StatusCode != 200 {
+		return fmt.Errorf("%s node=%s status=%d %s", operation, node, resp.StatusCode, respBody)
+	}
+
+	return nil
+}
+
+// using mgmt api list servers and compare with test server list
+func mgmtListServers(t *testing.T, addr string, servers []*testServer) error {
+	client := &http.Client{}
+	url := "http://" + addr + "/1/lb/nodes"
+
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return err
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := client.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != 200 {
+		return fmt.Errorf("list status=%d", resp.StatusCode)
+	}
+
+	var listResp struct {
+		Nodes map[string]string `json:"nodes"`
+	}
+
+	err = json.NewDecoder(resp.Body).Decode(&listResp)
+	if err != nil {
+		return err
+	}
+
+	cmpList := make(map[string]string, len(servers))
+	for _, val := range servers {
+		if val.inPool {
+			if val.healthy {
+				cmpList[val.addr] = "online"
+			} else {
+				cmpList[val.addr] = "offline"
+			}
+		}
+	}
+
+	t.Logf("list response=%v expected=%v", listResp.Nodes, cmpList)
+
+	for key, val1 := range cmpList {
+		val2, ok := listResp.Nodes[key]
+		if !ok {
+			t.Errorf("failed list comparison node=`%s` is not in received", key)
+			return nil
+
+		}
+
+		if val1 != val2 {
+			t.Errorf("failed list comparison node=`%s` expected=`%s` received=`%s`", key, val1, val2)
+			return nil
+		}
+
+		delete(cmpList, key)
+		delete(listResp.Nodes, key)
+	}
+
+	if len(cmpList) != 0 || len(listResp.Nodes) != 0 {
+		t.Errorf("failed list comparison (remaining unmatches) expected=`%v` received=`%v`", cmpList, listResp.Nodes)
+	}
+
+	return nil
+}
+
+// Basic tests via DB add/remove functions
+func TestRouteRunnerExecution(t *testing.T) {
+
+	a, err := initializeRunner()
+	if err != nil {
+		t.Errorf("Not expected error `%s`", err.Error())
+	}
+
+	var concrete *allGrouper
+	concrete = a.(*allGrouper)
+
+	// initialize and add some servers (all healthy)
+	serverCount := 10
+	servers, err := initializeTestServers(t, uint64(serverCount))
+	if err != nil {
+		t.Errorf("Not expected error `%s`", err.Error())
+	} else {
+		defer shutdownTestServers(t, servers)
+
+		if serverCount != len(servers) {
+			t.Errorf("Got %d servers, expected: %d", len(servers), serverCount)
+		}
+
+		srvList := make([]string, 0, len(servers))
+		for _, srv := range servers {
+			srvList = append(srvList, srv.addr)
+		}
+		t.Logf("Spawned servers %s", srvList)
+
+		testCompare(t, a, servers, "round0")
+
+		for _, srv := range servers {
+			// add these servers to allGrouper
+			err := concrete.add(srv.addr)
+			if err != nil {
+				t.Errorf("Not expected error `%s` when adding `%s`", err.Error(), srv.addr)
+			}
+			srv.inPool = true
+		}
+	}
+
+	t.Logf("Starting round1 all servers healthy")
+
+	// let health checker converge
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round1")
+
+	t.Logf("Starting round2 one server unhealthy")
+
+	// now set one server unhealthy
+	servers[2].healthy = false
+	t.Logf("Setting server %s to unhealthy", servers[2].addr)
+
+	// let health checker converge
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round2")
+
+	t.Logf("Starting round3 remove one server from grouper")
+
+	t.Logf("Removing server %s from grouper", servers[3].addr)
+	err = concrete.remove(servers[3].addr)
+	if err != nil {
+		t.Errorf("Not expected error `%s` when removing `%s`", err.Error(), servers[3].addr)
+	}
+	servers[3].inPool = false
+
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round3")
+
+	t.Logf("Starting round4 add server back to grouper")
+
+	t.Logf("Adding server %s to grouper", servers[3].addr)
+	err = concrete.add(servers[3].addr)
+	if err != nil {
+		t.Errorf("Not expected error `%s` when adding `%s`", err.Error(), servers[3].addr)
+	}
+	servers[3].inPool = true
+
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round4")
+
+	t.Logf("Starting round5 set unhealthy server back to healthy")
+	servers[2].healthy = true
+	t.Logf("Setting server %s to healthy", servers[2].addr)
+
+	// let health checker converge
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round5")
+
+	t.Logf("Starting round6 no change")
+	// fetch list again
+	testCompare(t, a, servers, "round6")
+}
+
+// Basic tests via mgmt API
+func TestRouteRunnerMgmtAPI(t *testing.T) {
+
+	a, err := initializeRunner()
+	if err != nil {
+		t.Errorf("Not expected error `%s`", err.Error())
+	}
+
+	mgmtSrv, mgmtAddr, err := initializeAPIServer(t, a)
+	if err != nil {
+		t.Errorf("cannot start mgmt api server `%s`", err.Error())
+	}
+	defer teardownServer(t, mgmtSrv)
+
+	// initialize and add some servers (all healthy)
+	serverCount := 5
+	servers, err := initializeTestServers(t, uint64(serverCount))
+	if err != nil {
+		t.Errorf("Not expected error `%s`", err.Error())
+	} else {
+		defer shutdownTestServers(t, servers)
+
+		if serverCount != len(servers) {
+			t.Errorf("Got %d servers, expected: %d", len(servers), serverCount)
+		}
+
+		srvList := make([]string, 0, len(servers))
+		for _, srv := range servers {
+			srvList = append(srvList, srv.addr)
+		}
+		t.Logf("Spawned servers %s", srvList)
+
+		testCompare(t, a, servers, "round0")
+
+		for _, srv := range servers {
+			err := mgmtModServer(t, mgmtAddr, "PUT", srv.addr)
+			if err != nil {
+				t.Errorf("Not expected error `%s` when adding `%s`", err.Error(), srv.addr)
+			}
+			srv.inPool = true
+		}
+	}
+
+	t.Logf("Starting round1 all servers healthy")
+
+	// let health checker converge
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round1")
+
+	err = mgmtListServers(t, mgmtAddr, servers)
+	if err != nil {
+		t.Errorf("Not expected error `%s` when listing", err.Error())
+	}
+
+	t.Logf("Starting round2 remove one server from grouper")
+
+	// let's set server at 2 as unhealthy as well
+	servers[2].healthy = false
+
+	t.Logf("Removing server %s from grouper", servers[3].addr)
+	err = mgmtModServer(t, mgmtAddr, "DELETE", servers[3].addr)
+	if err != nil {
+		t.Errorf("Not expected error `%s` when removing `%s`", err.Error(), servers[3].addr)
+	}
+	servers[3].inPool = false
+
+	time.Sleep(time.Duration(2) * time.Second)
+	testCompare(t, a, servers, "round2")
+
+	err = mgmtListServers(t, mgmtAddr, servers)
+	if err != nil {
+		t.Errorf("Not expected error `%s` when listing", err.Error())
+	}
+}
+
+// TODO: test old version case
+// TODO: test DB unhealthy case
+// TODO: test healthy/unhealthy thresholds
+// TODO: test health check timeout case
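Beyond the tests, the same /1/lb/nodes endpoints are how membership is managed at runtime. A sketch of registering a node against a running balancer, mirroring mgmtModServer above — the address, node value, and helper name are placeholders, and the usual bytes/fmt/net/http imports are assumed:

```go
// Sketch: add a node via the management API, the same call the tests make.
func addNode(lbAddr, node string) error {
	body := bytes.NewBufferString(fmt.Sprintf(`{"Node":%q}`, node))
	req, err := http.NewRequest(http.MethodPut, "http://"+lbAddr+"/1/lb/nodes", body)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("add node %s: status %d", node, resp.StatusCode)
	}
	return nil
}
```

A DELETE with the same payload removes the node, and a GET returns the {"nodes": {"<addr>": "online"|"offline"}} map that listNodes builds.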
diff --git a/fnlb/lb/ch.go b/fnlb/lb/ch.go
index 5cadda6d2..638418d3e 100644
--- a/fnlb/lb/ch.go
+++ b/fnlb/lb/ch.go
@@ -150,61 +150,61 @@ func loadKey(node, key string) string {
 	return node + "\x00" + key
 }
 
+func (ch *chRouter) checkLoad(key, n string) bool {
+	var load time.Duration
+	ch.loadMu.RLock()
+	loadPtr := ch.load[loadKey(n, key)]
+	ch.loadMu.RUnlock()
+	if loadPtr != nil {
+		load = time.Duration(atomic.LoadInt64(loadPtr))
+	}
+
+	const (
+		// TODO we should probably use deltas rather than fixed wait times. for 'cold'
+		// functions these could always trigger. i.e. if wait time increased 5x over last
+		// 100 data points, point the cannon elsewhere (we'd have to track 2 numbers but meh)
+		lowerLat = 500 * time.Millisecond
+		upperLat = 2 * time.Second
+	)
+
+	// TODO flesh out these values.
+	// if we send < 50% of traffic off to other nodes when loaded
+	// then as function scales nodes will get flooded, need to be careful.
+	//
+	// back off loaded node/function combos slightly to spread load
+	if load < lowerLat {
+		return true
+	} else if load > upperLat {
+		// really loaded
+		if ch.rng.Intn(100) < 10 { // XXX (reed): 10% could be problematic, should sliding scale prob with log(x) ?
+			return true
+		}
+	} else {
+		// 10 < x < 40, as load approaches upperLat, x decreases [linearly]
+		x := translate(int64(load), int64(lowerLat), int64(upperLat), 10, 40)
+		if ch.rng.Intn(100) < x {
+			return true
+		}
+	}
+
+	// node looks loaded, return false so the caller tries the next node
+	return false
+}
+
 func (ch *chRouter) besti(key string, i int, nodes []string) (string, error) {
 	if len(nodes) < 1 {
 		// supposed to be caught in grouper, but double check
 		return "", ErrNoNodes
 	}
 
-	// XXX (reed): trash the closure
-	f := func(n string) string {
-		var load time.Duration
-		ch.loadMu.RLock()
-		loadPtr := ch.load[loadKey(n, key)]
-		ch.loadMu.RUnlock()
-		if loadPtr != nil {
-			load = time.Duration(atomic.LoadInt64(loadPtr))
-		}
-
-		const (
-			// TODO we should probably use deltas rather than fixed wait times. for 'cold'
-			// functions these could always trigger. i.e. if wait time increased 5x over last
-			// 100 data points, point the cannon elsewhere (we'd have to track 2 numbers but meh)
-			lowerLat = 500 * time.Millisecond
-			upperLat = 2 * time.Second
-		)
-
-		// TODO flesh out these values.
-		// if we send < 50% of traffic off to other nodes when loaded
-		// then as function scales nodes will get flooded, need to be careful.
-		//
-		// back off loaded node/function combos slightly to spread load
-		if load < lowerLat {
-			return n
-		} else if load > upperLat {
-			// really loaded
-			if ch.rng.Intn(100) < 10 { // XXX (reed): 10% could be problematic, should sliding scale prob with log(x) ?
-				return n
-			}
-		} else {
-			// 10 < x < 40, as load approaches upperLat, x decreases [linearly]
-			x := translate(int64(load), int64(lowerLat), int64(upperLat), 10, 40)
-			if ch.rng.Intn(100) < x {
-				return n
-			}
-		}
-
-		// return invalid node to try next node
-		return ""
-	}
-
 	for ; ; i++ {
 		// theoretically this could take infinite time, but practically improbable...
 		// TODO we need a way to add a node for a given key from down here if a node is overloaded.
-		node := f(nodes[i])
-		if node != "" {
-			return node, nil
-		} else if i == len(nodes)-1 {
+		if ch.checkLoad(key, nodes[i]) {
+			return nodes[i], nil
+		}
+
+		if i == len(nodes)-1 {
 			i = -1 // reset i to 0
 		}
 	}
@@ -221,7 +221,6 @@ func translate(val, inFrom, inTo, outFrom, outTo int64) int {
 func (ch *chRouter) Wrap(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		switch r.URL.Path {
-		// XXX (reed): probably do these on a separate port to avoid conflicts
 		case "/1/lb/stats":
 			ch.statsGet(w, r)
 			return
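For a rough feel of checkLoad's bands: a node/key pair whose tracked wait time is under 500 ms is always kept, one above 2 s is kept only about 10% of the time, and in between the chance comes from translate() over the 10–40 range — per the inline comment it falls linearly as load rises, so at the 1.25 s midpoint a request stays on that node roughly 25% of the time before besti moves on to the next node in the ring.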
diff --git a/fnlb/lb/db.go b/fnlb/lb/db.go
new file mode 100644
index 000000000..8dc0b1daa
--- /dev/null
+++ b/fnlb/lb/db.go
@@ -0,0 +1,153 @@
+package lb
+
+import (
+	"database/sql"
+	"errors"
+	"net/url"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/go-sql-driver/mysql"
+	"github.com/jmoiron/sqlx"
+	"github.com/lib/pq"
+	"github.com/mattn/go-sqlite3"
+	"github.com/sirupsen/logrus"
+)
+
+func NewDB(conf Config) (DBStore, error) {
+	db, err := db(conf.DBurl)
+	if err != nil {
+		return nil, err
+	}
+
+	return db, err
+}
+
+// TODO put this somewhere better
+type DBStore interface {
+	Add(string) error
+	Delete(string) error
+	List() ([]string, error)
+}
+
+// implements DBStore
+type sqlStore struct {
+	db *sqlx.DB
+
+	// TODO we should prepare all of the statements, rebind them
+	// and store them all here.
+}
+
+// New will open the db specified by url, create any tables necessary
+// and return a models.Datastore safe for concurrent usage.
+func db(uri string) (DBStore, error) {
+	url, err := url.Parse(uri)
+	if err != nil {
+		return nil, err
+	}
+
+	driver := url.Scheme
+	// driver must be one of these for sqlx to work, double check:
+	switch driver {
+	case "postgres", "pgx", "mysql", "sqlite3", "oci8", "ora", "goracle":
+	default:
+		return nil, errors.New("invalid db driver, refer to the code")
+	}
+
+	if driver == "sqlite3" {
+		// make all the dirs so we can make the file..
+		dir := filepath.Dir(url.Path)
+		err := os.MkdirAll(dir, 0755)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	uri = url.String()
+	if driver != "postgres" {
+		// postgres seems to need this as a prefix in lib/pq, everyone else wants it stripped of scheme
+		uri = strings.TrimPrefix(url.String(), url.Scheme+"://")
+	}
+
+	sqldb, err := sql.Open(driver, uri)
+	if err != nil {
+		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't open db")
+		return nil, err
+	}
+
+	db := sqlx.NewDb(sqldb, driver)
+	// force a connection and test that it worked
+	err = db.Ping()
+	if err != nil {
+		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't ping db")
+		return nil, err
+	}
+
+	maxIdleConns := 30 // c.MaxIdleConnections
+	db.SetMaxIdleConns(maxIdleConns)
+	logrus.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed")
+
+	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS lb_nodes (
+	address text NOT NULL PRIMARY KEY
+);`)
+	if err != nil {
+		return nil, err
+	}
+
+	return &sqlStore{db: db}, nil
+}
+
+func (s *sqlStore) Add(node string) error {
+	query := s.db.Rebind("INSERT INTO lb_nodes (address) VALUES (?);")
+	_, err := s.db.Exec(query, node)
+	if err != nil {
+		// if it already exists, just filter that error out
+		switch err := err.(type) {
+		case *mysql.MySQLError:
+			if err.Number == 1062 {
+				return nil
+			}
+		case *pq.Error:
+			if err.Code == "23505" {
+				return nil
+			}
+		case sqlite3.Error:
+			if err.ExtendedCode == sqlite3.ErrConstraintUnique || err.ExtendedCode == sqlite3.ErrConstraintPrimaryKey {
+				return nil
+			}
+		}
+	}
+	return err
+}
+
+func (s *sqlStore) Delete(node string) error {
+	query := s.db.Rebind(`DELETE FROM lb_nodes WHERE address=?`)
+	_, err := s.db.Exec(query, node)
+	// TODO we can filter if it didn't exist, too...
+	return err
+}
+
+func (s *sqlStore) List() ([]string, error) {
+	query := s.db.Rebind(`SELECT DISTINCT address FROM lb_nodes`)
+	rows, err := s.db.Query(query)
+	if err != nil {
+		return nil, err
+	}
+
+	var nodes []string
+	for rows.Next() {
+		var node string
+		err := rows.Scan(&node)
+		if err == nil {
+			nodes = append(nodes, node)
+		}
+	}
+
+	err = rows.Err()
+	if err == sql.ErrNoRows {
+		err = nil // don't care...
+	}
+
+	return nodes, err
+}
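NewDB is what main.go now calls to build the store injected into NewAllGrouper; the URL scheme picks the driver (postgres, pgx, mysql, sqlite3, oci8, ora, goracle) and is stripped from the DSN for everything except postgres. A sketch of pointing the balancer at an on-disk sqlite file instead of the in-memory default — the path is only illustrative, and conf is the Config assembled from flags as in main.go:

```go
// Sketch: build a persistent node store and inject it, mirroring fnlb/main.go.
// Default DBurl is "sqlite3://:memory:"; the path below is an example only.
conf.DBurl = "sqlite3:///var/lib/fnlb/lb_nodes.db"

db, err := lb.NewDB(conf)
if err != nil {
	logrus.WithError(err).Fatal("error setting up database")
}

g, err := lb.NewAllGrouper(conf, db)
if err != nil {
	logrus.WithError(err).Fatal("error setting up grouper")
}
```

Any other DBStore implementation — the tests' mockDB, or a store backed by a shared Postgres/MySQL instance — can be swapped in here without touching the grouper.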
diff --git a/fnlb/lb/proxy.go b/fnlb/lb/proxy.go
index 3105f3e2e..148249fe7 100644
--- a/fnlb/lb/proxy.go
+++ b/fnlb/lb/proxy.go
@@ -32,11 +32,14 @@ import (
 
 type Config struct {
 	DBurl               string `json:"db_url"`
 	Listen              string `json:"port"`
+	MgmtListen          string `json:"mgmt_port"`
+	ShutdownTimeout     int    `json:"shutdown_timeout"`
 	ZipkinURL           string `json:"zipkin_url"`
 	Nodes               []string `json:"nodes"`
 	HealthcheckInterval int    `json:"healthcheck_interval"`
 	HealthcheckEndpoint string `json:"healthcheck_endpoint"`
 	HealthcheckUnhealthy int   `json:"healthcheck_unhealthy"`
+	HealthcheckHealthy  int    `json:"healthcheck_healthy"`
 	HealthcheckTimeout  int    `json:"healthcheck_timeout"`
 	MinAPIVersion       *semver.Version `json:"min_api_version"`
diff --git a/fnlb/lb/util.go b/fnlb/lb/util.go
index fb81273a0..2e8fcaeb5 100644
--- a/fnlb/lb/util.go
+++ b/fnlb/lb/util.go
@@ -9,7 +9,8 @@ import (
 )
 
 var (
-	ErrNoNodes = errors.New("no nodes available")
+	ErrNoNodes        = errors.New("no nodes available")
+	ErrUnknownCommand = errors.New("unknown command")
 )
 
 func sendValue(w http.ResponseWriter, v interface{}) {
@@ -45,3 +46,9 @@ func sendError(w http.ResponseWriter, code int, msg string) {
 		logrus.WithError(err).Error("error writing response response")
 	}
 }
+
+func NullHandler() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		sendError(w, http.StatusNotFound, ErrUnknownCommand.Error())
+	})
+}
diff --git a/fnlb/main.go b/fnlb/main.go
index 1db100e84..ca57ff687 100644
--- a/fnlb/main.go
+++ b/fnlb/main.go
@@ -27,9 +27,12 @@ func main() {
 	var conf lb.Config
 	flag.StringVar(&conf.DBurl, "db", "sqlite3://:memory:", "backend to store nodes, default to in memory")
 	flag.StringVar(&conf.Listen, "listen", ":8081", "port to run on")
+	flag.StringVar(&conf.MgmtListen, "mgmt-listen", ":8081", "management port to run on")
+	flag.IntVar(&conf.ShutdownTimeout, "shutdown-timeout", 0, "graceful shutdown timeout")
 	flag.IntVar(&conf.HealthcheckInterval, "hc-interval", 3, "how often to check f(x) nodes, in seconds")
 	flag.StringVar(&conf.HealthcheckEndpoint, "hc-path", "/version", "endpoint to determine node health")
 	flag.IntVar(&conf.HealthcheckUnhealthy, "hc-unhealthy", 2, "threshold of failed checks to declare node unhealthy")
+	flag.IntVar(&conf.HealthcheckHealthy, "hc-healthy", 1, "threshold of success checks to declare node healthy")
 	flag.IntVar(&conf.HealthcheckTimeout, "hc-timeout", 5, "timeout of healthcheck endpoint, in seconds")
 	flag.StringVar(&conf.ZipkinURL, "zipkin", "", "zipkin endpoint to send traces")
 	flag.Parse()
@@ -54,7 +57,12 @@ func main() {
 		},
 	}
 
-	g, err := lb.NewAllGrouper(conf)
+	db, err := lb.NewDB(conf)
+	if err != nil {
+		logrus.WithError(err).Fatal("error setting up database")
+	}
+
+	g, err := lb.NewAllGrouper(conf, db)
 	if err != nil {
 		logrus.WithError(err).Fatal("error setting up grouper")
 	}
@@ -64,27 +72,57 @@ func main() {
 		return r.URL.Path, nil
 	}
 
-	h := lb.NewProxy(k, g, r, conf)
-	h = g.Wrap(h) // add/del/list endpoints
-	h = r.Wrap(h) // stats / dash endpoint
+	servers := make([]*http.Server, 0, 1)
+	handler := lb.NewProxy(k, g, r, conf)
 
-	err = serve(conf.Listen, h)
-	if err != nil {
-		logrus.WithError(err).Fatal("server error")
+	// a separate mgmt listener is requested? then let's create a LB traffic only server
+	if conf.Listen != conf.MgmtListen {
+		servers = append(servers, &http.Server{Addr: conf.Listen, Handler: handler})
+		handler = lb.NullHandler()
 	}
+
+	// add mgmt endpoints to the handler
+	handler = g.Wrap(handler) // add/del/list endpoints
+	handler = r.Wrap(handler) // stats / dash endpoint
+
+	servers = append(servers, &http.Server{Addr: conf.MgmtListen, Handler: handler})
+	serve(servers, &conf)
 }
 
-func serve(addr string, handler http.Handler) error {
-	server := &http.Server{Addr: addr, Handler: handler}
+func serve(servers []*http.Server, conf *lb.Config) {
 	ch := make(chan os.Signal, 1)
 	signal.Notify(ch, syscall.SIGQUIT, syscall.SIGINT)
-	go func() {
-		for sig := range ch {
-			logrus.WithFields(logrus.Fields{"signal": sig}).Info("received signal")
-			server.Shutdown(context.Background()) // safe shutdown
-			return
+
+	for i := 0; i < len(servers); i++ {
+		go func(idx int) {
+			err := servers[idx].ListenAndServe()
+			if err != nil && err != http.ErrServerClosed {
+				logrus.WithFields(logrus.Fields{"server_id": idx}).WithError(err).Fatal("server error")
+			} else {
+				logrus.WithFields(logrus.Fields{"server_id": idx}).Info("server stopped")
+			}
+		}(i)
+	}
+
+	sig := <-ch
+	logrus.WithFields(logrus.Fields{"signal": sig}).Info("received signal")
+
+	for i := 0; i < len(servers); i++ {
+
+		ctx := context.Background()
+
+		if conf.ShutdownTimeout > 0 {
+			tmpCtx, cancel := context.WithTimeout(context.Background(), time.Duration(conf.ShutdownTimeout)*time.Second)
+			ctx = tmpCtx
+			defer cancel()
 		}
-	}()
-	return server.ListenAndServe()
+
+		err := servers[i].Shutdown(ctx) // safe shutdown
+		if err != nil {
+			logrus.WithFields(logrus.Fields{"server_id": i}).WithError(err).Fatal("server shutdown error")
+		} else {
+			logrus.WithFields(logrus.Fields{"server_id": i}).Info("server shutdown")
+		}
+	}
 }
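With the defaults, -listen and -mgmt-listen are both :8081, so behaviour matches the old single-listener setup. Splitting management traffic onto its own port and enabling the new knobs looks roughly like `fnlb -db "sqlite3:///var/lib/fnlb/lb_nodes.db" -listen :8081 -mgmt-listen :8085 -hc-healthy 2 -hc-unhealthy 3 -shutdown-timeout 30` (addresses and values here are only illustrative); the node add/remove/list and stats endpoints then answer only on the management port while proxied function traffic stays on the main listener.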