update vendor/ dir to latest w/o heroku, moby

had to lock a lot of things in place
This commit is contained in:
Reed Allman
2017-08-03 02:38:15 -07:00
parent 780791da1c
commit 30f3c45dbc
5637 changed files with 191713 additions and 1133103 deletions

View File

@@ -4,7 +4,10 @@
package http2
import "fmt"
import (
"errors"
"fmt"
)
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type ErrCode uint32
@@ -88,3 +91,32 @@ type connError struct {
func (e connError) Error() string {
return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
}
type pseudoHeaderError string
func (e pseudoHeaderError) Error() string {
return fmt.Sprintf("invalid pseudo-header %q", string(e))
}
type duplicatePseudoHeaderError string
func (e duplicatePseudoHeaderError) Error() string {
return fmt.Sprintf("duplicate pseudo-header %q", string(e))
}
type headerFieldNameError string
func (e headerFieldNameError) Error() string {
return fmt.Sprintf("invalid header field name %q", string(e))
}
type headerFieldValueError string
func (e headerFieldValueError) Error() string {
return fmt.Sprintf("invalid header field value %q", string(e))
}
var (
errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers")
errPseudoAfterRegular = errors.New("pseudo header field after regular")
)

View File

@@ -11,7 +11,10 @@ import (
"fmt"
"io"
"log"
"strings"
"sync"
"golang.org/x/net/http2/hpack"
)
const frameHeaderLen = 9
@@ -261,7 +264,7 @@ type Frame interface {
type Framer struct {
r io.Reader
lastFrame Frame
errReason string
errDetail error
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
@@ -293,8 +296,20 @@ type Framer struct {
// to return non-compliant frames or frame orders.
// This is for testing and permits using the Framer to test
// other HTTP/2 implementations' conformance to the spec.
// It is not compatible with ReadMetaHeaders.
AllowIllegalReads bool
// ReadMetaHeaders if non-nil causes ReadFrame to merge
// HEADERS and CONTINUATION frames together and return
// MetaHeadersFrame instead.
ReadMetaHeaders *hpack.Decoder
// MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE.
// It's used only if ReadMetaHeaders is set; 0 means a sane default
// (currently 16MB).
// If the limit is hit, MetaHeadersFrame.Truncated is set true.
MaxHeaderListSize uint32
// TODO: track which type of frame & with which flags was sent
// last. Then return an error (unless AllowIllegalWrites) if
// we're in the middle of a header block and a
@@ -307,6 +322,13 @@ type Framer struct {
debugFramerBuf *bytes.Buffer
}
func (fr *Framer) maxHeaderListSize() uint32 {
if fr.MaxHeaderListSize == 0 {
return 16 << 20 // sane default, per docs
}
return fr.MaxHeaderListSize
}
func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
// Write the FrameHeader.
f.wbuf = append(f.wbuf[:0],
@@ -402,6 +424,17 @@ func (fr *Framer) SetMaxReadFrameSize(v uint32) {
fr.maxReadSize = v
}
// ErrorDetail returns a more detailed error of the last error
// returned by Framer.ReadFrame. For instance, if ReadFrame
// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail
// will say exactly what was invalid. ErrorDetail is not guaranteed
// to return a non-nil value and like the rest of the http2 package,
// its return value is not protected by an API compatibility promise.
// ErrorDetail is reset after the next call to ReadFrame.
func (fr *Framer) ErrorDetail() error {
return fr.errDetail
}
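// A minimal sketch (hypothetical helper; fr is assumed to be an existing
// *Framer inside this package) of how the new ErrorDetail pairs with
// ReadFrame to surface the richer error:
func logReadFrameError(fr *Framer) {
	if _, err := fr.ReadFrame(); err != nil {
		if detail := fr.ErrorDetail(); detail != nil {
			log.Printf("http2: ReadFrame error %v; detail: %v", err, detail)
		} else {
			log.Printf("http2: ReadFrame error %v", err)
		}
	}
}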
// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer
// sends a frame that is larger than declared with SetMaxReadFrameSize.
var ErrFrameTooLarge = errors.New("http2: frame too large")
@@ -423,6 +456,7 @@ func terminalReadFrameError(err error) bool {
// ConnectionError, StreamError, or anything else from the underlying
// reader.
func (fr *Framer) ReadFrame() (Frame, error) {
fr.errDetail = nil
if fr.lastFrame != nil {
fr.lastFrame.invalidate()
}
@@ -450,6 +484,9 @@ func (fr *Framer) ReadFrame() (Frame, error) {
if fr.logReads {
log.Printf("http2: Framer %p: read %v", fr, summarizeFrame(f))
}
if fh.Type == FrameHeaders && fr.ReadMetaHeaders != nil {
return fr.readMetaFrame(f.(*HeadersFrame))
}
return f, nil
}
@@ -458,7 +495,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
// to the peer before hanging up on them. This might help others debug
// their implementations.
func (fr *Framer) connError(code ErrCode, reason string) error {
fr.errReason = reason
fr.errDetail = errors.New(reason)
return ConnectionError(code)
}
@@ -1225,6 +1262,196 @@ type headersEnder interface {
HeadersEnded() bool
}
type headersOrContinuation interface {
headersEnder
HeaderBlockFragment() []byte
}
// A MetaHeadersFrame is the representation of one HEADERS frame and
// zero or more contiguous CONTINUATION frames and the decoding of
// their HPACK-encoded contents.
//
// This type of frame does not appear on the wire and is only returned
// by the Framer when Framer.ReadMetaHeaders is set.
type MetaHeadersFrame struct {
*HeadersFrame
// Fields are the fields contained in the HEADERS and
// CONTINUATION frames. The underlying slice is owned by the
// Framer and must not be retained after the next call to
// ReadFrame.
//
// Fields are guaranteed to be in the correct http2 order and
// not have unknown pseudo header fields or invalid header
// field names or values. Required pseudo header fields may be
// missing, however. Use the MetaHeadersFrame.PseudoValue
// accessor method to access pseudo headers.
Fields []hpack.HeaderField
// Truncated is whether the max header list size limit was hit
// and Fields is incomplete. The hpack decoder state is still
// valid, however.
Truncated bool
}
// PseudoValue returns the given pseudo header field's value.
// The provided pseudo field should not contain the leading colon.
func (mh *MetaHeadersFrame) PseudoValue(pseudo string) string {
for _, hf := range mh.Fields {
if !hf.IsPseudo() {
return ""
}
if hf.Name[1:] == pseudo {
return hf.Value
}
}
return ""
}
// RegularFields returns the regular (non-pseudo) header fields of mh.
// The caller does not own the returned slice.
func (mh *MetaHeadersFrame) RegularFields() []hpack.HeaderField {
for i, hf := range mh.Fields {
if !hf.IsPseudo() {
return mh.Fields[i:]
}
}
return nil
}
// PseudoFields returns the pseudo header fields of mh.
// The caller does not own the returned slice.
func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
for i, hf := range mh.Fields {
if !hf.IsPseudo() {
return mh.Fields[:i]
}
}
return mh.Fields
}
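// A short sketch (hypothetical helper, assuming a "net/http" import) of how
// these accessors are meant to be consumed, mirroring what the server code
// later in this commit does when building a request:
func requestFromMeta(mh *MetaHeadersFrame) (method, path string, header http.Header) {
	method = mh.PseudoValue("method")
	path = mh.PseudoValue("path")
	header = make(http.Header)
	for _, hf := range mh.RegularFields() {
		header.Add(http.CanonicalHeaderKey(hf.Name), hf.Value)
	}
	return method, path, header
}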
func (mh *MetaHeadersFrame) checkPseudos() error {
var isRequest, isResponse bool
pf := mh.PseudoFields()
for i, hf := range pf {
switch hf.Name {
case ":method", ":path", ":scheme", ":authority":
isRequest = true
case ":status":
isResponse = true
default:
return pseudoHeaderError(hf.Name)
}
// Check for duplicates.
// This would be a bad algorithm, but N is 4.
// And this doesn't allocate.
for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name {
return duplicatePseudoHeaderError(hf.Name)
}
}
}
if isRequest && isResponse {
return errMixPseudoHeaderTypes
}
return nil
}
func (fr *Framer) maxHeaderStringLen() int {
v := fr.maxHeaderListSize()
if uint32(int(v)) == v {
return int(v)
}
// They had a crazy big number for MaxHeaderBytes anyway,
// so give them unlimited header lengths:
return 0
}
// readMetaFrame reads zero or more CONTINUATION frames from fr,
// merges them into the provided hf, and returns a MetaHeadersFrame
// with the decoded hpack values.
func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if fr.AllowIllegalReads {
return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
}
mh := &MetaHeadersFrame{
HeadersFrame: hf,
}
var remainSize = fr.maxHeaderListSize()
var sawRegular bool
var invalid error // pseudo header field errors
hdec := fr.ReadMetaHeaders
hdec.SetEmitEnabled(true)
hdec.SetMaxStringLength(fr.maxHeaderStringLen())
hdec.SetEmitFunc(func(hf hpack.HeaderField) {
if !validHeaderFieldValue(hf.Value) {
invalid = headerFieldValueError(hf.Value)
}
isPseudo := strings.HasPrefix(hf.Name, ":")
if isPseudo {
if sawRegular {
invalid = errPseudoAfterRegular
}
} else {
sawRegular = true
if !validHeaderFieldName(hf.Name) {
invalid = headerFieldNameError(hf.Name)
}
}
if invalid != nil {
hdec.SetEmitEnabled(false)
return
}
size := hf.Size()
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
return
}
remainSize -= size
mh.Fields = append(mh.Fields, hf)
})
// Lose reference to MetaHeadersFrame:
defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {})
var hc headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
if _, err := hdec.Write(frag); err != nil {
return nil, ConnectionError(ErrCodeCompression)
}
if hc.HeadersEnded() {
break
}
if f, err := fr.ReadFrame(); err != nil {
return nil, err
} else {
hc = f.(*ContinuationFrame) // guaranteed by checkFrameOrder
}
}
mh.HeadersFrame.headerFragBuf = nil
mh.HeadersFrame.invalidate()
if err := hdec.Close(); err != nil {
return nil, ConnectionError(ErrCodeCompression)
}
if invalid != nil {
fr.errDetail = invalid
return nil, StreamError{mh.StreamID, ErrCodeProtocol}
}
if err := mh.checkPseudos(); err != nil {
fr.errDetail = err
return nil, StreamError{mh.StreamID, ErrCodeProtocol}
}
return mh, nil
}
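// To tie the new pieces in this file together, a minimal sketch from a
// caller outside the package (conn and the 1 MB limit are illustrative;
// assumes imports "net", "golang.org/x/net/http2" and
// "golang.org/x/net/http2/hpack"):
func newMetaFramer(conn net.Conn) *http2.Framer {
	fr := http2.NewFramer(conn, conn)
	// 4096 is HTTP/2's default SETTINGS_HEADER_TABLE_SIZE.
	fr.ReadMetaHeaders = hpack.NewDecoder(4096, nil)
	fr.MaxHeaderListSize = 1 << 20 // 0 would fall back to the 16 MB default above
	// Subsequent fr.ReadFrame() calls return *http2.MetaHeadersFrame for
	// HEADERS (plus any CONTINUATION frames) instead of the raw frames.
	return fr
}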
func summarizeFrame(f Frame) string {
var buf bytes.Buffer
f.Header().writeDebug(&buf)

View File

@@ -12,6 +12,8 @@ import (
"strings"
"testing"
"unsafe"
"golang.org/x/net/http2/hpack"
)
func testFramer() (*Framer, *bytes.Buffer) {
@@ -725,11 +727,249 @@ func TestReadFrameOrder(t *testing.T) {
t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes())
continue
}
if f.errReason != tt.wantErr {
t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errReason, tt.wantErr, log.Bytes())
if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) {
t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes())
}
if n < tt.atLeast {
t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes())
}
}
}
func TestMetaFrameHeader(t *testing.T) {
write := func(f *Framer, frags ...[]byte) {
for i, frag := range frags {
end := (i == len(frags)-1)
if i == 0 {
f.WriteHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: frag,
EndHeaders: end,
})
} else {
f.WriteContinuation(1, end, frag)
}
}
}
want := func(flags Flags, length uint32, pairs ...string) *MetaHeadersFrame {
mh := &MetaHeadersFrame{
HeadersFrame: &HeadersFrame{
FrameHeader: FrameHeader{
Type: FrameHeaders,
Flags: flags,
Length: length,
StreamID: 1,
},
},
Fields: []hpack.HeaderField(nil),
}
for len(pairs) > 0 {
mh.Fields = append(mh.Fields, hpack.HeaderField{
Name: pairs[0],
Value: pairs[1],
})
pairs = pairs[2:]
}
return mh
}
truncated := func(mh *MetaHeadersFrame) *MetaHeadersFrame {
mh.Truncated = true
return mh
}
const noFlags Flags = 0
oneKBString := strings.Repeat("a", 1<<10)
tests := [...]struct {
name string
w func(*Framer)
want interface{} // *MetaHeaderFrame or error
wantErrReason string
maxHeaderListSize uint32
}{
0: {
name: "single_headers",
w: func(f *Framer) {
var he hpackEncoder
all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/")
write(f, all)
},
want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"),
},
1: {
name: "with_continuation",
w: func(f *Framer) {
var he hpackEncoder
all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
write(f, all[:1], all[1:])
},
want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"),
},
2: {
name: "with_two_continuation",
w: func(f *Framer) {
var he hpackEncoder
all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
write(f, all[:2], all[2:4], all[4:])
},
want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"),
},
3: {
name: "big_string_okay",
w: func(f *Framer) {
var he hpackEncoder
all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
write(f, all[:2], all[2:])
},
want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString),
},
4: {
name: "big_string_error",
w: func(f *Framer) {
var he hpackEncoder
all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
write(f, all[:2], all[2:])
},
maxHeaderListSize: (1 << 10) / 2,
want: ConnectionError(ErrCodeCompression),
},
5: {
name: "max_header_list_truncated",
w: func(f *Framer) {
var he hpackEncoder
var pairs = []string{":method", "GET", ":path", "/"}
for i := 0; i < 100; i++ {
pairs = append(pairs, "foo", "bar")
}
all := he.encodeHeaderRaw(t, pairs...)
write(f, all[:2], all[2:])
},
maxHeaderListSize: (1 << 10) / 2,
want: truncated(want(noFlags, 2,
":method", "GET",
":path", "/",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar",
"foo", "bar", // 11
)),
},
6: {
name: "pseudo_order",
w: func(f *Framer) {
write(f, encodeHeaderRaw(t,
":method", "GET",
"foo", "bar",
":path", "/", // bogus
))
},
want: StreamError{1, ErrCodeProtocol},
wantErrReason: "pseudo header field after regular",
},
7: {
name: "pseudo_unknown",
w: func(f *Framer) {
write(f, encodeHeaderRaw(t,
":unknown", "foo", // bogus
"foo", "bar",
))
},
want: StreamError{1, ErrCodeProtocol},
wantErrReason: "invalid pseudo-header \":unknown\"",
},
8: {
name: "pseudo_mix_request_response",
w: func(f *Framer) {
write(f, encodeHeaderRaw(t,
":method", "GET",
":status", "100",
))
},
want: StreamError{1, ErrCodeProtocol},
wantErrReason: "mix of request and response pseudo headers",
},
9: {
name: "pseudo_dup",
w: func(f *Framer) {
write(f, encodeHeaderRaw(t,
":method", "GET",
":method", "POST",
))
},
want: StreamError{1, ErrCodeProtocol},
wantErrReason: "duplicate pseudo-header \":method\"",
},
10: {
name: "trailer_okay_no_pseudo",
w: func(f *Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) },
want: want(FlagHeadersEndHeaders, 8, "foo", "bar"),
},
11: {
name: "invalid_field_name",
w: func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
want: StreamError{1, ErrCodeProtocol},
wantErrReason: "invalid header field name \"CapitalBad\"",
},
12: {
name: "invalid_field_value",
w: func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
want: StreamError{1, ErrCodeProtocol},
wantErrReason: "invalid header field value \"bad_null\\x00\"",
},
}
for i, tt := range tests {
buf := new(bytes.Buffer)
f := NewFramer(buf, buf)
f.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
f.MaxHeaderListSize = tt.maxHeaderListSize
tt.w(f)
name := tt.name
if name == "" {
name = fmt.Sprintf("test index %d", i)
}
var got interface{}
var err error
got, err = f.ReadFrame()
if err != nil {
got = err
}
if !reflect.DeepEqual(got, tt.want) {
if mhg, ok := got.(*MetaHeadersFrame); ok {
if mhw, ok := tt.want.(*MetaHeadersFrame); ok {
hg := mhg.HeadersFrame
hw := mhw.HeadersFrame
if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) {
t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw)
}
}
}
str := func(v interface{}) string {
if _, ok := v.(error); ok {
return fmt.Sprintf("error %v", v)
} else {
return fmt.Sprintf("value %#v", v)
}
}
t.Errorf("%s:\n got: %v\nwant: %s", name, str(got), str(tt.want))
}
if tt.wantErrReason != "" && tt.wantErrReason != fmt.Sprint(f.errDetail) {
t.Errorf("%s: got error reason %q; want %q", name, f.errDetail, tt.wantErrReason)
}
}
}
func encodeHeaderRaw(t *testing.T, pairs ...string) []byte {
var he hpackEncoder
return he.encodeHeaderRaw(t, pairs...)
}

View File

@@ -28,7 +28,7 @@ import (
"time"
"camlistore.org/pkg/googlestorage"
"camlistore.org/pkg/singleflight"
"go4.org/syncutil/singleflight"
"golang.org/x/net/http2"
)
@@ -79,7 +79,7 @@ is used transparently by the Go standard library from Go 1.6 and later.
</p>
<p>Contact info: <i>bradfitz@golang.org</i>, or <a
href="https://golang.org/issues">file a bug</a>.</p>
href="https://golang.org/s/http2bug">file a bug</a>.</p>
<h2>Handlers for testing</h2>
<ul>
@@ -440,11 +440,43 @@ func serveProd() error {
return <-errc
}
const idleTimeout = 5 * time.Minute
const activeTimeout = 10 * time.Minute
// TODO: put this into the standard library and actually send
// PING frames and GOAWAY, etc: golang.org/issue/14204
func idleTimeoutHook() func(net.Conn, http.ConnState) {
var mu sync.Mutex
m := map[net.Conn]*time.Timer{}
return func(c net.Conn, cs http.ConnState) {
mu.Lock()
defer mu.Unlock()
if t, ok := m[c]; ok {
delete(m, c)
t.Stop()
}
var d time.Duration
switch cs {
case http.StateNew, http.StateIdle:
d = idleTimeout
case http.StateActive:
d = activeTimeout
default:
return
}
m[c] = time.AfterFunc(d, func() {
log.Printf("closing idle conn %v after %v", c.RemoteAddr(), d)
go c.Close()
})
}
}
func main() {
var srv http.Server
flag.BoolVar(&http2.VerboseLogs, "verbose", false, "Verbose HTTP/2 debugging.")
flag.Parse()
srv.Addr = *httpsAddr
srv.ConnState = idleTimeoutHook()
registerHandlers()

View File

@@ -170,9 +170,9 @@ func main() {
},
},
NetworkInterfaces: []*compute.NetworkInterface{
&compute.NetworkInterface{
{
AccessConfigs: []*compute.AccessConfig{
&compute.AccessConfig{
{
Type: "ONE_TO_ONE_NAT",
Name: "External NAT",
NatIP: natIP,

View File

@@ -56,8 +56,8 @@ type command struct {
}
var commands = map[string]command{
"ping": command{run: (*h2i).cmdPing},
"settings": command{
"ping": {run: (*h2i).cmdPing},
"settings": {
run: (*h2i).cmdSettings,
complete: func() []string {
return []string{
@@ -71,14 +71,13 @@ var commands = map[string]command{
}
},
},
"quit": command{run: (*h2i).cmdQuit},
"headers": command{run: (*h2i).cmdHeaders},
"quit": {run: (*h2i).cmdQuit},
"headers": {run: (*h2i).cmdHeaders},
}
func usage() {
fmt.Fprintf(os.Stderr, "Usage: h2i <hostname>\n\n")
flag.PrintDefaults()
os.Exit(1)
}
// withPort adds ":443" if another port isn't already present.
@@ -111,6 +110,7 @@ func main() {
flag.Parse()
if flag.NArg() != 1 {
usage()
os.Exit(2)
}
log.SetFlags(0)

View File

@@ -144,7 +144,7 @@ func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
// shouldIndex reports whether f should be indexed.
func (e *Encoder) shouldIndex(f HeaderField) bool {
return !f.Sensitive && f.size() <= e.dynTab.maxSize
return !f.Sensitive && f.Size() <= e.dynTab.maxSize
}
// appendIndexed appends index i, as encoded in "Indexed Header Field"

View File

@@ -41,6 +41,14 @@ type HeaderField struct {
Sensitive bool
}
// IsPseudo reports whether the header field is an http2 pseudo header.
// That is, it reports whether it starts with a colon.
// It is not otherwise guaranteed to be a valid pseudo header field,
// though.
func (hf HeaderField) IsPseudo() bool {
return len(hf.Name) != 0 && hf.Name[0] == ':'
}
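// A tiny sketch (hypothetical helper) of the two newly exported HeaderField
// helpers; Size, exported below in this file, is len(Name)+len(Value)+32 per
// HPACK's entry-size rule (RFC 7541, section 4.1):
func headerFieldInfo(hf HeaderField) (pseudo bool, size uint32) {
	return hf.IsPseudo(), hf.Size()
}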
func (hf HeaderField) String() string {
var suffix string
if hf.Sensitive {
@@ -49,7 +57,8 @@ func (hf HeaderField) String() string {
return fmt.Sprintf("header field %q = %q%s", hf.Name, hf.Value, suffix)
}
func (hf *HeaderField) size() uint32 {
// Size returns the size of an entry per RFC 7541 section 4.1.
func (hf HeaderField) Size() uint32 {
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
// "The size of the dynamic table is the sum of the size of
// its entries. The size of an entry is the sum of its name's
@@ -171,7 +180,7 @@ func (dt *dynamicTable) setMaxSize(v uint32) {
func (dt *dynamicTable) add(f HeaderField) {
dt.ents = append(dt.ents, f)
dt.size += f.size()
dt.size += f.Size()
dt.evict()
}
@@ -179,7 +188,7 @@ func (dt *dynamicTable) add(f HeaderField) {
func (dt *dynamicTable) evict() {
base := dt.ents // keep base pointer of slice
for dt.size > dt.maxSize {
dt.size -= dt.ents[0].size()
dt.size -= dt.ents[0].Size()
dt.ents = dt.ents[1:]
}

View File

@@ -17,11 +17,13 @@ package http2
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
@@ -422,3 +424,40 @@ var isTokenTable = [127]bool{
'|': true,
'~': true,
}
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
var sorterPool = sync.Pool{New: func() interface{} { return new(sorter) }}
type sorter struct {
v []string // owned by sorter
}
func (s *sorter) Len() int { return len(s.v) }
func (s *sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] }
func (s *sorter) Less(i, j int) bool { return s.v[i] < s.v[j] }
// Keys returns the sorted keys of h.
//
// The returned slice is only valid until s is used again or returned to
// its pool.
func (s *sorter) Keys(h http.Header) []string {
keys := s.v[:0]
for k := range h {
keys = append(keys, k)
}
s.v = keys
sort.Sort(s)
return keys
}
func (s *sorter) SortStrings(ss []string) {
// Our sorter works on s.v, which the sorter owns, so
// stash it away while we sort the user's buffer.
save := s.v
s.v = ss
sort.Sort(s)
s.v = save
}
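// A package-internal sketch (hypothetical helper) of the intended pattern for
// the pooled sorter, mirroring its use in server.go later in this commit:
func sortedHeaderKeys(h http.Header) []string {
	s := sorterPool.Get().(*sorter)
	keys := s.Keys(h)
	// Keys' result aliases s.v, so copy it before returning s to the pool.
	out := make([]string, len(keys))
	copy(out, keys)
	sorterPool.Put(s)
	return out
}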

View File

@@ -172,3 +172,27 @@ func cleanDate(res *http.Response) {
d[0] = "XXX"
}
}
func TestSorterPoolAllocs(t *testing.T) {
ss := []string{"a", "b", "c"}
h := http.Header{
"a": nil,
"b": nil,
"c": nil,
}
sorter := new(sorter)
if allocs := testing.AllocsPerRun(100, func() {
sorter.SortStrings(ss)
}); allocs >= 1 {
t.Logf("SortStrings allocs = %v; want <1", allocs)
}
if allocs := testing.AllocsPerRun(5, func() {
if len(sorter.Keys(h)) != 3 {
t.Fatal("wrong result")
}
}); allocs > 0 {
t.Logf("Keys allocs = %v; want <1", allocs)
}
}

View File

@@ -48,6 +48,8 @@ import (
"net/http"
"net/textproto"
"net/url"
"os"
"reflect"
"runtime"
"strconv"
"strings"
@@ -192,28 +194,76 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if testHookOnConn != nil {
testHookOnConn()
}
conf.handleConn(hs, c, h)
conf.ServeConn(c, &ServeConnOpts{
Handler: h,
BaseConfig: hs,
})
}
s.TLSNextProto[NextProtoTLS] = protoHandler
s.TLSNextProto["h2-14"] = protoHandler // temporary; see above.
return nil
}
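// ConfigureServer now delegates to the exported ServeConn below. A hedged
// sketch (hypothetical listener and handler; the client is assumed to speak
// HTTP/2 with prior knowledge, i.e. no TLS or Upgrade step) of calling it
// directly:
func serveRawConns(ln net.Listener, h http.Handler) error {
	var srv Server
	for {
		c, err := ln.Accept()
		if err != nil {
			return err
		}
		go srv.ServeConn(c, &ServeConnOpts{Handler: h})
	}
}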
func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
// ServeConnOpts are options for the Server.ServeConn method.
type ServeConnOpts struct {
// BaseConfig optionally sets the base configuration
// for values. If nil, defaults are used.
BaseConfig *http.Server
// Handler specifies which handler to use for processing
// requests. If nil, BaseConfig.Handler is used. If BaseConfig
// or BaseConfig.Handler is nil, http.DefaultServeMux is used.
Handler http.Handler
}
func (o *ServeConnOpts) baseConfig() *http.Server {
if o != nil && o.BaseConfig != nil {
return o.BaseConfig
}
return new(http.Server)
}
func (o *ServeConnOpts) handler() http.Handler {
if o != nil {
if o.Handler != nil {
return o.Handler
}
if o.BaseConfig != nil && o.BaseConfig.Handler != nil {
return o.BaseConfig.Handler
}
}
return http.DefaultServeMux
}
// ServeConn serves HTTP/2 requests on the provided connection and
// blocks until the connection is no longer readable.
//
// ServeConn starts speaking HTTP/2 assuming that c has not had any
// reads or writes. It writes its initial settings frame and expects
// to be able to read the preface and settings frame from the
// client. If c has a ConnectionState method like a *tls.Conn, the
// ConnectionState is used to verify the TLS ciphersuite and to set
// the Request.TLS field in Handlers.
//
// ServeConn does not support h2c by itself. Any h2c support must be
// implemented in terms of providing a suitably-behaving net.Conn.
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc := &serverConn{
srv: srv,
hs: hs,
srv: s,
hs: opts.baseConfig(),
conn: c,
remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c),
handler: h,
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
wantWriteFrameCh: make(chan frameWriteMsg, 8),
wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}),
advMaxStreams: srv.maxConcurrentStreams(),
advMaxStreams: s.maxConcurrentStreams(),
writeSched: writeScheduler{
maxFrameSize: initialMaxFrameSize,
},
@@ -225,14 +275,14 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
sc.flow.add(initialWindowSize)
sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
fr := NewFramer(sc.bw, c)
fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize())
sc.framer = fr
if tc, ok := c.(*tls.Conn); ok {
if tc, ok := c.(connectionStater); ok {
sc.tlsState = new(tls.ConnectionState)
*sc.tlsState = tc.ConnectionState()
// 9.2 Use of TLS Features
@@ -262,7 +312,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
// So for now, do nothing here again.
}
if !srv.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated."
@@ -324,7 +374,6 @@ type serverConn struct {
bw *bufferedWriter // writing to conn
handler http.Handler
framer *Framer
hpackDecoder *hpack.Decoder
doneServing chan struct{} // closed when serverConn.serve ends
readFrameCh chan readFrameResult // written by serverConn.readFrames
wantWriteFrameCh chan frameWriteMsg // from handlers -> serve
@@ -351,7 +400,6 @@ type serverConn struct {
headerTableSize uint32
peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
req requestParam // non-zero while reading request headers
writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh
needsFrameFlush bool // last frame write wasn't a flush
writeSched writeScheduler
@@ -360,22 +408,13 @@ type serverConn struct {
goAwayCode ErrCode
shutdownTimerCh <-chan time.Time // nil until used
shutdownTimer *time.Timer // nil until used
freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
hpackEncoder *hpack.Encoder
}
func (sc *serverConn) maxHeaderStringLen() int {
v := sc.maxHeaderListSize()
if uint32(int(v)) == v {
return int(v)
}
// They had a crazy big number for MaxHeaderBytes anyway,
// so give them unlimited header lengths:
return 0
}
func (sc *serverConn) maxHeaderListSize() uint32 {
n := sc.hs.MaxHeaderBytes
if n <= 0 {
@@ -388,21 +427,6 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
return uint32(n + typicalHeaders*perFieldOverhead)
}
// requestParam is the state of the next request, initialized over
// potentially several frames HEADERS + zero or more CONTINUATION
// frames.
type requestParam struct {
// stream is non-nil if we're reading (HEADER or CONTINUATION)
// frames for a request (but not DATA).
stream *stream
header http.Header
method, path string
scheme, authority string
sawRegularHeader bool // saw a non-pseudo header already
invalidHeader bool // an invalid header was seen
headerListSize int64 // actually uint32, but easier math this way
}
// stream represents a stream. This is the minimal metadata needed by
// the serve goroutine. Most of the actual stream state is owned by
// the http.Handler's goroutine in the responseWriter. Because the
@@ -429,6 +453,7 @@ type stream struct {
sentReset bool // only true once detached from streams map
gotReset bool // only true once detacted from streams map
gotTrailerHeader bool // HEADER frame for trailers was seen
reqBuf []byte
trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer
@@ -482,12 +507,55 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
}
}
// errno returns v's underlying uintptr, else 0.
//
// TODO: remove this helper function once http2 can use build
// tags. See comment in isClosedConnError.
func errno(v error) uintptr {
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr {
return uintptr(rv.Uint())
}
return 0
}
// isClosedConnError reports whether err is an error from use of a closed
// network connection.
func isClosedConnError(err error) bool {
if err == nil {
return false
}
// TODO: remove this string search and be more like the Windows
// case below. That might involve modifying the standard library
// to return better error types.
str := err.Error()
if strings.Contains(str, "use of closed network connection") {
return true
}
// TODO(bradfitz): x/tools/cmd/bundle doesn't really support
// build tags, so I can't make an http2_windows.go file with
// Windows-specific stuff. Fix that and move this, once we
// have a way to bundle this into std's net/http somehow.
if runtime.GOOS == "windows" {
if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
const WSAECONNABORTED = 10053
const WSAECONNRESET = 10054
if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
return true
}
}
}
}
return false
}
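// For illustration, a sketch (hypothetical; assumes a "syscall" import, which
// server.go itself avoids by using reflect) of what errno extracts:
func errnoExample() bool {
	var err error = syscall.Errno(10054) // WSAECONNRESET on Windows; illustrative value
	// syscall.Errno's underlying kind is uintptr, so errno returns 10054 here.
	// isClosedConnError treats a "wsarecv" *os.SyscallError wrapping such a
	// value as an ordinary closed-connection error.
	return errno(err) == 10054
}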
func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
if err == nil {
return
}
str := err.Error()
if err == io.EOF || strings.Contains(str, "use of closed network connection") {
if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) {
// Boring, expected errors.
sc.vlogf(format, args...)
} else {
@@ -495,87 +563,6 @@ func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
}
}
func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
sc.serveG.check()
if VerboseLogs {
sc.vlogf("http2: server decoded %v", f)
}
switch {
case !validHeaderFieldValue(f.Value): // f.Name checked _after_ pseudo check, since ':' is invalid
sc.req.invalidHeader = true
case strings.HasPrefix(f.Name, ":"):
if sc.req.sawRegularHeader {
sc.logf("pseudo-header after regular header")
sc.req.invalidHeader = true
return
}
var dst *string
switch f.Name {
case ":method":
dst = &sc.req.method
case ":path":
dst = &sc.req.path
case ":scheme":
dst = &sc.req.scheme
case ":authority":
dst = &sc.req.authority
default:
// 8.1.2.1 Pseudo-Header Fields
// "Endpoints MUST treat a request or response
// that contains undefined or invalid
// pseudo-header fields as malformed (Section
// 8.1.2.6)."
sc.logf("invalid pseudo-header %q", f.Name)
sc.req.invalidHeader = true
return
}
if *dst != "" {
sc.logf("duplicate pseudo-header %q sent", f.Name)
sc.req.invalidHeader = true
return
}
*dst = f.Value
case !validHeaderFieldName(f.Name):
sc.req.invalidHeader = true
default:
sc.req.sawRegularHeader = true
sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
const headerFieldOverhead = 32 // per spec
sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
if sc.req.headerListSize > int64(sc.maxHeaderListSize()) {
sc.hpackDecoder.SetEmitEnabled(false)
}
}
}
func (st *stream) onNewTrailerField(f hpack.HeaderField) {
sc := st.sc
sc.serveG.check()
if VerboseLogs {
sc.vlogf("http2: server decoded trailer %v", f)
}
switch {
case strings.HasPrefix(f.Name, ":"):
sc.req.invalidHeader = true
return
case !validHeaderFieldName(f.Name) || !validHeaderFieldValue(f.Value):
sc.req.invalidHeader = true
return
default:
key := sc.canonicalHeader(f.Name)
if st.trailer != nil {
vv := append(st.trailer[key], f.Value)
st.trailer[key] = vv
// arbitrary; TODO: read spec about header list size limits wrt trailers
const tooBig = 1000
if len(vv) >= tooBig {
sc.hpackDecoder.SetEmitEnabled(false)
}
}
}
}
func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check()
cv, ok := commonCanonHeader[v]
@@ -610,10 +597,11 @@ type readFrameResult struct {
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
gate := make(gate)
gateDone := gate.Done
for {
f, err := sc.framer.ReadFrame()
select {
case sc.readFrameCh <- readFrameResult{f, err, gate.Done}:
case sc.readFrameCh <- readFrameResult{f, err, gateDone}:
case <-sc.doneServing:
return
}
@@ -1031,7 +1019,7 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
sc.goAway(ErrCodeFrameSize)
return true // goAway will close the loop
}
clientGone := err == io.EOF || strings.Contains(err.Error(), "use of closed network connection")
clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err)
if clientGone {
// TODO: could we also get into this state if
// the peer does a half close
@@ -1067,7 +1055,7 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
return true // goAway will handle shutdown
default:
if res.err != nil {
sc.logf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
} else {
sc.logf("http2: server closing client connection: %v", err)
}
@@ -1089,10 +1077,8 @@ func (sc *serverConn) processFrame(f Frame) error {
switch f := f.(type) {
case *SettingsFrame:
return sc.processSettings(f)
case *HeadersFrame:
case *MetaHeadersFrame:
return sc.processHeaders(f)
case *ContinuationFrame:
return sc.processContinuation(f)
case *WindowUpdateFrame:
return sc.processWindowUpdate(f)
case *PingFrame:
@@ -1192,6 +1178,18 @@ func (sc *serverConn) closeStream(st *stream, err error) {
}
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.forgetStream(st.id)
if st.reqBuf != nil {
// Stash this request body buffer (64k) away for reuse
// by a future POST/PUT/etc.
//
// TODO(bradfitz): share on the server? sync.Pool?
// Server requires locks and might hurt contention.
// sync.Pool might work, or might be worse, depending
// on goroutine CPU migrations. (get and put on
// separate CPUs). Maybe a mix of strategies. But
// this is an easy win for now.
sc.freeRequestBodyBuf = st.reqBuf
}
}
func (sc *serverConn) processSettings(f *SettingsFrame) error {
@@ -1348,7 +1346,7 @@ func (st *stream) copyTrailersToHandlerRequest() {
}
}
func (sc *serverConn) processHeaders(f *HeadersFrame) error {
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
sc.serveG.check()
id := f.Header().StreamID
if sc.inGoAway {
@@ -1377,13 +1375,11 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
// endpoint has opened or reserved. [...] An endpoint that
// receives an unexpected stream identifier MUST respond with
// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
if id <= sc.maxStreamID || sc.req.stream != nil {
if id <= sc.maxStreamID {
return ConnectionError(ErrCodeProtocol)
}
sc.maxStreamID = id
if id > sc.maxStreamID {
sc.maxStreamID = id
}
st = &stream{
sc: sc,
id: id,
@@ -1407,50 +1403,6 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
if sc.curOpenStreams == 1 {
sc.setConnState(http.StateActive)
}
sc.req = requestParam{
stream: st,
header: make(http.Header),
}
sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
sc.hpackDecoder.SetEmitEnabled(true)
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}
func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
sc := st.sc
sc.serveG.check()
if st.gotTrailerHeader {
return ConnectionError(ErrCodeProtocol)
}
st.gotTrailerHeader = true
if !f.StreamEnded() {
return StreamError{st.id, ErrCodeProtocol}
}
sc.resetPendingRequest() // we use invalidHeader from it for trailers
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
sc.serveG.check()
st := sc.streams[f.Header().StreamID]
if st.gotTrailerHeader {
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}
func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bool) error {
sc.serveG.check()
if _, err := sc.hpackDecoder.Write(frag); err != nil {
return ConnectionError(ErrCodeCompression)
}
if !end {
return nil
}
if err := sc.hpackDecoder.Close(); err != nil {
return ConnectionError(ErrCodeCompression)
}
defer sc.resetPendingRequest()
if sc.curOpenStreams > sc.advMaxStreams {
// "Endpoints MUST NOT exceed the limit set by their
// peer. An endpoint that receives a HEADERS frame
@@ -1470,7 +1422,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
return StreamError{st.id, ErrCodeRefusedStream}
}
rw, req, err := sc.newWriterAndRequest()
rw, req, err := sc.newWriterAndRequest(st, f)
if err != nil {
return err
}
@@ -1482,7 +1434,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
st.declBodyBytes = req.ContentLength
handler := sc.handler.ServeHTTP
if !sc.hpackDecoder.EmitEnabled() {
if f.Truncated {
// Their header list was too long. Send a 431 error.
handler = handleHeaderListTooLong
}
@@ -1491,27 +1443,27 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
return nil
}
func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
sc := st.sc
sc.serveG.check()
sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
if _, err := sc.hpackDecoder.Write(frag); err != nil {
return ConnectionError(ErrCodeCompression)
if st.gotTrailerHeader {
return ConnectionError(ErrCodeProtocol)
}
if !end {
return nil
st.gotTrailerHeader = true
if !f.StreamEnded() {
return StreamError{st.id, ErrCodeProtocol}
}
rp := &sc.req
if rp.invalidHeader {
return StreamError{rp.stream.id, ErrCodeProtocol}
if len(f.PseudoFields()) > 0 {
return StreamError{st.id, ErrCodeProtocol}
}
if st.trailer != nil {
for _, hf := range f.RegularFields() {
key := sc.canonicalHeader(hf.Name)
st.trailer[key] = append(st.trailer[key], hf.Value)
}
}
err := sc.hpackDecoder.Close()
st.endStream()
if err != nil {
return ConnectionError(ErrCodeCompression)
}
return nil
}
@@ -1556,29 +1508,21 @@ func adjustStreamPriority(streams map[uint32]*stream, streamID uint32, priority
}
}
// resetPendingRequest zeros out all state related to a HEADERS frame
// and its zero or more CONTINUATION frames sent to start a new
// request.
func (sc *serverConn) resetPendingRequest() {
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
sc.serveG.check()
sc.req = requestParam{}
}
func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
sc.serveG.check()
rp := &sc.req
method := f.PseudoValue("method")
path := f.PseudoValue("path")
scheme := f.PseudoValue("scheme")
authority := f.PseudoValue("authority")
if rp.invalidHeader {
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
isConnect := rp.method == "CONNECT"
isConnect := method == "CONNECT"
if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" {
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
if path != "" || scheme != "" || authority == "" {
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
}
} else if rp.method == "" || rp.path == "" ||
(rp.scheme != "https" && rp.scheme != "http") {
} else if method == "" || path == "" ||
(scheme != "https" && scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
// Malformed requests or responses that are detected
@@ -1589,35 +1533,40 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
// "All HTTP/2 requests MUST include exactly one valid
// value for the :method, :scheme, and :path
// pseudo-header fields"
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
}
bodyOpen := rp.stream.state == stateOpen
if rp.method == "HEAD" && bodyOpen {
bodyOpen := !f.StreamEnded()
if method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
}
var tlsState *tls.ConnectionState // nil if not scheme https
if rp.scheme == "https" {
if scheme == "https" {
tlsState = sc.tlsState
}
authority := rp.authority
if authority == "" {
authority = rp.header.Get("Host")
header := make(http.Header)
for _, hf := range f.RegularFields() {
header.Add(sc.canonicalHeader(hf.Name), hf.Value)
}
needsContinue := rp.header.Get("Expect") == "100-continue"
if authority == "" {
authority = header.Get("Host")
}
needsContinue := header.Get("Expect") == "100-continue"
if needsContinue {
rp.header.Del("Expect")
header.Del("Expect")
}
// Merge Cookie headers into one "; "-delimited value.
if cookies := rp.header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; "))
if cookies := header["Cookie"]; len(cookies) > 1 {
header.Set("Cookie", strings.Join(cookies, "; "))
}
// Setup Trailers
var trailer http.Header
for _, v := range rp.header["Trailer"] {
for _, v := range header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = http.CanonicalHeaderKey(strings.TrimSpace(key))
switch key {
@@ -1632,31 +1581,31 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
}
}
}
delete(rp.header, "Trailer")
delete(header, "Trailer")
body := &requestBody{
conn: sc,
stream: rp.stream,
stream: st,
needsContinue: needsContinue,
}
var url_ *url.URL
var requestURI string
if isConnect {
url_ = &url.URL{Host: rp.authority}
requestURI = rp.authority // mimic HTTP/1 server behavior
url_ = &url.URL{Host: authority}
requestURI = authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(rp.path)
url_, err = url.ParseRequestURI(path)
if err != nil {
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
}
requestURI = rp.path
requestURI = path
}
req := &http.Request{
Method: rp.method,
Method: method,
URL: url_,
RemoteAddr: sc.remoteAddrStr,
Header: rp.header,
Header: header,
RequestURI: requestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
@@ -1667,11 +1616,12 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
Trailer: trailer,
}
if bodyOpen {
st.reqBuf = sc.getRequestBodyBuf()
body.pipe = &pipe{
b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage
b: &fixedBuffer{buf: st.reqBuf},
}
if vv, ok := rp.header["Content-Length"]; ok {
if vv, ok := header["Content-Length"]; ok {
req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
} else {
req.ContentLength = -1
@@ -1684,7 +1634,7 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
rws.conn = sc
rws.bw = bwSave
rws.bw.Reset(chunkWriter{rws})
rws.stream = rp.stream
rws.stream = st
rws.req = req
rws.body = body
@@ -1692,6 +1642,15 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
return rw, req, nil
}
func (sc *serverConn) getRequestBodyBuf() []byte {
sc.serveG.check()
if buf := sc.freeRequestBodyBuf; buf != nil {
sc.freeRequestBodyBuf = nil
return buf
}
return make([]byte, initialWindowSize)
}
// Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
didPanic := true
@@ -1928,7 +1887,9 @@ func (rws *responseWriterState) declareTrailer(k string) {
// Forbidden by RFC 2616 14.40.
return
}
rws.trailers = append(rws.trailers, k)
if !strSliceContains(rws.trailers, k) {
rws.trailers = append(rws.trailers, k)
}
}
// writeChunk writes chunks from the bufio.Writer. But because
@@ -1996,6 +1957,10 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
return 0, nil
}
if rws.handlerDone {
rws.promoteUndeclaredTrailers()
}
endStream := rws.handlerDone && !rws.hasTrailers()
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
@@ -2016,6 +1981,58 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
return len(p), nil
}
// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
// that, if present, signals that the map entry is actually for
// the response trailers, and not the response headers. The prefix
// is stripped after the ServeHTTP call finishes and the values are
// sent in the trailers.
//
// This mechanism is intended only for trailers that are not known
// prior to the headers being written. If the set of trailers is fixed
// or known before the header is written, the normal Go trailers mechanism
// is preferred:
// https://golang.org/pkg/net/http/#ResponseWriter
// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
const TrailerPrefix = "Trailer:"
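// A brief sketch (hypothetical handler) of the escape hatch this constant
// enables, matching the trailer test added in server_test.go later in this
// commit:
func lateTrailerHandler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
	w.(http.Flusher).Flush() // headers are on the wire; too late to predeclare trailers
	// Promoted to the trailer "Post-Header-Trailer" when the handler returns.
	w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
}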
// promoteUndeclaredTrailers permits http.Handlers to set trailers
// after the header has already been flushed. Because the Go
// ResponseWriter interface has no way to set Trailers (only the
// Header), and because we didn't want to expand the ResponseWriter
// interface, and because nobody used trailers, and because RFC 2616
// says you SHOULD (but not must) predeclare any trailers in the
// header, the official ResponseWriter rules said trailers in Go must
// be predeclared, and then we reuse the same ResponseWriter.Header()
// map to mean both Headers and Trailers. When it's time to write the
// Trailers, we pick out the fields of Headers that were declared as
// trailers. That worked for a while, until we found the first major
// user of Trailers in the wild: gRPC (using them only over http2),
// and gRPC libraries permit setting trailers mid-stream without
// predeclaring them. So: change of plans. We still permit the old
// way, but we also permit this hack: if a Header() key begins with
// "Trailer:", the suffix of that key is a Trailer. Because ':' is an
// invalid token byte anyway, there is no ambiguity. (And it's already
// filtered out) It's mildly hacky, but not terrible.
//
// This method runs after the Handler is done and promotes any Header
// fields to be trailers.
func (rws *responseWriterState) promoteUndeclaredTrailers() {
for k, vv := range rws.handlerHeader {
if !strings.HasPrefix(k, TrailerPrefix) {
continue
}
trailerKey := strings.TrimPrefix(k, TrailerPrefix)
rws.declareTrailer(trailerKey)
rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv
}
if len(rws.trailers) > 1 {
sorter := sorterPool.Get().(*sorter)
sorter.SortStrings(rws.trailers)
sorterPool.Put(sorter)
}
}
func (w *responseWriter) Flush() {
rws := w.rws
if rws == nil {

View File

@@ -204,6 +204,13 @@ func (st *serverTester) awaitIdle() {
}
func (st *serverTester) Close() {
if st.t.Failed() {
// If we failed already (and are likely in a Fatal,
// unwinding), force close the connection, so the
// httptest.Server doesn't wait forever for the conn
// to close.
st.cc.Close()
}
st.ts.Close()
if st.cc != nil {
st.cc.Close()
@@ -2515,7 +2522,7 @@ func TestCompressionErrorOnWrite(t *testing.T) {
defer st.Close()
st.greet()
maxAllowed := st.sc.maxHeaderStringLen()
maxAllowed := st.sc.framer.maxHeaderStringLen()
// Crank this up, now that we have a conn connected with the
// hpack.Decoder's max string length set has been initialized
@@ -2524,8 +2531,12 @@ func TestCompressionErrorOnWrite(t *testing.T) {
// the max string size.
serverConfig.MaxHeaderBytes = 1 << 20
// First a request with a header that's exactly the max allowed size.
// First a request with a header that's exactly the max allowed size
// for the hpack compression. It's still too long for the header list
// size, so we'll get the 431 error, but that keeps the compression
// context still valid.
hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: hbf,
@@ -2533,8 +2544,24 @@ func TestCompressionErrorOnWrite(t *testing.T) {
EndHeaders: true,
})
h := st.wantHeaders()
if !h.HeadersEnded() || !h.StreamEnded() {
t.Errorf("Unexpected HEADER frame %v", h)
if !h.HeadersEnded() {
t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
}
headers := st.decodeHeader(h.HeaderBlockFragment())
want := [][2]string{
{":status", "431"},
{"content-type", "text/html; charset=utf-8"},
{"content-length", "63"},
}
if !reflect.DeepEqual(headers, want) {
t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
}
df := st.wantData()
if !strings.Contains(string(df.Data()), "HTTP Error 431") {
t.Errorf("Unexpected data body: %q", df.Data())
}
if !df.StreamEnded() {
t.Fatalf("expect data stream end")
}
// And now send one that's just one byte too big.
@@ -2647,7 +2674,13 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
}
w.Header().Set("Server-Trailer-A", "valuea")
w.Header().Set("Server-Trailer-C", "valuec") // skipping B
// After a flush, random keys like Server-Surprise shouldn't show up:
w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
// But we do permit promoting keys to trailers after a
// flush if they start with the magic
// otherwise-invalid "Trailer:" prefix:
w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
w.Header().Set("Trailer:post-header-trailer2", "hi2")
w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40")
w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40")
w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40")
@@ -2689,6 +2722,8 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
t.Fatalf("trailers HEADERS lacked END_HEADERS")
}
wanth = [][2]string{
{"post-header-trailer", "hi1"},
{"post-header-trailer2", "hi2"},
{"server-trailer-a", "valuea"},
{"server-trailer-c", "valuec"},
}
@@ -2699,7 +2734,39 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
})
}
// validate transmitted header field names & values
// golang.org/issue/14048
func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
w.Header().Add("OK1", "x")
w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key
w.Header().Add("Bad1\x00", "x") // null in key
w.Header().Add("Bad2", "x\x00y") // null in value
return nil
}, func(st *serverTester) {
getSlash(st)
hf := st.wantHeaders()
if !hf.StreamEnded() {
t.Error("response HEADERS lacked END_STREAM")
}
if !hf.HeadersEnded() {
t.Fatal("response HEADERS didn't have END_HEADERS")
}
goth := st.decodeHeader(hf.HeaderBlockFragment())
wanth := [][2]string{
{":status", "200"},
{"ok1", "x"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"},
}
if !reflect.DeepEqual(goth, wanth) {
t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
}
})
}
func BenchmarkServerGets(b *testing.B) {
defer disableGoroutineTracking()()
b.ReportAllocs()
const msg = "Hello, world"
@@ -2731,6 +2798,7 @@ func BenchmarkServerGets(b *testing.B) {
}
func BenchmarkServerPosts(b *testing.B) {
defer disableGoroutineTracking()()
b.ReportAllocs()
const msg = "Hello, world"
@@ -2769,12 +2837,16 @@ func TestIssue53(t *testing.T) {
"\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
s := &http.Server{
ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags),
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("hello"))
}),
}
s2 := &Server{
MaxReadFrameSize: 1 << 16,
PermitProhibitedCipherSuites: true,
}
s2 := &Server{MaxReadFrameSize: 1 << 16, PermitProhibitedCipherSuites: true}
c := &issue53Conn{[]byte(data), false, false}
s2.handleConn(s, c, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("hello"))
}))
s2.ServeConn(c, &ServeConnOpts{BaseConfig: s})
if !c.closed {
t.Fatal("connection is not closed")
}
@@ -2938,3 +3010,172 @@ func TestServerNoDuplicateContentType(t *testing.T) {
t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
}
}
func disableGoroutineTracking() (restore func()) {
old := DebugGoroutines
DebugGoroutines = false
return func() { DebugGoroutines = old }
}
func BenchmarkServer_GetRequest(b *testing.B) {
defer disableGoroutineTracking()()
b.ReportAllocs()
const msg = "Hello, world."
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
n, err := io.Copy(ioutil.Discard, r.Body)
if err != nil || n > 0 {
b.Error("Read %d bytes, error %v; want 0 bytes.", n, err)
}
io.WriteString(w, msg)
})
defer st.Close()
st.greet()
// Give the server quota to reply. (plus it has the 64KB)
if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
b.Fatal(err)
}
hbf := st.encodeHeader(":method", "GET")
for i := 0; i < b.N; i++ {
streamID := uint32(1 + 2*i)
st.writeHeaders(HeadersFrameParam{
StreamID: streamID,
BlockFragment: hbf,
EndStream: true,
EndHeaders: true,
})
st.wantHeaders()
st.wantData()
}
}
func BenchmarkServer_PostRequest(b *testing.B) {
defer disableGoroutineTracking()()
b.ReportAllocs()
const msg = "Hello, world."
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
n, err := io.Copy(ioutil.Discard, r.Body)
if err != nil || n > 0 {
b.Error("Read %d bytes, error %v; want 0 bytes.", n, err)
}
io.WriteString(w, msg)
})
defer st.Close()
st.greet()
// Give the server quota to reply. (plus it has the 64KB)
if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
b.Fatal(err)
}
hbf := st.encodeHeader(":method", "POST")
for i := 0; i < b.N; i++ {
streamID := uint32(1 + 2*i)
st.writeHeaders(HeadersFrameParam{
StreamID: streamID,
BlockFragment: hbf,
EndStream: false,
EndHeaders: true,
})
st.writeData(streamID, true, nil)
st.wantHeaders()
st.wantData()
}
}
type connStateConn struct {
net.Conn
cs tls.ConnectionState
}
func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs }
// golang.org/issue/12737 -- handle any net.Conn, not just
// *tls.Conn.
func TestServerHandleCustomConn(t *testing.T) {
var s Server
c1, c2 := net.Pipe()
clientDone := make(chan struct{})
handlerDone := make(chan struct{})
var req *http.Request
go func() {
defer close(clientDone)
defer c2.Close()
fr := NewFramer(c2, c2)
io.WriteString(c2, ClientPreface)
fr.WriteSettings()
fr.WriteSettingsAck()
f, err := fr.ReadFrame()
if err != nil {
t.Error(err)
return
}
if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
return
}
f, err = fr.ReadFrame()
if err != nil {
t.Error(err)
return
}
if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
return
}
var henc hpackEncoder
fr.WriteHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
EndStream: true,
EndHeaders: true,
})
go io.Copy(ioutil.Discard, c2)
<-handlerDone
}()
const testString = "my custom ConnectionState"
fakeConnState := tls.ConnectionState{
ServerName: testString,
Version: tls.VersionTLS12,
}
go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
BaseConfig: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(handlerDone)
req = r
}),
}})
select {
case <-clientDone:
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for handler")
}
if req.TLS == nil {
t.Fatalf("Request.TLS is nil. Got: %#v", req)
}
if req.TLS.ServerName != testString {
t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
}
}
type hpackEncoder struct {
enc *hpack.Encoder
buf bytes.Buffer
}
func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
if len(headers)%2 == 1 {
panic("odd number of kv args")
}
he.buf.Reset()
if he.enc == nil {
he.enc = hpack.NewEncoder(&he.buf)
}
for len(headers) > 0 {
k, v := headers[0], headers[1]
err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
if err != nil {
t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
}
headers = headers[2:]
}
return he.buf.Bytes()
}

View File

@@ -187,8 +187,8 @@ type clientStream struct {
done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
// owned by clientConnReadLoop:
pastHeaders bool // got HEADERS w/ END_HEADERS
pastTrailers bool // got second HEADERS frame w/ END_HEADERS
pastHeaders bool // got first MetaHeadersFrame (actual headers)
pastTrailers bool // got optional second MetaHeadersFrame (trailers)
trailer http.Header // accumulated trailers
resTrailer *http.Header // client's Response.Trailer
@@ -333,8 +333,12 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
if t.TLSClientConfig != nil {
*cfg = *t.TLSClientConfig
}
cfg.NextProtos = []string{NextProtoTLS} // TODO: don't override if already in list
cfg.ServerName = host
if !strSliceContains(cfg.NextProtos, NextProtoTLS) {
cfg.NextProtos = append([]string{NextProtoTLS}, cfg.NextProtos...)
}
if cfg.ServerName == "" {
cfg.ServerName = host
}
return cfg
}
@@ -401,22 +405,21 @@ func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
cc.br = bufio.NewReader(c)
cc.fr = NewFramer(cc.bw, cc.br)
cc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
// TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on
// henc in response to SETTINGS frames?
cc.henc = hpack.NewEncoder(&cc.hbuf)
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
if cs, ok := c.(connectionStater); ok {
state := cs.ConnectionState()
cc.tlsState = &state
}
initialSettings := []Setting{
Setting{ID: SettingEnablePush, Val: 0},
Setting{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
{ID: SettingEnablePush, Val: 0},
{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
}
if max := t.maxHeaderListSize(); max != 0 {
initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
@@ -565,7 +568,27 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration {
return 0
}
// checkConnHeaders checks whether req has any invalid connection-level headers,
// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields.
// Certain headers are special-cased as okay but not transmitted later.
func checkConnHeaders(req *http.Request) error {
if v := req.Header.Get("Upgrade"); v != "" {
return errors.New("http2: invalid Upgrade request header")
}
if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 {
return errors.New("http2: invalid Transfer-Encoding request header")
}
if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 {
return errors.New("http2: invalid Connection request header")
}
return nil
}
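// Illustrative sketch (assumed caller behavior, not code from this change):
// how the checkConnHeaders guard above treats a few representative requests.
//
//	req, _ := http.NewRequest("GET", "https://example.com/", nil)
//
//	req.Header.Set("Connection", "keep-alive")
//	fmt.Println(checkConnHeaders(req)) // <nil> (special-cased; stripped later in encodeHeaders)
//
//	req.Header.Set("Connection", "foo")
//	fmt.Println(checkConnHeaders(req)) // http2: invalid Connection request header
//
//	req.Header.Set("Connection", "close")
//	fmt.Println(checkConnHeaders(req)) // <nil>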
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
if err := checkConnHeaders(req); err != nil {
return nil, err
}
trailers, err := commaSeparatedTrailers(req)
if err != nil {
return nil, err
@@ -917,13 +940,24 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
if lowKey == "host" || lowKey == "content-length" {
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
}
if lowKey == "user-agent" {
case "connection", "proxy-connection", "transfer-encoding", "upgrade":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We deal with these earlier in
// RoundTrip, deciding whether they're
// error-worthy, but we don't want to mutate
// the user's *Request so at this point, just
// skip over them.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
@@ -1033,17 +1067,9 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type clientConnReadLoop struct {
cc *ClientConn
activeRes map[uint32]*clientStream // keyed by streamID
hdec *hpack.Decoder
// Fields reset on each HEADERS:
nextRes *http.Response
sawRegHeader bool // saw non-pseudo header
reqMalformed error // non-nil once known to be malformed
lastHeaderEndsStream bool
headerListSize int64 // actually uint32, but easier math this way
cc *ClientConn
activeRes map[uint32]*clientStream // keyed by streamID
closeWhenIdle bool
}
// readLoop runs in its own goroutine and reads and dispatches frames.
@@ -1052,7 +1078,6 @@ func (cc *ClientConn) readLoop() {
cc: cc,
activeRes: make(map[uint32]*clientStream),
}
rl.hdec = hpack.NewDecoder(initialHeaderTableSize, rl.onNewHeaderField)
defer rl.cleanup()
cc.readerErr = rl.run()
@@ -1094,7 +1119,7 @@ func (rl *clientConnReadLoop) cleanup() {
func (rl *clientConnReadLoop) run() error {
cc := rl.cc
closeWhenIdle := cc.t.disableKeepAlives()
rl.closeWhenIdle = cc.t.disableKeepAlives()
gotReply := false // ever saw a reply
for {
f, err := cc.fr.ReadFrame()
@@ -1102,8 +1127,10 @@ func (rl *clientConnReadLoop) run() error {
cc.vlogf("Transport readFrame error: (%T) %v", err, err)
}
if se, ok := err.(StreamError); ok {
// TODO: deal with stream errors from the framer.
return se
if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
rl.endStreamError(cs, cc.fr.errDetail)
}
continue
} else if err != nil {
return err
}
@@ -1113,13 +1140,10 @@ func (rl *clientConnReadLoop) run() error {
maybeIdle := false // whether frame might transition us to idle
switch f := f.(type) {
case *HeadersFrame:
case *MetaHeadersFrame:
err = rl.processHeaders(f)
maybeIdle = true
gotReply = true
case *ContinuationFrame:
err = rl.processContinuation(f)
maybeIdle = true
case *DataFrame:
err = rl.processData(f)
maybeIdle = true
@@ -1143,97 +1167,102 @@ func (rl *clientConnReadLoop) run() error {
if err != nil {
return err
}
if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
cc.closeIfIdle()
}
}
}
func (rl *clientConnReadLoop) processHeaders(f *HeadersFrame) error {
rl.sawRegHeader = false
rl.reqMalformed = nil
rl.lastHeaderEndsStream = f.StreamEnded()
rl.headerListSize = 0
rl.nextRes = &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: make(http.Header),
}
rl.hdec.SetEmitEnabled(true)
return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
}
func (rl *clientConnReadLoop) processContinuation(f *ContinuationFrame) error {
return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
}
func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error {
func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
cc := rl.cc
streamEnded := rl.lastHeaderEndsStream
cs := cc.streamByID(streamID, streamEnded && finalFrag)
cs := cc.streamByID(f.StreamID, f.StreamEnded())
if cs == nil {
// We'd get here if we canceled a request while the
// server was mid-way through replying with its
// headers. (The case of a CONTINUATION arriving
// without HEADERS would be rejected earlier by the
// Framer). So if this was just something we canceled,
// ignore it.
// server had its response still in flight. So if this
// was just something we canceled, ignore it.
return nil
}
if cs.pastHeaders {
rl.hdec.SetEmitFunc(func(f hpack.HeaderField) { rl.onNewTrailerField(cs, f) })
} else {
rl.hdec.SetEmitFunc(rl.onNewHeaderField)
}
_, err := rl.hdec.Write(frag)
if err != nil {
return ConnectionError(ErrCodeCompression)
}
if finalFrag {
if err := rl.hdec.Close(); err != nil {
return ConnectionError(ErrCodeCompression)
}
}
if !finalFrag {
return nil
}
if !cs.pastHeaders {
cs.pastHeaders = true
} else {
// We're dealing with trailers. (and specifically the
// final frame of headers)
if cs.pastTrailers {
// Too many HEADERS frames for this stream.
return ConnectionError(ErrCodeProtocol)
}
cs.pastTrailers = true
if !streamEnded {
// We expect that any header block fragment
// frame for trailers with END_HEADERS also
// has END_STREAM.
return ConnectionError(ErrCodeProtocol)
}
rl.endStream(cs)
return nil
return rl.processTrailers(cs, f)
}
if rl.reqMalformed != nil {
cs.resc <- resAndError{err: rl.reqMalformed}
rl.cc.writeStreamReset(cs.ID, ErrCodeProtocol, rl.reqMalformed)
res, err := rl.handleResponse(cs, f)
if err != nil {
if _, ok := err.(ConnectionError); ok {
return err
}
// Any other error type is a stream error.
cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err)
cs.resc <- resAndError{err: err}
return nil // return nil from process* funcs to keep conn alive
}
if res == nil {
// (nil, nil) special case. See handleResponse docs.
return nil
}
if res.Body != noBody {
rl.activeRes[cs.ID] = cs
}
cs.resTrailer = &res.Trailer
cs.resc <- resAndError{res: res}
return nil
}
res := rl.nextRes
// handleResponse may return a nil error or a ConnectionError. Any other error
// value is treated as a stream error of type ErrCodeProtocol; the returned
// error in that case is the detail.
//
// As a special case, handleResponse may return (nil, nil) to skip the
// frame (currently only used for 100 expect continue). This special
// case is going away after Issue 13851 is fixed.
func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
if res.StatusCode == 100 {
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
if statusCode == 100 {
// Just skip 100-continue response headers for now.
// TODO: golang.org/issue/13851 for doing it properly.
cs.pastHeaders = false // do it all again
return nil
return nil, nil
}
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
streamEnded := f.StreamEnded()
if !streamEnded || cs.req.Method == "HEAD" {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
@@ -1251,25 +1280,49 @@ func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID u
if streamEnded {
res.Body = noBody
} else {
buf := new(bytes.Buffer) // TODO(bradfitz): recycle this garbage
cs.bufPipe = pipe{b: buf}
cs.bytesRemain = res.ContentLength
res.Body = transportResponseBody{cs}
go cs.awaitRequestCancel(requestCancel(cs.req))
if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
}
return res, nil
}
cs.resTrailer = &res.Trailer
rl.activeRes[cs.ID] = cs
cs.resc <- resAndError{res: res}
rl.nextRes = nil // unused now; will be reset next HEADERS frame
buf := new(bytes.Buffer) // TODO(bradfitz): recycle this garbage
cs.bufPipe = pipe{b: buf}
cs.bytesRemain = res.ContentLength
res.Body = transportResponseBody{cs}
go cs.awaitRequestCancel(requestCancel(cs.req))
if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
}
return res, nil
}
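// A hedged caller-side sketch (the X-Checksum trailer name is hypothetical,
// not from this diff): handleResponse pre-populates res.Trailer with
// nil-valued keys from the declared Trailer header; the values only appear
// after the body is read to EOF, once processTrailers below has run and
// copyTrailers has filled them in.
//
//	res, err := tr.RoundTrip(req)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(res.Trailer)                   // declared keys present, values nil
//	io.Copy(ioutil.Discard, res.Body)          // read to EOF
//	res.Body.Close()
//	fmt.Println(res.Trailer.Get("X-Checksum")) // filled in by copyTrailers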
func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFrame) error {
if cs.pastTrailers {
// Too many HEADERS frames for this stream.
return ConnectionError(ErrCodeProtocol)
}
cs.pastTrailers = true
if !f.StreamEnded() {
// We expect that any headers for trailers also
// have END_STREAM.
return ConnectionError(ErrCodeProtocol)
}
if len(f.PseudoFields()) > 0 {
// No pseudo header fields are defined for trailers.
// TODO: ConnectionError might be overly harsh? Check.
return ConnectionError(ErrCodeProtocol)
}
trailer := make(http.Header)
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
trailer[key] = append(trailer[key], hf.Value)
}
cs.trailer = trailer
rl.endStream(cs)
return nil
}
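// A server-side sketch (hypothetical handler, shown only to illustrate what
// exercises processTrailers): declaring a trailer before writing the body and
// setting its value afterwards makes the server emit a final HEADERS frame
// with END_STREAM carrying only regular fields, which is the shape accepted
// above.
//
//	func handler(w http.ResponseWriter, r *http.Request) {
//		w.Header().Set("Trailer", "X-Checksum")
//		io.WriteString(w, "body bytes")
//		w.Header().Set("X-Checksum", "abc123") // sent as a trailer
//	}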
@@ -1387,6 +1440,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
cc.mu.Unlock()
if _, err := cs.bufPipe.Write(data); err != nil {
rl.endStreamError(cs, err)
return err
}
}
@@ -1402,14 +1456,20 @@ var errInvalidTrailers = errors.New("http2: invalid trailers")
func (rl *clientConnReadLoop) endStream(cs *clientStream) {
// TODO: check that any declared content-length matches, like
// server.go's (*stream).endStream method.
err := io.EOF
code := cs.copyTrailers
if rl.reqMalformed != nil {
err = rl.reqMalformed
code = nil
rl.endStreamError(cs, nil)
}
func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
var code func()
if err == nil {
err = io.EOF
code = cs.copyTrailers
}
cs.bufPipe.closeWithErrorAndCode(err, code)
delete(rl.activeRes, cs.ID)
if cs.req.Close || cs.req.Header.Get("Connection") == "close" {
rl.closeWhenIdle = true
}
}
func (cs *clientStream) copyTrailers() {
@@ -1542,118 +1602,6 @@ var (
errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers")
)
func (rl *clientConnReadLoop) checkHeaderField(f hpack.HeaderField) bool {
if rl.reqMalformed != nil {
return false
}
const headerFieldOverhead = 32 // per spec
rl.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
if max := rl.cc.t.maxHeaderListSize(); max != 0 && rl.headerListSize > int64(max) {
rl.hdec.SetEmitEnabled(false)
rl.reqMalformed = errResponseHeaderListSize
return false
}
if !validHeaderFieldValue(f.Value) {
rl.reqMalformed = errInvalidHeaderFieldValue
return false
}
isPseudo := strings.HasPrefix(f.Name, ":")
if isPseudo {
if rl.sawRegHeader {
rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header")
return false
}
} else {
if !validHeaderFieldName(f.Name) {
rl.reqMalformed = errInvalidHeaderFieldName
return false
}
rl.sawRegHeader = true
}
return true
}
// onNewHeaderField runs on the readLoop goroutine whenever a new
// hpack header field is decoded.
func (rl *clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
cc := rl.cc
if VerboseLogs {
cc.logf("http2: Transport decoded %v", f)
}
if !rl.checkHeaderField(f) {
return
}
isPseudo := strings.HasPrefix(f.Name, ":")
if isPseudo {
switch f.Name {
case ":status":
code, err := strconv.Atoi(f.Value)
if err != nil {
rl.reqMalformed = errors.New("http2: invalid :status")
return
}
rl.nextRes.Status = f.Value + " " + http.StatusText(code)
rl.nextRes.StatusCode = code
default:
// "Endpoints MUST NOT generate pseudo-header
// fields other than those defined in this
// document."
rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name)
}
return
}
key := http.CanonicalHeaderKey(f.Name)
if key == "Trailer" {
t := rl.nextRes.Trailer
if t == nil {
t = make(http.Header)
rl.nextRes.Trailer = t
}
foreachHeaderElement(f.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
rl.nextRes.Header.Add(key, f.Value)
}
}
func (rl *clientConnReadLoop) onNewTrailerField(cs *clientStream, f hpack.HeaderField) {
if VerboseLogs {
rl.cc.logf("http2: Transport decoded trailer %v", f)
}
if !rl.checkHeaderField(f) {
return
}
if strings.HasPrefix(f.Name, ":") {
// Pseudo-header fields MUST NOT appear in
// trailers. Endpoints MUST treat a request or
// response that contains undefined or invalid
// pseudo-header fields as malformed.
rl.reqMalformed = errPseudoTrailers
return
}
key := http.CanonicalHeaderKey(f.Name)
// The spec says one must predeclare their trailers but in practice
// popular users (which is to say the only user we found) do not, so we
// violate the spec and accept all of them.
const acceptAllTrailers = true
if _, ok := (*cs.resTrailer)[key]; ok || acceptAllTrailers {
if cs.trailer == nil {
cs.trailer = make(http.Header)
}
cs.trailer[key] = append(cs.trailer[key], f.Value)
}
}
func (cc *ClientConn) logf(format string, args ...interface{}) {
cc.t.logf(format, args...)
}
@@ -1691,13 +1639,18 @@ func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr io.Reader // lazily-initialized gzip reader
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}

View File

@@ -20,6 +20,7 @@ import (
"net/url"
"os"
"reflect"
"sort"
"strconv"
"strings"
"sync"
@@ -100,8 +101,7 @@ func TestTransport(t *testing.T) {
t.Errorf("Body = %q; want %q", slurp, body)
}
}
func TestTransportReusesConns(t *testing.T) {
func onSameConn(t *testing.T, modReq func(*http.Request)) bool {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
}, optOnlyServer)
@@ -113,6 +113,7 @@ func TestTransportReusesConns(t *testing.T) {
if err != nil {
t.Fatal(err)
}
modReq(req)
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
@@ -130,8 +131,24 @@ func TestTransportReusesConns(t *testing.T) {
}
first := get()
second := get()
if first != second {
t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
return first == second
}
func TestTransportReusesConns(t *testing.T) {
if !onSameConn(t, func(*http.Request) {}) {
t.Errorf("first and second responses were on different connections")
}
}
func TestTransportReusesConn_RequestClose(t *testing.T) {
if onSameConn(t, func(r *http.Request) { r.Close = true }) {
t.Errorf("first and second responses were not on different connections")
}
}
func TestTransportReusesConn_ConnClose(t *testing.T) {
if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) {
t.Errorf("first and second responses were not on different connections")
}
}
@@ -309,28 +326,28 @@ func randString(n int) string {
return string(b)
}
var bodyTests = []struct {
body string
noContentLen bool
}{
{body: "some message"},
{body: "some message", noContentLen: true},
{body: ""},
{body: "", noContentLen: true},
{body: strings.Repeat("a", 1<<20), noContentLen: true},
{body: strings.Repeat("a", 1<<20)},
{body: randString(16<<10 - 1)},
{body: randString(16 << 10)},
{body: randString(16<<10 + 1)},
{body: randString(512<<10 - 1)},
{body: randString(512 << 10)},
{body: randString(512<<10 + 1)},
{body: randString(1<<20 - 1)},
{body: randString(1 << 20)},
{body: randString(1<<20 + 2)},
}
func TestTransportBody(t *testing.T) {
bodyTests := []struct {
body string
noContentLen bool
}{
{body: "some message"},
{body: "some message", noContentLen: true},
{body: ""},
{body: "", noContentLen: true},
{body: strings.Repeat("a", 1<<20), noContentLen: true},
{body: strings.Repeat("a", 1<<20)},
{body: randString(16<<10 - 1)},
{body: randString(16 << 10)},
{body: randString(16<<10 + 1)},
{body: randString(512<<10 - 1)},
{body: randString(512 << 10)},
{body: randString(512<<10 + 1)},
{body: randString(1<<20 - 1)},
{body: randString(1 << 20)},
{body: randString(1<<20 + 2)},
}
type reqInfo struct {
req *http.Request
slurp []byte
@@ -370,7 +387,7 @@ func TestTransportBody(t *testing.T) {
defer res.Body.Close()
ri := <-gotc
if ri.err != nil {
t.Errorf("%#d: read error: %v", i, ri.err)
t.Errorf("#%d: read error: %v", i, ri.err)
continue
}
if got := string(ri.slurp); got != tt.body {
@@ -1087,7 +1104,7 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
testTransportInvalidTrailer_Pseudo(t, splitHeader)
}
func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
testInvalidTrailer(t, trailers, errPseudoTrailers, func(enc *hpack.Encoder) {
testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
})
@@ -1100,19 +1117,19 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) {
testTransportInvalidTrailer_Capital(t, splitHeader)
}
func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
testInvalidTrailer(t, trailers, errInvalidHeaderFieldName, func(enc *hpack.Encoder) {
testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
})
}
func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
testInvalidTrailer(t, oneHeader, errInvalidHeaderFieldName, func(enc *hpack.Encoder) {
testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
})
}
func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
testInvalidTrailer(t, oneHeader, errInvalidHeaderFieldValue, func(enc *hpack.Encoder) {
enc.WriteField(hpack.HeaderField{Name: "", Value: "has\nnewline"})
testInvalidTrailer(t, oneHeader, headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) {
enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
})
}
@@ -1130,7 +1147,7 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT
}
slurp, err := ioutil.ReadAll(res.Body)
if err != wantErr {
return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, wantErr)
return fmt.Errorf("res.Body ReadAll error = %q, %#v; want %T of %#v", slurp, err, wantErr, wantErr)
}
if len(slurp) > 0 {
return fmt.Errorf("body = %q; want nothing", slurp)
@@ -1549,3 +1566,175 @@ func TestTransportDisableCompression(t *testing.T) {
}
defer res.Body.Close()
}
// RFC 7540 section 8.1.2.2
func TestTransportRejectsConnHeaders(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
var got []string
for k := range r.Header {
got = append(got, k)
}
sort.Strings(got)
w.Header().Set("Got-Header", strings.Join(got, ","))
}, optOnlyServer)
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
tests := []struct {
key string
value []string
want string
}{
{
key: "Upgrade",
value: []string{"anything"},
want: "ERROR: http2: invalid Upgrade request header",
},
{
key: "Connection",
value: []string{"foo"},
want: "ERROR: http2: invalid Connection request header",
},
{
key: "Connection",
value: []string{"close"},
want: "Accept-Encoding,User-Agent",
},
{
key: "Connection",
value: []string{"close", "something-else"},
want: "ERROR: http2: invalid Connection request header",
},
{
key: "Connection",
value: []string{"keep-alive"},
want: "Accept-Encoding,User-Agent",
},
{
key: "Proxy-Connection", // just deleted and ignored
value: []string{"keep-alive"},
want: "Accept-Encoding,User-Agent",
},
{
key: "Transfer-Encoding",
value: []string{""},
want: "Accept-Encoding,User-Agent",
},
{
key: "Transfer-Encoding",
value: []string{"foo"},
want: "ERROR: http2: invalid Transfer-Encoding request header",
},
{
key: "Transfer-Encoding",
value: []string{"chunked"},
want: "Accept-Encoding,User-Agent",
},
{
key: "Transfer-Encoding",
value: []string{"chunked", "other"},
want: "ERROR: http2: invalid Transfer-Encoding request header",
},
{
key: "Content-Length",
value: []string{"123"},
want: "Accept-Encoding,User-Agent",
},
}
for _, tt := range tests {
req, _ := http.NewRequest("GET", st.ts.URL, nil)
req.Header[tt.key] = tt.value
res, err := tr.RoundTrip(req)
var got string
if err != nil {
got = fmt.Sprintf("ERROR: %v", err)
} else {
got = res.Header.Get("Got-Header")
res.Body.Close()
}
if got != tt.want {
t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
}
}
}
// Tests that gzipReader doesn't crash on a second Read call following
// the first Read call's gzip.NewReader returning an error.
func TestGzipReader_DoubleReadCrash(t *testing.T) {
gz := &gzipReader{
body: ioutil.NopCloser(strings.NewReader("0123456789")),
}
var buf [1]byte
n, err1 := gz.Read(buf[:])
if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
}
n, err2 := gz.Read(buf[:])
if n != 0 || err2 != err1 {
t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
}
}
func TestTransportNewTLSConfig(t *testing.T) {
tests := [...]struct {
conf *tls.Config
host string
want *tls.Config
}{
// Normal case.
0: {
conf: nil,
host: "foo.com",
want: &tls.Config{
ServerName: "foo.com",
NextProtos: []string{NextProtoTLS},
},
},
// User-provided name (bar.com) takes precedence:
1: {
conf: &tls.Config{
ServerName: "bar.com",
},
host: "foo.com",
want: &tls.Config{
ServerName: "bar.com",
NextProtos: []string{NextProtoTLS},
},
},
// NextProto is prepended:
2: {
conf: &tls.Config{
NextProtos: []string{"foo", "bar"},
},
host: "example.com",
want: &tls.Config{
ServerName: "example.com",
NextProtos: []string{NextProtoTLS, "foo", "bar"},
},
},
// NextProto is not duplicated:
3: {
conf: &tls.Config{
NextProtos: []string{"foo", "bar", NextProtoTLS},
},
host: "example.com",
want: &tls.Config{
ServerName: "example.com",
NextProtos: []string{"foo", "bar", NextProtoTLS},
},
},
}
for i, tt := range tests {
tr := &Transport{TLSClientConfig: tt.conf}
got := tr.newTLSConfig(tt.host)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
}
}
}

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"log"
"net/http"
"sort"
"time"
"golang.org/x/net/http2/hpack"
@@ -230,19 +229,29 @@ func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
}
func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
// TODO: garbage. pool sorters like http1? hot path for 1 key?
if keys == nil {
keys = make([]string, 0, len(h))
for k := range h {
keys = append(keys, k)
}
sort.Strings(keys)
sorter := sorterPool.Get().(*sorter)
// Using defer here, since the keys returned by the
// sorter.Keys method are only valid until the sorter
// is returned:
defer sorterPool.Put(sorter)
keys = sorter.Keys(h)
}
for _, k := range keys {
vv := h[k]
k = lowerHeader(k)
if !validHeaderFieldName(k) {
// TODO: return an error? golang.org/issue/14048
// For now just omit it.
continue
}
isTE := k == "transfer-encoding"
for _, v := range vv {
if !validHeaderFieldValue(v) {
// TODO: return an error? golang.org/issue/14048
// For now just omit it.
continue
}
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if isTE && v != "trailers" {
continue