diff --git a/api/agent/protocol/http.go b/api/agent/protocol/http.go index f15f46751..6da93792e 100644 --- a/api/agent/protocol/http.go +++ b/api/agent/protocol/http.go @@ -2,15 +2,23 @@ package protocol import ( "bufio" + "bytes" "context" "fmt" "io" + "io/ioutil" "net/http" + "strconv" + "sync" "github.com/fnproject/fn/api/models" opentracing "github.com/opentracing/opentracing-go" ) +var ( + bufPool = &sync.Pool{New: func() interface{} { return new(bytes.Buffer) }} +) + // HTTPProtocol converts stdin/stdout streams into HTTP/1.1 compliant // communication. It relies on Content-Length to know when to stop reading from // containers stdout. It also mandates valid HTTP headers back and forth, thus @@ -50,11 +58,20 @@ func (h *HTTPProtocol) Dispatch(ctx context.Context, ci CallInfo, w io.Writer) e if err != nil { return models.NewAPIError(http.StatusBadGateway, fmt.Errorf("invalid http response from function err: %v", err)) } - defer resp.Body.Close() span, _ = opentracing.StartSpanFromContext(ctx, "dispatch_http_write_response") defer span.Finish() + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufPool.Put(buf) + + // copy the response body into a buffer so that we read the whole thing. then set the content length. + io.Copy(buf, resp.Body) + resp.Body.Close() + resp.Body = ioutil.NopCloser(buf) + resp.Header.Set("Content-Length", strconv.Itoa(buf.Len())) + rw, ok := w.(http.ResponseWriter) if !ok { // async / [some] tests go through here. write a full http request to the writer diff --git a/api/agent/protocol/json.go b/api/agent/protocol/json.go index aaa363e16..49597e02e 100644 --- a/api/agent/protocol/json.go +++ b/api/agent/protocol/json.go @@ -7,21 +7,14 @@ import ( "fmt" "io" "net/http" + "strconv" "github.com/fnproject/fn/api/models" opentracing "github.com/opentracing/opentracing-go" ) -// This is sent into the function -// All HTTP request headers should be set in env -type jsonio struct { - Body string `json:"body"` - ContentType string `json:"content_type"` -} - // CallRequestHTTP for the protocol that was used by the end user to call this function. We only have HTTP right now. type CallRequestHTTP struct { - // TODO request method ? Type string `json:"type"` Method string `json:"method"` RequestURL string `json:"request_url"` @@ -36,19 +29,19 @@ type CallResponseHTTP struct { // jsonIn We're not using this since we're writing JSON directly right now, but trying to keep it current anyways, much easier to read/follow type jsonIn struct { - jsonio - CallID string `json:"call_id"` - ContentType string `json:"content_type"` - Type string `json:"type"` - Deadline string `json:"deadline"` - Body string `json:"body"` - Protocol *CallRequestHTTP `json:"protocol"` + CallID string `json:"call_id"` + Type string `json:"type"` + Deadline string `json:"deadline"` + Body string `json:"body"` + ContentType string `json:"content_type"` + Protocol CallRequestHTTP `json:"protocol"` } // jsonOut the expected response from the function container type jsonOut struct { - jsonio - Protocol *CallResponseHTTP `json:"protocol,omitempty"` + Body string `json:"body"` + ContentType string `json:"content_type"` + Protocol *CallResponseHTTP `json:"protocol,omitempty"` } // JSONProtocol converts stdin/stdout streams from HTTP into JSON format. @@ -62,131 +55,33 @@ func (p *JSONProtocol) IsStreamable() bool { return true } -func writeString(err error, dst io.Writer, str string) error { - if err != nil { - return err - } - _, err = io.WriteString(dst, str) - return err -} - -// TODO(xxx): headers, query parameters, body - what else should we add to func's payload? -// TODO(xxx): get rid of request body buffering somehow -// @treeder: I don't know why we don't just JSON marshal this, this is rough... func (h *JSONProtocol) writeJSONToContainer(ci CallInfo) error { - stdin := json.NewEncoder(h.in) - bb := new(bytes.Buffer) - _, err := bb.ReadFrom(ci.Input()) - // todo: better/simpler err handling - if err != nil { - return err - } - // open - err = writeString(err, h.in, "{\n") + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufPool.Put(buf) + + _, err := io.Copy(buf, ci.Input()) if err != nil { return err } - // call_id - err = writeString(err, h.in, `"call_id":`) - if err != nil { - return err - } - err = stdin.Encode(ci.CallID()) - if err != nil { - return err + body := buf.String() + + in := jsonIn{ + Body: body, + ContentType: ci.ContentType(), + CallID: ci.CallID(), + Type: ci.CallType(), + Deadline: ci.Deadline().String(), + Protocol: CallRequestHTTP{ + Type: ci.ProtocolType(), + Method: ci.Method(), + RequestURL: ci.RequestURL(), + Headers: ci.Headers(), + }, } - // content_type - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"content_type":`) - if err != nil { - return err - } - err = stdin.Encode(ci.ContentType()) - if err != nil { - return err - } - - // Call type (sync or async) - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"type":`) - if err != nil { - return err - } - err = stdin.Encode(ci.CallType()) - if err != nil { - return err - } - - // deadline - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"deadline":`) - if err != nil { - return err - } - err = stdin.Encode(ci.Deadline().String()) - if err != nil { - return err - } - - // body - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"body":`) - if err != nil { - return err - } - err = stdin.Encode(bb.String()) - if err != nil { - return err - } - - // now the extras - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"protocol":{`) // OK name? This is what OpenEvents is calling it in initial proposal - { - // Protocol type used to initiate the call. - err = writeString(err, h.in, `"type":`) - if err != nil { - return err - } - err = stdin.Encode(ci.ProtocolType()) - - // request method - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"method":`) - if err != nil { - return err - } - err = stdin.Encode(ci.Method()) - if err != nil { - return err - } - - // request URL - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"request_url":`) - if err != nil { - return err - } - err = stdin.Encode(ci.RequestURL()) - if err != nil { - return err - } - - // HTTP headers - err = writeString(err, h.in, ",") - err = writeString(err, h.in, `"headers":`) - if err != nil { - return err - } - err = stdin.Encode(ci.Headers()) - } - err = writeString(err, h.in, "}") - - // close - err = writeString(err, h.in, "\n}\n\n") - return err + return json.NewEncoder(h.in).Encode(in) } func (h *JSONProtocol) Dispatch(ctx context.Context, ci CallInfo, w io.Writer) error { @@ -232,6 +127,7 @@ func (h *JSONProtocol) Dispatch(ctx context.Context, ci CallInfo, w io.Writer) e rw.WriteHeader(p.StatusCode) } } + rw.Header().Set("Content-Length", strconv.Itoa(len(jout.Body))) _, err = io.WriteString(rw, jout.Body) return err }