From dbca70898bfc1dad2f0427f42b0795d09b50d34f Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Mon, 24 Oct 2022 10:30:52 -0400 Subject: [PATCH] Refactor cluster client --- cluster/client.go | 275 ++++++++++------------------------------------ 1 file changed, 60 insertions(+), 215 deletions(-) diff --git a/cluster/client.go b/cluster/client.go index 705ecec9..8f9c75ff 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -74,52 +74,12 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string, timeout time.Duration) (string, command := &Command{ Type: Command_COMMAND_TYPE_GET_NODE_API_URL, } - p, err := proto.Marshal(command) - if err != nil { - return "", fmt.Errorf("command marshal: %s", err) + if err := writeCommand(conn, command, timeout); err != nil { + handleConnError(conn) + return "", err } - // Write length of Protobuf - b := make([]byte, 4) - binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return "", err - } - _, err = conn.Write(b) - if err != nil { - handleConnError(conn) - return "", fmt.Errorf("write protobuf length: %s", err) - } - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return "", err - } - _, err = conn.Write(p) - if err != nil { - handleConnError(conn) - return "", fmt.Errorf("write protobuf: %s", err) - } - - // Read length of response. - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return "", err - } - _, err = io.ReadFull(conn, b) - if err != nil { - handleConnError(conn) - return "", err - } - sz := binary.LittleEndian.Uint16(b[0:]) - - // Read in the actual response. - p = make([]byte, sz) - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return "", err - } - _, err = io.ReadFull(conn, p) + p, err := readResponse(conn, timeout) if err != nil { handleConnError(conn) return "", err @@ -152,54 +112,12 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, creds *Cre }, Credentials: creds, } - - p, err := proto.Marshal(command) - if err != nil { - return nil, fmt.Errorf("command marshal: %s", err) - } - - // Write length of Protobuf - b := make([]byte, 4) - binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) - - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - _, err = conn.Write(b) - if err != nil { - handleConnError(conn) - return nil, err - } - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - _, err = conn.Write(p) - if err != nil { + if err := writeCommand(conn, command, timeout); err != nil { handleConnError(conn) return nil, err } - // Read length of response. - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - _, err = io.ReadFull(conn, b) - if err != nil { - handleConnError(conn) - return nil, err - } - sz := binary.LittleEndian.Uint32(b[0:]) - - // Read in the actual response. - p = make([]byte, sz) - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - _, err = io.ReadFull(conn, p) + p, err := readResponse(conn, timeout) if err != nil { handleConnError(conn) return nil, err @@ -233,51 +151,12 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, creds *Credent }, Credentials: creds, } - - p, err := proto.Marshal(command) - if err != nil { - return nil, fmt.Errorf("command marshal: %s", err) - } - - // Write length of Protobuf, then the Protobuf - b := make([]byte, 4) - binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) - - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - _, err = conn.Write(b) - if err != nil { - handleConnError(conn) - return nil, err - } - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - _, err = conn.Write(p) - if err != nil { + if err := writeCommand(conn, command, timeout); err != nil { handleConnError(conn) return nil, err } - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return nil, err - } - - // Read length of response. - _, err = io.ReadFull(conn, b) - if err != nil { - handleConnError(conn) - return nil, err - } - sz := binary.LittleEndian.Uint32(b[0:]) - - // Read in the actual response. - p = make([]byte, sz) - _, err = io.ReadFull(conn, p) + p, err := readResponse(conn, timeout) if err != nil { handleConnError(conn) return nil, err @@ -311,52 +190,12 @@ func (c *Client) Backup(br *command.BackupRequest, nodeAddr string, creds *Crede }, Credentials: creds, } - p, err := proto.Marshal(command) - if err != nil { - return fmt.Errorf("command marshal: %s", err) - } - - // Write length of Protobuf - b := make([]byte, 4) - binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return err - } - _, err = conn.Write(b) - if err != nil { - handleConnError(conn) - return fmt.Errorf("write protobuf length: %s", err) - } - - // Now write backup request proto itself. - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return err - } - _, err = conn.Write(p) - if err != nil { - handleConnError(conn) - return fmt.Errorf("write protobuf: %s", err) - } - - // Read the backup response - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + if err := writeCommand(conn, command, timeout); err != nil { handleConnError(conn) return err } - // Read length of response. - _, err = io.ReadFull(conn, b) - if err != nil { - handleConnError(conn) - return err - } - sz := binary.LittleEndian.Uint32(b[0:]) - - // Read in the actual response. - p = make([]byte, sz) - _, err = io.ReadFull(conn, p) + p, err := readResponse(conn, timeout) if err != nil { handleConnError(conn) return err @@ -401,54 +240,12 @@ func (c *Client) Load(lr *command.LoadRequest, nodeAddr string, creds *Credentia }, Credentials: creds, } - - p, err := proto.Marshal(command) - if err != nil { - return fmt.Errorf("command marshal: %s", err) - } - - // Write length of Protobuf - b := make([]byte, 4) - binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) - - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return err - } - _, err = conn.Write(b) - if err != nil { - handleConnError(conn) - return err - } - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return err - } - _, err = conn.Write(p) - if err != nil { + if err := writeCommand(conn, command, timeout); err != nil { handleConnError(conn) return err } - // Read length of response. - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return err - } - _, err = io.ReadFull(conn, b) - if err != nil { - handleConnError(conn) - return err - } - sz := binary.LittleEndian.Uint32(b[0:]) - - // Read in the actual response. - p = make([]byte, sz) - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - handleConnError(conn) - return err - } - _, err = io.ReadFull(conn, p) + p, err := readResponse(conn, timeout) if err != nil { handleConnError(conn) return err @@ -532,6 +329,54 @@ func (c *Client) dial(nodeAddr string, timeout time.Duration) (net.Conn, error) return conn, nil } +func writeCommand(conn net.Conn, c *Command, timeout time.Duration) error { + p, err := proto.Marshal(c) + if err != nil { + return fmt.Errorf("command marshal: %s", err) + } + + // Write length of Protobuf + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return err + } + b := make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + _, err = conn.Write(b) + if err != nil { + return err + } + // Write actual protobuf. + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return err + } + _, err = conn.Write(p) + return err +} + +func readResponse(conn net.Conn, timeout time.Duration) ([]byte, error) { + // Read length of incoming response. + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, err + } + b := make([]byte, 4) + _, err := io.ReadFull(conn, b) + if err != nil { + return nil, err + } + sz := binary.LittleEndian.Uint32(b[0:]) + + // Read in the actual response. + p := make([]byte, sz) + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, err + } + _, err = io.ReadFull(conn, p) + if err != nil { + return nil, err + } + return p, nil +} + func handleConnError(conn net.Conn) { if pc, ok := conn.(*pool.Conn); ok { pc.MarkUnusable()