From e0b82519aa996e9ff25e87e1d5402c2c06e9ff60 Mon Sep 17 00:00:00 2001 From: Srinidhi Chokkadi Puranik Date: Mon, 30 Apr 2018 13:13:24 -0700 Subject: [PATCH] Last middleware should use the request passed by preceding middleware. (#965) This is useful when preceding middleware reads httpRequest.Body to perform some logic, and assigns a new ReadCloser to httpRequest.Body (as body can be read only once). --- api/server/middleware.go | 2 +- api/server/middleware_test.go | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/api/server/middleware.go b/api/server/middleware.go index 82085f071..1d47d81d7 100644 --- a/api/server/middleware.go +++ b/api/server/middleware.go @@ -93,7 +93,7 @@ func (s *Server) runMiddleware(c *gin.Context, ms []fnext.Middleware) { c.Abort() return } - c.Request = c.Request.WithContext(ctx) + c.Request = r.WithContext(ctx) c.Next() }) diff --git a/api/server/middleware_test.go b/api/server/middleware_test.go index f603bba96..2701bfede 100644 --- a/api/server/middleware_test.go +++ b/api/server/middleware_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "io/ioutil" "net/http" @@ -110,6 +111,14 @@ func TestRootMiddleware(t *testing.T) { }) }) srv.AddRootMiddleware(&middleWareStruct{"middle"}) + srv.AddRootMiddlewareFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log("body reader log") + bodyBytes, _ := ioutil.ReadAll(r.Body) + r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + next.ServeHTTP(w, r) + }) + }) for i, test := range []struct { path string @@ -146,4 +155,16 @@ func TestRootMiddleware(t *testing.T) { t.Fatal(i, "middleware didn't work correctly", string(result)) } } + + req, err := http.NewRequest("POST", "http://127.0.0.1:8080/v1/apps", strings.NewReader("{\"app\": {\"name\": \"myapp3\"}}")) + if err != nil { + t.Fatalf("Test: Could not create create app request") + } + t.Log("TESTING: Create myapp3 when a middleware reads the body") + _, rec := routerRequest2(t, srv.Router, req) + + res, _ := ioutil.ReadAll(rec.Result().Body) + if !strings.Contains(string(res), "myapp3") { + t.Fatal("Middleware did not pass the request correctly to route handler") + } }