diff --git a/api/agent/protocol/json.go b/api/agent/protocol/json.go index c85672790..77f065c27 100644 --- a/api/agent/protocol/json.go +++ b/api/agent/protocol/json.go @@ -9,7 +9,7 @@ import ( // This is sent into the function // All HTTP request headers should be set in env -type JSONIO struct { +type jsonio struct { Headers http.Header `json:"headers,omitempty"` Body string `json:"body"` StatusCode int `json:"status_code,omitempty"` @@ -25,65 +25,35 @@ func (p *JSONProtocol) IsStreamable() bool { return true } -type RequestEncoder struct { - *json.Encoder +func writeString(err error, dst io.Writer, str string) error { + if err != nil { + return err + } + _, err = io.WriteString(dst, str) + return err } -func (h *JSONProtocol) DumpJSON(w io.Writer, req *http.Request) error { +func (h *JSONProtocol) DumpJSON(req *http.Request) error { stdin := json.NewEncoder(h.in) - _, err := io.WriteString(h.in, `{`) - if err != nil { - // this shouldn't happen - return err - } - bb := new(bytes.Buffer) - _, err = bb.ReadFrom(req.Body) - if err != nil { - return err - } - reqData := bb.String() - if reqData != "" { - _, err := io.WriteString(h.in, `"body": `) - if err != nil { - // this shouldn't happen - return err - } - err = stdin.Encode(reqData) - if err != nil { - return err - } - _, err = io.WriteString(h.in, `,`) - if err != nil { - // this shouldn't happen - return err - } - defer bb.Reset() - } - _, err = io.WriteString(h.in, `"headers:"`) - if err != nil { - // this shouldn't happen - return err - } + _, err := bb.ReadFrom(req.Body) + err = writeString(err, h.in, "{") + err = writeString(err, h.in, `"body":`) + err = stdin.Encode(bb.String()) + err = writeString(err, h.in, ",") + defer bb.Reset() + err = writeString(err, h.in, `"headers":`) err = stdin.Encode(req.Header) - if err != nil { - // this shouldn't happen - return err - } - _, err = io.WriteString(h.in, `"}`) - if err != nil { - // this shouldn't happen - return err - } - return nil + err = writeString(err, h.in, "}") + return err } func (h *JSONProtocol) Dispatch(w io.Writer, req *http.Request) error { - err := h.DumpJSON(w, req) + err := h.DumpJSON(req) if err != nil { return err } - jout := new(JSONIO) + jout := new(jsonio) dec := json.NewDecoder(h.out) if err := dec.Decode(jout); err != nil { return err diff --git a/api/agent/protocol/json_test.go b/api/agent/protocol/json_test.go new file mode 100644 index 000000000..f485ad94a --- /dev/null +++ b/api/agent/protocol/json_test.go @@ -0,0 +1,102 @@ +package protocol + +import ( + "bytes" + "testing" + "net/http" + "net/url" + "io/ioutil" + "io" + "encoding/json" +) + +type RequestData struct { + A string `json:"a"` +} + +func TestJSONProtocolDumpJSONRequestWithData(t *testing.T) { + req := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Scheme: "http", + Host: "localhost:8080", + Path: "/v1/apps", + RawQuery: "something=something&etc=etc", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Host": []string{"localhost:8080"}, + "User-Agent": []string{"curl/7.51.0"}, + "Content-Type": []string{"application/json"}, + }, + Host: "localhost:8080", + } + var buf bytes.Buffer + json.NewEncoder(&buf).Encode(RequestData{A: "a"}) + req.Body = ioutil.NopCloser(&buf) + + r, w := io.Pipe() + proto := JSONProtocol{w,r} + go func() { + err := proto.DumpJSON(req) + if err != nil { + t.Error(err.Error()) + } + w.Close() + }() + incomingReq := new(jsonio) + bb := new(bytes.Buffer) + + _, err := bb.ReadFrom(r) + if err != nil { + t.Error(err.Error()) + } + err = json.Unmarshal(bb.Bytes(), incomingReq) + if err != nil { + t.Error(err.Error()) + } +} + +func TestJSONProtocolDumpJSONRequestWithoutData(t *testing.T) { + req := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Scheme: "http", + Host: "localhost:8080", + Path: "/v1/apps", + RawQuery: "something=something&etc=etc", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Host": []string{"localhost:8080"}, + "User-Agent": []string{"curl/7.51.0"}, + "Content-Type": []string{"application/json"}, + }, + Host: "localhost:8080", + } + var buf bytes.Buffer + req.Body = ioutil.NopCloser(&buf) + + r, w := io.Pipe() + proto := JSONProtocol{w,r} + go func() { + err := proto.DumpJSON(req) + if err != nil { + t.Error(err.Error()) + } + w.Close() + }() + incomingReq := new(jsonio) + bb := new(bytes.Buffer) + + _, err := bb.ReadFrom(r) + if err != nil { + t.Error(err.Error()) + } + err = json.Unmarshal(bb.Bytes(), incomingReq) + if err != nil { + t.Error(err.Error()) + } +}