Files
fn-serverless/api/server/middleware_test.go
Travis Reeder 0798f9fac8 Middleware upgrade (#554)
* Adds root level middleware

* Added todo

* Better way for extensions to be added.

* Bad conflict merge?
2017-12-05 08:22:03 -08:00

150 lines
4.6 KiB
Go

package server
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/fnproject/fn/api/datastore"
"github.com/fnproject/fn/api/logs"
"github.com/fnproject/fn/api/models"
"github.com/fnproject/fn/api/mqs"
"github.com/fnproject/fn/fnext"
"github.com/gin-gonic/gin"
)
// TestMain is the package-level test entry point. It runs every test in the
// package and exits the process with the resulting status code. If the tests
// ever grow command-line flags, call flag.Parse() here before m.Run().
func TestMain(m *testing.M) {
	exitCode := m.Run()
	os.Exit(exitCode)
}
// middleWareStruct is a test middleware that writes its name followed by a
// comma to the response and then delegates to the next handler. Chaining
// several of them lets a test read the execution order back from the body.
type middleWareStruct struct {
	name string
}

// Handle wraps next in a handler that first writes mw.name + "," and then
// serves next. The name is read on every request, not captured once here.
// The Write error is intentionally ignored: this is a test helper and the
// enclosing test inspects the body directly.
func (mw *middleWareStruct) Handle(next http.Handler) http.Handler {
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		marker := mw.name + ","
		w.Write([]byte(marker))
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(wrapped)
}
// TestMiddlewareChaining verifies that API middlewares registered with
// AddAPIMiddleware run in registration order, each writing to the response
// before the final handler does.
func TestMiddlewareChaining(t *testing.T) {
	var lastHandler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("last"))
	})

	s := Server{}
	s.AddAPIMiddleware(&middleWareStruct{"first"})
	s.AddAPIMiddleware(&middleWareStruct{"second"})
	s.AddAPIMiddleware(&middleWareStruct{"third"})
	s.AddAPIMiddleware(&middleWareStruct{"fourth"})

	c := &gin.Context{}
	rec := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, "http://localhost/", nil)
	if err != nil {
		t.Fatalf("creating request: %v", err)
	}
	// Install the middleware controller in the request context so that
	// middlewares can retrieve it via fnext.GetMiddlewareController.
	ctx := context.WithValue(req.Context(), fnext.MiddlewareControllerKey, s.newMiddlewareController(c))
	req = req.WithContext(ctx)
	c.Request = req

	chainAndServe(s.apiMiddlewares, rec, req, lastHandler)

	res := rec.Result()
	defer res.Body.Close()
	result, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}

	const want = "first,second,third,fourth,last"
	if got := string(result); got != want {
		t.Fatalf("middlewares chained incorrectly: got %q, want %q", got, want)
	}
}
// TestRootMiddleware verifies that root-level middlewares wrap requests served
// by srv.Router. It covers two behaviours:
//   - a plain middleware ("middle") writes to the response before the route
//     handler runs, for both a root route and a sub-route;
//   - a middleware can bypass the API entirely and invoke a function through
//     the middleware controller. Here a request to /v1/apps carrying a
//     "funcit" header is redirected to myapp2's /app2func route, whose
//     output is expected to contain its configured NAME ("johnny").
//
// Each table case checks both the response status code and a substring of
// the body.
func TestRootMiddleware(t *testing.T) {
	ds := datastore.NewMockInit(
		[]*models.App{
			{Name: "myapp", Config: models.Config{}},
			{Name: "myapp2", Config: models.Config{}},
		},
		[]*models.Route{
			{Path: "/", AppName: "myapp", Image: "fnproject/hello", Type: "sync", Memory: 128, Timeout: 30, IdleTimeout: 30, Headers: map[string][]string{"X-Function": {"Test"}}},
			{Path: "/myroute", AppName: "myapp", Image: "fnproject/hello", Type: "sync", Memory: 128, Timeout: 30, IdleTimeout: 30, Headers: map[string][]string{"X-Function": {"Test"}}},
			{Path: "/app2func", AppName: "myapp2", Image: "fnproject/hello", Type: "sync", Memory: 128, Timeout: 30, IdleTimeout: 30, Headers: map[string][]string{"X-Function": {"Test"}},
				Config: map[string]string{"NAME": "johnny"},
			},
		}, nil,
	)

	rnr, cancelrnr := testRunner(t, ds)
	defer cancelrnr()

	fnl := logs.NewMock()
	srv := testServer(ds, &mqs.Mock{}, fnl, rnr)

	srv.AddRootMiddlewareFunc(func(next http.Handler) http.Handler {
		// This one overrides a call to the API based on a header.
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Header.Get("funcit") != "" {
				fmt.Fprintf(os.Stderr, "breaker breaker!\n")
				ctx := r.Context()
				// TODO: this is a little dicey, should have some functions to set these in case the context keys change or something.
				// NOTE(review): plain string keys trip go vet/staticcheck (SA1029),
				// but they must match whatever the server reads; don't change them here alone.
				ctx = context.WithValue(ctx, "app_name", "myapp2")
				ctx = context.WithValue(ctx, "path", "/app2func")
				mctx := fnext.GetMiddlewareController(ctx)
				mctx.CallFunction(w, r.WithContext(ctx))
				return
			}
			// If any context changes, user should use this: next.ServeHTTP(w, r.WithContext(ctx))
			next.ServeHTTP(w, r)
		})
	})
	// A pass-through middleware, so that more than one root middleware is
	// chained.
	srv.AddRootMiddlewareFunc(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
		})
	})
	srv.AddRootMiddleware(&middleWareStruct{"middle"})

	for i, test := range []struct {
		path           string
		body           string
		method         string
		headers        map[string][]string
		expectedCode   int
		expectedInBody string
	}{
		{"/r/myapp", ``, "GET", map[string][]string{}, http.StatusOK, "middle"},
		{"/r/myapp/myroute", ``, "GET", map[string][]string{}, http.StatusOK, "middle"},
		{"/v1/apps", ``, "GET", map[string][]string{"funcit": {"Test"}}, http.StatusOK, "johnny"},
	} {
		body := strings.NewReader(test.body)
		req, err := http.NewRequest(test.method, "http://127.0.0.1:8080"+test.path, body)
		if err != nil {
			t.Fatalf("Test: Could not create %s request to %s: %v", test.method, test.path, err)
		}
		// Add every value for each header; an empty value list adds nothing
		// instead of panicking.
		for k, values := range test.headers {
			for _, v := range values {
				req.Header.Add(k, v)
			}
		}

		t.Logf("test %d: %s %s", i, test.method, req.URL.String())
		_, rec := routerRequest2(t, srv.Router, req)

		res := rec.Result()
		result, err := ioutil.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			t.Fatalf("test %d: reading response body: %v", i, err)
		}
		rbody := string(result)
		t.Log("rbody:", rbody)

		if res.StatusCode != test.expectedCode {
			t.Errorf("test %d: %s %s: got status %d, want %d; body: %q",
				i, test.method, test.path, res.StatusCode, test.expectedCode, rbody)
			continue
		}
		if !strings.Contains(rbody, test.expectedInBody) {
			t.Errorf("test %d: middleware didn't work correctly: body %q does not contain %q",
				i, rbody, test.expectedInBody)
		}
	}
}