From aa12f3c724223f884f1e6e467738c0fbef7ade4f Mon Sep 17 00:00:00 2001 From: C Cirello Date: Thu, 6 Oct 2016 00:32:56 +0200 Subject: [PATCH] Add graceful shutdown support for async runners (#125) --- api/runner/async_runner.go | 59 +++++++++++++++++++++++---------- api/runner/async_runner_test.go | 20 +++++++++++ main.go | 21 +++++++++--- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/api/runner/async_runner.go b/api/runner/async_runner.go index 0e3e214d8..20e6600c0 100644 --- a/api/runner/async_runner.go +++ b/api/runner/async_runner.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "net/url" + "sync" "time" log "github.com/Sirupsen/logrus" @@ -95,28 +96,50 @@ func runTask(task *models.Task) error { } // RunAsyncRunner pulls tasks off a queue and processes them -func RunAsyncRunner(tasksrv, port string) { +func RunAsyncRunner(ctx context.Context, wgAsync *sync.WaitGroup, tasksrv, port string, n int) { u := tasksrvURL(tasksrv, port) + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go startAsyncRunners(ctx, &wg, i, u, runTask) + } + + wg.Wait() + <-ctx.Done() + wgAsync.Done() +} + +func startAsyncRunners(ctx context.Context, wg *sync.WaitGroup, i int, url string, runTask func(task *models.Task) error) { + defer wg.Done() for { - task, err := getTask(u) - if err != nil { - log.WithError(err).Info("Cannot get task") - time.Sleep(1 * time.Second) - continue - } - log.Info("Picked up task:", task.ID) + select { + case <-ctx.Done(): + return - // Process Task - if err := runTask(task); err != nil { - log.WithError(err) - continue - } - log.Info("Processed task:", task.ID) + default: + task, err := getTask(url) + if err != nil { + log.WithError(err).Error("Could not fetch task") + time.Sleep(1 * time.Second) + continue + } + log.Info("Picked up task:", task.ID) + + log.Info("Running task:", task.ID) + // Process Task + if err := runTask(task); err != nil { + log.WithError(err).WithFields(log.Fields{"async runner": i, "task_id": task.ID}).Error("Cannot run task") + continue + } + log.Info("Processed task:", task.ID) + + // Delete task from queue + if err := deleteTask(url, task); err != nil { + log.WithError(err).WithFields(log.Fields{"async runner": i, "task_id": task.ID}).Error("Cannot delete task") + continue + } - // Delete task from queue - if err := deleteTask(u, task); err != nil { - log.WithError(err) - } else { log.Info("Deleted task:", task.ID) } } diff --git a/api/runner/async_runner_test.go b/api/runner/async_runner_test.go index f9e6a0b67..292690cd0 100644 --- a/api/runner/async_runner_test.go +++ b/api/runner/async_runner_test.go @@ -7,7 +7,9 @@ import ( "math/rand" "net/http" "net/http/httptest" + "sync" "testing" + "time" "github.com/Sirupsen/logrus" "github.com/gin-gonic/gin" @@ -167,3 +169,21 @@ func TestTasksrvURL(t *testing.T) { } } } + +func TestAsyncRunnersGracefulShutdown(t *testing.T) { + mockTask := getMockTask() + ts := getTestServer([]*models.Task{&mockTask}) + defer ts.Close() + + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + var wg sync.WaitGroup + wg.Add(1) + go startAsyncRunners(ctx, &wg, 0, ts.URL+"/tasks", func(task *models.Task) error { + return nil + }) + wg.Wait() + + if err := ctx.Err(); err != context.DeadlineExceeded { + t.Errorf("async runners stopped unexpectedly. context error: %v", err) + } +} diff --git a/main.go b/main.go index 5557e75cc..3c9454e2e 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,9 @@ package main import ( "fmt" "os" + "os/signal" "strings" + "sync" log "github.com/Sirupsen/logrus" "github.com/iron-io/functions/api/datastore" @@ -38,7 +40,14 @@ func init() { } func main() { - ctx := context.Background() + ctx, halt := context.WithCancel(context.Background()) + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + <-c + log.Info("Halting...") + halt() + }() ds, err := datastore.New(viper.GetString("DB")) if err != nil { @@ -57,10 +66,14 @@ func main() { tasksURL, port, nasync := viper.GetString("tasks_url"), viper.GetString("port"), viper.GetInt("nasync") log.Info("async workers:", nasync) - for i := 0; i < nasync; i++ { - go runner.RunAsyncRunner(tasksURL, port) + var wgAsync sync.WaitGroup + if nasync > 0 { + wgAsync.Add(1) + go runner.RunAsyncRunner(ctx, &wgAsync, tasksURL, port, nasync) } srv := server.New(ds, mqType, rnr) - srv.Run(ctx) + go srv.Run(ctx) + <-ctx.Done() + wgAsync.Wait() }