Use chi style middle ware

This commit is contained in:
James Jeffrey
2017-07-19 13:44:26 -07:00
committed by Travis Reeder
parent 570e9265f1
commit cf2c3cf404
21 changed files with 192 additions and 240 deletions

View File

@@ -1,30 +1,34 @@
package server
import (
"context"
"gitlab-odx.oracle.com/odx/functions/api/models"
)
// AppListener is an interface used to inject custom code at key points in app lifecycle.
type AppListener interface {
// BeforeAppCreate called right before creating App in the database
BeforeAppCreate(ctx MiddlewareContext, app *models.App) error
BeforeAppCreate(ctx context.Context, app *models.App) error
// AfterAppCreate called after creating App in the database
AfterAppCreate(ctx MiddlewareContext, app *models.App) error
AfterAppCreate(ctx context.Context, app *models.App) error
// BeforeAppUpdate called right before updating App in the database
BeforeAppUpdate(ctx MiddlewareContext, app *models.App) error
BeforeAppUpdate(ctx context.Context, app *models.App) error
// AfterAppUpdate called after updating App in the database
AfterAppUpdate(ctx MiddlewareContext, app *models.App) error
AfterAppUpdate(ctx context.Context, app *models.App) error
// BeforeAppDelete called right before deleting App in the database
BeforeAppDelete(ctx MiddlewareContext, app *models.App) error
BeforeAppDelete(ctx context.Context, app *models.App) error
// AfterAppDelete called after deleting App in the database
AfterAppDelete(ctx MiddlewareContext, app *models.App) error
AfterAppDelete(ctx context.Context, app *models.App) error
}
// AddAppCreateListener adds a listener that will be notified on App created.
// AddAppListener adds a listener that will be notified on App created.
func (s *Server) AddAppListener(listener AppListener) {
s.appListeners = append(s.appListeners, listener)
}
func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) error {
// FireBeforeAppCreate is used to call all the server's Listeners BeforeAppCreate functions.
func (s *Server) FireBeforeAppCreate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners {
err := l.BeforeAppCreate(ctx, app)
if err != nil {
@@ -34,7 +38,8 @@ func (s *Server) FireBeforeAppCreate(ctx MiddlewareContext, app *models.App) err
return nil
}
func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) error {
// FireAfterAppCreate is used to call all the server's Listeners AfterAppCreate functions.
func (s *Server) FireAfterAppCreate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners {
err := l.AfterAppCreate(ctx, app)
if err != nil {
@@ -44,7 +49,8 @@ func (s *Server) FireAfterAppCreate(ctx MiddlewareContext, app *models.App) erro
return nil
}
func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) error {
// FireBeforeAppUpdate is used to call all the server's Listeners BeforeAppUpdate functions.
func (s *Server) FireBeforeAppUpdate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners {
err := l.BeforeAppUpdate(ctx, app)
if err != nil {
@@ -54,7 +60,8 @@ func (s *Server) FireBeforeAppUpdate(ctx MiddlewareContext, app *models.App) err
return nil
}
func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) error {
// FireAfterAppUpdate is used to call all the server's Listeners AfterAppUpdate functions.
func (s *Server) FireAfterAppUpdate(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners {
err := l.AfterAppUpdate(ctx, app)
if err != nil {
@@ -64,7 +71,8 @@ func (s *Server) FireAfterAppUpdate(ctx MiddlewareContext, app *models.App) erro
return nil
}
func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) error {
// FireBeforeAppDelete is used to call all the server's Listeners BeforeAppDelete functions.
func (s *Server) FireBeforeAppDelete(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners {
err := l.BeforeAppDelete(ctx, app)
if err != nil {
@@ -74,7 +82,8 @@ func (s *Server) FireBeforeAppDelete(ctx MiddlewareContext, app *models.App) err
return nil
}
func (s *Server) FireAfterAppDelete(ctx MiddlewareContext, app *models.App) error {
// FireAfterAppDelete is used to call all the server's Listeners AfterAppDelete functions.
func (s *Server) FireAfterAppDelete(ctx context.Context, app *models.App) error {
for _, l := range s.appListeners {
err := l.AfterAppDelete(ctx, app)
if err != nil {

View File

@@ -8,7 +8,7 @@ import (
)
func (s *Server) handleAppCreate(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext)
ctx := c.Request.Context()
var wapp models.AppWrapper
@@ -23,7 +23,7 @@ func (s *Server) handleAppCreate(c *gin.Context) {
return
}
if err := wapp.Validate(); err != nil {
if err = wapp.Validate(); err != nil {
handleErrorResponse(c, err)
return
}

View File

@@ -10,7 +10,7 @@ import (
)
func (s *Server) handleAppDelete(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext)
ctx := c.Request.Context()
log := common.Logger(ctx)
app := &models.App{Name: c.MustGet(api.AppName).(string)}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
)
func (s *Server) handleAppGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
appName := c.MustGet(api.AppName).(string)
app, err := s.Datastore.GetApp(ctx, appName)

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
)
func (s *Server) handleAppList(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
filter := &models.AppFilter{}

View File

@@ -9,7 +9,7 @@ import (
)
func (s *Server) handleAppUpdate(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext)
ctx := c.Request.Context()
wapp := models.AppWrapper{}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
)
func (s *Server) handleCallGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
callID := c.Param(api.Call)
callObj, err := s.Datastore.GetTask(ctx, callID)

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
@@ -10,7 +9,7 @@ import (
)
func (s *Server) handleCallList(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
appName, ok := c.MustGet(api.AppName).(string)
if ok && appName == "" {

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
@@ -9,7 +8,7 @@ import (
)
func (s *Server) handleCallLogGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
callID := c.Param(api.Call)
_, err := s.Datastore.GetTask(ctx, callID)
@@ -28,7 +27,7 @@ func (s *Server) handleCallLogGet(c *gin.Context) {
}
func (s *Server) handleCallLogDelete(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
callID := c.Param(api.Call)
_, err := s.Datastore.GetTask(ctx, callID)

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"errors"
"net/http"
"runtime/debug"
@@ -12,6 +11,7 @@ import (
"gitlab-odx.oracle.com/odx/functions/api/runner/common"
)
// ErrInternalServerError returned when something exceptional happens.
var ErrInternalServerError = errors.New("internal server error")
func simpleError(err error) *models.Error {
@@ -19,14 +19,14 @@ func simpleError(err error) *models.Error {
}
func handleErrorResponse(c *gin.Context, err error) {
ctx := c.MustGet("ctx").(context.Context)
log := common.Logger(ctx)
if aerr, ok := err.(models.APIError); ok {
log.WithFields(logrus.Fields{"code": aerr.Code()}).WithError(err).Error("api error")
c.JSON(aerr.Code(), simpleError(err))
} else if err != nil {
// get a stack trace so we can trace this error
log := common.Logger(c.Request.Context())
switch e := err.(type) {
case models.APIError:
if e.Code() >= 500 {
log.WithFields(logrus.Fields{"code": e.Code()}).WithError(e).Error("api error")
}
c.JSON(e.Code(), simpleError(e))
default:
log.WithError(err).WithFields(logrus.Fields{"stack": string(debug.Stack())}).Error("internal server error")
c.JSON(http.StatusInternalServerError, simpleError(ErrInternalServerError))
}

View File

@@ -1,132 +1,69 @@
// TODO: it would be nice to move these into the top level folder so people can use these with the "functions" package, eg: functions.AddMiddleware(...)
package server
import (
"context"
"net/http"
"github.com/Sirupsen/logrus"
"gitlab-odx.oracle.com/odx/functions/api/runner/common"
"github.com/gin-gonic/gin"
"gitlab-odx.oracle.com/odx/functions/api/models"
)
// Middleware is the interface required for implementing functions middlewar
// Middleware just takes a http.Handler and returns one. So the next middle ware must be called
// within the returned handler or it would be ignored.
type Middleware interface {
// Serve is what the Middleware must implement. Can modify the request, write output, etc.
// todo: should we abstract the HTTP out of this? In case we want to support other protocols.
Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error
Chain(next http.Handler) http.Handler
}
// MiddlewareFunc func form of Middleware
type MiddlewareFunc func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error
// MiddlewareFunc is a here to allow a plain function to be a middleware.
type MiddlewareFunc func(next http.Handler) http.Handler
// Serve wrapper
func (f MiddlewareFunc) Serve(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error {
return f(ctx, w, r, app)
}
// MiddlewareContext extends context.Context for Middleware
type MiddlewareContext interface {
context.Context
// Set is used to store a new key/value pair exclusively for this context.
// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well.
Set(key string, value interface{})
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
Get(key string) (value interface{}, exists bool)
// MustGet returns the value for the given key if it exists, otherwise it panics.
MustGet(key string) interface{}
// Middleware can call Next() explicitly to call the next middleware in the chain. If Next() is not called and an error is not returned, Next() will automatically be called.
Next()
// Index returns the index of where we're at in the chain
Index() int
}
type middlewareContextImpl struct {
context.Context
ginContext *gin.Context
nextCalled bool
index int
middlewares []Middleware
}
// Set is used to store a new key/value pair exclusively for this context.
// This is different than WithValue(), as it does not make a copy of the context with the new value, it will be available up the chain as well.
func (c *middlewareContextImpl) Set(key string, value interface{}) {
c.ginContext.Set(key, value)
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *middlewareContextImpl) Get(key string) (value interface{}, exists bool) {
return c.ginContext.Get(key)
}
// MustGet returns the value for the given key if it exists, otherwise it panics.
func (c *middlewareContextImpl) MustGet(key string) interface{} {
return c.ginContext.MustGet(key)
}
func (c *middlewareContextImpl) Next() {
c.nextCalled = true
c.index++
c.serveNext()
}
func (c *middlewareContextImpl) serveNext() {
if c.Index() >= len(c.middlewares) {
return
}
// make shallow copy:
fctx2 := *c
fctx2.nextCalled = false
r := c.ginContext.Request.WithContext(fctx2)
err := c.middlewares[c.Index()].Serve(&fctx2, c.ginContext.Writer, r, nil)
if err != nil {
logrus.WithError(err).Warnln("Middleware error")
// todo: might be a good idea to check if anything is written yet, and if not, output the error: simpleError(err)
// see: http://stackoverflow.com/questions/39415827/golang-http-check-if-responsewriter-has-been-written
c.ginContext.Error(err)
c.ginContext.Abort()
return
}
if !fctx2.nextCalled {
// then we automatically call next
fctx2.Next()
}
}
func (c *middlewareContextImpl) Index() int {
return c.index
// Chain used to allow middlewarefuncs to be middleware.
func (m MiddlewareFunc) Chain(next http.Handler) http.Handler {
return m(next)
}
func (s *Server) middlewareWrapperFunc(ctx context.Context) gin.HandlerFunc {
return func(c *gin.Context) {
// TODO: we should get rid of this, gin context and middleware context both implement context, don't need a third one here
ctx = c.MustGet("ctx").(context.Context)
fctx := &middlewareContextImpl{Context: ctx}
// add this context to gin context so we can grab it later
c.Set("mctx", fctx)
fctx.index = -1
fctx.ginContext = c
fctx.middlewares = s.middlewares
// start the chain:
fctx.Next()
if len(s.middlewares) > 0 {
defer func() {
//This is so that if the server errors or panics on a middleware the server will still respond and not send eof to client.
err := recover()
if err != nil {
common.Logger(c.Request.Context()).WithField("MiddleWarePanicRecovery:", err).Errorln("A panic occurred during middleware.")
handleErrorResponse(c, ErrInternalServerError)
}
}()
var h http.Handler
keepgoing := false
h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Request = c.Request.WithContext(r.Context())
keepgoing = true
})
s.chainAndServe(c.Writer, c.Request, h)
if !keepgoing {
c.Abort()
}
}
}
}
func (s *Server) chainAndServe(w http.ResponseWriter, r *http.Request, h http.Handler) {
for _, m := range s.middlewares {
h = m.Chain(h)
}
h.ServeHTTP(w, r)
}
// AddMiddleware add middleware
func (s *Server) AddMiddleware(m Middleware) {
s.middlewares = append(s.middlewares, m)
//Prepend to array so that we can do first,second,third,last,third,second,first
//and not third,second,first,last,first,second,third
s.middlewares = append([]Middleware{m}, s.middlewares...)
}
// AddMiddlewareFunc adds middleware function
func (s *Server) AddMiddlewareFunc(m func(ctx MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error) {
s.AddMiddleware(MiddlewareFunc(m))
// AddMiddlewareFunc add middlewarefunc
func (s *Server) AddMiddlewareFunc(m MiddlewareFunc) {
s.AddMiddleware(m)
}

View File

@@ -0,0 +1,46 @@
package server
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
type middleWareStruct struct {
name string
}
func (m *middleWareStruct) Chain(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(m.name + ","))
next.ServeHTTP(w, r)
})
}
func TestMiddleWareChaining(t *testing.T) {
var lastHandler http.Handler
lastHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("last"))
})
s := Server{}
s.AddMiddleware(&middleWareStruct{"first"})
s.AddMiddleware(&middleWareStruct{"second"})
s.AddMiddleware(&middleWareStruct{"third"})
s.AddMiddleware(&middleWareStruct{"fourth"})
rec := httptest.NewRecorder()
req, _ := http.NewRequest("get", "http://localhost/", nil)
s.chainAndServe(rec, req, lastHandler)
result, err := ioutil.ReadAll(rec.Result().Body)
if err != nil {
t.Fatal(err)
}
if string(result) != "first,second,third,fourth,last" {
t.Fatal("You failed to chain correctly.", string(result))
}
}

View File

@@ -23,12 +23,12 @@ import (
Patch accepts partial updates / skips validation of zero values.
*/
func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
ctx := c.MustGet("mctx").(MiddlewareContext)
ctx := c.Request.Context()
method := strings.ToUpper(c.Request.Method)
var wroute models.RouteWrapper
err := s.bindAndValidate(ctx, c, method, &wroute)
err := s.bindAndValidate(c, method, &wroute)
if err != nil {
handleErrorResponse(c, err)
return
@@ -53,7 +53,7 @@ func (s *Server) handleRouteCreateOrUpdate(c *gin.Context) {
}
// ensureApp will only execute if it is on post or put. Patch is not allowed to create apps.
func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, method string) error {
func (s *Server) ensureApp(ctx context.Context, wroute *models.RouteWrapper, method string) error {
if !(method == http.MethodPost || method == http.MethodPut) {
return nil
}
@@ -90,7 +90,7 @@ func (s *Server) ensureApp(ctx MiddlewareContext, wroute *models.RouteWrapper, m
If it is a put or patch it makes sure that the path in the url matches the provideed one in the body.
Defaults are set and if patch skipZero is true for validating the RouteWrapper
*/
func (s *Server) bindAndValidate(ctx context.Context, c *gin.Context, method string, wroute *models.RouteWrapper) error {
func (s *Server) bindAndValidate(c *gin.Context, method string, wroute *models.RouteWrapper) error {
err := c.BindJSON(wroute)
if err != nil {
return models.ErrInvalidJSON

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"path"
@@ -10,7 +9,7 @@ import (
)
func (s *Server) handleRouteDelete(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
appName := c.MustGet(api.AppName).(string)
routePath := path.Clean(c.MustGet(api.Path).(string))

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"path"
@@ -10,7 +9,7 @@ import (
)
func (s *Server) handleRouteGet(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
appName := c.MustGet(api.AppName).(string)
routePath := path.Clean(c.MustGet(api.Path).(string))

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
@@ -10,7 +9,7 @@ import (
)
func (s *Server) handleRouteList(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
filter := &models.RouteFilter{}
@@ -20,8 +19,10 @@ func (s *Server) handleRouteList(c *gin.Context) {
var routes []*models.Route
var err error
if appName, ok := c.MustGet(api.AppName).(string); ok && appName != "" {
routes, err = s.Datastore.GetRoutesByApp(ctx, appName, filter)
appName, exists := c.Get(api.AppName)
name, ok := appName.(string)
if exists && ok && name != "" {
routes, err = s.Datastore.GetRoutesByApp(ctx, name, filter)
} else {
routes, err = s.Datastore.GetRoutes(ctx, filter)
}
@@ -31,5 +32,5 @@ func (s *Server) handleRouteList(c *gin.Context) {
return
}
c.JSON(http.StatusOK, routesResponse{"Sucessfully listed routes", routes})
c.JSON(http.StatusOK, routesResponse{"Successfully listed routes", routes})
}

View File

@@ -28,21 +28,21 @@ type runnerResponse struct {
}
func (s *Server) handleSpecial(c *gin.Context) {
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
ctx = context.WithValue(ctx, api.AppName, "")
c.Set(api.AppName, "")
ctx = context.WithValue(ctx, api.Path, c.Request.URL.Path)
c.Set(api.Path, c.Request.URL.Path)
ctx, err := s.UseSpecialHandlers(ctx, c.Request, c.Writer)
r, err := s.UseSpecialHandlers(c.Writer, c.Request)
if err != nil {
handleErrorResponse(c, err)
return
}
c.Set("ctx", ctx)
c.Set(api.AppName, ctx.Value(api.AppName).(string))
c.Request = r
c.Set(api.AppName, r.Context().Value(api.AppName).(string))
if c.MustGet(api.AppName).(string) == "" {
handleErrorResponse(c, models.ErrRunnerRouteNotFound)
return
@@ -66,7 +66,7 @@ func (s *Server) handleRequest(c *gin.Context, enqueue models.Enqueue) {
return
}
ctx := c.MustGet("ctx").(context.Context)
ctx := c.Request.Context()
reqID := id.New().String()
ctx, log := common.LoggerWithFields(ctx, logrus.Fields{"call_id": reqID})

View File

@@ -161,8 +161,6 @@ func prepareMiddleware(ctx context.Context) gin.HandlerFunc {
c.Set(api.Path, routePath)
}
// todo: can probably replace the "ctx" value with the Go 1.7 context on the http.Request
c.Set("ctx", ctx)
c.Request = c.Request.WithContext(ctx)
c.Next()
}

View File

@@ -1,73 +1,36 @@
package server
import (
"context"
"net/http"
"gitlab-odx.oracle.com/odx/functions/api/models"
)
// SpecialHandler verysimilar to a handler but since it is not used as middle ware no way
// to get context without returning it. So we just return a request which could have newly made
// contexts.
type SpecialHandler interface {
Handle(c HandlerContext) error
}
// Each handler can modify the context here so when it gets passed along, it will use the new info.
type HandlerContext interface {
// Context return the context object
Context() context.Context
// Request returns the underlying http.Request object
Request() *http.Request
// Response returns the http.ResponseWriter
Response() http.ResponseWriter
// Overwrite value in the context
Set(key string, value interface{})
}
type SpecialHandlerContext struct {
request *http.Request
response http.ResponseWriter
ctx context.Context
}
func (c *SpecialHandlerContext) Context() context.Context {
return c.ctx
}
func (c *SpecialHandlerContext) Request() *http.Request {
return c.request
}
func (c *SpecialHandlerContext) Response() http.ResponseWriter {
return c.response
}
func (c *SpecialHandlerContext) Set(key string, value interface{}) {
c.ctx = context.WithValue(c.ctx, key, value)
Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error)
}
// AddSpecialHandler adds the SpecialHandler to the specialHandlers list.
func (s *Server) AddSpecialHandler(handler SpecialHandler) {
s.specialHandlers = append(s.specialHandlers, handler)
}
// UseSpecialHandlers execute all special handlers
func (s *Server) UseSpecialHandlers(ctx context.Context, req *http.Request, resp http.ResponseWriter) (context.Context, error) {
func (s *Server) UseSpecialHandlers(resp http.ResponseWriter, req *http.Request) (*http.Request, error) {
if len(s.specialHandlers) == 0 {
return ctx, models.ErrNoSpecialHandlerFound
return req, models.ErrNoSpecialHandlerFound
}
var r *http.Request
var err error
c := &SpecialHandlerContext{
request: req,
response: resp,
ctx: ctx,
}
for _, l := range s.specialHandlers {
err := l.Handle(c)
r, err = l.Handle(resp, req)
if err != nil {
return c.ctx, err
return nil, err
}
}
return c.ctx, nil
return r, nil
}

View File

@@ -1,12 +1,15 @@
package server
import "testing"
import (
"net/http"
"testing"
)
type testSpecialHandler struct{}
func (h *testSpecialHandler) Handle(c HandlerContext) error {
// c.Set(api.AppName, "test")
return nil
func (h *testSpecialHandler) Handle(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
// r = r.WithContext(context.WithValue(r.Context(), api.AppName, "test"))
return r, nil
}
func TestSpecialHandlerSet(t *testing.T) {

View File

@@ -3,13 +3,11 @@ package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"gitlab-odx.oracle.com/odx/functions/api/models"
"gitlab-odx.oracle.com/odx/functions/api/server"
)
@@ -18,13 +16,15 @@ func main() {
funcServer := server.NewFromEnv(ctx)
funcServer.AddMiddlewareFunc(func(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error {
start := time.Now()
fmt.Println("CustomMiddlewareFunc called at:", start)
ctx.Next()
fmt.Println("Duration:", (time.Now().Sub(start)))
return nil
funcServer.AddMiddlewareFunc(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
fmt.Println("CustomMiddlewareFunc called at:", start)
next.ServeHTTP(w, r)
fmt.Println("Duration:", (time.Now().Sub(start)))
})
})
funcServer.AddMiddleware(&CustomMiddleware{})
funcServer.Start(ctx)
@@ -33,20 +33,22 @@ func main() {
type CustomMiddleware struct {
}
func (h *CustomMiddleware) Serve(ctx server.MiddlewareContext, w http.ResponseWriter, r *http.Request, app *models.App) error {
fmt.Println("CustomMiddleware called")
func (h *CustomMiddleware) Serve(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println("CustomMiddleware called")
// check auth header
tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3)
if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
m2 := map[string]string{"message": "Invalid Authorization token."}
m := map[string]map[string]string{"error": m2}
json.NewEncoder(w).Encode(m)
return errors.New("Invalid authorization token.")
}
fmt.Println("auth succeeded!")
ctx.Set("user", "I'm in!")
return nil
// check auth header
tokenHeader := strings.SplitN(r.Header.Get("Authorization"), " ", 3)
if len(tokenHeader) < 2 || tokenHeader[1] != "KlaatuBaradaNikto" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
m2 := map[string]string{"message": "Invalid Authorization token."}
m := map[string]map[string]string{"error": m2}
json.NewEncoder(w).Encode(m)
return
}
fmt.Println("auth succeeded!")
r = r.WithContext(context.WithValue(r.Context(), "user", "I'm in!"))
next.ServeHTTP(w, r)
})
}