Use chi style middle ware

This commit is contained in:
James Jeffrey
2017-07-19 13:44:26 -07:00
committed by Travis Reeder
parent 570e9265f1
commit cf2c3cf404
21 changed files with 192 additions and 240 deletions

View File

@@ -1,30 +1,34 @@
package server package server
import ( import (
"context"
"gitlab-odx.oracle.com/odx/functions/api/models" "gitlab-odx.oracle.com/odx/functions/api/models"
) )
// AppListener is an interface used to inject custom code at key points in app lifecycle.
type AppListener interface { type AppListener interface {
// BeforeAppCreate called right before creating App in the database // BeforeAppCreate called right before creating App in the database
BeforeAppCreate(ctx MiddlewareContext, app *models.App) error BeforeAppCreate(ctx context.Context, app *models.App) error
// AfterAppCreate called after creating App in the database // AfterAppCreate called after creating App in the database
AfterAppCreate(ctx MiddlewareContext, app *models.App) error AfterAppCreate(ctx context.Context, app *models.App) error
// BeforeAppUpdate called right before updating App in the database // BeforeAppUpdate called right before updating App in the database
BeforeAppUpdate(ctx MiddlewareContext, app *models.App) error BeforeAppUpdate(ctx context.Context, app *models.App) error
// AfterAppUpdate called after updating App in the database // AfterAppUpdate called after updating App in the database
AfterAppUpdate(ctx MiddlewareContext, app *models.App) error AfterAppUpdate(ctx context.Context, app *models.App) error
// BeforeAppDelete called right before deleting App in the database // BeforeAppDelete called right before deleting App in the database
BeforeAppDelete(ctx MiddlewareContext, app *models.App) error BeforeAppDelete(ctx context.Context, app *models.App) error
// AfterAppDelete called after deleting App in the database // AfterAppDelete called after deleting App in the database
AfterAppDelete(ctx MiddlewareContext, app *models.App) error AfterAppDelete(ctx context.Context, app *models.App) error
} }
// AddAppCreateListener adds a listener that will be notified on App created. // AddAppListener adds a listener that will be notified on App created.
func (s *Server) AddAppListener(listener AppListener) { func (s *Server) AddAppListener(listener AppListener) {
s.appListeners = append(s.appListeners, listener) s.appListeners = append(s.appListeners, listener)
} }
func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) error { // FireBeforeAppCreate is used to call all the server's Listeners BeforeAppCreate functions.
func (s *Server) FireBeforeAppCreate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners { for _, l := range s.appListeners {
err := l.BeforeAppCreate(ctx, app) err := l.BeforeAppCreate(ctx, app)
if err != nil { if err != nil {
@@ -34,7 +38,8 @@ func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) err
return nil return nil
} }
func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) error { // FireAfterAppCreate is used to call all the server's Listeners AfterAppCreate functions.
func (s *Server) FireAfterAppCreate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners { for _, l := range s.appListeners {
err := l.AfterAppCreate(ctx, app) err := l.AfterAppCreate(ctx, app)
if err != nil { if err != nil {
@@ -44,7 +49,8 @@ func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) erro
return nil return nil
} }
func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) error { // FireBeforeAppUpdate is used to call all the server's Listeners BeforeAppUpdate functions.
func (s *Server) FireBeforeAppUpdate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners { for _, l := range s.appListeners {
err := l.BeforeAppUpdate(ctx, app) err := l.BeforeAppUpdate(ctx, app)
if err != nil { if err != nil {
@@ -54,7 +60,8 @@ func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) err
return nil return nil
} }
func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) error { // FireAfterAppUpdate is used to call all the server's Listeners AfterAppUpdate functions.
func (s *Server) FireAfterAppUpdate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners { for _, l := range s.appListeners {
err := l.AfterAppUpdate(ctx, app) err := l.AfterAppUpdate(ctx, app)
if err != nil { if err != nil {
@@ -64,7 +71,8 @@ func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) erro
return nil return nil
} }
func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) error { // FireBeforeAppDelete is used to call all the server's Listeners BeforeAppDelete functions.
func (s *Server) FireBeforeAppDelete(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners { for _, l := range s.appListeners {
err := l.BeforeAppDelete(ctx, app) err := l.BeforeAppDelete(ctx, app)
if err != nil { if err != nil {
@@ -74,7 +82,8 @@ func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) err
return nil return nil
} }
func (s *Server) FireAfterAppDelete(ctx MiddlewareContext, app *models.App) error { // FireAfterAppDelete is used to call all the server's Listeners AfterAppDelete functions.
func (s *Server) FireAfterAppDelete(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners { for _, l := range s.appListeners {
err := l.AfterAppDelete(ctx, app) err := l.AfterAppDelete(ctx, app)
if err != nil { if err != nil {

View File

@@ -8,7 +8,7 @@ import (
) )
func (s *Server) handleAppCreate(c *gin.Context) { func (s *Server) handleAppCreate(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext) ctx := c.Request.Context()
var wapp models.AppWrapper var wapp models.AppWrapper
@@ -23,7 +23,7 @@ func (s *Server) handleAppCreate(c *gin.Context) {
return return
} }
if err := wapp.Validate(); err != nil { if err = wapp.Validate(); err != nil {
handleErrorResponse(c, err) handleErrorResponse(c, err)
return return
} }

View File

@@ -10,7 +10,7 @@ import (
) )
func (s *Server) handleAppDelete(c *gin.Context) { func (s *Server) handleAppDelete(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext) ctx := c.Request.Context()
log := common.Logger(ctx) log := common.Logger(ctx)
app := &models.App{Name: c.MustGet(api.AppName).(string)} app := &models.App{Name: c.MustGet(api.AppName).(string)}

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
) )
func (s *Server) handleAppGet(c *gin.Context) { func (s *Server) handleAppGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
appName := c.MustGet(api.AppName).(string) appName := c.MustGet(api.AppName).(string)
app, err := s.Datastore.GetApp(ctx, appName) app, err := s.Datastore.GetApp(ctx, appName)

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
) )
func (s *Server) handleAppList(c *gin.Context) { func (s *Server) handleAppList(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
filter := &models.AppFilter{} filter := &models.AppFilter{}

View File

@@ -9,7 +9,7 @@ import (
) )
func (s *Server) handleAppUpdate(c *gin.Context) { func (s *Server) handleAppUpdate(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext) ctx := c.Request.Context()
wapp := models.AppWrapper{} wapp := models.AppWrapper{}

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
) )
func (s *Server) handleCallGet(c *gin.Context) { func (s *Server) handleCallGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
callID := c.Param(api.Call) callID := c.Param(api.Call)
callObj, err := s.Datastore.GetTask(ctx, callID) callObj, err := s.Datastore.GetTask(ctx, callID)

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -10,7 +9,7 @@ import (
) )
func (s *Server) handleCallList(c *gin.Context) { func (s *Server) handleCallList(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
appName, ok := c.MustGet(api.AppName).(string) appName, ok := c.MustGet(api.AppName).(string)
if ok && appName == "" { if ok && appName == "" {

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
) )
func (s *Server) handleCallLogGet(c *gin.Context) { func (s *Server) handleCallLogGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
callID := c.Param(api.Call) callID := c.Param(api.Call)
_, err := s.Datastore.GetTask(ctx, callID) _, err := s.Datastore.GetTask(ctx, callID)
@@ -28,7 +27,7 @@ func (s *Server) handleCallLogGet(c *gin.Context) {
} }
func (s *Server) handleCallLogDelete(c *gin.Context) { func (s *Server) handleCallLogDelete(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
callID := c.Param(api.Call) callID := c.Param(api.Call)
_, err := s.Datastore.GetTask(ctx, callID) _, err := s.Datastore.GetTask(ctx, callID)

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
@@ -12,6 +11,7 @@ import (
"gitlab-odx.oracle.com/odx/functions/api/runner/common" "gitlab-odx.oracle.com/odx/functions/api/runner/common"
) )
// ErrInternalServerError returned when something exceptional happens.
var ErrInternalServerError = errors.New("internal server error") var ErrInternalServerError = errors.New("internal server error")
func simpleError(err error) *models.Error { func simpleError(err error) *models.Error {
@@ -19,14 +19,14 @@ func simpleError(err error) *models.Error {
} }
func handleErrorResponse(c *gin.Context, err error) { func handleErrorResponse(c *gin.Context, err error) {
ctx := c.MustGet("ctx").(context.Context) log := common.Logger(c.Request.Context())
log := common.Logger(ctx) switch e := err.(type) {
case models.APIError:
if aerr, ok := err.(models.APIError); ok { if e.Code() >= 500 {
log.WithFields(logrus.Fields{"code": aerr.Code()}).WithError(err).Error("api error") log.WithFields(logrus.Fields{"code": e.Code()}).WithError(e).Error("api error")
c.JSON(aerr.Code(), simpleError(err)) }
} else if err != nil { c.JSON(e.Code(), simpleError(e))
// get a stack trace so we can trace this error default:
log.WithError(err).WithFields(logrus.Fields{"stack": string(debug.Stack())}).Error("internal server error") log.WithError(err).WithFields(logrus.Fields{"stack": string(debug.Stack())}).Error("internal server error")
c.JSON(http.StatusInternalServerError, simpleError(ErrInternalServerError)) c.JSON(http.StatusInternalServerError, simpleError(ErrInternalServerError))
} }

View File

@@ -1,132 +1,69 @@
// TODO: it would be nice to move these into the top level folder so people can use these with the "functions" package, eg: functions.AddMiddleware(...)
package server package server
import ( import (
"context" "context"
"net/http" "net/http"
"github.com/Sirupsen/logrus" "gitlab-odx.oracle.com/odx/functions/api/runner/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gitlab-odx.oracle.com/odx/functions/api/models"
) )
// Middleware is the interface required for implementing functions middlewar // Middleware just takes a http.Handler and returns one. So the next middle ware must be called
// within the returned handler or it would be ignored.
type Middleware interface { type Middleware interface {
// Serve is what the Middleware must implement. Can modify the request, write output, etc. Chain(next http.Handler) http.Handler
// todo: should we abstract the HTTP out of this? In case we want to support other protocols.
Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error
} }
// MiddlewareFunc func form of Middleware // MiddlewareFunc is a here to allow a plain function to be a middleware.
type MiddlewareFunc func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error type MiddlewareFunc func(next http.Handler) http.Handler
// Serve wrapper // Chain used to allow middlewarefuncs to be middleware.
func (f MiddlewareFunc) Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error { func (m MiddlewareFunc) Chain(next http.Handler) http.Handler {
return f(ctx, w, r, app) return m(next)
}
// MiddlewareContext extends context.Context for Middleware
type MiddlewareContext interface {
context.Context
// Set is used to store a new key/value pair exclusively for this context.
// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well.
Set(key string, value interface{})
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
Get(key string) (value interface{}, exists bool)
// MustGet returns the value for the given key if it exists, otherwise it panics.
MustGet(key string) interface{}
// Middleware can call Next() explicitly to call the next middleware in the chain. If Next() is not called and an error is not returned, Next() will automatically be called.
Next()
// Index returns the index of where we're at in the chain
Index() int
}
type middlewareContextImpl struct {
context.Context
ginContext *gin.Context
nextCalled bool
index int
middlewares []Middleware
}
// Set is used to store a new key/value pair exclusively for this context.
// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well.
func (c *middlewareContextImpl) Set(key string, value interface{}) {
c.ginContext.Set(key, value)
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *middlewareContextImpl) Get(key string) (value interface{}, exists bool) {
return c.ginContext.Get(key)
}
// MustGet returns the value for the given key if it exists, otherwise it panics.
func (c *middlewareContextImpl) MustGet(key string) interface{} {
return c.ginContext.MustGet(key)
}
func (c *middlewareContextImpl) Next() {
c.nextCalled = true
c.index++
c.serveNext()
}
func (c *middlewareContextImpl) serveNext() {
if c.Index() >= len(c.middlewares) {
return
}
// make shallow copy:
fctx2 := *c
fctx2.nextCalled = false
r := c.ginContext.Request.WithContext(fctx2)
err := c.middlewares[c.Index()].Serve(&fctx2, c.ginContext.Writer, r, nil)
if err != nil {
logrus.WithError(err).Warnln("Middleware error")
// todo: might be a good idea to check if anything is written yet, and if not, output the error: simpleError(err)
// see: http://stackoverflow.com/questions/39415827/golang-http-check-if-responsewriter-has-been-written
c.ginContext.Error(err)
c.ginContext.Abort()
return
}
if !fctx2.nextCalled {
// then we automatically call next
fctx2.Next()
}
}
func (c *middlewareContextImpl) Index() int {
return c.index
} }
func (s *Server) middlewareWrapperFunc(ctx context.Context) gin.HandlerFunc { func (s *Server) middlewareWrapperFunc(ctx context.Context) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// TODO: we should get rid of this, gin context and middleware context both implement context, don't need a third one here if len(s.middlewares) > 0 {
ctx = c.MustGet("ctx").(context.Context) defer func() {
fctx := &middlewareContextImpl{Context: ctx} //This is so that if the server errors or panics on a middleware the server will still respond and not send eof to client.
// add this context to gin context so we can grab it later err := recover()
c.Set("mctx", fctx) if err != nil {
fctx.index = -1 common.Logger(c.Request.Context()).WithField("MiddleWarePanicRecovery:", err).Errorln("A panic occurred during middleware.")
fctx.ginContext = c handleErrorResponse(c, ErrInternalServerError)
fctx.middlewares = s.middlewares }
// start the chain: }()
fctx.Next() var h http.Handler
keepgoing := false
h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Request = c.Request.WithContext(r.Context())
keepgoing = true
})
s.chainAndServe(c.Writer, c.Request, h)
if !keepgoing {
c.Abort()
}
}
} }
} }
func (s *Server) chainAndServe(w http.ResponseWriter, r *http.Request, h http.Handler) {
for _, m := range s.middlewares {
h = m.Chain(h)
}
h.ServeHTTP(w, r)
}
// AddMiddleware add middleware // AddMiddleware add middleware
func (s *Server) AddMiddleware(m Middleware) { func (s *Server) AddMiddleware(m Middleware) {
s.middlewares = append(s.middlewares, m) //Prepend to array so that we can do first,second,third,last,third,second,first
//and not third,second,first,last,first,second,third
s.middlewares = append([]Middleware{m}, s.middlewares...)
} }
// AddMiddlewareFunc adds middleware function // AddMiddlewareFunc add middlewarefunc
func (s *Server) AddMiddlewareFunc(m func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error) { func (s *Server) AddMiddlewareFunc(m MiddlewareFunc) {
s.AddMiddleware(MiddlewareFunc(m)) s.AddMiddleware(m)
} }

View File

@@ -0,0 +1,46 @@
package server
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
type middleWareStruct struct {
name string
}
func (m *middleWareStruct) Chain(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(m.name + ","))
next.ServeHTTP(w, r)
})
}
func TestMiddleWareChaining(t *testing.T) {
var lastHandler http.Handler
lastHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("last"))
})
s := Server{}
s.AddMiddleware(&middleWareStruct{"first"})
s.AddMiddleware(&middleWareStruct{"second"})
s.AddMiddleware(&middleWareStruct{"third"})
s.AddMiddleware(&middleWareStruct{"fourth"})
rec := httptest.NewRecorder()
req, _ := http.NewRequest("get", "http://localhost/", nil)
s.chainAndServe(rec, req, lastHandler)
result, err := ioutil.ReadAll(rec.Result().Body)
if err != nil {
t.Fatal(err)
}
if string(result) != "first,second,third,fourth,last" {
t.Fatal("You failed to chain correctly.", string(result))
}
}

View File

@@ -23,12 +23,12 @@ import (
Patch accepts partial updates / skips validation of zero values. Patch accepts partial updates / skips validation of zero values.
*/ */
func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) { func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext) ctx := c.Request.Context()
method := strings.ToUpper(c.Request.Method) method := strings.ToUpper(c.Request.Method)
var wroute models.RouteWrapper var wroute models.RouteWrapper
err := s.bindAndValidate(ctx, c, method, &wroute) err := s.bindAndValidate(c, method, &wroute)
if err != nil { if err != nil {
handleErrorResponse(c, err) handleErrorResponse(c, err)
return return
@@ -53,7 +53,7 @@ func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
} }
// ensureApp will only execute if it is on post or put. Patch is not allowed to create apps. // ensureApp will only execute if it is on post or put. Patch is not allowed to create apps.
func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, method string) error { func (s *Server) ensureApp(ctx context.Context, wroute *models.RouteWrapper, method string) error {
if !(method == http.MethodPost || method == http.MethodPut) { if !(method == http.MethodPost || method == http.MethodPut) {
return nil return nil
} }
@@ -90,7 +90,7 @@ func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, m
If it is a put or patch it makes sure that the path in the url matches the provideed one in the body. If it is a put or patch it makes sure that the path in the url matches the provideed one in the body.
Defaults are set and if patch skipZero is true for validating the RouteWrapper Defaults are set and if patch skipZero is true for validating the RouteWrapper
*/ */
func (s *Server) bindAndValidate(ctx context.Context, c *gin.Context, method string, wroute *models.RouteWrapper) error { func (s *Server) bindAndValidate(c *gin.Context, method string, wroute *models.RouteWrapper) error {
err := c.BindJSON(wroute) err := c.BindJSON(wroute)
if err != nil { if err != nil {
return models.ErrInvalidJSON return models.ErrInvalidJSON

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"path" "path"
@@ -10,7 +9,7 @@ import (
) )
func (s *Server) handleRouteDelete(c *gin.Context) { func (s *Server) handleRouteDelete(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
appName := c.MustGet(api.AppName).(string) appName := c.MustGet(api.AppName).(string)
routePath := path.Clean(c.MustGet(api.Path).(string)) routePath := path.Clean(c.MustGet(api.Path).(string))

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"path" "path"
@@ -10,7 +9,7 @@ import (
) )
func (s *Server) handleRouteGet(c *gin.Context) { func (s *Server) handleRouteGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
appName := c.MustGet(api.AppName).(string) appName := c.MustGet(api.AppName).(string)
routePath := path.Clean(c.MustGet(api.Path).(string)) routePath := path.Clean(c.MustGet(api.Path).(string))

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -10,7 +9,7 @@ import (
) )
func (s *Server) handleRouteList(c *gin.Context) { func (s *Server) handleRouteList(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
filter := &models.RouteFilter{} filter := &models.RouteFilter{}
@@ -20,8 +19,10 @@ func (s *Server) handleRouteList(c *gin.Context) {
var routes []*models.Route var routes []*models.Route
var err error var err error
if appName, ok := c.MustGet(api.AppName).(string); ok && appName != "" { appName, exists := c.Get(api.AppName)
routes, err = s.Datastore.GetRoutesByApp(ctx, appName, filter) name, ok := appName.(string)
if exists && ok && name != "" {
routes, err = s.Datastore.GetRoutesByApp(ctx, name, filter)
} else { } else {
routes, err = s.Datastore.GetRoutes(ctx, filter) routes, err = s.Datastore.GetRoutes(ctx, filter)
} }
@@ -31,5 +32,5 @@ func (s *Server) handleRouteList(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, routesResponse{"Sucessfully listed routes", routes}) c.JSON(http.StatusOK, routesResponse{"Successfully listed routes", routes})
} }

View File

@@ -28,21 +28,21 @@ type runnerResponse struct {
} }
func (s *Server) handleSpecial(c *gin.Context) { func (s *Server) handleSpecial(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
ctx = context.WithValue(ctx, api.AppName, "") ctx = context.WithValue(ctx, api.AppName, "")
c.Set(api.AppName, "") c.Set(api.AppName, "")
ctx = context.WithValue(ctx, api.Path, c.Request.URL.Path) ctx = context.WithValue(ctx, api.Path, c.Request.URL.Path)
c.Set(api.Path, c.Request.URL.Path) c.Set(api.Path, c.Request.URL.Path)
ctx, err := s.UseSpecialHandlers(ctx, c.Request, c.Writer) r, err := s.UseSpecialHandlers(c.Writer, c.Request)
if err != nil { if err != nil {
handleErrorResponse(c, err) handleErrorResponse(c, err)
return return
} }
c.Set("ctx", ctx) c.Request = r
c.Set(api.AppName, ctx.Value(api.AppName).(string)) c.Set(api.AppName, r.Context().Value(api.AppName).(string))
if c.MustGet(api.AppName).(string) == "" { if c.MustGet(api.AppName).(string) == "" {
handleErrorResponse(c, models.ErrRunnerRouteNotFound) handleErrorResponse(c, models.ErrRunnerRouteNotFound)
return return
@@ -66,7 +66,7 @@ func (s *Server) handleRequest(c *gin.Context, enqueue models.Enqueue) {
return return
} }
ctx := c.MustGet("ctx").(context.Context) ctx := c.Request.Context()
reqID := id.New().String() reqID := id.New().String()
ctx, log := common.LoggerWithFields(ctx, logrus.Fields{"call_id": reqID}) ctx, log := common.LoggerWithFields(ctx, logrus.Fields{"call_id": reqID})

View File

@@ -161,8 +161,6 @@ func prepareMiddleware(ctx context.Context) gin.HandlerFunc {
c.Set(api.Path, routePath) c.Set(api.Path, routePath)
} }
// todo: can probably replace the "ctx" value with the Go 1.7 context on the http.Request
c.Set("ctx", ctx)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Next() c.Next()
} }

View File

@@ -1,73 +1,36 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"gitlab-odx.oracle.com/odx/functions/api/models" "gitlab-odx.oracle.com/odx/functions/api/models"
) )
// SpecialHandler verysimilar to a handler but since it is not used as middle ware no way
// to get context without returning it. So we just return a request which could have newly made
// contexts.
type SpecialHandler interface { type SpecialHandler interface {
Handle(c HandlerContext) error Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error)
}
// Each handler can modify the context here so when it gets passed along, it will use the new info.
type HandlerContext interface {
// Context return the context object
Context() context.Context
// Request returns the underlying http.Request object
Request() *http.Request
// Response returns the http.ResponseWriter
Response() http.ResponseWriter
// Overwrite value in the context
Set(key string, value interface{})
}
type SpecialHandlerContext struct {
request *http.Request
response http.ResponseWriter
ctx context.Context
}
func (c *SpecialHandlerContext) Context() context.Context {
return c.ctx
}
func (c *SpecialHandlerContext) Request() *http.Request {
return c.request
}
func (c *SpecialHandlerContext) Response() http.ResponseWriter {
return c.response
}
func (c *SpecialHandlerContext) Set(key string, value interface{}) {
c.ctx = context.WithValue(c.ctx, key, value)
} }
// AddSpecialHandler adds the SpecialHandler to the specialHandlers list.
func (s *Server) AddSpecialHandler(handler SpecialHandler) { func (s *Server) AddSpecialHandler(handler SpecialHandler) {
s.specialHandlers = append(s.specialHandlers, handler) s.specialHandlers = append(s.specialHandlers, handler)
} }
// UseSpecialHandlers execute all special handlers // UseSpecialHandlers execute all special handlers
func (s *Server) UseSpecialHandlers(ctx context.Context, req *http.Request, resp http.ResponseWriter) (context.Context, error) { func (s *Server) UseSpecialHandlers(resp http.ResponseWriter, req *http.Request) (*http.Request, error) {
if len(s.specialHandlers) == 0 { if len(s.specialHandlers) == 0 {
return ctx, models.ErrNoSpecialHandlerFound return req, models.ErrNoSpecialHandlerFound
} }
var r *http.Request
var err error
c := &SpecialHandlerContext{
request: req,
response: resp,
ctx: ctx,
}
for _, l := range s.specialHandlers { for _, l := range s.specialHandlers {
err := l.Handle(c) r, err = l.Handle(resp, req)
if err != nil { if err != nil {
return c.ctx, err return nil, err
} }
} }
return c.ctx, nil return r, nil
} }

View File

@@ -1,12 +1,15 @@
package server package server
import "testing" import (
"net/http"
"testing"
)
type testSpecialHandler struct{} type testSpecialHandler struct{}
func (h *testSpecialHandler) Handle(c HandlerContext) error { func (h *testSpecialHandler) Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
// c.Set(api.AppName, "test") // r = r.WithContext(context.WithValue(r.Context(), api.AppName, "test"))
return nil return r, nil
} }
func TestSpecialHandlerSet(t *testing.T) { func TestSpecialHandlerSet(t *testing.T) {

View File

@@ -3,13 +3,11 @@ package main
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"gitlab-odx.oracle.com/odx/functions/api/models"
"gitlab-odx.oracle.com/odx/functions/api/server" "gitlab-odx.oracle.com/odx/functions/api/server"
) )
@@ -18,13 +16,15 @@ func main() {
funcServer := server.NewFromEnv(ctx) funcServer := server.NewFromEnv(ctx)
funcServer.AddMiddlewareFunc(func(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error { funcServer.AddMiddlewareFunc(func(next http.Handler) http.Handler {
start := time.Now() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println("CustomMiddlewareFunc called at:", start) start := time.Now()
ctx.Next() fmt.Println("CustomMiddlewareFunc called at:", start)
fmt.Println("Duration:", (time.Now().Sub(start))) next.ServeHTTP(w, r)
return nil fmt.Println("Duration:", (time.Now().Sub(start)))
})
}) })
funcServer.AddMiddleware(&CustomMiddleware{}) funcServer.AddMiddleware(&CustomMiddleware{})
funcServer.Start(ctx) funcServer.Start(ctx)
@@ -33,20 +33,22 @@ func main() {
type CustomMiddleware struct { type CustomMiddleware struct {
} }
func (h *CustomMiddleware) Serve(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error { func (h *CustomMiddleware) Serve(next http.Handler) http.Handler {
fmt.Println("CustomMiddleware called") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println("CustomMiddleware called")
// check auth header // check auth header
tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3) tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3)
if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" { if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
m2 := map[string]string{"message": "Invalid Authorization token."} m2 := map[string]string{"message": "Invalid Authorization token."}
m := map[string]map[string]string{"error": m2} m := map[string]map[string]string{"error": m2}
json.NewEncoder(w).Encode(m) json.NewEncoder(w).Encode(m)
return errors.New("Invalid authorization token.") return
} }
fmt.Println("auth succeeded!") fmt.Println("auth succeeded!")
ctx.Set("user", "I'm in!") r = r.WithContext(context.WithValue(r.Context(), "user", "I'm in!"))
return nil next.ServeHTTP(w, r)
})
} }