From cf2c3cf40482926b71539262ea0d4283043cb62d Mon Sep 17 00:00:00 2001 From: James Jeffrey Date: Wed, 19 Jul 2017 13:44:26 -0700 Subject: [PATCH] Use chi style middle ware --- api/server/app_listeners.go | 35 ++++--- api/server/apps_create.go | 4 +- api/server/apps_delete.go | 2 +- api/server/apps_get.go | 3 +- api/server/apps_list.go | 3 +- api/server/apps_update.go | 2 +- api/server/call_get.go | 3 +- api/server/call_list.go | 3 +- api/server/call_logs.go | 5 +- api/server/error_response.go | 18 ++-- api/server/middleware.go | 151 +++++++++-------------------- api/server/middleware_test.go | 46 +++++++++ api/server/routes_create_update.go | 8 +- api/server/routes_delete.go | 3 +- api/server/routes_get.go | 3 +- api/server/routes_list.go | 11 ++- api/server/runner.go | 10 +- api/server/server.go | 2 - api/server/special_handler.go | 61 +++--------- api/server/special_handler_test.go | 11 ++- examples/middleware/main.go | 48 ++++----- 21 files changed, 192 insertions(+), 240 deletions(-) create mode 100644 api/server/middleware_test.go diff --git a/api/server/app_listeners.go b/api/server/app_listeners.go index 3d02ce085..f80b675d4 100644 --- a/api/server/app_listeners.go +++ b/api/server/app_listeners.go @@ -1,30 +1,34 @@ package server import ( + "context" + "gitlab-odx.oracle.com/odx/functions/api/models" ) +// AppListener is an interface used to inject custom code at key points in app lifecycle. type AppListener interface { // BeforeAppCreate called right before creating App in the database - BeforeAppCreate(ctx MiddlewareContext, app *models.App) error + BeforeAppCreate(ctx context.Context, app *models.App) error // AfterAppCreate called after creating App in the database - AfterAppCreate(ctx MiddlewareContext, app *models.App) error + AfterAppCreate(ctx context.Context, app *models.App) error // BeforeAppUpdate called right before updating App in the database - BeforeAppUpdate(ctx MiddlewareContext, app *models.App) error + BeforeAppUpdate(ctx context.Context, app *models.App) error // AfterAppUpdate called after updating App in the database - AfterAppUpdate(ctx MiddlewareContext, app *models.App) error + AfterAppUpdate(ctx context.Context, app *models.App) error // BeforeAppDelete called right before deleting App in the database - BeforeAppDelete(ctx MiddlewareContext, app *models.App) error + BeforeAppDelete(ctx context.Context, app *models.App) error // AfterAppDelete called after deleting App in the database - AfterAppDelete(ctx MiddlewareContext, app *models.App) error + AfterAppDelete(ctx context.Context, app *models.App) error } -// AddAppCreateListener adds a listener that will be notified on App created. +// AddAppListener adds a listener that will be notified on App created. func (s *Server) AddAppListener(listener AppListener) { s.appListeners = append(s.appListeners, listener) } -func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) error { +// FireBeforeAppCreate is used to call all the server's Listeners BeforeAppCreate functions. +func (s *Server) FireBeforeAppCreate(ctx context.Context, app *models.App) error { for _, l := range s.appListeners { err := l.BeforeAppCreate(ctx, app) if err != nil { @@ -34,7 +38,8 @@ func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) err return nil } -func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) error { +// FireAfterAppCreate is used to call all the server's Listeners AfterAppCreate functions. +func (s *Server) FireAfterAppCreate(ctx context.Context, app *models.App) error { for _, l := range s.appListeners { err := l.AfterAppCreate(ctx, app) if err != nil { @@ -44,7 +49,8 @@ func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) erro return nil } -func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) error { +// FireBeforeAppUpdate is used to call all the server's Listeners BeforeAppUpdate functions. +func (s *Server) FireBeforeAppUpdate(ctx context.Context, app *models.App) error { for _, l := range s.appListeners { err := l.BeforeAppUpdate(ctx, app) if err != nil { @@ -54,7 +60,8 @@ func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) err return nil } -func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) error { +// FireAfterAppUpdate is used to call all the server's Listeners AfterAppUpdate functions. +func (s *Server) FireAfterAppUpdate(ctx context.Context, app *models.App) error { for _, l := range s.appListeners { err := l.AfterAppUpdate(ctx, app) if err != nil { @@ -64,7 +71,8 @@ func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) erro return nil } -func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) error { +// FireBeforeAppDelete is used to call all the server's Listeners BeforeAppDelete functions. +func (s *Server) FireBeforeAppDelete(ctx context.Context, app *models.App) error { for _, l := range s.appListeners { err := l.BeforeAppDelete(ctx, app) if err != nil { @@ -74,7 +82,8 @@ func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) err return nil } -func (s *Server) FireAfterAppDelete(ctx MiddlewareContext, app *models.App) error { +// FireAfterAppDelete is used to call all the server's Listeners AfterAppDelete functions. +func (s *Server) FireAfterAppDelete(ctx context.Context, app *models.App) error { for _, l := range s.appListeners { err := l.AfterAppDelete(ctx, app) if err != nil { diff --git a/api/server/apps_create.go b/api/server/apps_create.go index 63e51a7ac..d8fca1189 100644 --- a/api/server/apps_create.go +++ b/api/server/apps_create.go @@ -8,7 +8,7 @@ import ( ) func (s *Server) handleAppCreate(c *gin.Context) { - ctx := c.MustGet("mctx").(MiddlewareContext) + ctx := c.Request.Context() var wapp models.AppWrapper @@ -23,7 +23,7 @@ func (s *Server) handleAppCreate(c *gin.Context) { return } - if err := wapp.Validate(); err != nil { + if err = wapp.Validate(); err != nil { handleErrorResponse(c, err) return } diff --git a/api/server/apps_delete.go b/api/server/apps_delete.go index bcfe13b35..a260747a6 100644 --- a/api/server/apps_delete.go +++ b/api/server/apps_delete.go @@ -10,7 +10,7 @@ import ( ) func (s *Server) handleAppDelete(c *gin.Context) { - ctx := c.MustGet("mctx").(MiddlewareContext) + ctx := c.Request.Context() log := common.Logger(ctx) app := &models.App{Name: c.MustGet(api.AppName).(string)} diff --git a/api/server/apps_get.go b/api/server/apps_get.go index d0cc480b2..a6af705db 100644 --- a/api/server/apps_get.go +++ b/api/server/apps_get.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -9,7 +8,7 @@ import ( ) func (s *Server) handleAppGet(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() appName := c.MustGet(api.AppName).(string) app, err := s.Datastore.GetApp(ctx, appName) diff --git a/api/server/apps_list.go b/api/server/apps_list.go index 54ea3e1ab..851e23830 100644 --- a/api/server/apps_list.go +++ b/api/server/apps_list.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -9,7 +8,7 @@ import ( ) func (s *Server) handleAppList(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() filter := &models.AppFilter{} diff --git a/api/server/apps_update.go b/api/server/apps_update.go index 0430bdbdf..1d1e0fb8e 100644 --- a/api/server/apps_update.go +++ b/api/server/apps_update.go @@ -9,7 +9,7 @@ import ( ) func (s *Server) handleAppUpdate(c *gin.Context) { - ctx := c.MustGet("mctx").(MiddlewareContext) + ctx := c.Request.Context() wapp := models.AppWrapper{} diff --git a/api/server/call_get.go b/api/server/call_get.go index 35d3549f2..6131918bd 100644 --- a/api/server/call_get.go +++ b/api/server/call_get.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -9,7 +8,7 @@ import ( ) func (s *Server) handleCallGet(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() callID := c.Param(api.Call) callObj, err := s.Datastore.GetTask(ctx, callID) diff --git a/api/server/call_list.go b/api/server/call_list.go index 0fc74d903..611f67cfa 100644 --- a/api/server/call_list.go +++ b/api/server/call_list.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -10,7 +9,7 @@ import ( ) func (s *Server) handleCallList(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() appName, ok := c.MustGet(api.AppName).(string) if ok && appName == "" { diff --git a/api/server/call_logs.go b/api/server/call_logs.go index 79a031dd5..c9b0beb33 100644 --- a/api/server/call_logs.go +++ b/api/server/call_logs.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -9,7 +8,7 @@ import ( ) func (s *Server) handleCallLogGet(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() callID := c.Param(api.Call) _, err := s.Datastore.GetTask(ctx, callID) @@ -28,7 +27,7 @@ func (s *Server) handleCallLogGet(c *gin.Context) { } func (s *Server) handleCallLogDelete(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() callID := c.Param(api.Call) _, err := s.Datastore.GetTask(ctx, callID) diff --git a/api/server/error_response.go b/api/server/error_response.go index 5972d7eef..7cfceea98 100644 --- a/api/server/error_response.go +++ b/api/server/error_response.go @@ -1,7 +1,6 @@ package server import ( - "context" "errors" "net/http" "runtime/debug" @@ -12,6 +11,7 @@ import ( "gitlab-odx.oracle.com/odx/functions/api/runner/common" ) +// ErrInternalServerError returned when something exceptional happens. var ErrInternalServerError = errors.New("internal server error") func simpleError(err error) *models.Error { @@ -19,14 +19,14 @@ func simpleError(err error) *models.Error { } func handleErrorResponse(c *gin.Context, err error) { - ctx := c.MustGet("ctx").(context.Context) - log := common.Logger(ctx) - - if aerr, ok := err.(models.APIError); ok { - log.WithFields(logrus.Fields{"code": aerr.Code()}).WithError(err).Error("api error") - c.JSON(aerr.Code(), simpleError(err)) - } else if err != nil { - // get a stack trace so we can trace this error + log := common.Logger(c.Request.Context()) + switch e := err.(type) { + case models.APIError: + if e.Code() >= 500 { + log.WithFields(logrus.Fields{"code": e.Code()}).WithError(e).Error("api error") + } + c.JSON(e.Code(), simpleError(e)) + default: log.WithError(err).WithFields(logrus.Fields{"stack": string(debug.Stack())}).Error("internal server error") c.JSON(http.StatusInternalServerError, simpleError(ErrInternalServerError)) } diff --git a/api/server/middleware.go b/api/server/middleware.go index e2f0220d4..97f1318b8 100644 --- a/api/server/middleware.go +++ b/api/server/middleware.go @@ -1,132 +1,69 @@ -// TODO: it would be nice to move these into the top level folder so people can use these with the "functions" package, eg: functions.AddMiddleware(...) package server import ( "context" "net/http" - "github.com/Sirupsen/logrus" + "gitlab-odx.oracle.com/odx/functions/api/runner/common" + "github.com/gin-gonic/gin" - "gitlab-odx.oracle.com/odx/functions/api/models" ) -// Middleware is the interface required for implementing functions middlewar +// Middleware just takes a http.Handler and returns one. So the next middle ware must be called +// within the returned handler or it would be ignored. type Middleware interface { - // Serve is what the Middleware must implement. Can modify the request, write output, etc. - // todo: should we abstract the HTTP out of this? In case we want to support other protocols. - Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error + Chain(next http.Handler) http.Handler } -// MiddlewareFunc func form of Middleware -type MiddlewareFunc func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error +// MiddlewareFunc is a here to allow a plain function to be a middleware. +type MiddlewareFunc func(next http.Handler) http.Handler -// Serve wrapper -func (f MiddlewareFunc) Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error { - return f(ctx, w, r, app) -} - -// MiddlewareContext extends context.Context for Middleware -type MiddlewareContext interface { - context.Context - - // Set is used to store a new key/value pair exclusively for this context. - // This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well. - Set(key string, value interface{}) - - // Get returns the value for the given key, ie: (value, true). - // If the value does not exists it returns (nil, false) - Get(key string) (value interface{}, exists bool) - - // MustGet returns the value for the given key if it exists, otherwise it panics. - MustGet(key string) interface{} - - // Middleware can call Next() explicitly to call the next middleware in the chain. If Next() is not called and an error is not returned, Next() will automatically be called. - Next() - // Index returns the index of where we're at in the chain - Index() int -} - -type middlewareContextImpl struct { - context.Context - - ginContext *gin.Context - nextCalled bool - index int - middlewares []Middleware -} - -// Set is used to store a new key/value pair exclusively for this context. -// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well. -func (c *middlewareContextImpl) Set(key string, value interface{}) { - c.ginContext.Set(key, value) -} - -// Get returns the value for the given key, ie: (value, true). -// If the value does not exists it returns (nil, false) -func (c *middlewareContextImpl) Get(key string) (value interface{}, exists bool) { - return c.ginContext.Get(key) -} - -// MustGet returns the value for the given key if it exists, otherwise it panics. -func (c *middlewareContextImpl) MustGet(key string) interface{} { - return c.ginContext.MustGet(key) -} - -func (c *middlewareContextImpl) Next() { - c.nextCalled = true - c.index++ - c.serveNext() -} - -func (c *middlewareContextImpl) serveNext() { - if c.Index() >= len(c.middlewares) { - return - } - // make shallow copy: - fctx2 := *c - fctx2.nextCalled = false - r := c.ginContext.Request.WithContext(fctx2) - err := c.middlewares[c.Index()].Serve(&fctx2, c.ginContext.Writer, r, nil) - if err != nil { - logrus.WithError(err).Warnln("Middleware error") - // todo: might be a good idea to check if anything is written yet, and if not, output the error: simpleError(err) - // see: http://stackoverflow.com/questions/39415827/golang-http-check-if-responsewriter-has-been-written - c.ginContext.Error(err) - c.ginContext.Abort() - return - } - if !fctx2.nextCalled { - // then we automatically call next - fctx2.Next() - } - -} - -func (c *middlewareContextImpl) Index() int { - return c.index +// Chain used to allow middlewarefuncs to be middleware. +func (m MiddlewareFunc) Chain(next http.Handler) http.Handler { + return m(next) } func (s *Server) middlewareWrapperFunc(ctx context.Context) gin.HandlerFunc { return func(c *gin.Context) { - // TODO: we should get rid of this, gin context and middleware context both implement context, don't need a third one here - ctx = c.MustGet("ctx").(context.Context) - fctx := &middlewareContextImpl{Context: ctx} - // add this context to gin context so we can grab it later - c.Set("mctx", fctx) - fctx.index = -1 - fctx.ginContext = c - fctx.middlewares = s.middlewares - // start the chain: - fctx.Next() + if len(s.middlewares) > 0 { + defer func() { + //This is so that if the server errors or panics on a middleware the server will still respond and not send eof to client. + err := recover() + if err != nil { + common.Logger(c.Request.Context()).WithField("MiddleWarePanicRecovery:", err).Errorln("A panic occurred during middleware.") + handleErrorResponse(c, ErrInternalServerError) + } + }() + var h http.Handler + keepgoing := false + h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Request = c.Request.WithContext(r.Context()) + keepgoing = true + }) + + s.chainAndServe(c.Writer, c.Request, h) + if !keepgoing { + c.Abort() + } + } } } +func (s *Server) chainAndServe(w http.ResponseWriter, r *http.Request, h http.Handler) { + for _, m := range s.middlewares { + h = m.Chain(h) + } + h.ServeHTTP(w, r) +} + // AddMiddleware add middleware func (s *Server) AddMiddleware(m Middleware) { - s.middlewares = append(s.middlewares, m) + //Prepend to array so that we can do first,second,third,last,third,second,first + //and not third,second,first,last,first,second,third + s.middlewares = append([]Middleware{m}, s.middlewares...) } -// AddMiddlewareFunc adds middleware function -func (s *Server) AddMiddlewareFunc(m func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error) { - s.AddMiddleware(MiddlewareFunc(m)) +// AddMiddlewareFunc add middlewarefunc +func (s *Server) AddMiddlewareFunc(m MiddlewareFunc) { + s.AddMiddleware(m) } diff --git a/api/server/middleware_test.go b/api/server/middleware_test.go new file mode 100644 index 000000000..3e0b0e2c4 --- /dev/null +++ b/api/server/middleware_test.go @@ -0,0 +1,46 @@ +package server + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +type middleWareStruct struct { + name string +} + +func (m *middleWareStruct) Chain(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(m.name + ",")) + next.ServeHTTP(w, r) + }) +} + +func TestMiddleWareChaining(t *testing.T) { + var lastHandler http.Handler + lastHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("last")) + }) + + s := Server{} + s.AddMiddleware(&middleWareStruct{"first"}) + s.AddMiddleware(&middleWareStruct{"second"}) + s.AddMiddleware(&middleWareStruct{"third"}) + s.AddMiddleware(&middleWareStruct{"fourth"}) + + rec := httptest.NewRecorder() + req, _ := http.NewRequest("get", "http://localhost/", nil) + + s.chainAndServe(rec, req, lastHandler) + + result, err := ioutil.ReadAll(rec.Result().Body) + if err != nil { + t.Fatal(err) + } + + if string(result) != "first,second,third,fourth,last" { + t.Fatal("You failed to chain correctly.", string(result)) + } +} diff --git a/api/server/routes_create_update.go b/api/server/routes_create_update.go index e9d23bfb4..183652953 100644 --- a/api/server/routes_create_update.go +++ b/api/server/routes_create_update.go @@ -23,12 +23,12 @@ import ( Patch accepts partial updates / skips validation of zero values. */ func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) { - ctx := c.MustGet("mctx").(MiddlewareContext) + ctx := c.Request.Context() method := strings.ToUpper(c.Request.Method) var wroute models.RouteWrapper - err := s.bindAndValidate(ctx, c, method, &wroute) + err := s.bindAndValidate(c, method, &wroute) if err != nil { handleErrorResponse(c, err) return @@ -53,7 +53,7 @@ func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) { } // ensureApp will only execute if it is on post or put. Patch is not allowed to create apps. -func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, method string) error { +func (s *Server) ensureApp(ctx context.Context, wroute *models.RouteWrapper, method string) error { if !(method == http.MethodPost || method == http.MethodPut) { return nil } @@ -90,7 +90,7 @@ func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, m If it is a put or patch it makes sure that the path in the url matches the provideed one in the body. Defaults are set and if patch skipZero is true for validating the RouteWrapper */ -func (s *Server) bindAndValidate(ctx context.Context, c *gin.Context, method string, wroute *models.RouteWrapper) error { +func (s *Server) bindAndValidate(c *gin.Context, method string, wroute *models.RouteWrapper) error { err := c.BindJSON(wroute) if err != nil { return models.ErrInvalidJSON diff --git a/api/server/routes_delete.go b/api/server/routes_delete.go index d5670c7dc..9749342cb 100644 --- a/api/server/routes_delete.go +++ b/api/server/routes_delete.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "path" @@ -10,7 +9,7 @@ import ( ) func (s *Server) handleRouteDelete(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() appName := c.MustGet(api.AppName).(string) routePath := path.Clean(c.MustGet(api.Path).(string)) diff --git a/api/server/routes_get.go b/api/server/routes_get.go index 09eec93ef..b5444caf0 100644 --- a/api/server/routes_get.go +++ b/api/server/routes_get.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "path" @@ -10,7 +9,7 @@ import ( ) func (s *Server) handleRouteGet(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() appName := c.MustGet(api.AppName).(string) routePath := path.Clean(c.MustGet(api.Path).(string)) diff --git a/api/server/routes_list.go b/api/server/routes_list.go index 3ea13f465..6501280a9 100644 --- a/api/server/routes_list.go +++ b/api/server/routes_list.go @@ -1,7 +1,6 @@ package server import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -10,7 +9,7 @@ import ( ) func (s *Server) handleRouteList(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() filter := &models.RouteFilter{} @@ -20,8 +19,10 @@ func (s *Server) handleRouteList(c *gin.Context) { var routes []*models.Route var err error - if appName, ok := c.MustGet(api.AppName).(string); ok && appName != "" { - routes, err = s.Datastore.GetRoutesByApp(ctx, appName, filter) + appName, exists := c.Get(api.AppName) + name, ok := appName.(string) + if exists && ok && name != "" { + routes, err = s.Datastore.GetRoutesByApp(ctx, name, filter) } else { routes, err = s.Datastore.GetRoutes(ctx, filter) } @@ -31,5 +32,5 @@ func (s *Server) handleRouteList(c *gin.Context) { return } - c.JSON(http.StatusOK, routesResponse{"Sucessfully listed routes", routes}) + c.JSON(http.StatusOK, routesResponse{"Successfully listed routes", routes}) } diff --git a/api/server/runner.go b/api/server/runner.go index f06d635ba..94fb5210a 100644 --- a/api/server/runner.go +++ b/api/server/runner.go @@ -28,21 +28,21 @@ type runnerResponse struct { } func (s *Server) handleSpecial(c *gin.Context) { - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() ctx = context.WithValue(ctx, api.AppName, "") c.Set(api.AppName, "") ctx = context.WithValue(ctx, api.Path, c.Request.URL.Path) c.Set(api.Path, c.Request.URL.Path) - ctx, err := s.UseSpecialHandlers(ctx, c.Request, c.Writer) + r, err := s.UseSpecialHandlers(c.Writer, c.Request) if err != nil { handleErrorResponse(c, err) return } - c.Set("ctx", ctx) - c.Set(api.AppName, ctx.Value(api.AppName).(string)) + c.Request = r + c.Set(api.AppName, r.Context().Value(api.AppName).(string)) if c.MustGet(api.AppName).(string) == "" { handleErrorResponse(c, models.ErrRunnerRouteNotFound) return @@ -66,7 +66,7 @@ func (s *Server) handleRequest(c *gin.Context, enqueue models.Enqueue) { return } - ctx := c.MustGet("ctx").(context.Context) + ctx := c.Request.Context() reqID := id.New().String() ctx, log := common.LoggerWithFields(ctx, logrus.Fields{"call_id": reqID}) diff --git a/api/server/server.go b/api/server/server.go index d1a86b953..6747769a3 100644 --- a/api/server/server.go +++ b/api/server/server.go @@ -161,8 +161,6 @@ func prepareMiddleware(ctx context.Context) gin.HandlerFunc { c.Set(api.Path, routePath) } - // todo: can probably replace the "ctx" value with the Go 1.7 context on the http.Request - c.Set("ctx", ctx) c.Request = c.Request.WithContext(ctx) c.Next() } diff --git a/api/server/special_handler.go b/api/server/special_handler.go index 00023a433..831c4bf6a 100644 --- a/api/server/special_handler.go +++ b/api/server/special_handler.go @@ -1,73 +1,36 @@ package server import ( - "context" "net/http" "gitlab-odx.oracle.com/odx/functions/api/models" ) +// SpecialHandler verysimilar to a handler but since it is not used as middle ware no way +// to get context without returning it. So we just return a request which could have newly made +// contexts. type SpecialHandler interface { - Handle(c HandlerContext) error -} - -// Each handler can modify the context here so when it gets passed along, it will use the new info. -type HandlerContext interface { - // Context return the context object - Context() context.Context - - // Request returns the underlying http.Request object - Request() *http.Request - - // Response returns the http.ResponseWriter - Response() http.ResponseWriter - - // Overwrite value in the context - Set(key string, value interface{}) -} - -type SpecialHandlerContext struct { - request *http.Request - response http.ResponseWriter - ctx context.Context -} - -func (c *SpecialHandlerContext) Context() context.Context { - return c.ctx -} - -func (c *SpecialHandlerContext) Request() *http.Request { - return c.request -} - -func (c *SpecialHandlerContext) Response() http.ResponseWriter { - return c.response -} - -func (c *SpecialHandlerContext) Set(key string, value interface{}) { - c.ctx = context.WithValue(c.ctx, key, value) + Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error) } +// AddSpecialHandler adds the SpecialHandler to the specialHandlers list. func (s *Server) AddSpecialHandler(handler SpecialHandler) { s.specialHandlers = append(s.specialHandlers, handler) } // UseSpecialHandlers execute all special handlers -func (s *Server) UseSpecialHandlers(ctx context.Context, req *http.Request, resp http.ResponseWriter) (context.Context, error) { +func (s *Server) UseSpecialHandlers(resp http.ResponseWriter, req *http.Request) (*http.Request, error) { if len(s.specialHandlers) == 0 { - return ctx, models.ErrNoSpecialHandlerFound + return req, models.ErrNoSpecialHandlerFound } + var r *http.Request + var err error - c := &SpecialHandlerContext{ - request: req, - response: resp, - ctx: ctx, - } for _, l := range s.specialHandlers { - err := l.Handle(c) + r, err = l.Handle(resp, req) if err != nil { - return c.ctx, err + return nil, err } } - return c.ctx, nil + return r, nil } diff --git a/api/server/special_handler_test.go b/api/server/special_handler_test.go index 85ddf23f8..b6989ad11 100644 --- a/api/server/special_handler_test.go +++ b/api/server/special_handler_test.go @@ -1,12 +1,15 @@ package server -import "testing" +import ( + "net/http" + "testing" +) type testSpecialHandler struct{} -func (h *testSpecialHandler) Handle(c HandlerContext) error { - // c.Set(api.AppName, "test") - return nil +func (h *testSpecialHandler) Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error) { + // r = r.WithContext(context.WithValue(r.Context(), api.AppName, "test")) + return r, nil } func TestSpecialHandlerSet(t *testing.T) { diff --git a/examples/middleware/main.go b/examples/middleware/main.go index 143b67ea6..413e6b3a6 100644 --- a/examples/middleware/main.go +++ b/examples/middleware/main.go @@ -3,13 +3,11 @@ package main import ( "context" "encoding/json" - "errors" "fmt" "net/http" "strings" "time" - "gitlab-odx.oracle.com/odx/functions/api/models" "gitlab-odx.oracle.com/odx/functions/api/server" ) @@ -18,13 +16,15 @@ func main() { funcServer := server.NewFromEnv(ctx) - funcServer.AddMiddlewareFunc(func(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error { - start := time.Now() - fmt.Println("CustomMiddlewareFunc called at:", start) - ctx.Next() - fmt.Println("Duration:", (time.Now().Sub(start))) - return nil + funcServer.AddMiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + fmt.Println("CustomMiddlewareFunc called at:", start) + next.ServeHTTP(w, r) + fmt.Println("Duration:", (time.Now().Sub(start))) + }) }) + funcServer.AddMiddleware(&CustomMiddleware{}) funcServer.Start(ctx) @@ -33,20 +33,22 @@ func main() { type CustomMiddleware struct { } -func (h *CustomMiddleware) Serve(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error { - fmt.Println("CustomMiddleware called") +func (h *CustomMiddleware) Serve(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println("CustomMiddleware called") - // check auth header - tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3) - if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - m2 := map[string]string{"message": "Invalid Authorization token."} - m := map[string]map[string]string{"error": m2} - json.NewEncoder(w).Encode(m) - return errors.New("Invalid authorization token.") - } - fmt.Println("auth succeeded!") - ctx.Set("user", "I'm in!") - return nil + // check auth header + tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3) + if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + m2 := map[string]string{"message": "Invalid Authorization token."} + m := map[string]map[string]string{"error": m2} + json.NewEncoder(w).Encode(m) + return + } + fmt.Println("auth succeeded!") + r = r.WithContext(context.WithValue(r.Context(), "user", "I'm in!")) + next.ServeHTTP(w, r) + }) }