mirror of
https://github.com/fnproject/fn.git
synced 2022-10-28 21:29:17 +03:00
Use chi style middle ware
This commit is contained in:
committed by
Travis Reeder
parent
570e9265f1
commit
cf2c3cf404
@@ -1,30 +1,34 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"gitlab-odx.oracle.com/odx/functions/api/models"
|
"gitlab-odx.oracle.com/odx/functions/api/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AppListener is an interface used to inject custom code at key points in app lifecycle.
|
||||||
type AppListener interface {
|
type AppListener interface {
|
||||||
// BeforeAppCreate called right before creating App in the database
|
// BeforeAppCreate called right before creating App in the database
|
||||||
BeforeAppCreate(ctx MiddlewareContext, app *models.App) error
|
BeforeAppCreate(ctx context.Context, app *models.App) error
|
||||||
// AfterAppCreate called after creating App in the database
|
// AfterAppCreate called after creating App in the database
|
||||||
AfterAppCreate(ctx MiddlewareContext, app *models.App) error
|
AfterAppCreate(ctx context.Context, app *models.App) error
|
||||||
// BeforeAppUpdate called right before updating App in the database
|
// BeforeAppUpdate called right before updating App in the database
|
||||||
BeforeAppUpdate(ctx MiddlewareContext, app *models.App) error
|
BeforeAppUpdate(ctx context.Context, app *models.App) error
|
||||||
// AfterAppUpdate called after updating App in the database
|
// AfterAppUpdate called after updating App in the database
|
||||||
AfterAppUpdate(ctx MiddlewareContext, app *models.App) error
|
AfterAppUpdate(ctx context.Context, app *models.App) error
|
||||||
// BeforeAppDelete called right before deleting App in the database
|
// BeforeAppDelete called right before deleting App in the database
|
||||||
BeforeAppDelete(ctx MiddlewareContext, app *models.App) error
|
BeforeAppDelete(ctx context.Context, app *models.App) error
|
||||||
// AfterAppDelete called after deleting App in the database
|
// AfterAppDelete called after deleting App in the database
|
||||||
AfterAppDelete(ctx MiddlewareContext, app *models.App) error
|
AfterAppDelete(ctx context.Context, app *models.App) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAppCreateListener adds a listener that will be notified on App created.
|
// AddAppListener adds a listener that will be notified on App created.
|
||||||
func (s *Server) AddAppListener(listener AppListener) {
|
func (s *Server) AddAppListener(listener AppListener) {
|
||||||
s.appListeners = append(s.appListeners, listener)
|
s.appListeners = append(s.appListeners, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) error {
|
// FireBeforeAppCreate is used to call all the server's Listeners BeforeAppCreate functions.
|
||||||
|
func (s *Server) FireBeforeAppCreate(ctx context.Context, app *models.App) error {
|
||||||
for _, l := range s.appListeners {
|
for _, l := range s.appListeners {
|
||||||
err := l.BeforeAppCreate(ctx, app)
|
err := l.BeforeAppCreate(ctx, app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -34,7 +38,8 @@ func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) error {
|
// FireAfterAppCreate is used to call all the server's Listeners AfterAppCreate functions.
|
||||||
|
func (s *Server) FireAfterAppCreate(ctx context.Context, app *models.App) error {
|
||||||
for _, l := range s.appListeners {
|
for _, l := range s.appListeners {
|
||||||
err := l.AfterAppCreate(ctx, app)
|
err := l.AfterAppCreate(ctx, app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -44,7 +49,8 @@ func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) error {
|
// FireBeforeAppUpdate is used to call all the server's Listeners BeforeAppUpdate functions.
|
||||||
|
func (s *Server) FireBeforeAppUpdate(ctx context.Context, app *models.App) error {
|
||||||
for _, l := range s.appListeners {
|
for _, l := range s.appListeners {
|
||||||
err := l.BeforeAppUpdate(ctx, app)
|
err := l.BeforeAppUpdate(ctx, app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -54,7 +60,8 @@ func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) error {
|
// FireAfterAppUpdate is used to call all the server's Listeners AfterAppUpdate functions.
|
||||||
|
func (s *Server) FireAfterAppUpdate(ctx context.Context, app *models.App) error {
|
||||||
for _, l := range s.appListeners {
|
for _, l := range s.appListeners {
|
||||||
err := l.AfterAppUpdate(ctx, app)
|
err := l.AfterAppUpdate(ctx, app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -64,7 +71,8 @@ func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) error {
|
// FireBeforeAppDelete is used to call all the server's Listeners BeforeAppDelete functions.
|
||||||
|
func (s *Server) FireBeforeAppDelete(ctx context.Context, app *models.App) error {
|
||||||
for _, l := range s.appListeners {
|
for _, l := range s.appListeners {
|
||||||
err := l.BeforeAppDelete(ctx, app)
|
err := l.BeforeAppDelete(ctx, app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,7 +82,8 @@ func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FireAfterAppDelete(ctx MiddlewareContext, app *models.App) error {
|
// FireAfterAppDelete is used to call all the server's Listeners AfterAppDelete functions.
|
||||||
|
func (s *Server) FireAfterAppDelete(ctx context.Context, app *models.App) error {
|
||||||
for _, l := range s.appListeners {
|
for _, l := range s.appListeners {
|
||||||
err := l.AfterAppDelete(ctx, app)
|
err := l.AfterAppDelete(ctx, app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleAppCreate(c *gin.Context) {
|
func (s *Server) handleAppCreate(c *gin.Context) {
|
||||||
ctx := c.MustGet("mctx").(MiddlewareContext)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
var wapp models.AppWrapper
|
var wapp models.AppWrapper
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ func (s *Server) handleAppCreate(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := wapp.Validate(); err != nil {
|
if err = wapp.Validate(); err != nil {
|
||||||
handleErrorResponse(c, err)
|
handleErrorResponse(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleAppDelete(c *gin.Context) {
|
func (s *Server) handleAppDelete(c *gin.Context) {
|
||||||
ctx := c.MustGet("mctx").(MiddlewareContext)
|
ctx := c.Request.Context()
|
||||||
log := common.Logger(ctx)
|
log := common.Logger(ctx)
|
||||||
|
|
||||||
app := &models.App{Name: c.MustGet(api.AppName).(string)}
|
app := &models.App{Name: c.MustGet(api.AppName).(string)}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -9,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleAppGet(c *gin.Context) {
|
func (s *Server) handleAppGet(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
appName := c.MustGet(api.AppName).(string)
|
appName := c.MustGet(api.AppName).(string)
|
||||||
app, err := s.Datastore.GetApp(ctx, appName)
|
app, err := s.Datastore.GetApp(ctx, appName)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -9,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleAppList(c *gin.Context) {
|
func (s *Server) handleAppList(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
filter := &models.AppFilter{}
|
filter := &models.AppFilter{}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleAppUpdate(c *gin.Context) {
|
func (s *Server) handleAppUpdate(c *gin.Context) {
|
||||||
ctx := c.MustGet("mctx").(MiddlewareContext)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
wapp := models.AppWrapper{}
|
wapp := models.AppWrapper{}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -9,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleCallGet(c *gin.Context) {
|
func (s *Server) handleCallGet(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
callID := c.Param(api.Call)
|
callID := c.Param(api.Call)
|
||||||
callObj, err := s.Datastore.GetTask(ctx, callID)
|
callObj, err := s.Datastore.GetTask(ctx, callID)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -10,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleCallList(c *gin.Context) {
|
func (s *Server) handleCallList(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
appName, ok := c.MustGet(api.AppName).(string)
|
appName, ok := c.MustGet(api.AppName).(string)
|
||||||
if ok && appName == "" {
|
if ok && appName == "" {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -9,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleCallLogGet(c *gin.Context) {
|
func (s *Server) handleCallLogGet(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
callID := c.Param(api.Call)
|
callID := c.Param(api.Call)
|
||||||
_, err := s.Datastore.GetTask(ctx, callID)
|
_, err := s.Datastore.GetTask(ctx, callID)
|
||||||
@@ -28,7 +27,7 @@ func (s *Server) handleCallLogGet(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleCallLogDelete(c *gin.Context) {
|
func (s *Server) handleCallLogDelete(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
callID := c.Param(api.Call)
|
callID := c.Param(api.Call)
|
||||||
_, err := s.Datastore.GetTask(ctx, callID)
|
_, err := s.Datastore.GetTask(ctx, callID)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
@@ -12,6 +11,7 @@ import (
|
|||||||
"gitlab-odx.oracle.com/odx/functions/api/runner/common"
|
"gitlab-odx.oracle.com/odx/functions/api/runner/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrInternalServerError returned when something exceptional happens.
|
||||||
var ErrInternalServerError = errors.New("internal server error")
|
var ErrInternalServerError = errors.New("internal server error")
|
||||||
|
|
||||||
func simpleError(err error) *models.Error {
|
func simpleError(err error) *models.Error {
|
||||||
@@ -19,14 +19,14 @@ func simpleError(err error) *models.Error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleErrorResponse(c *gin.Context, err error) {
|
func handleErrorResponse(c *gin.Context, err error) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
log := common.Logger(c.Request.Context())
|
||||||
log := common.Logger(ctx)
|
switch e := err.(type) {
|
||||||
|
case models.APIError:
|
||||||
if aerr, ok := err.(models.APIError); ok {
|
if e.Code() >= 500 {
|
||||||
log.WithFields(logrus.Fields{"code": aerr.Code()}).WithError(err).Error("api error")
|
log.WithFields(logrus.Fields{"code": e.Code()}).WithError(e).Error("api error")
|
||||||
c.JSON(aerr.Code(), simpleError(err))
|
}
|
||||||
} else if err != nil {
|
c.JSON(e.Code(), simpleError(e))
|
||||||
// get a stack trace so we can trace this error
|
default:
|
||||||
log.WithError(err).WithFields(logrus.Fields{"stack": string(debug.Stack())}).Error("internal server error")
|
log.WithError(err).WithFields(logrus.Fields{"stack": string(debug.Stack())}).Error("internal server error")
|
||||||
c.JSON(http.StatusInternalServerError, simpleError(ErrInternalServerError))
|
c.JSON(http.StatusInternalServerError, simpleError(ErrInternalServerError))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,132 +1,69 @@
|
|||||||
// TODO: it would be nice to move these into the top level folder so people can use these with the "functions" package, eg: functions.AddMiddleware(...)
|
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"gitlab-odx.oracle.com/odx/functions/api/runner/common"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gitlab-odx.oracle.com/odx/functions/api/models"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Middleware is the interface required for implementing functions middlewar
|
// Middleware just takes a http.Handler and returns one. So the next middle ware must be called
|
||||||
|
// within the returned handler or it would be ignored.
|
||||||
type Middleware interface {
|
type Middleware interface {
|
||||||
// Serve is what the Middleware must implement. Can modify the request, write output, etc.
|
Chain(next http.Handler) http.Handler
|
||||||
// todo: should we abstract the HTTP out of this? In case we want to support other protocols.
|
|
||||||
Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MiddlewareFunc func form of Middleware
|
// MiddlewareFunc is a here to allow a plain function to be a middleware.
|
||||||
type MiddlewareFunc func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error
|
type MiddlewareFunc func(next http.Handler) http.Handler
|
||||||
|
|
||||||
// Serve wrapper
|
// Chain used to allow middlewarefuncs to be middleware.
|
||||||
func (f MiddlewareFunc) Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error {
|
func (m MiddlewareFunc) Chain(next http.Handler) http.Handler {
|
||||||
return f(ctx, w, r, app)
|
return m(next)
|
||||||
}
|
|
||||||
|
|
||||||
// MiddlewareContext extends context.Context for Middleware
|
|
||||||
type MiddlewareContext interface {
|
|
||||||
context.Context
|
|
||||||
|
|
||||||
// Set is used to store a new key/value pair exclusively for this context.
|
|
||||||
// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well.
|
|
||||||
Set(key string, value interface{})
|
|
||||||
|
|
||||||
// Get returns the value for the given key, ie: (value, true).
|
|
||||||
// If the value does not exists it returns (nil, false)
|
|
||||||
Get(key string) (value interface{}, exists bool)
|
|
||||||
|
|
||||||
// MustGet returns the value for the given key if it exists, otherwise it panics.
|
|
||||||
MustGet(key string) interface{}
|
|
||||||
|
|
||||||
// Middleware can call Next() explicitly to call the next middleware in the chain. If Next() is not called and an error is not returned, Next() will automatically be called.
|
|
||||||
Next()
|
|
||||||
// Index returns the index of where we're at in the chain
|
|
||||||
Index() int
|
|
||||||
}
|
|
||||||
|
|
||||||
type middlewareContextImpl struct {
|
|
||||||
context.Context
|
|
||||||
|
|
||||||
ginContext *gin.Context
|
|
||||||
nextCalled bool
|
|
||||||
index int
|
|
||||||
middlewares []Middleware
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set is used to store a new key/value pair exclusively for this context.
|
|
||||||
// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well.
|
|
||||||
func (c *middlewareContextImpl) Set(key string, value interface{}) {
|
|
||||||
c.ginContext.Set(key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the value for the given key, ie: (value, true).
|
|
||||||
// If the value does not exists it returns (nil, false)
|
|
||||||
func (c *middlewareContextImpl) Get(key string) (value interface{}, exists bool) {
|
|
||||||
return c.ginContext.Get(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustGet returns the value for the given key if it exists, otherwise it panics.
|
|
||||||
func (c *middlewareContextImpl) MustGet(key string) interface{} {
|
|
||||||
return c.ginContext.MustGet(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *middlewareContextImpl) Next() {
|
|
||||||
c.nextCalled = true
|
|
||||||
c.index++
|
|
||||||
c.serveNext()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *middlewareContextImpl) serveNext() {
|
|
||||||
if c.Index() >= len(c.middlewares) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// make shallow copy:
|
|
||||||
fctx2 := *c
|
|
||||||
fctx2.nextCalled = false
|
|
||||||
r := c.ginContext.Request.WithContext(fctx2)
|
|
||||||
err := c.middlewares[c.Index()].Serve(&fctx2, c.ginContext.Writer, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Warnln("Middleware error")
|
|
||||||
// todo: might be a good idea to check if anything is written yet, and if not, output the error: simpleError(err)
|
|
||||||
// see: http://stackoverflow.com/questions/39415827/golang-http-check-if-responsewriter-has-been-written
|
|
||||||
c.ginContext.Error(err)
|
|
||||||
c.ginContext.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !fctx2.nextCalled {
|
|
||||||
// then we automatically call next
|
|
||||||
fctx2.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *middlewareContextImpl) Index() int {
|
|
||||||
return c.index
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) middlewareWrapperFunc(ctx context.Context) gin.HandlerFunc {
|
func (s *Server) middlewareWrapperFunc(ctx context.Context) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// TODO: we should get rid of this, gin context and middleware context both implement context, don't need a third one here
|
if len(s.middlewares) > 0 {
|
||||||
ctx = c.MustGet("ctx").(context.Context)
|
defer func() {
|
||||||
fctx := &middlewareContextImpl{Context: ctx}
|
//This is so that if the server errors or panics on a middleware the server will still respond and not send eof to client.
|
||||||
// add this context to gin context so we can grab it later
|
err := recover()
|
||||||
c.Set("mctx", fctx)
|
if err != nil {
|
||||||
fctx.index = -1
|
common.Logger(c.Request.Context()).WithField("MiddleWarePanicRecovery:", err).Errorln("A panic occurred during middleware.")
|
||||||
fctx.ginContext = c
|
handleErrorResponse(c, ErrInternalServerError)
|
||||||
fctx.middlewares = s.middlewares
|
}
|
||||||
// start the chain:
|
}()
|
||||||
fctx.Next()
|
var h http.Handler
|
||||||
|
keepgoing := false
|
||||||
|
h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
c.Request = c.Request.WithContext(r.Context())
|
||||||
|
keepgoing = true
|
||||||
|
})
|
||||||
|
|
||||||
|
s.chainAndServe(c.Writer, c.Request, h)
|
||||||
|
if !keepgoing {
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) chainAndServe(w http.ResponseWriter, r *http.Request, h http.Handler) {
|
||||||
|
for _, m := range s.middlewares {
|
||||||
|
h = m.Chain(h)
|
||||||
|
}
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
// AddMiddleware add middleware
|
// AddMiddleware add middleware
|
||||||
func (s *Server) AddMiddleware(m Middleware) {
|
func (s *Server) AddMiddleware(m Middleware) {
|
||||||
s.middlewares = append(s.middlewares, m)
|
//Prepend to array so that we can do first,second,third,last,third,second,first
|
||||||
|
//and not third,second,first,last,first,second,third
|
||||||
|
s.middlewares = append([]Middleware{m}, s.middlewares...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMiddlewareFunc adds middleware function
|
// AddMiddlewareFunc add middlewarefunc
|
||||||
func (s *Server) AddMiddlewareFunc(m func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error) {
|
func (s *Server) AddMiddlewareFunc(m MiddlewareFunc) {
|
||||||
s.AddMiddleware(MiddlewareFunc(m))
|
s.AddMiddleware(m)
|
||||||
}
|
}
|
||||||
|
|||||||
46
api/server/middleware_test.go
Normal file
46
api/server/middleware_test.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type middleWareStruct struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *middleWareStruct) Chain(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(m.name + ","))
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddleWareChaining(t *testing.T) {
|
||||||
|
var lastHandler http.Handler
|
||||||
|
lastHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("last"))
|
||||||
|
})
|
||||||
|
|
||||||
|
s := Server{}
|
||||||
|
s.AddMiddleware(&middleWareStruct{"first"})
|
||||||
|
s.AddMiddleware(&middleWareStruct{"second"})
|
||||||
|
s.AddMiddleware(&middleWareStruct{"third"})
|
||||||
|
s.AddMiddleware(&middleWareStruct{"fourth"})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("get", "http://localhost/", nil)
|
||||||
|
|
||||||
|
s.chainAndServe(rec, req, lastHandler)
|
||||||
|
|
||||||
|
result, err := ioutil.ReadAll(rec.Result().Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result) != "first,second,third,fourth,last" {
|
||||||
|
t.Fatal("You failed to chain correctly.", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,12 +23,12 @@ import (
|
|||||||
Patch accepts partial updates / skips validation of zero values.
|
Patch accepts partial updates / skips validation of zero values.
|
||||||
*/
|
*/
|
||||||
func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
|
func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
|
||||||
ctx := c.MustGet("mctx").(MiddlewareContext)
|
ctx := c.Request.Context()
|
||||||
method := strings.ToUpper(c.Request.Method)
|
method := strings.ToUpper(c.Request.Method)
|
||||||
|
|
||||||
var wroute models.RouteWrapper
|
var wroute models.RouteWrapper
|
||||||
|
|
||||||
err := s.bindAndValidate(ctx, c, method, &wroute)
|
err := s.bindAndValidate(c, method, &wroute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleErrorResponse(c, err)
|
handleErrorResponse(c, err)
|
||||||
return
|
return
|
||||||
@@ -53,7 +53,7 @@ func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ensureApp will only execute if it is on post or put. Patch is not allowed to create apps.
|
// ensureApp will only execute if it is on post or put. Patch is not allowed to create apps.
|
||||||
func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, method string) error {
|
func (s *Server) ensureApp(ctx context.Context, wroute *models.RouteWrapper, method string) error {
|
||||||
if !(method == http.MethodPost || method == http.MethodPut) {
|
if !(method == http.MethodPost || method == http.MethodPut) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, m
|
|||||||
If it is a put or patch it makes sure that the path in the url matches the provideed one in the body.
|
If it is a put or patch it makes sure that the path in the url matches the provideed one in the body.
|
||||||
Defaults are set and if patch skipZero is true for validating the RouteWrapper
|
Defaults are set and if patch skipZero is true for validating the RouteWrapper
|
||||||
*/
|
*/
|
||||||
func (s *Server) bindAndValidate(ctx context.Context, c *gin.Context, method string, wroute *models.RouteWrapper) error {
|
func (s *Server) bindAndValidate(c *gin.Context, method string, wroute *models.RouteWrapper) error {
|
||||||
err := c.BindJSON(wroute)
|
err := c.BindJSON(wroute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return models.ErrInvalidJSON
|
return models.ErrInvalidJSON
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
@@ -10,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleRouteDelete(c *gin.Context) {
|
func (s *Server) handleRouteDelete(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
appName := c.MustGet(api.AppName).(string)
|
appName := c.MustGet(api.AppName).(string)
|
||||||
routePath := path.Clean(c.MustGet(api.Path).(string))
|
routePath := path.Clean(c.MustGet(api.Path).(string))
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
@@ -10,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleRouteGet(c *gin.Context) {
|
func (s *Server) handleRouteGet(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
appName := c.MustGet(api.AppName).(string)
|
appName := c.MustGet(api.AppName).(string)
|
||||||
routePath := path.Clean(c.MustGet(api.Path).(string))
|
routePath := path.Clean(c.MustGet(api.Path).(string))
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -10,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleRouteList(c *gin.Context) {
|
func (s *Server) handleRouteList(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
filter := &models.RouteFilter{}
|
filter := &models.RouteFilter{}
|
||||||
|
|
||||||
@@ -20,8 +19,10 @@ func (s *Server) handleRouteList(c *gin.Context) {
|
|||||||
|
|
||||||
var routes []*models.Route
|
var routes []*models.Route
|
||||||
var err error
|
var err error
|
||||||
if appName, ok := c.MustGet(api.AppName).(string); ok && appName != "" {
|
appName, exists := c.Get(api.AppName)
|
||||||
routes, err = s.Datastore.GetRoutesByApp(ctx, appName, filter)
|
name, ok := appName.(string)
|
||||||
|
if exists && ok && name != "" {
|
||||||
|
routes, err = s.Datastore.GetRoutesByApp(ctx, name, filter)
|
||||||
} else {
|
} else {
|
||||||
routes, err = s.Datastore.GetRoutes(ctx, filter)
|
routes, err = s.Datastore.GetRoutes(ctx, filter)
|
||||||
}
|
}
|
||||||
@@ -31,5 +32,5 @@ func (s *Server) handleRouteList(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, routesResponse{"Sucessfully listed routes", routes})
|
c.JSON(http.StatusOK, routesResponse{"Successfully listed routes", routes})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,21 +28,21 @@ type runnerResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleSpecial(c *gin.Context) {
|
func (s *Server) handleSpecial(c *gin.Context) {
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, api.AppName, "")
|
ctx = context.WithValue(ctx, api.AppName, "")
|
||||||
c.Set(api.AppName, "")
|
c.Set(api.AppName, "")
|
||||||
ctx = context.WithValue(ctx, api.Path, c.Request.URL.Path)
|
ctx = context.WithValue(ctx, api.Path, c.Request.URL.Path)
|
||||||
c.Set(api.Path, c.Request.URL.Path)
|
c.Set(api.Path, c.Request.URL.Path)
|
||||||
|
|
||||||
ctx, err := s.UseSpecialHandlers(ctx, c.Request, c.Writer)
|
r, err := s.UseSpecialHandlers(c.Writer, c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleErrorResponse(c, err)
|
handleErrorResponse(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set("ctx", ctx)
|
c.Request = r
|
||||||
c.Set(api.AppName, ctx.Value(api.AppName).(string))
|
c.Set(api.AppName, r.Context().Value(api.AppName).(string))
|
||||||
if c.MustGet(api.AppName).(string) == "" {
|
if c.MustGet(api.AppName).(string) == "" {
|
||||||
handleErrorResponse(c, models.ErrRunnerRouteNotFound)
|
handleErrorResponse(c, models.ErrRunnerRouteNotFound)
|
||||||
return
|
return
|
||||||
@@ -66,7 +66,7 @@ func (s *Server) handleRequest(c *gin.Context, enqueue models.Enqueue) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := c.MustGet("ctx").(context.Context)
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
reqID := id.New().String()
|
reqID := id.New().String()
|
||||||
ctx, log := common.LoggerWithFields(ctx, logrus.Fields{"call_id": reqID})
|
ctx, log := common.LoggerWithFields(ctx, logrus.Fields{"call_id": reqID})
|
||||||
|
|||||||
@@ -161,8 +161,6 @@ func prepareMiddleware(ctx context.Context) gin.HandlerFunc {
|
|||||||
c.Set(api.Path, routePath)
|
c.Set(api.Path, routePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: can probably replace the "ctx" value with the Go 1.7 context on the http.Request
|
|
||||||
c.Set("ctx", ctx)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,73 +1,36 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"gitlab-odx.oracle.com/odx/functions/api/models"
|
"gitlab-odx.oracle.com/odx/functions/api/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SpecialHandler verysimilar to a handler but since it is not used as middle ware no way
|
||||||
|
// to get context without returning it. So we just return a request which could have newly made
|
||||||
|
// contexts.
|
||||||
type SpecialHandler interface {
|
type SpecialHandler interface {
|
||||||
Handle(c HandlerContext) error
|
Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error)
|
||||||
}
|
|
||||||
|
|
||||||
// Each handler can modify the context here so when it gets passed along, it will use the new info.
|
|
||||||
type HandlerContext interface {
|
|
||||||
// Context return the context object
|
|
||||||
Context() context.Context
|
|
||||||
|
|
||||||
// Request returns the underlying http.Request object
|
|
||||||
Request() *http.Request
|
|
||||||
|
|
||||||
// Response returns the http.ResponseWriter
|
|
||||||
Response() http.ResponseWriter
|
|
||||||
|
|
||||||
// Overwrite value in the context
|
|
||||||
Set(key string, value interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type SpecialHandlerContext struct {
|
|
||||||
request *http.Request
|
|
||||||
response http.ResponseWriter
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SpecialHandlerContext) Context() context.Context {
|
|
||||||
return c.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SpecialHandlerContext) Request() *http.Request {
|
|
||||||
return c.request
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SpecialHandlerContext) Response() http.ResponseWriter {
|
|
||||||
return c.response
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SpecialHandlerContext) Set(key string, value interface{}) {
|
|
||||||
c.ctx = context.WithValue(c.ctx, key, value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddSpecialHandler adds the SpecialHandler to the specialHandlers list.
|
||||||
func (s *Server) AddSpecialHandler(handler SpecialHandler) {
|
func (s *Server) AddSpecialHandler(handler SpecialHandler) {
|
||||||
s.specialHandlers = append(s.specialHandlers, handler)
|
s.specialHandlers = append(s.specialHandlers, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UseSpecialHandlers execute all special handlers
|
// UseSpecialHandlers execute all special handlers
|
||||||
func (s *Server) UseSpecialHandlers(ctx context.Context, req *http.Request, resp http.ResponseWriter) (context.Context, error) {
|
func (s *Server) UseSpecialHandlers(resp http.ResponseWriter, req *http.Request) (*http.Request, error) {
|
||||||
if len(s.specialHandlers) == 0 {
|
if len(s.specialHandlers) == 0 {
|
||||||
return ctx, models.ErrNoSpecialHandlerFound
|
return req, models.ErrNoSpecialHandlerFound
|
||||||
}
|
}
|
||||||
|
var r *http.Request
|
||||||
|
var err error
|
||||||
|
|
||||||
c := &SpecialHandlerContext{
|
|
||||||
request: req,
|
|
||||||
response: resp,
|
|
||||||
ctx: ctx,
|
|
||||||
}
|
|
||||||
for _, l := range s.specialHandlers {
|
for _, l := range s.specialHandlers {
|
||||||
err := l.Handle(c)
|
r, err = l.Handle(resp, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ctx, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return c.ctx, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
type testSpecialHandler struct{}
|
type testSpecialHandler struct{}
|
||||||
|
|
||||||
func (h *testSpecialHandler) Handle(c HandlerContext) error {
|
func (h *testSpecialHandler) Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
|
||||||
// c.Set(api.AppName, "test")
|
// r = r.WithContext(context.WithValue(r.Context(), api.AppName, "test"))
|
||||||
return nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSpecialHandlerSet(t *testing.T) {
|
func TestSpecialHandlerSet(t *testing.T) {
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitlab-odx.oracle.com/odx/functions/api/models"
|
|
||||||
"gitlab-odx.oracle.com/odx/functions/api/server"
|
"gitlab-odx.oracle.com/odx/functions/api/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,13 +16,15 @@ func main() {
|
|||||||
|
|
||||||
funcServer := server.NewFromEnv(ctx)
|
funcServer := server.NewFromEnv(ctx)
|
||||||
|
|
||||||
funcServer.AddMiddlewareFunc(func(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error {
|
funcServer.AddMiddlewareFunc(func(next http.Handler) http.Handler {
|
||||||
start := time.Now()
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Println("CustomMiddlewareFunc called at:", start)
|
start := time.Now()
|
||||||
ctx.Next()
|
fmt.Println("CustomMiddlewareFunc called at:", start)
|
||||||
fmt.Println("Duration:", (time.Now().Sub(start)))
|
next.ServeHTTP(w, r)
|
||||||
return nil
|
fmt.Println("Duration:", (time.Now().Sub(start)))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
funcServer.AddMiddleware(&CustomMiddleware{})
|
funcServer.AddMiddleware(&CustomMiddleware{})
|
||||||
|
|
||||||
funcServer.Start(ctx)
|
funcServer.Start(ctx)
|
||||||
@@ -33,20 +33,22 @@ func main() {
|
|||||||
type CustomMiddleware struct {
|
type CustomMiddleware struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *CustomMiddleware) Serve(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error {
|
func (h *CustomMiddleware) Serve(next http.Handler) http.Handler {
|
||||||
fmt.Println("CustomMiddleware called")
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Println("CustomMiddleware called")
|
||||||
|
|
||||||
// check auth header
|
// check auth header
|
||||||
tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3)
|
tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3)
|
||||||
if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" {
|
if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
m2 := map[string]string{"message": "Invalid Authorization token."}
|
m2 := map[string]string{"message": "Invalid Authorization token."}
|
||||||
m := map[string]map[string]string{"error": m2}
|
m := map[string]map[string]string{"error": m2}
|
||||||
json.NewEncoder(w).Encode(m)
|
json.NewEncoder(w).Encode(m)
|
||||||
return errors.New("Invalid authorization token.")
|
return
|
||||||
}
|
}
|
||||||
fmt.Println("auth succeeded!")
|
fmt.Println("auth succeeded!")
|
||||||
ctx.Set("user", "I'm in!")
|
r = r.WithContext(context.WithValue(r.Context(), "user", "I'm in!"))
|
||||||
return nil
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user