Files
fn-serverless/vendor/github.com/grpc-ecosystem/go-grpc-middleware/auth/auth_test.go
Andrea Rosa e637661ea2 Adding a way to inject a request ID (#1046)
* Adding a way to inject a request ID

It is very useful to associate a request ID with each incoming request;
this change allows a function that does so to be provided via a Server Option.
The change comes with a default function which will generate a new
request ID. The request ID is put in the request context along with a
common logger which always logs the request ID.

We add gRPC interceptors to the server so it can get the request ID out
of the gRPC metadata and put it in the common logger stored in the
context, so that all the log lines using the common logger from the context
will have the request ID logged.
2018-06-14 10:40:55 +01:00

207 lines
7.4 KiB
Go

// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_auth_test
import (
"testing"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"fmt"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/testing"
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/oauth"
"google.golang.org/grpc/metadata"
)
// Shared fixtures for the auth interceptor tests in this file.
var (
// commonAuthToken is the token accepted by the auth function installed on the interceptors.
commonAuthToken = "some_good_token"
// overrideAuthToken is the token accepted by authOverrideTestService.AuthFuncOverride.
overrideAuthToken = "override_token"
// authedMarker is the context key under which buildDummyAuthFunction stores its marker value.
authedMarker = "some_context_marker"
// goodPing is the request payload sent by every test call.
goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
)
// TODO(mwitkow): Add auth from metadata client dialer, which requires TLS.

// buildDummyAuthFunction returns an auth function for use in tests.
//
// The returned function asks grpc_auth.AuthFromMD for the token under
// expectedScheme and forwards any error it reports unchanged. A token that
// differs from expectedToken is rejected with codes.PermissionDenied. On
// success the returned context is ctx with "marker_exists" stored under
// authedMarker.
func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) {
	authFn := func(ctx context.Context) (context.Context, error) {
		got, err := grpc_auth.AuthFromMD(ctx, expectedScheme)
		switch {
		case err != nil:
			return nil, err
		case got != expectedToken:
			return nil, grpc.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token")
		}
		// Tag the context so handlers can verify the auth function actually ran.
		return context.WithValue(ctx, authedMarker, "marker_exists"), nil
	}
	return authFn
}
// assertAuthMarkerExists checks that ctx carries the value that the auth
// function built by buildDummyAuthFunction stores under authedMarker.
//
// The stored value is compared as-is instead of being force-cast to string:
// when the marker is absent ctx.Value returns nil, and an unchecked type
// assertion would panic in the calling handler rather than record an
// ordinary assertion failure on t.
func assertAuthMarkerExists(t *testing.T, ctx context.Context) {
	assert.Equal(t, "marker_exists", ctx.Value(authedMarker), "auth marker from buildDummyAuthFunction must be passed around")
}
// assertingPingService wraps a TestServiceServer and, in the methods it
// overrides in this file (PingError and PingList), asserts that the auth
// marker is present on the context before delegating to the embedded server.
type assertingPingService struct {
// Embedded server that handles every call after the assertion.
pb_testproto.TestServiceServer
// T is the test handle on which assertion failures are reported.
T *testing.T
}
// PingError asserts that the auth marker is present on ctx and then hands the
// call to the embedded TestServiceServer, returning its result unchanged.
func (s *assertingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
	assertAuthMarkerExists(s.T, ctx)
	inner := s.TestServiceServer
	return inner.PingError(ctx, ping)
}
// PingList asserts that the auth marker is present on the stream's context
// and then hands the call to the embedded TestServiceServer.
func (s *assertingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
	streamCtx := stream.Context()
	assertAuthMarkerExists(s.T, streamCtx)
	inner := s.TestServiceServer
	return inner.PingList(ping, stream)
}
// ctxWithToken returns a copy of ctx whose outgoing metadata carries an
// "authorization" entry with the value "<scheme> <token>".
func ctxWithToken(ctx context.Context, scheme string, token string) context.Context {
	headerValue := fmt.Sprintf("%s %v", scheme, token)
	return metautils.NiceMD(metadata.Pairs("authorization", headerValue)).ToOutgoing(ctx)
}
// TestAuthTestSuite runs AuthTestSuite against a server whose unary and stream
// interceptors both use an auth function that accepts commonAuthToken under
// the "bearer" scheme. The service is wrapped in assertingPingService.
func TestAuthTestSuite(t *testing.T) {
	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)

	service := &assertingPingService{
		TestServiceServer: &grpc_testing.TestPingService{T: t},
		T:                 t,
	}
	serverOpts := []grpc.ServerOption{
		grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
		grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
	}

	suite.Run(t, &AuthTestSuite{
		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
			TestService: service,
			ServerOpts:  serverOpts,
		},
	})
}
// AuthTestSuite holds the auth interceptor tests that TestAuthTestSuite runs.
// It embeds grpc_testing.InterceptorTestSuite, whose Client, SimpleCtx,
// NewClient and T members are used by the test methods below.
type AuthTestSuite struct {
*grpc_testing.InterceptorTestSuite
}
// TestUnary_NoAuth expects a unary Ping sent without any authorization
// metadata to fail with codes.Unauthenticated.
func (s *AuthTestSuite) TestUnary_NoAuth() {
	t := s.T()
	_, callErr := s.Client.Ping(s.SimpleCtx(), goodPing)
	assert.Error(t, callErr, "there must be an error")
	assert.Equal(t, codes.Unauthenticated, grpc.Code(callErr), "must error with unauthenticated")
}
// TestUnary_BadAuth expects a unary Ping carrying a wrong bearer token to
// fail with codes.PermissionDenied.
func (s *AuthTestSuite) TestUnary_BadAuth() {
	t := s.T()
	badCtx := ctxWithToken(s.SimpleCtx(), "bearer", "bad_token")
	_, callErr := s.Client.Ping(badCtx, goodPing)
	assert.Error(t, callErr, "there must be an error")
	assert.Equal(t, codes.PermissionDenied, grpc.Code(callErr), "must error with permission denied")
}
// TestUnary_PassesAuth expects a unary Ping carrying commonAuthToken under
// the "bearer" scheme to succeed.
func (s *AuthTestSuite) TestUnary_PassesAuth() {
	authedCtx := ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken)
	_, callErr := s.Client.Ping(authedCtx, goodPing)
	require.NoError(s.T(), callErr, "no error must occur")
}
// TestUnary_PassesWithPerRpcCredentials expects a unary Ping to succeed when
// the token is supplied through per-RPC credentials on a dedicated client
// rather than through metadata set on the call context.
func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() {
	tokenSource := &fakeOAuth2TokenSource{accessToken: commonAuthToken}
	perRPCCreds := oauth.TokenSource{TokenSource: tokenSource}
	client := s.NewClient(grpc.WithPerRPCCredentials(perRPCCreds))

	_, callErr := client.Ping(s.SimpleCtx(), goodPing)
	require.NoError(s.T(), callErr, "no error must occur")
}
// TestStream_NoAuth expects a PingList stream opened without authorization
// metadata to be established without error, and the first Recv to fail with
// codes.Unauthenticated.
func (s *AuthTestSuite) TestStream_NoAuth() {
	t := s.T()
	stream, openErr := s.Client.PingList(s.SimpleCtx(), goodPing)
	require.NoError(t, openErr, "should not fail on establishing the stream")

	_, recvErr := stream.Recv()
	assert.Error(t, recvErr, "there must be an error")
	assert.Equal(t, codes.Unauthenticated, grpc.Code(recvErr), "must error with unauthenticated")
}
// TestStream_BadAuth expects a PingList stream opened with a wrong bearer
// token to be established without error, and the first Recv to fail with
// codes.PermissionDenied.
func (s *AuthTestSuite) TestStream_BadAuth() {
	t := s.T()
	badCtx := ctxWithToken(s.SimpleCtx(), "bearer", "bad_token")
	stream, openErr := s.Client.PingList(badCtx, goodPing)
	require.NoError(t, openErr, "should not fail on establishing the stream")

	_, recvErr := stream.Recv()
	assert.Error(t, recvErr, "there must be an error")
	assert.Equal(t, codes.PermissionDenied, grpc.Code(recvErr), "must error with permission denied")
}
// TestStream_PassesAuth expects a PingList stream carrying commonAuthToken to
// deliver a non-nil first message.
//
// NOTE(review): the scheme is sent as "Bearer" while the server-side auth
// function is built with "bearer"; presumably this exercises case-insensitive
// scheme matching in grpc_auth.AuthFromMD — confirm.
func (s *AuthTestSuite) TestStream_PassesAuth() {
	t := s.T()
	authedCtx := ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken)
	stream, openErr := s.Client.PingList(authedCtx, goodPing)
	require.NoError(t, openErr, "should not fail on establishing the stream")

	first, recvErr := stream.Recv()
	require.NoError(t, recvErr, "no error must occur")
	require.NotNil(t, first, "pong must not be nil")
}
// TestStream_PassesWithPerRpcCredentials expects a PingList stream to deliver
// a non-nil first message when the token is supplied through per-RPC
// credentials on a dedicated client.
func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() {
	t := s.T()
	tokenSource := &fakeOAuth2TokenSource{accessToken: commonAuthToken}
	perRPCCreds := oauth.TokenSource{TokenSource: tokenSource}
	client := s.NewClient(grpc.WithPerRPCCredentials(perRPCCreds))

	stream, openErr := client.PingList(s.SimpleCtx(), goodPing)
	require.NoError(t, openErr, "should not fail on establishing the stream")

	first, recvErr := stream.Recv()
	require.NoError(t, recvErr, "no error must occur")
	require.NotNil(t, first, "pong must not be nil")
}
// authOverrideTestService wraps a TestServiceServer and additionally provides
// AuthFuncOverride, which authenticates against overrideAuthToken.
type authOverrideTestService struct {
// Embedded server that handles the actual RPCs.
pb_testproto.TestServiceServer
// T is the test handle on which assertion failures are reported.
T *testing.T
}
// AuthFuncOverride asserts that a non-empty method name was supplied and then
// authenticates ctx with a dummy auth function that accepts overrideAuthToken
// under the "bearer" scheme.
//
// NOTE(review): presumably the grpc_auth interceptors call this in place of
// their configured auth function — confirm against the grpc_auth package.
func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) {
	assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around")
	overrideAuth := buildDummyAuthFunction("bearer", overrideAuthToken)
	return overrideAuth(ctx)
}
// TestAuthOverrideTestSuite runs AuthOverrideTestSuite against a server whose
// interceptors are configured with an auth function accepting commonAuthToken,
// while the service itself (authOverrideTestService wrapping
// assertingPingService) provides AuthFuncOverride.
func TestAuthOverrideTestSuite(t *testing.T) {
	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)

	asserting := &assertingPingService{
		TestServiceServer: &grpc_testing.TestPingService{T: t},
		T:                 t,
	}
	service := &authOverrideTestService{
		TestServiceServer: asserting,
		T:                 t,
	}
	serverOpts := []grpc.ServerOption{
		grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
		grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
	}

	suite.Run(t, &AuthOverrideTestSuite{
		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
			TestService: service,
			ServerOpts:  serverOpts,
		},
	})
}
// AuthOverrideTestSuite holds the tests that TestAuthOverrideTestSuite runs
// against a service providing AuthFuncOverride. It embeds
// grpc_testing.InterceptorTestSuite for the client and context helpers.
type AuthOverrideTestSuite struct {
*grpc_testing.InterceptorTestSuite
}
// TestUnary_PassesAuth expects a unary Ping carrying overrideAuthToken under
// the "bearer" scheme to succeed.
func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() {
	authedCtx := ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken)
	_, callErr := s.Client.Ping(authedCtx, goodPing)
	require.NoError(s.T(), callErr, "no error must occur")
}
// TestStream_PassesAuth expects a PingList stream carrying overrideAuthToken
// (sent with the scheme spelled "Bearer") to deliver a non-nil first message.
func (s *AuthOverrideTestSuite) TestStream_PassesAuth() {
	t := s.T()
	authedCtx := ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken)
	stream, openErr := s.Client.PingList(authedCtx, goodPing)
	require.NoError(t, openErr, "should not fail on establishing the stream")

	first, recvErr := stream.Recv()
	require.NoError(t, recvErr, "no error must occur")
	require.NotNil(t, first, "pong must not be nil")
}
// fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test.
// Its Token method hands out accessToken without contacting any OAuth2 server.
type fakeOAuth2TokenSource struct {
// accessToken is the value returned as the AccessToken of every issued token.
accessToken string
}
// Token returns a freshly built "bearer" token that carries ts.accessToken
// and expires one minute after the call. The error result is always nil.
func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) {
	expiresAt := time.Now().Add(1 * time.Minute)
	return &oauth2.Token{
		AccessToken: ts.accessToken,
		Expiry:      expiresAt,
		TokenType:   "bearer",
	}, nil
}