fn: fnlb: enhancements and new grouper tests (#493)

* fn: fnlb: enhancements and new grouper tests

*) added healthy threshold (default: 1)
*) grouper is now using configured hcEndpoint for version checks
*) grouper now logs when servers switch between healthy/unhealthy status
*) moved DB code out of grouper
*) run health check immediately at start (don't wait until hcInterval)
*) optional shutdown timeout (default: 0) & mgmt port (default: 8081)
*) hot path List() in grouper now uses atomic ptr Load
*) consistent router: moved closure to a new function
*) bugfix: version parsing from fn servers should not panic fnlb
*) bugfix: servers removed from DB, stayed in healthy list
*) bugfix: if DB is down, health checker stopped monitoring
*) basic new tests for grouper (add/rm/unhealthy/healthy) server
This commit is contained in:
Tolga Ceylan
2017-11-16 11:35:30 -08:00
committed by GitHub
parent 910612d0b1
commit 657afd5838
7 changed files with 937 additions and 290 deletions

View File

@@ -2,26 +2,16 @@ package lb
import (
"context"
"database/sql"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"fmt"
"github.com/coreos/go-semver/semver"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"github.com/sirupsen/logrus"
)
@@ -29,26 +19,24 @@ import (
// that are being maintained, regardless of key. An 'AllGrouper' will health
// check servers at a specified interval, taking them in and out as they
// pass/fail and exposes endpoints for adding, removing and listing nodes.
func NewAllGrouper(conf Config) (Grouper, error) {
db, err := db(conf.DBurl)
if err != nil {
return nil, err
}
func NewAllGrouper(conf Config, db DBStore) (Grouper, error) {
a := &allGrouper{
ded: make(map[string]int64),
db: db,
nodeList: make(map[string]nodeState),
nodeHealthyList: make([]string, 0),
db: db,
// XXX (reed): need to be reconfigurable at some point
hcInterval: time.Duration(conf.HealthcheckInterval) * time.Second,
hcEndpoint: conf.HealthcheckEndpoint,
hcUnhealthy: int64(conf.HealthcheckUnhealthy),
hcHealthy: int64(conf.HealthcheckHealthy),
hcTimeout: time.Duration(conf.HealthcheckTimeout) * time.Second,
minAPIVersion: *conf.MinAPIVersion,
// for health checks
httpClient: &http.Client{Transport: conf.Transport},
}
for _, n := range conf.Nodes {
err := a.add(n)
if err != nil {
@@ -60,6 +48,17 @@ func NewAllGrouper(conf Config) (Grouper, error) {
return a, nil
}
// nodeState is used to store success/fail counts and other health related data.
// One nodeState is kept per node address in allGrouper.nodeList; it is stored
// by value, so callers read it, modify the copy and write it back under the lock.
type nodeState struct {
// num of consecutive successes & failures
// (each counter is reset to zero whenever the opposite outcome is recorded)
success uint64
fail uint64
// current health state
// (only healthy nodes are published to the list returned by List)
healthy bool
}
// allGrouper will return all healthy nodes it is tracking from List.
// nodes may be added / removed through the HTTP api. each allGrouper will
// poll its database for the full list of nodes, and then run its own
@@ -70,14 +69,12 @@ func NewAllGrouper(conf Config) (Grouper, error) {
// to maintain a list among nodes in the db, which could have thrashing
// due to network connectivity between any pair).
type allGrouper struct {
// protects allNodes, healthy & ded
sync.RWMutex
// TODO rename nodes to 'allNodes' or something so everything breaks and then stitch
// ded is the set of disjoint nodes nodes from intersecting nodes & healthy
allNodes, healthy []string
ded map[string]int64 // [node] -> failedCount
// allNodes is a cache of db.List, we can probably trash it..
// health checker state and lock
nodeLock sync.RWMutex
nodeList map[string]nodeState
nodeHealthyList []string
db DBStore
httpClient *http.Client
@@ -85,138 +82,11 @@ type allGrouper struct {
hcInterval time.Duration
hcEndpoint string
hcUnhealthy int64
hcHealthy int64
hcTimeout time.Duration
minAPIVersion semver.Version
}
// TODO put this somewhere better
// DBStore is the persistence layer the grouper uses to keep track of the
// node addresses it should be health checking.
type DBStore interface {
// Add stores the given node address.
Add(string) error
// Delete removes the given node address.
Delete(string) error
// List returns all node addresses currently stored.
List() ([]string, error)
}
// implements DBStore
// sqlStore keeps the node addresses in the lb_nodes table of a SQL database
// accessed through sqlx.
type sqlStore struct {
// db is the open database handle all queries are executed on.
db *sqlx.DB
// TODO we should prepare all of the statements, rebind them
// and store them all here.
}
// db opens the database named by uri, creates the lb_nodes table if it does
// not exist yet and returns a DBStore backed by it.
//
// The URI scheme selects the driver and must be one of the names accepted by
// the switch below. For sqlite3 the directory that will hold the database file
// is created first. If the connection cannot be verified or the table cannot
// be created, the freshly opened handle is closed before the error is
// returned so that no connection is leaked.
func db(uri string) (DBStore, error) {
	// named "parsed" (not "url") so the net/url package stays reachable
	parsed, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}

	driver := parsed.Scheme

	// driver must be one of these for sqlx to work, double check:
	switch driver {
	case "postgres", "pgx", "mysql", "sqlite3", "oci8", "ora", "goracle":
	default:
		return nil, errors.New("invalid db driver, refer to the code")
	}

	if driver == "sqlite3" {
		// make all the dirs so we can make the file..
		if err := os.MkdirAll(filepath.Dir(parsed.Path), 0755); err != nil {
			return nil, err
		}
	}

	uri = parsed.String()
	if driver != "postgres" {
		// postgres seems to need this as a prefix in lib/pq, everyone else wants it stripped of scheme
		uri = strings.TrimPrefix(uri, parsed.Scheme+"://")
	}

	sqldb, err := sql.Open(driver, uri)
	if err != nil {
		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't open db")
		return nil, err
	}

	conn := sqlx.NewDb(sqldb, driver)

	// force a connection and test that it worked
	if err := conn.Ping(); err != nil {
		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't ping db")
		// best effort: the ping error is the one worth reporting
		_ = conn.Close()
		return nil, err
	}

	maxIdleConns := 30 // c.MaxIdleConnections
	conn.SetMaxIdleConns(maxIdleConns)
	logrus.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed")

	_, err = conn.Exec(`CREATE TABLE IF NOT EXISTS lb_nodes (
address text NOT NULL PRIMARY KEY
);`)
	if err != nil {
		// best effort: the exec error is the one worth reporting
		_ = conn.Close()
		return nil, err
	}

	return &sqlStore{db: conn}, nil
}
// Add inserts node into the lb_nodes table. An error reported by the mysql,
// postgres or sqlite3 driver because the address already exists is treated as
// success; every other error is returned unchanged.
func (s *sqlStore) Add(node string) error {
	_, err := s.db.Exec(s.db.Rebind("INSERT INTO lb_nodes (address) VALUES (?);"), node)
	if err == nil {
		return nil
	}

	// if it already exists, just filter that error out
	switch dbErr := err.(type) {
	case *mysql.MySQLError:
		if dbErr.Number == 1062 {
			return nil
		}
	case *pq.Error:
		if dbErr.Code == "23505" {
			return nil
		}
	case sqlite3.Error:
		if dbErr.ExtendedCode == sqlite3.ErrConstraintUnique || dbErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey {
			return nil
		}
	}
	return err
}
// Delete removes node from the lb_nodes table and returns whatever error the
// DELETE statement produced.
func (s *sqlStore) Delete(node string) error {
	// TODO we can filter if it didn't exist, too...
	_, err := s.db.Exec(s.db.Rebind(`DELETE FROM lb_nodes WHERE address=?`), node)
	return err
}
// List returns the distinct node addresses stored in the lb_nodes table.
//
// Rows that cannot be scanned are skipped (and logged) rather than failing
// the whole listing. The error returned is the query error or the error that
// ended row iteration, if any.
func (s *sqlStore) List() ([]string, error) {
	query := s.db.Rebind(`SELECT DISTINCT address FROM lb_nodes`)
	rows, err := s.db.Query(query)
	if err != nil {
		return nil, err
	}
	// always hand the connection back to the pool, whatever path we leave by
	defer rows.Close()

	var nodes []string
	for rows.Next() {
		var node string
		if err := rows.Scan(&node); err != nil {
			// keep the listing best-effort, but don't drop the row silently
			logrus.WithError(err).Warn("skipping unreadable lb_nodes row")
			continue
		}
		nodes = append(nodes, node)
	}

	err = rows.Err()
	if err == sql.ErrNoRows {
		err = nil // don't care...
	}

	return nodes, err
}
func (a *allGrouper) add(newb string) error {
if newb == "" {
return nil // we can't really do a lot of validation since hosts could be an ip or domain but we have health checks
@@ -232,32 +102,33 @@ func (a *allGrouper) remove(ded string) error {
return a.db.Delete(ded)
}
// addHealthy inserts newb into the sorted a.healthy slice unless it is
// already present.
//
// call with a.Lock held
func (a *allGrouper) addHealthy(newb string) {
	// filter dupes, under lock. sorted, so binary search
	i := sort.SearchStrings(a.healthy, newb)
	if i < len(a.healthy) && a.healthy[i] == newb {
		return
	}

	// need to keep in sorted order so that hash index works across nodes.
	// i is already the insertion point, so shift the tail right by one and
	// drop newb in place instead of re-sorting the whole slice.
	a.healthy = append(a.healthy, "")
	copy(a.healthy[i+1:], a.healthy[i:])
	a.healthy[i] = newb
}
func (a *allGrouper) publishHealth() {
// call with a.Lock held
func (a *allGrouper) removeHealthy(ded string) {
i := sort.SearchStrings(a.healthy, ded)
if i < len(a.healthy) && a.healthy[i] == ded {
a.healthy = append(a.healthy[:i], a.healthy[i+1:]...)
a.nodeLock.Lock()
// get a list of healthy nodes
newList := make([]string, 0, len(a.nodeList))
for key, value := range a.nodeList {
if value.healthy {
newList = append(newList, key)
}
}
// sort and update healthy List
sort.Strings(newList)
a.nodeHealthyList = newList
a.nodeLock.Unlock()
}
// return a copy
func (a *allGrouper) List(string) ([]string, error) {
a.RLock()
ret := make([]string, len(a.healthy))
copy(ret, a.healthy)
a.RUnlock()
a.nodeLock.RLock()
ret := make([]string, len(a.nodeHealthyList))
copy(ret, a.nodeHealthyList)
a.nodeLock.RUnlock()
var err error
if len(ret) == 0 {
err = ErrNoNodes
@@ -265,22 +136,76 @@ func (a *allGrouper) List(string) ([]string, error) {
return ret, err
}
// runHealthCheck performs one health check round.
//
// It fetches the node list from the DB and reconciles a.nodeList with it:
// nodes new to the DB are added (initially marked healthy) and nodes no longer
// in the DB are dropped; the healthy list is re-published if anything changed.
// If the DB cannot be listed, the error is logged and the nodes already known
// in a.nodeList are checked instead. Finally one ping goroutine is started
// per node.
func (a *allGrouper) runHealthCheck() {
	// fetch a list of nodes from DB
	targets, err := a.db.List()
	if err != nil {
		// if DB fails, the show must go on, report it but perform HC
		logrus.WithError(err).Error("error checking db for nodes")

		// fall back to the nodes we already know about
		a.nodeLock.RLock()
		targets = make([]string, 0, len(a.nodeList))
		for addr := range a.nodeList {
			targets = append(targets, addr)
		}
		a.nodeLock.RUnlock()
	} else {
		// set of addresses present in the DB, used for the deletion check
		inDB := make(map[string]bool, len(targets))
		for _, addr := range targets {
			inDB[addr] = true
		}

		changed := false

		a.nodeLock.Lock()

		// handle new nodes
		for _, addr := range targets {
			if _, known := a.nodeList[addr]; !known {
				a.nodeList[addr] = nodeState{healthy: true}
				changed = true
			}
		}

		// handle deleted nodes: purge everything the DB no longer lists
		for addr := range a.nodeList {
			if !inDB[addr] {
				delete(a.nodeList, addr)
				changed = true
			}
		}

		a.nodeLock.Unlock()

		// publish if add/deleted nodes
		if changed {
			a.publishHealth()
		}
	}

	// spawn health checkers
	for _, addr := range targets {
		go a.ping(addr)
	}
}
func (a *allGrouper) healthcheck() {
// run hc immediately upon startup
a.runHealthCheck()
for range time.Tick(a.hcInterval) {
// health check the entire list of nodes [from db]
list, err := a.db.List()
if err != nil {
logrus.WithError(err).Error("error checking db for nodes")
continue
}
a.Lock()
a.allNodes = list
a.Unlock()
for _, n := range list {
go a.ping(n)
}
a.runHealthCheck()
}
}
@@ -312,14 +237,18 @@ func (a *allGrouper) getVersion(urlString string) (string, error) {
}
func (a *allGrouper) checkAPIVersion(node string) error {
versionURL := "http://" + node + "/version"
versionURL := "http://" + node + a.hcEndpoint
version, err := a.getVersion(versionURL)
if err != nil {
return err
}
nodeVer := semver.New(version)
nodeVer, err := semver.NewVersion(version)
if err != nil {
return err
}
if nodeVer.LessThan(a.minAPIVersion) {
return fmt.Errorf("incompatible API version: %v", nodeVer)
}
@@ -336,32 +265,79 @@ func (a *allGrouper) ping(node string) {
}
}
func (a *allGrouper) fail(node string) {
// shouldn't be a hot path so shouldn't be too contended on since health
// checks are infrequent
a.Lock()
a.ded[node]++
failed := a.ded[node]
if failed >= a.hcUnhealthy {
a.removeHealthy(node)
func (a *allGrouper) fail(key string) {
isChanged := false
a.nodeLock.Lock()
// if deleted, skip
node, ok := a.nodeList[key]
if !ok {
a.nodeLock.Unlock()
return
}
node.success = 0
node.fail++
// overflow case
if node.fail == 0 {
node.fail = uint64(a.hcUnhealthy)
}
if node.healthy && node.fail >= uint64(a.hcUnhealthy) {
node.healthy = false
isChanged = true
}
a.nodeList[key] = node
a.nodeLock.Unlock()
if isChanged {
logrus.WithFields(logrus.Fields{"node": key}).Info("is unhealthy")
a.publishHealth()
}
a.Unlock()
}
func (a *allGrouper) alive(node string) {
// TODO alive is gonna get called a lot, should maybe start w/ every node in ded
// so we can RLock (but lock contention should be low since these are ~quick) --
// "a lot" being every 1s per node, so not too crazy really, but 1k nodes @ ms each...
a.Lock()
delete(a.ded, node)
a.addHealthy(node)
a.Unlock()
func (a *allGrouper) alive(key string) {
isChanged := false
a.nodeLock.Lock()
// if deleted, skip
node, ok := a.nodeList[key]
if !ok {
a.nodeLock.Unlock()
return
}
node.fail = 0
node.success++
// overflow case
if node.success == 0 {
node.success = uint64(a.hcHealthy)
}
if !node.healthy && node.success >= uint64(a.hcHealthy) {
node.healthy = true
isChanged = true
}
a.nodeList[key] = node
a.nodeLock.Unlock()
if isChanged {
logrus.WithFields(logrus.Fields{"node": key}).Info("is healthy")
a.publishHealth()
}
}
func (a *allGrouper) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
// XXX (reed): probably do these on a separate port to avoid conflicts
case "/1/lb/nodes":
switch r.Method {
case "PUT":
@@ -415,33 +391,24 @@ func (a *allGrouper) removeNode(w http.ResponseWriter, r *http.Request) {
}
func (a *allGrouper) listNodes(w http.ResponseWriter, r *http.Request) {
a.RLock()
nodes := make([]string, len(a.allNodes))
copy(nodes, a.allNodes)
a.RUnlock()
// TODO this isn't correct until at least one health check has hit all nodes (on start up).
// seems like not a huge deal, but here's a note anyway (every node will simply 'appear' healthy
// from this api even if we aren't routing to it [until first health check]).
out := make(map[string]string, len(nodes))
for _, n := range nodes {
if a.isDead(n) {
out[n] = "offline"
a.nodeLock.RLock()
out := make(map[string]string, len(a.nodeList))
for key, value := range a.nodeList {
if value.healthy {
out[key] = "online"
} else {
out[n] = "online"
out[key] = "offline"
}
}
a.nodeLock.RUnlock()
sendValue(w, struct {
Nodes map[string]string `json:"nodes"`
}{
Nodes: out,
})
}
// isDead reports whether node has a recorded failure count that has reached
// the hcUnhealthy threshold.
func (a *allGrouper) isDead(node string) bool {
	a.RLock()
	failures, tracked := a.ded[node]
	a.RUnlock()

	if !tracked {
		return false
	}
	return failures >= a.hcUnhealthy
}

480
fnlb/lb/allgrouper_test.go Normal file
View File

@@ -0,0 +1,480 @@
package lb
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/coreos/go-semver/semver"
)
// mockDB is an in-memory DBStore for tests. Each is*Error flag makes the
// corresponding method fail with a simulated error.
//
// The methods may be called from different goroutines (the mgmt API handlers
// run on HTTP server goroutines while the grouper's health check lists the
// nodes), so nodeList is guarded by mu; unsynchronized concurrent map access
// is a fatal runtime error in Go.
type mockDB struct {
	isAddError    bool
	isDeleteError bool
	isListError   bool

	// mu protects nodeList
	mu       sync.Mutex
	nodeList map[string]bool
}

// Add records node, or fails if isAddError is set.
func (mock *mockDB) Add(node string) error {
	if mock.isAddError {
		return errors.New("simulated add error")
	}
	mock.mu.Lock()
	defer mock.mu.Unlock()
	mock.nodeList[node] = true
	return nil
}

// Delete forgets node, or fails if isDeleteError is set.
func (mock *mockDB) Delete(node string) error {
	if mock.isDeleteError {
		return errors.New("simulated delete error")
	}
	mock.mu.Lock()
	defer mock.mu.Unlock()
	delete(mock.nodeList, node)
	return nil
}

// List returns the recorded nodes in unspecified order, or fails if
// isListError is set.
func (mock *mockDB) List() ([]string, error) {
	if mock.isListError {
		return nil, errors.New("simulated list error")
	}
	mock.mu.Lock()
	defer mock.mu.Unlock()
	list := make([]string, 0, len(mock.nodeList))
	for key := range mock.nodeList {
		list = append(list, key)
	}
	return list, nil
}
// initializeRunner builds an all-grouper backed by an empty mockDB, using
// one-second health check interval/timeout, thresholds of one and "/version"
// as the health check endpoint.
func initializeRunner() (Grouper, error) {
	store := &mockDB{nodeList: map[string]bool{}}

	return NewAllGrouper(Config{
		HealthcheckInterval:  1,
		HealthcheckEndpoint:  "/version",
		HealthcheckUnhealthy: 1,
		HealthcheckHealthy:   1,
		HealthcheckTimeout:   1,
		MinAPIVersion:        semver.New("0.0.104"),
		Transport:            &http.Transport{},
	}, store)
}
// testServer is a fake backend node used by the tests together with the
// state the tests expect the grouper to observe for it.
type testServer struct {
// addr is the listener address (host:port) the server is reachable on
addr string
// version is reported by the /version endpoint while healthy is true
version string
// healthy makes /version succeed (true) or answer 503 (false)
healthy bool
// inPool records whether the test has added this server to the grouper
inPool bool
// listener is the listener the server accepts connections on
listener *net.Listener
// server is the HTTP server serving getHandler
server *http.Server
}
// getHandler returns the HTTP handler of the fake backend: "/version" answers
// with the server's version while it is marked healthy and with 503
// otherwise; every other path gets a 404.
func (s *testServer) getHandler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/version" {
			sendError(w, http.StatusNotFound, "unknown uri")
			return
		}
		if !s.healthy {
			sendError(w, http.StatusServiceUnavailable, "service unhealthy")
			return
		}
		sendValue(w, fnVersion{Version: s.version})
	})
}
// getCurrentHealthySet returns the sorted addresses of the servers that are
// supposed to be reported healthy: marked healthy and added to the pool.
func getCurrentHealthySet(list []*testServer) []string {
	healthy := []string{}
	for i := range list {
		srv := list[i]
		if !srv.healthy || !srv.inPool {
			continue
		}
		healthy = append(healthy, srv.addr)
	}
	sort.Strings(healthy)
	return healthy
}
// teardownServer shuts server down, waiting at most 100ms; a shutdown error
// is logged but does not fail the test.
func teardownServer(t *testing.T, server *http.Server) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if err := server.Shutdown(ctx); err != nil {
		t.Logf("shutdown error: %s", err.Error())
	}
}
// initializeAPIServer starts the mgmt API server for grouper on a free local
// port and returns the server together with the address it listens on. All
// paths not handled by the grouper's endpoints are answered by NullHandler.
func initializeAPIServer(t *testing.T, grouper Grouper) (*http.Server, string, error) {
	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		return nil, "", err
	}

	addr := listener.Addr().String()

	handler := grouper.Wrap(NullHandler()) // add/del/list endpoints

	server := &http.Server{Handler: handler}

	go func(srv *http.Server, addr string) {
		// Serve always returns a non-nil error. ErrServerClosed is the normal
		// result of Shutdown; skipping it also avoids calling t.Logf from this
		// goroutine after the test that shut the server down has finished.
		if err := srv.Serve(listener); err != nil && err != http.ErrServerClosed {
			t.Logf("server exited %s with %s", addr, err.Error())
		}
	}(server, addr)

	return server, addr, nil
}
// initializeTestServers spins up numOfServers fake backend servers, each on
// its own free local port, initially healthy and not yet in the pool.
//
// If a listener cannot be created, the servers started so far are returned
// together with the error so the caller can still shut them down.
func initializeTestServers(t *testing.T, numOfServers uint64) ([]*testServer, error) {
	list := make([]*testServer, 0, numOfServers)

	for i := uint64(0); i < numOfServers; i++ {
		listener, err := net.Listen("tcp", ":0")
		if err != nil {
			return list, err
		}

		server := &testServer{
			addr:     listener.Addr().String(),
			version:  "0.0.104",
			healthy:  true,
			inPool:   false,
			listener: &listener,
		}
		server.server = &http.Server{Handler: server.getHandler()}

		go func(srv *testServer) {
			// Serve always returns a non-nil error. ErrServerClosed is the
			// normal result of Shutdown; skipping it also avoids calling
			// t.Logf from this goroutine after the test has finished.
			if err := srv.server.Serve(listener); err != nil && err != http.ErrServerClosed {
				t.Logf("server exited %s with %s", srv.addr, err.Error())
			}
		}(server)

		list = append(list, server)
	}

	return list, nil
}
// shutdownTestServers tears down every backend server in servers, in order.
func shutdownTestServers(t *testing.T, servers []*testServer) {
	for i := range servers {
		teardownServer(t, servers[i].server)
	}
}
// testCompare checks that the healthy list reported by grouper equals the set
// of servers the test expects to be healthy. ctx is a label prefixed to every
// log and error message.
//
// An error from grouper.List is only acceptable when no server is expected to
// be healthy.
func testCompare(t *testing.T, grouper Grouper, servers []*testServer, ctx string) {
	// compare current supposed to be healthy VS healthy list from allGrouper
	current := getCurrentHealthySet(servers)

	t.Logf("%s Expecting healthy servers %v", ctx, current)

	round, err := grouper.List("ignore")
	if err != nil {
		if len(current) != 0 {
			t.Errorf("%s Not expected error %s", ctx, err.Error())
		}
		return
	}

	t.Logf("%s Detected healthy servers %v", ctx, round)

	if len(current) != len(round) {
		t.Errorf("%s Got %d servers, expected: %d", ctx, len(round), len(current))
	}

	// Only compare the common prefix: when the grouper returns more servers
	// than expected, indexing current with every index of round would panic.
	for idx := 0; idx < len(round) && idx < len(current); idx++ {
		if round[idx] != current[idx] {
			t.Errorf("%s Mismatch idx: %d %s != %s", ctx, idx, round[idx], current[idx])
		}
	}
}
// mgmtModServer adds or removes node through the mgmt API listening on addr.
// operation is the HTTP method to send to /1/lb/nodes (the tests use "PUT" to
// add and "DELETE" to remove). Any response status other than 200 is returned
// as an error that includes the response body.
func mgmtModServer(t *testing.T, addr string, operation string, node string) error {
	// bounded so a wedged mgmt server fails the test instead of hanging it
	client := &http.Client{Timeout: 5 * time.Second}
	url := "http://" + addr + "/1/lb/nodes"

	// marshal instead of formatting by hand so node is always escaped correctly
	body, err := json.Marshal(struct{ Node string }{Node: node})
	if err != nil {
		return err
	}

	req, err := http.NewRequest(operation, url, bytes.NewBuffer(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	t.Logf("%s node=%s response=%s status=%d", operation, node, respBody, resp.StatusCode)

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s node=%s status=%d %s", operation, node, resp.StatusCode, respBody)
	}

	return nil
}
// mgmtListServers fetches the node list from the mgmt API listening on addr
// and compares it with the state expected from servers: every server in the
// pool must be listed as "online" or "offline" according to its healthy flag,
// and nothing else may be listed.
//
// Transport, status and decoding problems are returned as an error;
// comparison failures are reported on t and nil is returned.
func mgmtListServers(t *testing.T, addr string, servers []*testServer) error {
	req, err := http.NewRequest("GET", "http://"+addr+"/1/lb/nodes", nil)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("list status=%d", resp.StatusCode)
	}

	var listResp struct {
		Nodes map[string]string `json:"nodes"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
		return err
	}

	// what the API should have answered
	expected := make(map[string]string, len(servers))
	for _, srv := range servers {
		if !srv.inPool {
			continue
		}
		status := "offline"
		if srv.healthy {
			status = "online"
		}
		expected[srv.addr] = status
	}

	t.Logf("list response=%v expected=%v", listResp.Nodes, expected)

	// tick off matching entries from both sides; anything left over is a mismatch
	for node, want := range expected {
		got, found := listResp.Nodes[node]
		if !found {
			t.Errorf("failed list comparison node=`%s` is not in received", node)
			return nil
		}
		if want != got {
			t.Errorf("failed list comparison node=`%s` expected=`%s` received=`%s`", node, want, got)
			return nil
		}
		delete(expected, node)
		delete(listResp.Nodes, node)
	}

	if len(expected) != 0 || len(listResp.Nodes) != 0 {
		t.Errorf("failed list comparison (remaining unmatches) expected=`%v` received=`%v`", expected, listResp.Nodes)
	}

	return nil
}
// Basic tests via DB add/remove functions
//
// Spins up ten backend servers, adds them to the grouper directly and then
// walks through unhealthy / removed / re-added / healthy-again rounds,
// comparing the grouper's healthy list with the expected one after each.
// Setup failures abort the test immediately: every later step depends on the
// grouper and on the full set of servers being available.
func TestRouteRunnerExecution(t *testing.T) {
	a, err := initializeRunner()
	if err != nil {
		t.Fatalf("Not expected error `%s`", err.Error())
	}

	concrete, ok := a.(*allGrouper)
	if !ok {
		t.Fatalf("Not expected grouper type `%T`", a)
	}

	// initialize and add some servers (all healthy)
	serverCount := 10

	servers, err := initializeTestServers(t, uint64(serverCount))
	// registered before the error check so that servers started before a
	// partial failure are shut down as well
	defer shutdownTestServers(t, servers)
	if err != nil {
		t.Fatalf("Not expected error `%s`", err.Error())
	}
	if serverCount != len(servers) {
		t.Fatalf("Got %d servers, expected: %d", len(servers), serverCount)
	}

	srvList := make([]string, 0, len(servers))
	for _, srv := range servers {
		srvList = append(srvList, srv.addr)
	}
	t.Logf("Spawned servers %s", srvList)

	testCompare(t, a, servers, "round0")

	for _, srv := range servers {
		// add these servers to allGrouper
		err := concrete.add(srv.addr)
		if err != nil {
			t.Errorf("Not expected error `%s` when adding `%s`", err.Error(), srv.addr)
		}
		srv.inPool = true
	}

	t.Logf("Starting round1 all servers healthy")

	// let health checker converge
	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round1")

	t.Logf("Starting round2 one server unhealthy")

	// now set one server unhealthy
	servers[2].healthy = false
	t.Logf("Setting server %s to unhealthy", servers[2].addr)

	// let health checker converge
	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round2")

	t.Logf("Starting round3 remove one server from grouper")

	t.Logf("Removing server %s from grouper", servers[3].addr)
	err = concrete.remove(servers[3].addr)
	if err != nil {
		t.Errorf("Not expected error `%s` when removing `%s`", err.Error(), servers[3].addr)
	}
	servers[3].inPool = false

	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round3")

	t.Logf("Starting round4 add server back to grouper")

	t.Logf("Adding server %s to grouper", servers[3].addr)
	err = concrete.add(servers[3].addr)
	if err != nil {
		t.Errorf("Not expected error `%s` when adding `%s`", err.Error(), servers[3].addr)
	}
	servers[3].inPool = true

	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round4")

	t.Logf("Starting round5 set unhealthy server back to healthy")

	servers[2].healthy = true
	t.Logf("Setting server %s to healthy", servers[2].addr)

	// let health checker converge
	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round5")

	t.Logf("Starting round6 no change")

	// fetch list again
	testCompare(t, a, servers, "round6")
}
// Basic tests via mgmt API
//
// Starts a mgmt API server in front of the grouper, adds five backend servers
// through it and checks both the grouper's healthy list and the API's node
// listing, before and after marking one server unhealthy and removing
// another. Setup failures abort the test immediately: every later step
// depends on the grouper, the mgmt server and the full set of servers.
func TestRouteRunnerMgmtAPI(t *testing.T) {
	a, err := initializeRunner()
	if err != nil {
		t.Fatalf("Not expected error `%s`", err.Error())
	}

	mgmtSrv, mgmtAddr, err := initializeAPIServer(t, a)
	if err != nil {
		// mgmtSrv is nil here, so there is nothing to tear down
		t.Fatalf("cannot start mgmt api server `%s`", err.Error())
	}
	defer teardownServer(t, mgmtSrv)

	// initialize and add some servers (all healthy)
	serverCount := 5

	servers, err := initializeTestServers(t, uint64(serverCount))
	// registered before the error check so that servers started before a
	// partial failure are shut down as well
	defer shutdownTestServers(t, servers)
	if err != nil {
		t.Fatalf("Not expected error `%s`", err.Error())
	}
	if serverCount != len(servers) {
		t.Fatalf("Got %d servers, expected: %d", len(servers), serverCount)
	}

	srvList := make([]string, 0, len(servers))
	for _, srv := range servers {
		srvList = append(srvList, srv.addr)
	}
	t.Logf("Spawned servers %s", srvList)

	testCompare(t, a, servers, "round0")

	for _, srv := range servers {
		err := mgmtModServer(t, mgmtAddr, "PUT", srv.addr)
		if err != nil {
			t.Errorf("Not expected error `%s` when adding `%s`", err.Error(), srv.addr)
		}
		srv.inPool = true
	}

	t.Logf("Starting round1 all servers healthy")

	// let health checker converge
	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round1")

	err = mgmtListServers(t, mgmtAddr, servers)
	if err != nil {
		t.Errorf("Not expected error `%s` when listing", err.Error())
	}

	t.Logf("Starting round2 remove one server from grouper")

	// let's set server at 2 as unhealthy as well
	servers[2].healthy = false

	t.Logf("Removing server %s from grouper", servers[3].addr)
	err = mgmtModServer(t, mgmtAddr, "DELETE", servers[3].addr)
	if err != nil {
		t.Errorf("Not expected error `%s` when removing `%s`", err.Error(), servers[3].addr)
	}
	servers[3].inPool = false

	time.Sleep(time.Duration(2) * time.Second)
	testCompare(t, a, servers, "round2")

	err = mgmtListServers(t, mgmtAddr, servers)
	if err != nil {
		t.Errorf("Not expected error `%s` when listing", err.Error())
	}
}
// TODO: test old version case
// TODO: test DB unhealthy case
// TODO: test healthy/unhealthy thresholds
// TODO: test health check timeout case

View File

@@ -150,61 +150,61 @@ func loadKey(node, key string) string {
return node + "\x00" + key
}
// checkLoad decides whether node n should be used for key, based on the load
// value recorded for that node/key pair (zero if none is recorded).
//
// Below lowerLat the node is always accepted. Above upperLat it is accepted
// with a 10% chance. In between, the acceptance chance is x%, where x comes
// from translate. A false result tells the caller to try another node.
func (ch *chRouter) checkLoad(key, n string) bool {
	const (
		// TODO we should probably use deltas rather than fixed wait times. for 'cold'
		// functions these could always trigger. i.e. if wait time increased 5x over last
		// 100 data points, point the cannon elsewhere (we'd have to track 2 numbers but meh)
		lowerLat = 500 * time.Millisecond
		upperLat = 2 * time.Second
	)

	ch.loadMu.RLock()
	loadPtr := ch.load[loadKey(n, key)]
	ch.loadMu.RUnlock()

	var load time.Duration
	if loadPtr != nil {
		load = time.Duration(atomic.LoadInt64(loadPtr))
	}

	// TODO flesh out these values.
	// if we send < 50% of traffic off to other nodes when loaded
	// then as function scales nodes will get flooded, need to be careful.
	//
	// back off loaded node/function combos slightly to spread load
	switch {
	case load < lowerLat:
		return true
	case load > upperLat:
		// really loaded
		// XXX (reed): 10% could be problematic, should sliding scale prob with log(x) ?
		return ch.rng.Intn(100) < 10
	default:
		// 10 < x < 40, as load approaches upperLat, x decreases [linearly]
		x := translate(int64(load), int64(lowerLat), int64(upperLat), 10, 40)
		return ch.rng.Intn(100) < x
	}
}
func (ch *chRouter) besti(key string, i int, nodes []string) (string, error) {
if len(nodes) < 1 {
// supposed to be caught in grouper, but double check
return "", ErrNoNodes
}
// XXX (reed): trash the closure
f := func(n string) string {
var load time.Duration
ch.loadMu.RLock()
loadPtr := ch.load[loadKey(n, key)]
ch.loadMu.RUnlock()
if loadPtr != nil {
load = time.Duration(atomic.LoadInt64(loadPtr))
}
const (
// TODO we should probably use deltas rather than fixed wait times. for 'cold'
// functions these could always trigger. i.e. if wait time increased 5x over last
// 100 data points, point the cannon elsewhere (we'd have to track 2 numbers but meh)
lowerLat = 500 * time.Millisecond
upperLat = 2 * time.Second
)
// TODO flesh out these values.
// if we send < 50% of traffic off to other nodes when loaded
// then as function scales nodes will get flooded, need to be careful.
//
// back off loaded node/function combos slightly to spread load
if load < lowerLat {
return n
} else if load > upperLat {
// really loaded
if ch.rng.Intn(100) < 10 { // XXX (reed): 10% could be problematic, should sliding scale prob with log(x) ?
return n
}
} else {
// 10 < x < 40, as load approaches upperLat, x decreases [linearly]
x := translate(int64(load), int64(lowerLat), int64(upperLat), 10, 40)
if ch.rng.Intn(100) < x {
return n
}
}
// return invalid node to try next node
return ""
}
for ; ; i++ {
// theoretically this could take infinite time, but practically improbable...
// TODO we need a way to add a node for a given key from down here if a node is overloaded.
node := f(nodes[i])
if node != "" {
return node, nil
} else if i == len(nodes)-1 {
if ch.checkLoad(key, nodes[i]) {
return nodes[i], nil
}
if i == len(nodes)-1 {
i = -1 // reset i to 0
}
}
@@ -221,7 +221,6 @@ func translate(val, inFrom, inTo, outFrom, outTo int64) int {
func (ch *chRouter) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
// XXX (reed): probably do these on a separate port to avoid conflicts
case "/1/lb/stats":
ch.statsGet(w, r)
return

153
fnlb/lb/db.go Normal file
View File

@@ -0,0 +1,153 @@
package lb
import (
"database/sql"
"errors"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"github.com/sirupsen/logrus"
)
// NewDB opens the node database configured in conf.DBurl and returns the
// DBStore backed by it.
func NewDB(conf Config) (DBStore, error) {
	return db(conf.DBurl)
}
// TODO put this somewhere better
// DBStore is the persistence layer the grouper uses to keep track of the
// node addresses it should be health checking.
type DBStore interface {
// Add stores the given node address.
Add(string) error
// Delete removes the given node address.
Delete(string) error
// List returns all node addresses currently stored.
List() ([]string, error)
}
// implements DBStore
// sqlStore keeps the node addresses in the lb_nodes table of a SQL database
// accessed through sqlx.
type sqlStore struct {
// db is the open database handle all queries are executed on.
db *sqlx.DB
// TODO we should prepare all of the statements, rebind them
// and store them all here.
}
// db opens the database named by uri, creates the lb_nodes table if it does
// not exist yet and returns a DBStore backed by it.
//
// The URI scheme selects the driver and must be one of the names accepted by
// the switch below. For sqlite3 the directory that will hold the database file
// is created first. If the connection cannot be verified or the table cannot
// be created, the freshly opened handle is closed before the error is
// returned so that no connection is leaked.
func db(uri string) (DBStore, error) {
	// named "parsed" (not "url") so the net/url package stays reachable
	parsed, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}

	driver := parsed.Scheme

	// driver must be one of these for sqlx to work, double check:
	switch driver {
	case "postgres", "pgx", "mysql", "sqlite3", "oci8", "ora", "goracle":
	default:
		return nil, errors.New("invalid db driver, refer to the code")
	}

	if driver == "sqlite3" {
		// make all the dirs so we can make the file..
		if err := os.MkdirAll(filepath.Dir(parsed.Path), 0755); err != nil {
			return nil, err
		}
	}

	uri = parsed.String()
	if driver != "postgres" {
		// postgres seems to need this as a prefix in lib/pq, everyone else wants it stripped of scheme
		uri = strings.TrimPrefix(uri, parsed.Scheme+"://")
	}

	sqldb, err := sql.Open(driver, uri)
	if err != nil {
		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't open db")
		return nil, err
	}

	conn := sqlx.NewDb(sqldb, driver)

	// force a connection and test that it worked
	if err := conn.Ping(); err != nil {
		logrus.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't ping db")
		// best effort: the ping error is the one worth reporting
		_ = conn.Close()
		return nil, err
	}

	maxIdleConns := 30 // c.MaxIdleConnections
	conn.SetMaxIdleConns(maxIdleConns)
	logrus.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed")

	_, err = conn.Exec(`CREATE TABLE IF NOT EXISTS lb_nodes (
address text NOT NULL PRIMARY KEY
);`)
	if err != nil {
		// best effort: the exec error is the one worth reporting
		_ = conn.Close()
		return nil, err
	}

	return &sqlStore{db: conn}, nil
}
// Add inserts node into the lb_nodes table. An error reported by the mysql,
// postgres or sqlite3 driver because the address already exists is treated as
// success; every other error is returned unchanged.
func (s *sqlStore) Add(node string) error {
	_, err := s.db.Exec(s.db.Rebind("INSERT INTO lb_nodes (address) VALUES (?);"), node)
	if err == nil {
		return nil
	}

	// if it already exists, just filter that error out
	switch dbErr := err.(type) {
	case *mysql.MySQLError:
		if dbErr.Number == 1062 {
			return nil
		}
	case *pq.Error:
		if dbErr.Code == "23505" {
			return nil
		}
	case sqlite3.Error:
		if dbErr.ExtendedCode == sqlite3.ErrConstraintUnique || dbErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey {
			return nil
		}
	}
	return err
}
// Delete removes node from the lb_nodes table and returns whatever error the
// DELETE statement produced.
func (s *sqlStore) Delete(node string) error {
	// TODO we can filter if it didn't exist, too...
	_, err := s.db.Exec(s.db.Rebind(`DELETE FROM lb_nodes WHERE address=?`), node)
	return err
}
// List returns the distinct node addresses stored in the lb_nodes table.
//
// Rows that cannot be scanned are skipped (and logged) rather than failing
// the whole listing. The error returned is the query error or the error that
// ended row iteration, if any.
func (s *sqlStore) List() ([]string, error) {
	query := s.db.Rebind(`SELECT DISTINCT address FROM lb_nodes`)
	rows, err := s.db.Query(query)
	if err != nil {
		return nil, err
	}
	// always hand the connection back to the pool, whatever path we leave by
	defer rows.Close()

	var nodes []string
	for rows.Next() {
		var node string
		if err := rows.Scan(&node); err != nil {
			// keep the listing best-effort, but don't drop the row silently
			logrus.WithError(err).Warn("skipping unreadable lb_nodes row")
			continue
		}
		nodes = append(nodes, node)
	}

	err = rows.Err()
	if err == sql.ErrNoRows {
		err = nil // don't care...
	}

	return nodes, err
}

View File

@@ -32,11 +32,14 @@ import (
type Config struct {
DBurl string `json:"db_url"`
Listen string `json:"port"`
MgmtListen string `json:"mgmt_port"`
ShutdownTimeout int `json:"shutdown_timeout"`
ZipkinURL string `json:"zipkin_url"`
Nodes []string `json:"nodes"`
HealthcheckInterval int `json:"healthcheck_interval"`
HealthcheckEndpoint string `json:"healthcheck_endpoint"`
HealthcheckUnhealthy int `json:"healthcheck_unhealthy"`
HealthcheckHealthy int `json:"healthcheck_healthy"`
HealthcheckTimeout int `json:"healthcheck_timeout"`
MinAPIVersion *semver.Version `json:"min_api_version"`

View File

@@ -9,7 +9,8 @@ import (
)
var (
ErrNoNodes = errors.New("no nodes available")
ErrNoNodes = errors.New("no nodes available")
ErrUnknownCommand = errors.New("unknown command")
)
func sendValue(w http.ResponseWriter, v interface{}) {
@@ -45,3 +46,9 @@ func sendError(w http.ResponseWriter, code int, msg string) {
logrus.WithError(err).Error("error writing response response")
}
}
// NullHandler returns a handler that answers every request with a 404
// carrying the ErrUnknownCommand message.
func NullHandler() http.Handler {
	notFound := func(w http.ResponseWriter, _ *http.Request) {
		sendError(w, http.StatusNotFound, ErrUnknownCommand.Error())
	}
	return http.HandlerFunc(notFound)
}

View File

@@ -27,9 +27,12 @@ func main() {
var conf lb.Config
flag.StringVar(&conf.DBurl, "db", "sqlite3://:memory:", "backend to store nodes, default to in memory")
flag.StringVar(&conf.Listen, "listen", ":8081", "port to run on")
flag.StringVar(&conf.MgmtListen, "mgmt-listen", ":8081", "management port to run on")
flag.IntVar(&conf.ShutdownTimeout, "shutdown-timeout", 0, "graceful shutdown timeout")
flag.IntVar(&conf.HealthcheckInterval, "hc-interval", 3, "how often to check f(x) nodes, in seconds")
flag.StringVar(&conf.HealthcheckEndpoint, "hc-path", "/version", "endpoint to determine node health")
flag.IntVar(&conf.HealthcheckUnhealthy, "hc-unhealthy", 2, "threshold of failed checks to declare node unhealthy")
flag.IntVar(&conf.HealthcheckHealthy, "hc-healthy", 1, "threshold of success checks to declare node healthy")
flag.IntVar(&conf.HealthcheckTimeout, "hc-timeout", 5, "timeout of healthcheck endpoint, in seconds")
flag.StringVar(&conf.ZipkinURL, "zipkin", "", "zipkin endpoint to send traces")
flag.Parse()
@@ -54,7 +57,12 @@ func main() {
},
}
g, err := lb.NewAllGrouper(conf)
db, err := lb.NewDB(conf)
if err != nil {
logrus.WithError(err).Fatal("error setting up database")
}
g, err := lb.NewAllGrouper(conf, db)
if err != nil {
logrus.WithError(err).Fatal("error setting up grouper")
}
@@ -64,27 +72,57 @@ func main() {
return r.URL.Path, nil
}
h := lb.NewProxy(k, g, r, conf)
h = g.Wrap(h) // add/del/list endpoints
h = r.Wrap(h) // stats / dash endpoint
servers := make([]*http.Server, 0, 1)
handler := lb.NewProxy(k, g, r, conf)
err = serve(conf.Listen, h)
if err != nil {
logrus.WithError(err).Fatal("server error")
// a separate mgmt listener is requested? then let's create a LB traffic only server
if conf.Listen != conf.MgmtListen {
servers = append(servers, &http.Server{Addr: conf.Listen, Handler: handler})
handler = lb.NullHandler()
}
// add mgmt endpoints to the handler
handler = g.Wrap(handler) // add/del/list endpoints
handler = r.Wrap(handler) // stats / dash endpoint
servers = append(servers, &http.Server{Addr: conf.MgmtListen, Handler: handler})
serve(servers, &conf)
}
func serve(addr string, handler http.Handler) error {
server := &http.Server{Addr: addr, Handler: handler}
func serve(servers []*http.Server, conf *lb.Config) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGQUIT, syscall.SIGINT)
go func() {
for sig := range ch {
logrus.WithFields(logrus.Fields{"signal": sig}).Info("received signal")
server.Shutdown(context.Background()) // safe shutdown
return
for i := 0; i < len(servers); i++ {
go func(idx int) {
err := servers[idx].ListenAndServe()
if err != nil && err != http.ErrServerClosed {
logrus.WithFields(logrus.Fields{"server_id": idx}).WithError(err).Fatal("server error")
} else {
logrus.WithFields(logrus.Fields{"server_id": idx}).Info("server stopped")
}
}(i)
}
sig := <-ch
logrus.WithFields(logrus.Fields{"signal": sig}).Info("received signal")
for i := 0; i < len(servers); i++ {
ctx := context.Background()
if conf.ShutdownTimeout > 0 {
tmpCtx, cancel := context.WithTimeout(context.Background(), time.Duration(conf.ShutdownTimeout)*time.Second)
ctx = tmpCtx
defer cancel()
}
}()
return server.ListenAndServe()
err := servers[i].Shutdown(ctx) // safe shutdown
if err != nil {
logrus.WithFields(logrus.Fields{"server_id": i}).WithError(err).Fatal("server shutdown error")
} else {
logrus.WithFields(logrus.Fields{"server_id": i}).Info("server shutdown")
}
}
}