package nsqd

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/nsqio/nsq/internal/protocol"
	"github.com/nsqio/nsq/internal/version"
)

const maxTimeout = time.Hour

// Frame types passed to protocol.SendFramedResponse.
const (
	frameTypeResponse int32 = 0
	frameTypeError    int32 = 1
	frameTypeMessage  int32 = 2
)

var separatorBytes = []byte(" ")
var heartbeatBytes = []byte("_heartbeat_")
var okBytes = []byte("OK")

// protocolV2 handles "V2" protocol connections for an nsqd instance.
type protocolV2 struct {
	nsqd *NSQD
}

// NewClient allocates the next client ID from nsqd.clientIDSequence and
// wraps conn in a new clientV2.
func (p *protocolV2) NewClient(conn net.Conn) protocol.Client {
	clientID := atomic.AddInt64(&p.nsqd.clientIDSequence, 1)
	return newClientV2(clientID, conn, p.nsqd)
}

// IOLoop is the per-connection command loop. It starts messagePump (and
// waits for it to signal startup), then repeatedly reads a newline-terminated
// command, dispatches it through Exec and writes back the response or an
// error frame.
//
// It returns nil when the client disconnects cleanly (io.EOF) and a non-nil
// error when a read or response write fails or a command returns a
// *protocol.FatalClientErr. On exit it closes client.ExitChan (which makes
// messagePump return) and removes the client from its channel, if any.
func (p *protocolV2) IOLoop(c protocol.Client) error {
	var err error
	var line []byte
	var zeroTime time.Time

	client := c.(*clientV2)

	// synchronize the startup of messagePump in order
	// to guarantee that it gets a chance to initialize
	// goroutine local state derived from client attributes
	// and avoid a potential race with IDENTIFY (where a client
	// could have changed or disabled said attributes)
	messagePumpStartedChan := make(chan bool)
	go p.messagePump(client, messagePumpStartedChan)
	<-messagePumpStartedChan

	for {
		if client.HeartbeatInterval > 0 {
			client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
		} else {
			client.SetReadDeadline(zeroTime)
		}

		// ReadSlice does not allocate new space for the data each request
		// ie. the returned slice is only valid until the next call to it
		line, err = client.Reader.ReadSlice('\n')
		if err != nil {
			if err == io.EOF {
				err = nil
			} else {
				err = fmt.Errorf("failed to read command - %s", err)
			}
			break
		}

		// trim the '\n'
		line = line[:len(line)-1]
		// optionally trim the '\r'
		if len(line) > 0 && line[len(line)-1] == '\r' {
			line = line[:len(line)-1]
		}
		params := bytes.Split(line, separatorBytes)

		p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): [%s] %s", client, params)

		var response []byte
		response, err = p.Exec(client, params)
		if err != nil {
			ctx := ""
			// checked assertion: an error that does not implement
			// protocol.ChildErr must not panic the connection goroutine
			if childErr, ok := err.(protocol.ChildErr); ok {
				if parentErr := childErr.Parent(); parentErr != nil {
					ctx = " - " + parentErr.Error()
				}
			}
			p.nsqd.logf(LOG_ERROR, "[%s] - %s%s", client, err, ctx)

			sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
			if sendErr != nil {
				p.nsqd.logf(LOG_ERROR, "[%s] - %s%s", client, sendErr, ctx)
				break
			}

			// errors of type FatalClientErr should forcibly close the connection
			if _, ok := err.(*protocol.FatalClientErr); ok {
				break
			}
			continue
		}

		if response != nil {
			err = p.Send(client, frameTypeResponse, response)
			if err != nil {
				err = fmt.Errorf("failed to send response - %s", err)
				break
			}
		}
	}

	p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] exiting ioloop", client)
	close(client.ExitChan)
	if client.Channel != nil {
		client.Channel.RemoveClient(client.ID)
	}

	return err
}

