diff --git a/api/agent/agent.go b/api/agent/agent.go
index ca3bfad91..09afc991d 100644
--- a/api/agent/agent.go
+++ b/api/agent/agent.go
@@ -112,21 +112,22 @@ type agent struct {
 	resources ResourceTracker
 	// used to track running calls / safe shutdown
-	wg       sync.WaitGroup // TODO rename
+	shutWg   *common.WaitGroup
 	shutonce sync.Once
-	shutdown chan struct{}
 
 	callEndCount int64
 }
 
 // New creates an Agent that executes functions locally as Docker containers.
 func New(da DataAccess) Agent {
-	a := createAgent(da, true).(*agent)
-	a.wg.Add(1)
+	a := createAgent(da, true, nil).(*agent)
+	if !a.shutWg.AddSession(1) {
+		logrus.Fatalf("cannot start agent, unable to add session")
+	}
 	go a.asyncDequeue() // safe shutdown can nanny this fine
 	return a
 }
 
-func createAgent(da DataAccess, withDocker bool) Agent {
+func createAgent(da DataAccess, withDocker bool, withShutWg *common.WaitGroup) Agent {
 	cfg, err := NewAgentConfig()
 	if err != nil {
 		logrus.WithError(err).Fatalf("error in agent config cfg=%+v", cfg)
@@ -147,6 +148,9 @@ func createAgent(da DataAccess, withDocker bool) Agent {
 	} else {
 		driver = mock.New()
 	}
+	if withShutWg == nil {
+		withShutWg = common.NewWaitGroup()
+	}
 
 	a := &agent{
 		cfg:       *cfg,
@@ -154,7 +158,7 @@ func createAgent(da DataAccess, withDocker bool) Agent {
 		driver:    driver,
 		slotMgr:   NewSlotQueueMgr(),
 		resources: NewResourceTracker(cfg),
-		shutdown:  make(chan struct{}),
+		shutWg:    withShutWg,
 	}
 
 	// TODO assert that agent doesn't get started for API nodes up above ?
@@ -176,25 +180,23 @@ func (a *agent) Enqueue(ctx context.Context, call *models.Call) error {
 
 func (a *agent) Close() error {
 	var err error
+
+	// wait for ongoing sessions
+	a.shutWg.CloseGroup()
+
 	a.shutonce.Do(func() {
+		// now close docker layer
 		if a.driver != nil {
 			err = a.driver.Close()
 		}
-		close(a.shutdown)
 	})
 
-	a.wg.Wait()
 	return err
 }
 
 func (a *agent) Submit(callI Call) error {
-	a.wg.Add(1)
-	defer a.wg.Done()
-
-	select {
-	case <-a.shutdown:
+	if !a.shutWg.AddSession(1) {
 		return models.ErrCallTimeoutServerBusy
-	default:
 	}
 
 	call := callI.(*call)
@@ -254,15 +256,24 @@ func (a *agent) submit(ctx context.Context, call *call) error {
 }
 
 func (a *agent) scheduleCallEnd(fn func()) {
-	a.wg.Add(1)
 	atomic.AddInt64(&a.callEndCount, 1)
 	go func() {
 		fn()
		atomic.AddInt64(&a.callEndCount, -1)
-		a.wg.Done()
+		a.shutWg.AddSession(-1)
 	}()
 }
 
+func (a *agent) finalizeCallEnd(ctx context.Context, err error, isRetriable, isScheduled bool) error {
+	// if scheduled in background, let scheduleCallEnd() handle
+	// the shutWg group, otherwise decrement here.
+	if !isScheduled {
+		a.shutWg.AddSession(-1)
+	}
+	handleStatsEnd(ctx, err)
+	return transformTimeout(err, isRetriable)
+}
+
 func (a *agent) handleCallEnd(ctx context.Context, call *call, slot Slot, err error, isCommitted bool) error {
 
 	// For hot-containers, slot close is a simple channel close... No need
@@ -284,9 +295,7 @@ func (a *agent) handleCallEnd(ctx context.Context, call *call, slot Slot, err er
 			call.End(ctx, err)
 			cancel()
 		})
-
-		handleStatsEnd(ctx, err)
-		return transformTimeout(err, false)
+		return a.finalizeCallEnd(ctx, err, false, true)
 	}
 
 	// The call did not succeed. And it is retriable. We close the slot
@@ -296,10 +305,10 @@ func (a *agent) handleCallEnd(ctx context.Context, call *call, slot Slot, err er
 		a.scheduleCallEnd(func() {
 			slot.Close(common.BackgroundContext(ctx)) // (no timeout)
 		})
+		return a.finalizeCallEnd(ctx, err, true, true)
 	}
 
-	handleStatsDequeue(ctx, err)
-	return transformTimeout(err, true)
+	return a.finalizeCallEnd(ctx, err, true, false)
 }
 
 func transformTimeout(e error, isRetriable bool) error {
@@ -400,7 +409,7 @@ func (a *agent) hotLauncher(ctx context.Context, call *call) {
 		a.checkLaunch(ctx, call)
 
 		select {
-		case <-a.shutdown: // server shutdown
+		case <-a.shutWg.Closer(): // server shutdown
 			cancel()
 			return
 		case <-ctx.Done(): // timed out
@@ -431,17 +440,22 @@ func (a *agent) checkLaunch(ctx context.Context, call *call) {
 
 	select {
 	case tok := <-a.resources.GetResourceToken(ctx, call.Memory, uint64(call.CPUs), isAsync):
-		a.wg.Add(1) // add waiter in this thread
-		go func() {
-			// NOTE: runHot will not inherit the timeout from ctx (ignore timings)
-			a.runHot(ctx, call, tok, state)
-			a.wg.Done()
-		}()
+		if a.shutWg.AddSession(1) {
+			go func() {
+				// NOTE: runHot will not inherit the timeout from ctx (ignore timings)
+				a.runHot(ctx, call, tok, state)
+				a.shutWg.AddSession(-1)
+			}()
+			return
+		}
+		if tok != nil {
+			tok.Close()
+		}
 	case <-ctx.Done(): // timeout
-		state.UpdateState(ctx, ContainerStateDone, call.slots)
-	case <-a.shutdown: // server shutdown
-		state.UpdateState(ctx, ContainerStateDone, call.slots)
+	case <-a.shutWg.Closer(): // server shutdown
 	}
+
+	state.UpdateState(ctx, ContainerStateDone, call.slots)
 }
 
 // waitHot pings and waits for a hot container from the slot queue
@@ -471,7 +485,7 @@ func (a *agent) waitHot(ctx context.Context, call *call) (Slot, error) {
 			// we failed to take ownership of the token (eg. container idle timeout) => try again
 		case <-ctx.Done():
 			return nil, ctx.Err()
-		case <-a.shutdown: // server shutdown
+		case <-a.shutWg.Closer(): // server shutdown
 			return nil, models.ErrCallTimeoutServerBusy
 		case <-time.After(sleep):
 			// ping dequeuer again
@@ -735,7 +749,7 @@ func (a *agent) runHot(ctx context.Context, call *call, tok ResourceToken, state
 		select { // make sure everything is up before trying to send slot
 		case <-ctx.Done(): // container shutdown
 			return
-		case <-a.shutdown: // server shutdown
+		case <-a.shutWg.Closer(): // server shutdown
 			return
 		default: // ok
 		}
@@ -808,7 +822,7 @@ func (a *agent) runHotReq(ctx context.Context, call *call, state ContainerState,
 	select {
 	case <-s.trigger: // slot already consumed
 	case <-ctx.Done(): // container shutdown
-	case <-a.shutdown: // server shutdown
+	case <-a.shutWg.Closer(): // server shutdown
 	case <-idleTimer.C:
 	case <-freezeTimer.C:
 		if !isFrozen {
diff --git a/api/agent/async.go b/api/agent/async.go
index 43c6771b6..a26bbb6f4 100644
--- a/api/agent/async.go
+++ b/api/agent/async.go
@@ -12,8 +12,6 @@ import (
 )
 
 func (a *agent) asyncDequeue() {
-	defer a.wg.Done() // we can treat this thread like one big task and get safe shutdown fo free
-
 	// this is just so we can hang up the dequeue request if we get shut down
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
@@ -24,7 +22,8 @@
 
 	for {
 		select {
-		case <-a.shutdown:
+		case <-a.shutWg.Closer():
+			a.shutWg.AddSession(-1)
 			return
 		case <-a.resources.WaitAsyncResource(ctx):
 			// TODO we _could_ return a token here to reserve the ram so that there's
@@ -35,15 +34,20 @@
 
 		// we think we can get a cookie now, so go get a cookie
 		select {
-		case <-a.shutdown:
+		case <-a.shutWg.Closer():
+			a.shutWg.AddSession(-1)
 			return
 		case model, ok := <-a.asyncChew(ctx):
 			if ok {
-				a.wg.Add(1) // need to add 1 in this thread to ensure safe shutdown
 				go func(model *models.Call) {
 					a.asyncRun(ctx, model)
-					a.wg.Done() // can shed it after this is done, Submit will add 1 too but it's fine
+					a.shutWg.AddSession(-1)
 				}(model)
+
+				// WARNING: tricky. We reserve another session for the next iteration of the loop
+				if !a.shutWg.AddSession(1) {
+					return
+				}
 			}
 		}
 	}
diff --git a/api/agent/lb_agent.go b/api/agent/lb_agent.go
index 44bc5cce9..6e833af76 100644
--- a/api/agent/lb_agent.go
+++ b/api/agent/lb_agent.go
@@ -2,12 +2,12 @@ package agent
 
 import (
 	"context"
-	"sync"
 	"time"
 
 	"github.com/sirupsen/logrus"
 	"go.opencensus.io/trace"
 
+	"github.com/fnproject/fn/api/common"
 	"github.com/fnproject/fn/api/models"
 	pool "github.com/fnproject/fn/api/runnerpool"
 	"github.com/fnproject/fn/fnext"
@@ -28,19 +28,19 @@ type lbAgent struct {
 	delegatedAgent Agent
 	rp             pool.RunnerPool
 	placer         pool.Placer
-
-	wg       sync.WaitGroup // Needs a good name
-	shutdown chan struct{}
+	shutWg         *common.WaitGroup
 }
 
 // NewLBAgent creates an Agent that knows how to load-balance function calls
 // across a group of runner nodes.
 func NewLBAgent(da DataAccess, rp pool.RunnerPool, p pool.Placer) (Agent, error) {
-	agent := createAgent(da, false)
+	wg := common.NewWaitGroup()
+	agent := createAgent(da, false, wg)
 	a := &lbAgent{
 		delegatedAgent: agent,
 		rp:             rp,
 		placer:         p,
+		shutWg:         wg,
 	}
 	return a, nil
 }
@@ -63,18 +63,31 @@ func (a *lbAgent) GetCall(opts ...CallOpt) (Call, error) {
 }
 
 func (a *lbAgent) Close() error {
-	// we should really be passing the server's context here
+
+	// start closing the front gate first
+	ch := a.shutWg.CloseGroupNB()
+
+	// delegated agent shutdown next, blocks here...
+	err1 := a.delegatedAgent.Close()
+	if err1 != nil {
+		logrus.WithError(err1).Warn("Delegated agent shutdown error")
+	}
+
+	// finally, shut down the runner pool
 	ctx, cancel := context.WithTimeout(context.Background(), runnerPoolShutdownTimeout)
 	defer cancel()
-
-	close(a.shutdown)
-	a.rp.Shutdown(ctx)
-	err := a.delegatedAgent.Close()
-	a.wg.Wait()
-	if err != nil {
-		return err
+	err2 := a.rp.Shutdown(ctx)
+	if err2 != nil {
+		logrus.WithError(err2).Warn("Runner pool shutdown error")
 	}
-	return nil
+
+	// gate on the front gate; it should complete once the delegated agent and runner pool are gone.
+	<-ch
+
+	if err1 != nil {
+		return err1
+	}
+	return err2
 }
 
 func GetGroupID(call *models.Call) string {
@@ -90,13 +103,8 @@ func GetGroupID(call *models.Call) string {
 }
 
 func (a *lbAgent) Submit(callI Call) error {
-	a.wg.Add(1)
-	defer a.wg.Done()
-
-	select {
-	case <-a.shutdown:
+	if !a.shutWg.AddSession(1) {
 		return models.ErrCallTimeoutServerBusy
-	default:
 	}
 
 	call := callI.(*call)
diff --git a/api/agent/pure_runner.go b/api/agent/pure_runner.go
index 887839ec9..956cf536c 100644
--- a/api/agent/pure_runner.go
+++ b/api/agent/pure_runner.go
@@ -671,7 +671,7 @@ func DefaultPureRunner(cancel context.CancelFunc, addr string, da DataAccess, ce
 }
 
 func NewPureRunner(cancel context.CancelFunc, addr string, da DataAccess, cert string, key string, ca string, gate CapacityGate) (Agent, error) {
-	a := createAgent(da, true)
+	a := createAgent(da, true, nil)
 	var pr *pureRunner
 	var err error
 	if cert != "" && key != "" && ca != "" {
diff --git a/api/agent/runner_client.go b/api/agent/runner_client.go
index d4843c38e..7a6819426 100644
--- a/api/agent/runner_client.go
+++ b/api/agent/runner_client.go
@@ -3,22 +3,26 @@ package agent
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"io"
-	"sync"
 	"time"
 
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials"
 
 	pb "github.com/fnproject/fn/api/agent/grpc"
+	"github.com/fnproject/fn/api/common"
 	pool "github.com/fnproject/fn/api/runnerpool"
 	"github.com/fnproject/fn/grpcutil"
 	"github.com/sirupsen/logrus"
 )
 
+var (
+	ErrorRunnerClosed = errors.New("Runner is closed")
+)
+
 type gRPCRunner struct {
-	// Need a WaitGroup of TryExec in flight
-	wg sync.WaitGroup
+	shutWg  *common.WaitGroup
 	address string
 	conn    *grpc.ClientConn
 	client  pb.RunnerProtocolClient
@@ -31,6 +35,7 @@ func SecureGRPCRunnerFactory(addr, runnerCertCN string, pki *pool.PKIData) (pool
 	}
 
 	return &gRPCRunner{
+		shutWg:  common.NewWaitGroup(),
 		address: addr,
 		conn:    conn,
 		client:  client,
@@ -43,7 +48,7 @@ func (r *gRPCRunner) Close(ctx context.Context) error {
 	err := make(chan error, 1)
 	go func() {
 		defer close(err)
-		r.wg.Wait()
+		r.shutWg.CloseGroup()
 		err <- r.conn.Close()
 	}()
 
@@ -86,8 +91,10 @@ func (r *gRPCRunner) Address() string {
 
 func (r *gRPCRunner) TryExec(ctx context.Context, call pool.RunnerCall) (bool, error) {
 	logrus.WithField("runner_addr", r.address).Debug("Attempting to place call")
-	r.wg.Add(1)
-	defer r.wg.Done()
+	if !r.shutWg.AddSession(1) {
+		return true, ErrorRunnerClosed
+	}
+	defer r.shutWg.AddSession(-1)
 
 	// extract the call's model data to pass on to the pure runner
 	modelJSON, err := json.Marshal(call.Model())
diff --git a/api/common/wait_utils.go b/api/common/wait_utils.go
index cc1f5c35f..0fa544f5c 100644
--- a/api/common/wait_utils.go
+++ b/api/common/wait_utils.go
@@ -1,6 +1,7 @@
 package common
 
 import (
+	"fmt"
 	"math"
 	"sync"
 )
@@ -8,17 +9,20 @@ import (
 /*
 	WaitGroup is used to manage and wait for a collection
 	of sessions. It is similar to sync.WaitGroup, but
-	AddSession/RmSession/WaitClose session is not only thread
+	AddSession/CloseGroup are not only thread
 	safe but can be executed in any order unlike sync.WaitGroup.
 
 	Once a shutdown is initiated via CloseGroup(), add/rm
 	operations will still function correctly, where
-	AddSession would return false error.
-	In this state, CloseGroup() blocks until sessions get drained
-	via RmSession() calls.
+	AddSession would return false. In this state,
+	CloseGroup() blocks until sessions get drained
+	via remove operations.
 
-	It is an error to call RmSession without a corresponding
-	successful AddSession.
+	It is an error to call AddSession() with a delta that
+	would make the session count negative. For example, if
+	the current session count is 1, AddSession can add any
+	amount but subtract at most 1. Callers must keep their
+	addition/subtraction math consistent when using WaitGroup.
 
 	Example usage:
 
@@ -26,11 +30,11 @@
 
 	for item := range(items) {
 		go func(item string) {
-			if !group.AddSession() {
+			if !group.AddSession(1) {
 				// group may be closing or full
 				return
 			}
-			defer group.RmSession()
+			defer group.AddSession(-1)
 
 			// do stuff
 		}(item)
@@ -42,57 +46,95 @@ type WaitGroup struct {
 	cond     *sync.Cond
+	closer   chan struct{}
 	isClosed bool
 	sessions uint64
 }
 
 func NewWaitGroup() *WaitGroup {
 	return &WaitGroup{
-		cond: sync.NewCond(new(sync.Mutex)),
+		cond:   sync.NewCond(new(sync.Mutex)),
+		closer: make(chan struct{}),
 	}
 }
 
-func (r *WaitGroup) AddSession() bool {
+// Closer returns a channel that is closed once
+// the WaitGroup enters the closing state.
+func (r *WaitGroup) Closer() chan struct{} {
+	return r.closer
+}
+
+// AddSession manipulates the session counter by adding
+// or subtracting the delta value. Once a close has been
+// initiated, the counter can no longer be incremented and
+// AddSession returns false. It is the caller's
+// responsibility to make sure the addition and
+// subtraction math is correct.
+func (r *WaitGroup) AddSession(delta int64) bool {
 	r.cond.L.Lock()
 	defer r.cond.L.Unlock()
 
-	if r.isClosed {
-		return false
-	}
-	if r.sessions == math.MaxUint64 {
-		return false
-	}
+	if delta >= 0 {
+		// we cannot add if we are being shut down
+		if r.isClosed {
+			return false
+		}
 
-	r.sessions++
+		incr := uint64(delta)
+
+		// we have maxed out
+		if r.sessions > math.MaxUint64-incr {
+			return false
+		}
+
+		r.sessions += incr
+	} else {
+		decr := uint64(-delta)
+
+		// illegal operation, it is the caller's responsibility
+		// to make sure the subtraction and addition math is correct.
+		if r.sessions < decr {
+			panic(fmt.Sprintf("common.WaitGroup misuse sum=%d decr=%d isClosed=%v",
+				r.sessions, decr, r.isClosed))
+		}
+
+		r.sessions -= decr
+
+		// subtractions need to notify CloseGroup
+		r.cond.Broadcast()
+	}
 
 	return true
 }
 
-func (r *WaitGroup) RmSession() {
-	r.cond.L.Lock()
-
-	if r.sessions == 0 {
-		panic("WaitGroup misuse: no sessions to remove")
-	}
-
-	r.sessions--
-	r.cond.Broadcast()
-
-	r.cond.L.Unlock()
-}
-
+// CloseGroup initiates a close and blocks until
+// the session counter reaches zero.
 func (r *WaitGroup) CloseGroup() {
 	r.cond.L.Lock()
 
-	r.isClosed = true
-	for r.sessions > 0 {
+	if !r.isClosed {
+		r.isClosed = true
+		close(r.closer)
+	}
+
+	for r.sessions != 0 {
 		r.cond.Wait()
 	}
 
 	r.cond.L.Unlock()
 }
 
+// CloseGroupNB is the non-blocking version of CloseGroup;
+// it returns a channel that can be waited on.
 func (r *WaitGroup) CloseGroupNB() chan struct{} {
+	// set to closing state immediately
+	r.cond.L.Lock()
+	if !r.isClosed {
+		r.isClosed = true
+		close(r.closer)
+	}
+	r.cond.L.Unlock()
+
 	closer := make(chan struct{})
 
 	go func() {
diff --git a/api/common/wait_utils_test.go b/api/common/wait_utils_test.go
new file mode 100644
index 000000000..76661b498
--- /dev/null
+++ b/api/common/wait_utils_test.go
@@ -0,0 +1,127 @@
+package common
+
+import (
+	"testing"
+)
+
+func isClosed(ch chan struct{}) bool {
+	select {
+	case <-ch:
+		return true
+	default:
+	}
+	return false
+}
+
+func TestWaitGroupEmpty(t *testing.T) {
+
+	wg := NewWaitGroup()
+
+	if !wg.AddSession(0) {
+		t.Fatalf("Add 0 should not fail")
+	}
+
+	if isClosed(wg.Closer()) {
+		t.Fatalf("Should not be closed yet")
+	}
+
+	done := wg.CloseGroupNB()
+
+	// gate-on close
+	wg.CloseGroup()
+
+	if !isClosed(wg.Closer()) {
+		t.Fatalf("Should be closing state")
+	}
+
+	// wait for the background close to finish
+	<-done
+	if !isClosed(done) {
+		t.Fatalf("NB Chan I should be closed")
+	}
+
+	done = wg.CloseGroupNB()
+	<-done
+	if !isClosed(done) {
+		t.Fatalf("NB Chan II should be closed")
+	}
+}
+
+func TestWaitGroupSingle(t *testing.T) {
+
+	wg := NewWaitGroup()
+
+	if isClosed(wg.Closer()) {
+		t.Fatalf("Should not be closing state yet")
+	}
+
+	if !wg.AddSession(1) {
+		t.Fatalf("Add 1 should not fail")
+	}
+
+	if isClosed(wg.Closer()) {
+		t.Fatalf("Should not be closing state yet")
+	}
+
+	if !wg.AddSession(-1) {
+		t.Fatalf("Add -1 should not fail")
+	}
+
+	// sum should be zero now.
+
+	if !wg.AddSession(2) {
+		t.Fatalf("Add 2 should not fail")
+	}
+
+	// sum is 2 now
+	// initiate shutdown
+	done := wg.CloseGroupNB()
+
+	if isClosed(done) {
+		t.Fatalf("NB Chan should not be closed yet, since sum is 2")
+	}
+
+	if !wg.AddSession(-1) {
+		t.Fatalf("Add -1 should not fail")
+	}
+	if wg.AddSession(1) {
+		t.Fatalf("Add 1 should fail (we are shutting down)")
+	}
+	if !isClosed(wg.Closer()) {
+		t.Fatalf("Should be closing state")
+	}
+
+	// sum is 1 now
+
+	if isClosed(done) {
+		t.Fatalf("NB Chan should not be closed yet, since sum is 1")
+	}
+
+	if wg.AddSession(0) {
+		t.Fatalf("Add 0 should fail (zero is treated as an increment and we are closing)")
+	}
+
+	if wg.AddSession(100) {
+		t.Fatalf("Add 100 should fail (we are shutting down)")
+	}
+
+	if !isClosed(wg.Closer()) {
+		t.Fatalf("Should be closing state")
+	}
+
+	if !wg.AddSession(-1) {
+		t.Fatalf("Add -1 should not fail")
+	}
+
+	// sum is 0 now
+	<-done
+
+	if !isClosed(done) {
+		t.Fatalf("NB Chan should be closed, since sum is 0")
+	}
+
+	if !isClosed(wg.Closer()) {
+		t.Fatalf("Should be closing state")
+	}
+
+}
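
Illustrative usage sketch (not part of the diff above): the snippet below shows how the common.WaitGroup sessions introduced in api/common/wait_utils.go are meant to be driven by a caller, mirroring the agent's Submit/Close pattern. The import path and the WaitGroup methods (NewWaitGroup, AddSession, Closer, CloseGroup) come from the diff; the main function, the loop count, and the simulated work are hypothetical.

package main

import (
	"fmt"
	"time"

	"github.com/fnproject/fn/api/common"
)

func main() {
	wg := common.NewWaitGroup()

	// Each unit of work holds one session. AddSession(1) returns false
	// once a close has been initiated, so late arrivals are rejected.
	for i := 0; i < 3; i++ {
		if !wg.AddSession(1) {
			fmt.Println("rejected: group is closing")
			continue
		}
		go func(n int) {
			defer wg.AddSession(-1) // release the session when the work is done
			select {
			case <-wg.Closer(): // close initiated, cut the work short
			case <-time.After(10 * time.Millisecond): // simulated work
			}
			fmt.Println("finished", n)
		}(i)
	}

	// CloseGroup flips the group into the closing state and blocks
	// until every outstanding session has been released.
	wg.CloseGroup()
	fmt.Println("all sessions drained")
}

CloseGroupNB behaves the same way but returns a channel instead of blocking, which is how lbAgent.Close in the diff overlaps the front-gate close with the delegated agent and runner pool shutdown before waiting on that channel.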