12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765 |
- package nsq
- import (
- "bufio"
- "bytes"
- "compress/flate"
- "crypto/tls"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/golang/snappy"
- )
// IdentifyResponse holds the settings and capabilities that nsqd reports in
// its JSON reply to an IDENTIFY command. Each field maps onto one key of
// that JSON body.
type IdentifyResponse struct {
	// MaxRdyCount is the largest RDY count nsqd will accept on the connection.
	MaxRdyCount int64 `json:"max_rdy_count"`
	// TLSv1 reports whether the connection is to be upgraded to TLS.
	TLSv1 bool `json:"tls_v1"`
	// Deflate reports whether the connection is to be upgraded to deflate.
	Deflate bool `json:"deflate"`
	// Snappy reports whether the connection is to be upgraded to snappy.
	Snappy bool `json:"snappy"`
	// AuthRequired reports whether an AUTH command must be sent.
	AuthRequired bool `json:"auth_required"`
}
// AuthResponse holds the JSON body that nsqd returns when an AUTH command
// is accepted.
type AuthResponse struct {
	// Identity is the identity the auth server associated with the secret.
	Identity string `json:"identity"`
	// IdentityUrl is a URL describing that identity.
	IdentityUrl string `json:"identity_url"`
	// PermissionCount is the number of permissions granted.
	PermissionCount int64 `json:"permission_count"`
}
- type msgResponse struct {
- msg *Message
- cmd *Command
- success bool
- backoff bool
- }
- // Conn represents a connection to nsqd
- //
- // Conn exposes a set of callbacks for the
- // various events that occur on a connection
- type Conn struct {
- // 64bit atomic vars need to be first for proper alignment on 32bit platforms
- messagesInFlight int64
- maxRdyCount int64
- rdyCount int64
- lastRdyTimestamp int64
- lastMsgTimestamp int64
- mtx sync.Mutex
- config *Config
- conn *net.TCPConn
- tlsConn *tls.Conn
- addr string
- delegate ConnDelegate
- logger []logger
- logLvl LogLevel
- logFmt []string
- logGuard sync.RWMutex
- r io.Reader
- w io.Writer
- cmdChan chan *Command
- msgResponseChan chan *msgResponse
- exitChan chan int
- drainReady chan int
- closeFlag int32
- stopper sync.Once
- wg sync.WaitGroup
- readLoopRunning int32
- }
- // NewConn returns a new Conn instance
- func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
- if !config.initialized {
- panic("Config must be created with NewConfig()")
- }
- return &Conn{
- addr: addr,
- config: config,
- delegate: delegate,
- maxRdyCount: 2500,
- lastMsgTimestamp: time.Now().UnixNano(),
- cmdChan: make(chan *Command),
- msgResponseChan: make(chan *msgResponse),
- exitChan: make(chan int),
- drainReady: make(chan int),
- logger: make([]logger, LogLevelMax+1),
- logFmt: make([]string, LogLevelMax+1),
- }
- }
- // SetLogger assigns the logger to use as well as a level.
- //
- // The format parameter is expected to be a printf compatible string with
- // a single %s argument. This is useful if you want to provide additional
- // context to the log messages that the connection will print, the default
- // is '(%s)'.
- //
- // The logger parameter is an interface that requires the following
- // method to be implemented (such as the the stdlib log.Logger):
- //
- // Output(calldepth int, s string)
- //
- func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) {
- c.logGuard.Lock()
- defer c.logGuard.Unlock()
- if format == "" {
- format = "(%s)"
- }
- for level := range c.logger {
- c.logger[level] = l
- c.logFmt[level] = format
- }
- c.logLvl = lvl
- }
- func (c *Conn) SetLoggerForLevel(l logger, lvl LogLevel, format string) {
- c.logGuard.Lock()
- defer c.logGuard.Unlock()
- if format == "" {
- format = "(%s)"
- }
- c.logger[lvl] = l
- c.logFmt[lvl] = format
- }
- // SetLoggerLevel sets the package logging level.
- func (c *Conn) SetLoggerLevel(lvl LogLevel) {
- c.logGuard.Lock()
- defer c.logGuard.Unlock()
- c.logLvl = lvl
- }
- func (c *Conn) getLogger(lvl LogLevel) (logger, LogLevel, string) {
- c.logGuard.RLock()
- defer c.logGuard.RUnlock()
- return c.logger[lvl], c.logLvl, c.logFmt[lvl]
- }
- func (c *Conn) getLogLevel() LogLevel {
- c.logGuard.RLock()
- defer c.logGuard.RUnlock()
- return c.logLvl
- }
- // Connect dials and bootstraps the nsqd connection
- // (including IDENTIFY) and returns the IdentifyResponse
- func (c *Conn) Connect() (*IdentifyResponse, error) {
- dialer := &net.Dialer{
- LocalAddr: c.config.LocalAddr,
- Timeout: c.config.DialTimeout,
- }
- conn, err := dialer.Dial("tcp", c.addr)
- if err != nil {
- return nil, err
- }
- c.conn = conn.(*net.TCPConn)
- c.r = conn
- c.w = conn
- _, err = c.Write(MagicV2)
- if err != nil {
- c.Close()
- return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err)
- }
- resp, err := c.identify()
- if err != nil {
- return nil, err
- }
- if resp != nil && resp.AuthRequired {
- if c.config.AuthSecret == "" {
- c.log(LogLevelError, "Auth Required")
- return nil, errors.New("Auth Required")
- }
- err := c.auth(c.config.AuthSecret)
- if err != nil {
- c.log(LogLevelError, "Auth Failed %s", err)
- return nil, err
- }
- }
- c.wg.Add(2)
- atomic.StoreInt32(&c.readLoopRunning, 1)
- go c.readLoop()
- go c.writeLoop()
- return resp, nil
- }
- // Close idempotently initiates connection close
- func (c *Conn) Close() error {
- atomic.StoreInt32(&c.closeFlag, 1)
- if c.conn != nil && atomic.LoadInt64(&c.messagesInFlight) == 0 {
- return c.conn.CloseRead()
- }
- return nil
- }
- // IsClosing indicates whether or not the
- // connection is currently in the processing of
- // gracefully closing
- func (c *Conn) IsClosing() bool {
- return atomic.LoadInt32(&c.closeFlag) == 1
- }
- // RDY returns the current RDY count
- func (c *Conn) RDY() int64 {
- return atomic.LoadInt64(&c.rdyCount)
- }
- // LastRDY returns the previously set RDY count
- func (c *Conn) LastRDY() int64 {
- return atomic.LoadInt64(&c.rdyCount)
- }
- // SetRDY stores the specified RDY count
- func (c *Conn) SetRDY(rdy int64) {
- atomic.StoreInt64(&c.rdyCount, rdy)
- if rdy > 0 {
- atomic.StoreInt64(&c.lastRdyTimestamp, time.Now().UnixNano())
- }
- }
- // MaxRDY returns the nsqd negotiated maximum
- // RDY count that it will accept for this connection
- func (c *Conn) MaxRDY() int64 {
- return c.maxRdyCount
- }
- // LastRdyTime returns the time of the last non-zero RDY
- // update for this connection
- func (c *Conn) LastRdyTime() time.Time {
- return time.Unix(0, atomic.LoadInt64(&c.lastRdyTimestamp))
- }
- // LastMessageTime returns a time.Time representing
- // the time at which the last message was received
- func (c *Conn) LastMessageTime() time.Time {
- return time.Unix(0, atomic.LoadInt64(&c.lastMsgTimestamp))
- }
- // RemoteAddr returns the configured destination nsqd address
- func (c *Conn) RemoteAddr() net.Addr {
- return c.conn.RemoteAddr()
- }
- // String returns the fully-qualified address
- func (c *Conn) String() string {
- return c.addr
- }
- // Read performs a deadlined read on the underlying TCP connection
- func (c *Conn) Read(p []byte) (int, error) {
- c.conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
- return c.r.Read(p)
- }
- // Write performs a deadlined write on the underlying TCP connection
- func (c *Conn) Write(p []byte) (int, error) {
- c.conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout))
- return c.w.Write(p)
- }
- // WriteCommand is a goroutine safe method to write a Command
- // to this connection, and flush.
- func (c *Conn) WriteCommand(cmd *Command) error {
- c.mtx.Lock()
- _, err := cmd.WriteTo(c)
- if err != nil {
- goto exit
- }
- err = c.Flush()
- exit:
- c.mtx.Unlock()
- if err != nil {
- c.log(LogLevelError, "IO error - %s", err)
- c.delegate.OnIOError(c, err)
- }
- return err
- }
// flusher is implemented by writers that buffer output and must be flushed
// before it reaches the connection (for example *bufio.Writer and
// *flate.Writer).
type flusher interface {
	Flush() error
}
- // Flush writes all buffered data to the underlying TCP connection
- func (c *Conn) Flush() error {
- if f, ok := c.w.(flusher); ok {
- return f.Flush()
- }
- return nil
- }
- func (c *Conn) identify() (*IdentifyResponse, error) {
- ci := make(map[string]interface{})
- ci["client_id"] = c.config.ClientID
- ci["hostname"] = c.config.Hostname
- ci["user_agent"] = c.config.UserAgent
- ci["short_id"] = c.config.ClientID // deprecated
- ci["long_id"] = c.config.Hostname // deprecated
- ci["tls_v1"] = c.config.TlsV1
- ci["deflate"] = c.config.Deflate
- ci["deflate_level"] = c.config.DeflateLevel
- ci["snappy"] = c.config.Snappy
- ci["feature_negotiation"] = true
- if c.config.HeartbeatInterval == -1 {
- ci["heartbeat_interval"] = -1
- } else {
- ci["heartbeat_interval"] = int64(c.config.HeartbeatInterval / time.Millisecond)
- }
- ci["sample_rate"] = c.config.SampleRate
- ci["output_buffer_size"] = c.config.OutputBufferSize
- if c.config.OutputBufferTimeout == -1 {
- ci["output_buffer_timeout"] = -1
- } else {
- ci["output_buffer_timeout"] = int64(c.config.OutputBufferTimeout / time.Millisecond)
- }
- ci["msg_timeout"] = int64(c.config.MsgTimeout / time.Millisecond)
- cmd, err := Identify(ci)
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- err = c.WriteCommand(cmd)
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- frameType, data, err := ReadUnpackedResponse(c)
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- if frameType == FrameTypeError {
- return nil, ErrIdentify{string(data)}
- }
- // check to see if the server was able to respond w/ capabilities
- // i.e. it was a JSON response
- if data[0] != '{' {
- return nil, nil
- }
- resp := &IdentifyResponse{}
- err = json.Unmarshal(data, resp)
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- c.log(LogLevelDebug, "IDENTIFY response: %+v", resp)
- c.maxRdyCount = resp.MaxRdyCount
- if resp.TLSv1 {
- c.log(LogLevelInfo, "upgrading to TLS")
- err := c.upgradeTLS(c.config.TlsConfig)
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- }
- if resp.Deflate {
- c.log(LogLevelInfo, "upgrading to Deflate")
- err := c.upgradeDeflate(c.config.DeflateLevel)
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- }
- if resp.Snappy {
- c.log(LogLevelInfo, "upgrading to Snappy")
- err := c.upgradeSnappy()
- if err != nil {
- return nil, ErrIdentify{err.Error()}
- }
- }
- // now that connection is bootstrapped, enable read buffering
- // (and write buffering if it's not already capable of Flush())
- c.r = bufio.NewReader(c.r)
- if _, ok := c.w.(flusher); !ok {
- c.w = bufio.NewWriter(c.w)
- }
- return resp, nil
- }
- func (c *Conn) upgradeTLS(tlsConf *tls.Config) error {
- host, _, err := net.SplitHostPort(c.addr)
- if err != nil {
- return err
- }
- // create a local copy of the config to set ServerName for this connection
- conf := &tls.Config{}
- if tlsConf != nil {
- conf = tlsConf.Clone()
- }
- conf.ServerName = host
- c.tlsConn = tls.Client(c.conn, conf)
- err = c.tlsConn.Handshake()
- if err != nil {
- return err
- }
- c.r = c.tlsConn
- c.w = c.tlsConn
- frameType, data, err := ReadUnpackedResponse(c)
- if err != nil {
- return err
- }
- if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) {
- return errors.New("invalid response from TLS upgrade")
- }
- return nil
- }
- func (c *Conn) upgradeDeflate(level int) error {
- conn := net.Conn(c.conn)
- if c.tlsConn != nil {
- conn = c.tlsConn
- }
- fw, _ := flate.NewWriter(conn, level)
- c.r = flate.NewReader(conn)
- c.w = fw
- frameType, data, err := ReadUnpackedResponse(c)
- if err != nil {
- return err
- }
- if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) {
- return errors.New("invalid response from Deflate upgrade")
- }
- return nil
- }
- func (c *Conn) upgradeSnappy() error {
- conn := net.Conn(c.conn)
- if c.tlsConn != nil {
- conn = c.tlsConn
- }
- c.r = snappy.NewReader(conn)
- c.w = snappy.NewWriter(conn)
- frameType, data, err := ReadUnpackedResponse(c)
- if err != nil {
- return err
- }
- if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) {
- return errors.New("invalid response from Snappy upgrade")
- }
- return nil
- }
- func (c *Conn) auth(secret string) error {
- cmd, err := Auth(secret)
- if err != nil {
- return err
- }
- err = c.WriteCommand(cmd)
- if err != nil {
- return err
- }
- frameType, data, err := ReadUnpackedResponse(c)
- if err != nil {
- return err
- }
- if frameType == FrameTypeError {
- return errors.New("Error authenticating " + string(data))
- }
- resp := &AuthResponse{}
- err = json.Unmarshal(data, resp)
- if err != nil {
- return err
- }
- c.log(LogLevelInfo, "Auth accepted. Identity: %q %s Permissions: %d",
- resp.Identity, resp.IdentityUrl, resp.PermissionCount)
- return nil
- }
- func (c *Conn) readLoop() {
- delegate := &connMessageDelegate{c}
- for {
- if atomic.LoadInt32(&c.closeFlag) == 1 {
- goto exit
- }
- frameType, data, err := ReadUnpackedResponse(c)
- if err != nil {
- if err == io.EOF && atomic.LoadInt32(&c.closeFlag) == 1 {
- goto exit
- }
- if !strings.Contains(err.Error(), "use of closed network connection") {
- c.log(LogLevelError, "IO error - %s", err)
- c.delegate.OnIOError(c, err)
- }
- goto exit
- }
- if frameType == FrameTypeResponse && bytes.Equal(data, []byte("_heartbeat_")) {
- c.log(LogLevelDebug, "heartbeat received")
- c.delegate.OnHeartbeat(c)
- err := c.WriteCommand(Nop())
- if err != nil {
- c.log(LogLevelError, "IO error - %s", err)
- c.delegate.OnIOError(c, err)
- goto exit
- }
- continue
- }
- switch frameType {
- case FrameTypeResponse:
- c.delegate.OnResponse(c, data)
- case FrameTypeMessage:
- msg, err := DecodeMessage(data)
- if err != nil {
- c.log(LogLevelError, "IO error - %s", err)
- c.delegate.OnIOError(c, err)
- goto exit
- }
- msg.Delegate = delegate
- msg.NSQDAddress = c.String()
- atomic.AddInt64(&c.messagesInFlight, 1)
- atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano())
- c.delegate.OnMessage(c, msg)
- case FrameTypeError:
- c.log(LogLevelError, "protocol error - %s", data)
- c.delegate.OnError(c, data)
- default:
- c.log(LogLevelError, "IO error - %s", err)
- c.delegate.OnIOError(c, fmt.Errorf("unknown frame type %d", frameType))
- }
- }
- exit:
- atomic.StoreInt32(&c.readLoopRunning, 0)
- // start the connection close
- messagesInFlight := atomic.LoadInt64(&c.messagesInFlight)
- if messagesInFlight == 0 {
- // if we exited readLoop with no messages in flight
- // we need to explicitly trigger the close because
- // writeLoop won't
- c.close()
- } else {
- c.log(LogLevelWarning, "delaying close, %d outstanding messages", messagesInFlight)
- }
- c.wg.Done()
- c.log(LogLevelInfo, "readLoop exiting")
- }
- func (c *Conn) writeLoop() {
- for {
- select {
- case <-c.exitChan:
- c.log(LogLevelInfo, "breaking out of writeLoop")
- // Indicate drainReady because we will not pull any more off msgResponseChan
- close(c.drainReady)
- goto exit
- case cmd := <-c.cmdChan:
- err := c.WriteCommand(cmd)
- if err != nil {
- c.log(LogLevelError, "error sending command %s - %s", cmd, err)
- c.close()
- continue
- }
- case resp := <-c.msgResponseChan:
- // Decrement this here so it is correct even if we can't respond to nsqd
- msgsInFlight := atomic.AddInt64(&c.messagesInFlight, -1)
- if resp.success {
- c.log(LogLevelDebug, "FIN %s", resp.msg.ID)
- c.delegate.OnMessageFinished(c, resp.msg)
- c.delegate.OnResume(c)
- } else {
- c.log(LogLevelDebug, "REQ %s", resp.msg.ID)
- c.delegate.OnMessageRequeued(c, resp.msg)
- if resp.backoff {
- c.delegate.OnBackoff(c)
- } else {
- c.delegate.OnContinue(c)
- }
- }
- err := c.WriteCommand(resp.cmd)
- if err != nil {
- c.log(LogLevelError, "error sending command %s - %s", resp.cmd, err)
- c.close()
- continue
- }
- if msgsInFlight == 0 &&
- atomic.LoadInt32(&c.closeFlag) == 1 {
- c.close()
- continue
- }
- }
- }
- exit:
- c.wg.Done()
- c.log(LogLevelInfo, "writeLoop exiting")
- }
- func (c *Conn) close() {
- // a "clean" connection close is orchestrated as follows:
- //
- // 1. CLOSE cmd sent to nsqd
- // 2. CLOSE_WAIT response received from nsqd
- // 3. set c.closeFlag
- // 4. readLoop() exits
- // a. if messages-in-flight > 0 delay close()
- // i. writeLoop() continues receiving on c.msgResponseChan chan
- // x. when messages-in-flight == 0 call close()
- // b. else call close() immediately
- // 5. c.exitChan close
- // a. writeLoop() exits
- // i. c.drainReady close
- // 6a. launch cleanup() goroutine (we're racing with intraprocess
- // routed messages, see comments below)
- // a. wait on c.drainReady
- // b. loop and receive on c.msgResponseChan chan
- // until messages-in-flight == 0
- // i. ensure that readLoop has exited
- // 6b. launch waitForCleanup() goroutine
- // b. wait on waitgroup (covers readLoop() and writeLoop()
- // and cleanup goroutine)
- // c. underlying TCP connection close
- // d. trigger Delegate OnClose()
- //
- c.stopper.Do(func() {
- c.log(LogLevelInfo, "beginning close")
- close(c.exitChan)
- c.conn.CloseRead()
- c.wg.Add(1)
- go c.cleanup()
- go c.waitForCleanup()
- })
- }
- func (c *Conn) cleanup() {
- <-c.drainReady
- ticker := time.NewTicker(100 * time.Millisecond)
- lastWarning := time.Now()
- // writeLoop has exited, drain any remaining in flight messages
- for {
- // we're racing with readLoop which potentially has a message
- // for handling so infinitely loop until messagesInFlight == 0
- // and readLoop has exited
- var msgsInFlight int64
- select {
- case <-c.msgResponseChan:
- msgsInFlight = atomic.AddInt64(&c.messagesInFlight, -1)
- case <-ticker.C:
- msgsInFlight = atomic.LoadInt64(&c.messagesInFlight)
- }
- if msgsInFlight > 0 {
- if time.Now().Sub(lastWarning) > time.Second {
- c.log(LogLevelWarning, "draining... waiting for %d messages in flight", msgsInFlight)
- lastWarning = time.Now()
- }
- continue
- }
- // until the readLoop has exited we cannot be sure that there
- // still won't be a race
- if atomic.LoadInt32(&c.readLoopRunning) == 1 {
- if time.Now().Sub(lastWarning) > time.Second {
- c.log(LogLevelWarning, "draining... readLoop still running")
- lastWarning = time.Now()
- }
- continue
- }
- goto exit
- }
- exit:
- ticker.Stop()
- c.wg.Done()
- c.log(LogLevelInfo, "finished draining, cleanup exiting")
- }
- func (c *Conn) waitForCleanup() {
- // this blocks until readLoop and writeLoop
- // (and cleanup goroutine above) have exited
- c.wg.Wait()
- c.conn.CloseWrite()
- c.log(LogLevelInfo, "clean close complete")
- c.delegate.OnClose(c)
- }
- func (c *Conn) onMessageFinish(m *Message) {
- c.msgResponseChan <- &msgResponse{msg: m, cmd: Finish(m.ID), success: true}
- }
- func (c *Conn) onMessageRequeue(m *Message, delay time.Duration, backoff bool) {
- if delay == -1 {
- // linear delay
- delay = c.config.DefaultRequeueDelay * time.Duration(m.Attempts)
- // bound the requeueDelay to configured max
- if delay > c.config.MaxRequeueDelay {
- delay = c.config.MaxRequeueDelay
- }
- }
- c.msgResponseChan <- &msgResponse{msg: m, cmd: Requeue(m.ID, delay), success: false, backoff: backoff}
- }
- func (c *Conn) onMessageTouch(m *Message) {
- select {
- case c.cmdChan <- Touch(m.ID):
- case <-c.exitChan:
- }
- }
- func (c *Conn) log(lvl LogLevel, line string, args ...interface{}) {
- logger, logLvl, logFmt := c.getLogger(lvl)
- if logger == nil {
- return
- }
- if logLvl > lvl {
- return
- }
- logger.Output(2, fmt.Sprintf("%-4s %s %s", lvl,
- fmt.Sprintf(logFmt, c.String()),
- fmt.Sprintf(line, args...)))
- }
|