// SendMessage serializes msg into a pooled buffer and writes it to the
// client as a frameTypeMessage frame. Message frames are not flushed by
// Send; messagePump decides when to flush.
func (p *protocolV2) SendMessage(client *clientV2, msg *Message) error {
	p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): writing msg(%s) to client(%s) - %s", msg.ID, client, msg.Body)

	buf := bufferPoolGet()
	defer bufferPoolPut(buf)

	if _, err := msg.WriteTo(buf); err != nil {
		return err
	}
	return p.Send(client, frameTypeMessage, buf.Bytes())
}

// Send writes one framed response of frameType to the client while holding
// client.writeLock. When heartbeats are enabled the write deadline is one
// heartbeat interval from now; otherwise the deadline is cleared. Frames
// other than frameTypeMessage are flushed immediately.
func (p *protocolV2) Send(client *clientV2, frameType int32, data []byte) error {
	client.writeLock.Lock()
	defer client.writeLock.Unlock()

	var zeroTime time.Time
	if client.HeartbeatInterval > 0 {
		client.SetWriteDeadline(time.Now().Add(client.HeartbeatInterval))
	} else {
		client.SetWriteDeadline(zeroTime)
	}

	if _, err := protocol.SendFramedResponse(client.Writer, frameType, data); err != nil {
		return err
	}

	if frameType != frameTypeMessage {
		return client.Flush()
	}
	return nil
}

// Exec dispatches a parsed command line (params[0] is the command name) to
// its handler. Every command except IDENTIFY is first subject to the TLS
// policy, since IDENTIFY is where the TLS upgrade is negotiated. Unknown
// commands yield a fatal E_INVALID error.
func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) {
	// string(b) in a comparison/switch does not allocate
	if string(params[0]) == "IDENTIFY" {
		return p.IDENTIFY(client, params)
	}

	if err := enforceTLSPolicy(client, p, params[0]); err != nil {
		return nil, err
	}

	switch string(params[0]) {
	case "FIN":
		return p.FIN(client, params)
	case "RDY":
		return p.RDY(client, params)
	case "REQ":
		return p.REQ(client, params)
	case "PUB":
		return p.PUB(client, params)
	case "MPUB":
		return p.MPUB(client, params)
	case "DPUB":
		return p.DPUB(client, params)
	case "NOP":
		return p.NOP(client, params)
	case "TOUCH":
		return p.TOUCH(client, params)
	case "SUB":
		return p.SUB(client, params)
	case "CLS":
		return p.CLS(client, params)
	case "AUTH":
		return p.AUTH(client, params)
	}
	return nil, protocol.NewFatalClientErr(nil, "E_INVALID", fmt.Sprintf("invalid command %s", params[0]))
}

