Files
fn-serverless/grpcutil/dial.go
Andrea Rosa 3261e48843 Add a timeout to the net dialer (#844)
This change adds the option to set a timeout for the dialer used when
making gRPC connections; with that, we remove the check on the state of
the connections and therefore remove any potential race conditions.
2018-03-12 13:36:53 +00:00

82 lines
2.5 KiB
Go

package grpcutil
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"time"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// DialWithBackoff opens a gRPC client connection to address. Reconnection
// attempts follow the supplied backoff configuration. When creds is nil the
// connection is plaintext. timeout bounds each TCP connect performed by the
// underlying dialer.
func DialWithBackoff(ctx context.Context, address string, creds credentials.TransportCredentials, timeout time.Duration, backoffCfg grpc.BackoffConfig) (*grpc.ClientConn, error) {
	backoffOpt := grpc.WithBackoffConfig(backoffCfg)
	return dial(ctx, address, creds, timeout, backoffOpt)
}
// dial opens a gRPC client connection to address. It uses a custom dialer
// that performs the TCP connect itself, bounded by timeoutDialer. When creds
// is non-nil, the dialer also runs the transport-security handshake. Because
// that handshake happens here, gRPC is told (WithInsecure) not to add its own.
//
// Reconnection follows the gRPC connection backoff protocol
// (https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md).
// Callers can tune it by passing grpc.WithBackoffConfig in opts.
//
// NOTE(review): the dialer closure captures ctx, and gRPC may invoke the
// dialer again for reconnects after dial returns. If ctx is cancelled or
// times out later, those reconnect attempts fail immediately. Confirm this
// is intended.
func dial(ctx context.Context, address string, creds credentials.TransportCredentials, timeoutDialer time.Duration, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
	dialer := func(address string, timeout time.Duration) (net.Conn, error) {
		// timeout is the per-attempt deadline handed in by gRPC. It bounds
		// both the TCP connect and the handshake below.
		dialCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()

		logger := logrus.WithField("grpc_addr", address)

		// DialContext replaces the deprecated net.Dialer.Cancel channel with
		// the same cancellation semantics.
		rawConn, err := (&net.Dialer{Timeout: timeoutDialer}).DialContext(dialCtx, "tcp", address)
		if err != nil {
			logger.WithError(err).Warn("Failed to dial grpc connection")
			return nil, err
		}
		if creds == nil {
			logger.Warn("Created insecure grpc connection")
			return rawConn, nil
		}

		secureConn, _, err := creds.ClientHandshake(dialCtx, address, rawConn)
		if err != nil {
			// On failure the handshake gives back no usable conn. Close the
			// raw TCP connection ourselves so it is not leaked. This is best
			// effort: the handshake error is the one worth reporting.
			_ = rawConn.Close()
			logger.WithError(err).Warn("Failed grpc handshake")
			return nil, err
		}
		return secureConn, nil
	}

	// Build a fresh slice rather than appending to opts in place. A
	// caller-owned slice passed via opts... with spare capacity must not be
	// written to.
	dialOpts := make([]grpc.DialOption, 0, len(opts)+2)
	dialOpts = append(dialOpts, opts...)
	dialOpts = append(dialOpts,
		grpc.WithDialer(dialer),
		grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
	)
	return grpc.DialContext(ctx, address, dialOpts...)
}
// CreateCredentials builds gRPC TLS transport credentials. It presents the
// client certificate/key pair loaded from certPath and keyPath, and verifies
// the server against the PEM-encoded CA bundle at caCertPath. It returns an
// error if the key pair cannot be loaded, the CA file cannot be read, or the
// CA file contains no usable PEM certificates.
//
// NOTE(review): ServerName is fixed to "127.0.0.1", so the server certificate
// is always verified against that name regardless of the dialed address.
// Confirm the server certs are issued for it.
func CreateCredentials(certPath string, keyPath string, caCertPath string) (credentials.TransportCredentials, error) {
	clientCert, loadErr := tls.LoadX509KeyPair(certPath, keyPath)
	if loadErr != nil {
		return nil, fmt.Errorf("could not load client key pair: %s", loadErr)
	}

	caPEM, readErr := ioutil.ReadFile(caCertPath)
	if readErr != nil {
		return nil, fmt.Errorf("could not read ca certificate: %s", readErr)
	}

	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(caPEM) {
		return nil, errors.New("failed to append ca certs")
	}

	cfg := &tls.Config{
		ServerName:   "127.0.0.1", // NOTE: this is required!
		Certificates: []tls.Certificate{clientCert},
		RootCAs:      roots,
	}
	return credentials.NewTLS(cfg), nil
}