// messagePump runs in its own goroutine per client (started by IOLoop). It
// sends heartbeats, delivers messages from the subscribed channel's memory
// and backend queues once the client is ready, and flushes buffered output.
// It closes startedChan after initializing its local state and returns when
// client.ExitChan is closed or a write/flush fails (the error is logged).
func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
	var err error
	var memoryMsgChan chan *Message
	var backendMsgChan <-chan []byte
	var subChannel *Channel
	// NOTE: `flusherChan` is used to bound message latency for
	// the pathological case of a channel on a low volume topic
	// with >1 clients having >1 RDY counts
	var flusherChan <-chan time.Time
	var sampleRate int32

	subEventChan := client.SubEventChan
	identifyEventChan := client.IdentifyEventChan
	outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
	heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
	heartbeatChan := heartbeatTicker.C
	msgTimeout := client.MsgTimeout

	// v2 opportunistically buffers data to clients to reduce write system calls
	// we force flush in two cases:
	//    1. when the client is not ready to receive messages
	//    2. we're buffered and the channel has nothing left to send us
	//       (ie. we would block in this loop anyway)
	//
	flushed := true

	// signal to the goroutine that started the messagePump
	// that we've started up
	close(startedChan)

	for {
		if subChannel == nil || !client.IsReadyForMessages() {
			// the client is not ready to receive messages...
			memoryMsgChan = nil
			backendMsgChan = nil
			flusherChan = nil
			// force flush
			client.writeLock.Lock()
			err = client.Flush()
			client.writeLock.Unlock()
			if err != nil {
				goto exit
			}
			flushed = true
		} else if flushed {
			// last iteration we flushed...
			// do not select on the flusher ticker channel
			memoryMsgChan = subChannel.memoryMsgChan
			backendMsgChan = subChannel.backend.ReadChan()
			flusherChan = nil
		} else {
			// we're buffered (if there isn't any more data we should flush)...
			// select on the flusher ticker channel, too
			memoryMsgChan = subChannel.memoryMsgChan
			backendMsgChan = subChannel.backend.ReadChan()
			flusherChan = outputBufferTicker.C
		}

		select {
		case <-flusherChan:
			// if this case wins, we're either starved
			// or we won the race between other channels...
			// in either case, force flush
			client.writeLock.Lock()
			err = client.Flush()
			client.writeLock.Unlock()
			if err != nil {
				goto exit
			}
			flushed = true
		case <-client.ReadyStateChan:
		case subChannel = <-subEventChan:
			// you can't SUB anymore
			subEventChan = nil
		case identifyData := <-identifyEventChan:
			// you can't IDENTIFY anymore
			identifyEventChan = nil

			outputBufferTicker.Stop()
			if identifyData.OutputBufferTimeout > 0 {
				outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
			}

			heartbeatTicker.Stop()
			heartbeatChan = nil
			if identifyData.HeartbeatInterval > 0 {
				heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
				heartbeatChan = heartbeatTicker.C
			}

			if identifyData.SampleRate > 0 {
				sampleRate = identifyData.SampleRate
			}

			msgTimeout = identifyData.MsgTimeout
		case <-heartbeatChan:
			err = p.Send(client, frameTypeResponse, heartbeatBytes)
			if err != nil {
				goto exit
			}
		case b := <-backendMsgChan:
			if sampleRate > 0 && rand.Int31n(100) > sampleRate {
				continue
			}

			// Use a distinct name for the decode error: `msg, err := ...`
			// would shadow the outer err, so a SendMessage failure below
			// would reach `exit` with the outer err still nil and never be
			// logged.
			msg, decodeErr := decodeMessage(b)
			if decodeErr != nil {
				p.nsqd.logf(LOG_ERROR, "failed to decode message - %s", decodeErr)
				continue
			}
			msg.Attempts++

			subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
			client.SendingMessage()
			err = p.SendMessage(client, msg)
			if err != nil {
				goto exit
			}
			flushed = false
		case msg := <-memoryMsgChan:
			if sampleRate > 0 && rand.Int31n(100) > sampleRate {
				continue
			}
			msg.Attempts++

			subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
			client.SendingMessage()
			err = p.SendMessage(client, msg)
			if err != nil {
				goto exit
			}
			flushed = false
		case <-client.ExitChan:
			goto exit
		}
	}

exit:
	p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] exiting messagePump", client)
	heartbeatTicker.Stop()
	outputBufferTicker.Stop()
	if err != nil {
		p.nsqd.logf(LOG_ERROR, "PROTOCOL(V2): [%s] messagePump error - %s", client, err)
	}
}

// IDENTIFY reads a 4-byte length-prefixed JSON body and applies it to the
// client via client.Identify. Without feature negotiation it responds "OK".
// With feature negotiation it sends the negotiated settings as JSON, then
// performs any requested connection upgrades (TLS, then snappy or deflate),
// sending an "OK" frame after each one, and returns a nil response. Only
// valid in stateInit; all failures are fatal.
func (p *protocolV2) IDENTIFY(client *clientV2, params [][]byte) ([]byte, error) {
	if atomic.LoadInt32(&client.State) != stateInit {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot IDENTIFY in current state")
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body size")
	}

	if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("IDENTIFY body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("IDENTIFY invalid body size %d", bodyLen))
	}

	body := make([]byte, bodyLen)
	_, err = io.ReadFull(client.Reader, body)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body")
	}

	// body is a json structure with producer information
	var identifyData identifyDataV2
	err = json.Unmarshal(body, &identifyData)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to decode JSON body")
	}

	p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): [%s] %+v", client, identifyData)

	err = client.Identify(identifyData)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY "+err.Error())
	}

	// bail out early if we're not negotiating features
	if !identifyData.FeatureNegotiation {
		return okBytes, nil
	}

	tlsv1 := p.nsqd.tlsConfig != nil && identifyData.TLSv1
	deflate := p.nsqd.getOpts().DeflateEnabled && identifyData.Deflate
	deflateLevel := 6
	if deflate && identifyData.DeflateLevel > 0 {
		deflateLevel = identifyData.DeflateLevel
	}
	if maxLevel := p.nsqd.getOpts().MaxDeflateLevel; maxLevel < deflateLevel {
		deflateLevel = maxLevel
	}
	snappy := p.nsqd.getOpts().SnappyEnabled && identifyData.Snappy

	if deflate && snappy {
		return nil, protocol.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression")
	}

	resp, err := json.Marshal(struct {
		MaxRdyCount         int64  `json:"max_rdy_count"`
		Version             string `json:"version"`
		MaxMsgTimeout       int64  `json:"max_msg_timeout"`
		MsgTimeout          int64  `json:"msg_timeout"`
		TLSv1               bool   `json:"tls_v1"`
		Deflate             bool   `json:"deflate"`
		DeflateLevel        int    `json:"deflate_level"`
		MaxDeflateLevel     int    `json:"max_deflate_level"`
		Snappy              bool   `json:"snappy"`
		SampleRate          int32  `json:"sample_rate"`
		AuthRequired        bool   `json:"auth_required"`
		OutputBufferSize    int    `json:"output_buffer_size"`
		OutputBufferTimeout int64  `json:"output_buffer_timeout"`
	}{
		MaxRdyCount:         p.nsqd.getOpts().MaxRdyCount,
		Version:             version.Binary,
		MaxMsgTimeout:       int64(p.nsqd.getOpts().MaxMsgTimeout / time.Millisecond),
		MsgTimeout:          int64(client.MsgTimeout / time.Millisecond),
		TLSv1:               tlsv1,
		Deflate:             deflate,
		DeflateLevel:        deflateLevel,
		MaxDeflateLevel:     p.nsqd.getOpts().MaxDeflateLevel,
		Snappy:              snappy,
		SampleRate:          client.SampleRate,
		AuthRequired:        p.nsqd.IsAuthEnabled(),
		OutputBufferSize:    client.OutputBufferSize,
		OutputBufferTimeout: int64(client.OutputBufferTimeout / time.Millisecond),
	})
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
	}

	err = p.Send(client, frameTypeResponse, resp)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
	}

	if tlsv1 {
		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to TLS", client)
		err = client.UpgradeTLS()
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}

		err = p.Send(client, frameTypeResponse, okBytes)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}
	}

	if snappy {
		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to snappy", client)
		err = client.UpgradeSnappy()
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}

		err = p.Send(client, frameTypeResponse, okBytes)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}
	}

	if deflate {
		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to deflate (level %d)", client, deflateLevel)
		err = client.UpgradeDeflate(deflateLevel)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}

		err = p.Send(client, frameTypeResponse, okBytes)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}
	}

	return nil, nil
}

// AUTH reads a 4-byte length-prefixed secret and authenticates the client
// via client.Auth, then replies with the identity and permission count as
// JSON (the response is sent here, so nil is returned). It is only valid in
// stateInit, only once per connection, and only when auth is enabled.
// Errors from the auth call are logged but not revealed to the client.
func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
	if atomic.LoadInt32(&client.State) != stateInit {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot AUTH in current state")
	}

	if len(params) != 1 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH invalid number of parameters")
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body size")
	}

	if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("AUTH body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("AUTH invalid body size %d", bodyLen))
	}

	body := make([]byte, bodyLen)
	_, err = io.ReadFull(client.Reader, body)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body")
	}

	if client.HasAuthorizations() {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH already set")
	}

	if !client.nsqd.IsAuthEnabled() {
		return nil, protocol.NewFatalClientErr(nil, "E_AUTH_DISABLED", "AUTH disabled")
	}

	if err := client.Auth(string(body)); err != nil {
		// we don't want to leak errors contacting the auth server to untrusted clients
		p.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
		return nil, protocol.NewFatalClientErr(err, "E_AUTH_FAILED", "AUTH failed")
	}

	if !client.HasAuthorizations() {
		return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH no authorizations found")
	}

	resp, err := json.Marshal(struct {
		Identity        string `json:"identity"`
		IdentityURL     string `json:"identity_url"`
		PermissionCount int    `json:"permission_count"`
	}{
		Identity:        client.AuthState.Identity,
		IdentityURL:     client.AuthState.IdentityURL,
		PermissionCount: len(client.AuthState.Authorizations),
	})
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
	}

	err = p.Send(client, frameTypeResponse, resp)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
	}

	return nil, nil
}

// CheckAuth returns a fatal error if auth is enabled and the client either
// has not authenticated yet or is not authorized for cmd on the given
// topic/channel. It returns nil when auth is disabled.
func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName string) error {
	// if auth is enabled, the client must have authorized already
	// compare topic/channel against cached authorization data (refetching if expired)
	if client.nsqd.IsAuthEnabled() {
		if !client.HasAuthorizations() {
			return protocol.NewFatalClientErr(nil, "E_AUTH_FIRST",
				fmt.Sprintf("AUTH required before %s", cmd))
		}
		ok, err := client.IsAuthorized(topicName, channelName)
		if err != nil {
			// we don't want to leak errors contacting the auth server to untrusted clients
			p.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
			return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
		}
		if !ok {
			return protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED",
				fmt.Sprintf("AUTH failed for %s on %q %q", cmd, topicName, channelName))
		}
	}
	return nil
}

// SUB subscribes the client to channel params[2] of topic params[1]. It
// requires stateInit and enabled heartbeats, validates the names and auth,
// adds the client to the channel (retrying once if an ephemeral
// topic/channel is exiting), moves the client to stateSubscribed and hands
// the channel to messagePump via client.SubEventChan.
func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) {
	if atomic.LoadInt32(&client.State) != stateInit {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB in current state")
	}

	if client.HeartbeatInterval <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB with heartbeats disabled")
	}

	if len(params) < 3 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "SUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("SUB topic name %q is not valid", topicName))
	}

	channelName := string(params[2])
	if !protocol.IsValidChannelName(channelName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_CHANNEL",
			fmt.Sprintf("SUB channel name %q is not valid", channelName))
	}

	if err := p.CheckAuth(client, "SUB", topicName, channelName); err != nil {
		return nil, err
	}

	// This retry-loop is a work-around for a race condition, where the
	// last client can leave the channel between GetChannel() and AddClient().
	// Avoid adding a client to an ephemeral channel / topic which has started exiting.
	var channel *Channel
	for i := 1; ; i++ {
		topic := p.nsqd.GetTopic(topicName)
		channel = topic.GetChannel(channelName)
		if err := channel.AddClient(client.ID, client); err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_SUB_FAILED", "SUB failed "+err.Error())
		}

		if (channel.ephemeral && channel.Exiting()) || (topic.ephemeral && topic.Exiting()) {
			channel.RemoveClient(client.ID)
			if i < 2 {
				time.Sleep(100 * time.Millisecond)
				continue
			}
			return nil, protocol.NewFatalClientErr(nil, "E_SUB_FAILED", "SUB failed to deleted topic/channel")
		}
		break
	}
	atomic.StoreInt32(&client.State, stateSubscribed)
	client.Channel = channel
	// update message pump
	client.SubEventChan <- channel

	return okBytes, nil
}

// RDY sets the client's ready count to params[1] (1 if omitted). It is
// ignored after CLS and otherwise requires stateSubscribed. Counts outside
// 0..MaxRdyCount are a fatal error.
func (p *protocolV2) RDY(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)

	if state == stateClosing {
		// just ignore ready changes on a closing channel
		p.nsqd.logf(LOG_INFO,
			"PROTOCOL(V2): [%s] ignoring RDY after CLS in state ClientStateV2Closing",
			client)
		return nil, nil
	}

	if state != stateSubscribed {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot RDY in current state")
	}

	count := int64(1)
	if len(params) > 1 {
		b10, err := protocol.ByteToBase10(params[1])
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_INVALID",
				fmt.Sprintf("RDY could not parse count %s", params[1]))
		}
		count = int64(b10)
	}

	if count < 0 || count > p.nsqd.getOpts().MaxRdyCount {
		// this needs to be a fatal error otherwise clients would have
		// inconsistent state
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
			fmt.Sprintf("RDY count %d out of range 0-%d", count, p.nsqd.getOpts().MaxRdyCount))
	}

	client.SetReadyCount(count)

	return nil, nil
}

// FIN marks in-flight message params[1] as finished on the client's channel.
// A failed finish (e.g. unknown ID) is a non-fatal E_FIN_FAILED error.
func (p *protocolV2) FIN(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)
	if state != stateSubscribed && state != stateClosing {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot FIN in current state")
	}

	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "FIN insufficient number of params")
	}

	id, err := getMessageID(params[1])
	if err != nil {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
	}

	err = client.Channel.FinishMessage(client.ID, *id)
	if err != nil {
		return nil, protocol.NewClientErr(err, "E_FIN_FAILED",
			fmt.Sprintf("FIN %s failed %s", *id, err.Error()))
	}

	client.FinishedMessage()

	return nil, nil
}

// REQ requeues in-flight message params[1] with a delay of params[2]
// milliseconds, clamped to MaxReqTimeout (out-of-range values are logged,
// not rejected). A failed requeue is a non-fatal E_REQ_FAILED error.
func (p *protocolV2) REQ(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)
	if state != stateSubscribed && state != stateClosing {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot REQ in current state")
	}

	if len(params) < 3 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "REQ insufficient number of params")
	}

	id, err := getMessageID(params[1])
	if err != nil {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
	}

	timeoutMs, err := protocol.ByteToBase10(params[2])
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_INVALID",
			fmt.Sprintf("REQ could not parse timeout %s", params[2]))
	}

	// Clamp in millisecond space *before* converting to a time.Duration:
	// multiplying an arbitrarily large client-supplied count by
	// time.Millisecond can overflow int64 and wrap back into [0, max],
	// which would silently bypass the clamp.
	maxReqTimeout := p.nsqd.getOpts().MaxReqTimeout
	maxReqTimeoutMs := uint64(maxReqTimeout / time.Millisecond)
	timeoutDuration := maxReqTimeout
	if timeoutMs <= maxReqTimeoutMs {
		timeoutDuration = time.Duration(timeoutMs) * time.Millisecond
	} else {
		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] REQ timeout %d out of range 0-%d. Setting to %d",
			client, timeoutMs, maxReqTimeoutMs, maxReqTimeoutMs)
	}

	err = client.Channel.RequeueMessage(client.ID, *id, timeoutDuration)
	if err != nil {
		return nil, protocol.NewClientErr(err, "E_REQ_FAILED",
			fmt.Sprintf("REQ %s failed %s", *id, err.Error()))
	}

	client.RequeuedMessage()

	return nil, nil
}

// CLS starts closing a subscribed client and responds "CLOSE_WAIT".
func (p *protocolV2) CLS(client *clientV2, params [][]byte) ([]byte, error) {
	if atomic.LoadInt32(&client.State) != stateSubscribed {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot CLS in current state")
	}

	client.StartClose()

	return []byte("CLOSE_WAIT"), nil
}

// NOP does nothing and sends no response.
func (p *protocolV2) NOP(client *clientV2, params [][]byte) ([]byte, error) {
	return nil, nil
}

// PUB reads a single 4-byte length-prefixed message body (at most
// MaxMsgSize) and publishes it to topic params[1], responding "OK".
func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) {
	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "PUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("PUB topic name %q is not valid", topicName))
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body size")
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("PUB invalid message body size %d", bodyLen))
	}

	if int64(bodyLen) > p.nsqd.getOpts().MaxMsgSize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("PUB message too big %d > %d", bodyLen, p.nsqd.getOpts().MaxMsgSize))
	}

	messageBody := make([]byte, bodyLen)
	_, err = io.ReadFull(client.Reader, messageBody)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body")
	}

	if err := p.CheckAuth(client, "PUB", topicName, ""); err != nil {
		return nil, err
	}

	topic := p.nsqd.GetTopic(topicName)
	msg := NewMessage(topic.GenerateID(), messageBody)
	err = topic.PutMessage(msg)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
	}

	client.PublishedMessage(topicName, 1)

	return okBytes, nil
}

// MPUB reads a 4-byte length-prefixed body (at most MaxBodySize) holding
// several messages in the readMPUB format and publishes them all to topic
// params[1], responding "OK".
func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) {
	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "MPUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		// the code is already carried by the error; don't repeat it in the text
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("MPUB topic name %q is not valid", topicName))
	}

	if err := p.CheckAuth(client, "MPUB", topicName, ""); err != nil {
		return nil, err
	}

	topic := p.nsqd.GetTopic(topicName)

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read body size")
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("MPUB invalid body size %d", bodyLen))
	}

	if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("MPUB body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
	}

	messages, err := readMPUB(client.Reader, client.lenSlice, topic,
		p.nsqd.getOpts().MaxMsgSize, p.nsqd.getOpts().MaxBodySize)
	if err != nil {
		return nil, err
	}

	// if we've made it this far we've validated all the input,
	// the only possible error is that the topic is exiting during
	// this next call (and no messages will be queued in that case)
	err = topic.PutMessages(messages)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_MPUB_FAILED", "MPUB failed "+err.Error())
	}

	client.PublishedMessage(topicName, uint64(len(messages)))

	return okBytes, nil
}

// DPUB publishes a single 4-byte length-prefixed message body to topic
// params[1], deferred by params[2] milliseconds. Delays above MaxReqTimeout
// are rejected (unlike REQ, which clamps). Responds "OK".
func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) {
	if len(params) < 3 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "DPUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("DPUB topic name %q is not valid", topicName))
	}

	timeoutMs, err := protocol.ByteToBase10(params[2])
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_INVALID",
			fmt.Sprintf("DPUB could not parse timeout %s", params[2]))
	}

	// Validate in millisecond space before converting to a time.Duration so
	// an overflowing product cannot wrap back into the accepted range.
	maxReqTimeoutMs := uint64(p.nsqd.getOpts().MaxReqTimeout / time.Millisecond)
	if timeoutMs > maxReqTimeoutMs {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
			fmt.Sprintf("DPUB timeout %d out of range 0-%d", timeoutMs, maxReqTimeoutMs))
	}
	timeoutDuration := time.Duration(timeoutMs) * time.Millisecond

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body size")
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("DPUB invalid message body size %d", bodyLen))
	}

	if int64(bodyLen) > p.nsqd.getOpts().MaxMsgSize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("DPUB message too big %d > %d", bodyLen, p.nsqd.getOpts().MaxMsgSize))
	}

	messageBody := make([]byte, bodyLen)
	_, err = io.ReadFull(client.Reader, messageBody)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body")
	}

	if err := p.CheckAuth(client, "DPUB", topicName, ""); err != nil {
		return nil, err
	}

	topic := p.nsqd.GetTopic(topicName)
	msg := NewMessage(topic.GenerateID(), messageBody)
	msg.deferred = timeoutDuration
	err = topic.PutMessage(msg)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed "+err.Error())
	}

	client.PublishedMessage(topicName, 1)

	return okBytes, nil
}

// TOUCH resets the in-flight timeout of message params[1] to the client's
// current MsgTimeout. A failed touch is a non-fatal E_TOUCH_FAILED error.
func (p *protocolV2) TOUCH(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)
	if state != stateSubscribed && state != stateClosing {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot TOUCH in current state")
	}

	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "TOUCH insufficient number of params")
	}

	id, err := getMessageID(params[1])
	if err != nil {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
	}

	client.writeLock.RLock()
	msgTimeout := client.MsgTimeout
	client.writeLock.RUnlock()
	err = client.Channel.TouchMessage(client.ID, *id, msgTimeout)
	if err != nil {
		return nil, protocol.NewClientErr(err, "E_TOUCH_FAILED",
			fmt.Sprintf("TOUCH %s failed %s", *id, err.Error()))
	}

	return nil, nil
}

// readMPUB reads an MPUB payload from r: a 4-byte message count followed by
// that many 4-byte length-prefixed message bodies. tmp is scratch space for
// the length prefixes (see readLen). Each body must be 1..maxMessageSize
// bytes, and the total wire size (count prefix included) must not exceed
// maxBodySize. Every message gets an ID from topic.GenerateID. All errors
// are fatal client errors.
func readMPUB(r io.Reader, tmp []byte, topic *Topic, maxMessageSize int64, maxBodySize int64) ([]*Message, error) {
	numMessages, err := readLen(r, tmp)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read message count")
	}

	// 4 == total num, 5 == length + min 1
	maxMessages := (maxBodySize - 4) / 5
	if numMessages <= 0 || int64(numMessages) > maxMessages {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("MPUB invalid message count %d", numMessages))
	}

	// bodySize is the wire size consumed so far: the 4-byte count plus each
	// 4-byte length prefix and body. Without this running check a client
	// could send maxMessages messages of maxMessageSize bytes each, far more
	// than maxBodySize, and force that much allocation.
	bodySize := int64(4)
	messages := make([]*Message, 0, numMessages)
	for i := int32(0); i < numMessages; i++ {
		messageSize, err := readLen(r, tmp)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE",
				fmt.Sprintf("MPUB failed to read message(%d) body size", i))
		}

		if messageSize <= 0 {
			return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
				fmt.Sprintf("MPUB invalid message(%d) body size %d", i, messageSize))
		}

		if int64(messageSize) > maxMessageSize {
			return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
				fmt.Sprintf("MPUB message too big %d > %d", messageSize, maxMessageSize))
		}

		bodySize += 4 + int64(messageSize)
		if bodySize > maxBodySize {
			return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
				fmt.Sprintf("MPUB body too big %d > %d", bodySize, maxBodySize))
		}

		msgBody := make([]byte, messageSize)
		_, err = io.ReadFull(r, msgBody)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "MPUB failed to read message body")
		}

		messages = append(messages, NewMessage(topic.GenerateID(), msgBody))
	}

	return messages, nil
}

// getMessageID validates that p is exactly MsgIDLength bytes and casts it
// to a *MessageID without copying. The result aliases p's backing array, so
// it is only valid for as long as p is (for IOLoop commands, until the next
// ReadSlice).
func getMessageID(p []byte) (*MessageID, error) {
	if len(p) != MsgIDLength {
		return nil, errors.New("Invalid Message ID")
	}
	return (*MessageID)(unsafe.Pointer(&p[0])), nil
}

// readLen reads exactly len(tmp) bytes from r into tmp and decodes the first
// four as a big-endian uint32 reinterpreted as int32 (so values >= 2^31 come
// back negative). tmp must be at least 4 bytes long.
func readLen(r io.Reader, tmp []byte) (int32, error) {
	_, err := io.ReadFull(r, tmp)
	if err != nil {
		return 0, err
	}
	return int32(binary.BigEndian.Uint32(tmp)), nil
}

// enforceTLSPolicy returns a fatal E_INVALID error naming command when the
// server's TLSRequired option is not TLSNotRequired and this client has not
// upgraded to TLS (client.TLS != 1).
func enforceTLSPolicy(client *clientV2, p *protocolV2, command []byte) error {
	if p.nsqd.getOpts().TLSRequired != TLSNotRequired && atomic.LoadInt32(&client.TLS) != 1 {
		return protocol.NewFatalClientErr(nil, "E_INVALID",
			fmt.Sprintf("cannot %s in current state (TLS required)", command))
	}
	return nil
}