protocol_v2.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026
  1. package nsqd
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math/rand"
  10. "net"
  11. "sync/atomic"
  12. "time"
  13. "unsafe"
  14. "github.com/nsqio/nsq/internal/protocol"
  15. "github.com/nsqio/nsq/internal/version"
  16. )
  // maxTimeout is a one-hour duration constant.
  // NOTE(review): no use is visible in this portion of the file — confirm its callers.
  const maxTimeout = time.Hour

  // Frame type identifiers that Send passes to protocol.SendFramedResponse
  // together with the payload.
  const (
  	frameTypeResponse int32 = 0
  	frameTypeError    int32 = 1
  	frameTypeMessage  int32 = 2
  )

  // separatorBytes is the delimiter IOLoop uses to split a command line into parameters.
  var separatorBytes = []byte(" ")

  // heartbeatBytes is the response payload messagePump sends on each heartbeat tick.
  var heartbeatBytes = []byte("_heartbeat_")

  // okBytes is the generic success response returned by several commands.
  var okBytes = []byte("OK")
  // protocolV2 carries the V2 protocol handlers (IOLoop, Exec and the
  // per-command methods) together with the NSQD instance they operate on.
  type protocolV2 struct {
  	nsqd *NSQD
  }
  29. func (p *protocolV2) NewClient(conn net.Conn) protocol.Client {
  30. clientID := atomic.AddInt64(&p.nsqd.clientIDSequence, 1)
  31. return newClientV2(clientID, conn, p.nsqd)
  32. }
  33. func (p *protocolV2) IOLoop(c protocol.Client) error {
  34. var err error
  35. var line []byte
  36. var zeroTime time.Time
  37. client := c.(*clientV2)
  38. // synchronize the startup of messagePump in order
  39. // to guarantee that it gets a chance to initialize
  40. // goroutine local state derived from client attributes
  41. // and avoid a potential race with IDENTIFY (where a client
  42. // could have changed or disabled said attributes)
  43. messagePumpStartedChan := make(chan bool)
  44. go p.messagePump(client, messagePumpStartedChan)
  45. <-messagePumpStartedChan
  46. for {
  47. if client.HeartbeatInterval > 0 {
  48. client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
  49. } else {
  50. client.SetReadDeadline(zeroTime)
  51. }
  52. // ReadSlice does not allocate new space for the data each request
  53. // ie. the returned slice is only valid until the next call to it
  54. line, err = client.Reader.ReadSlice('\n')
  55. if err != nil {
  56. if err == io.EOF {
  57. err = nil
  58. } else {
  59. err = fmt.Errorf("failed to read command - %s", err)
  60. }
  61. break
  62. }
  63. // trim the '\n'
  64. line = line[:len(line)-1]
  65. // optionally trim the '\r'
  66. if len(line) > 0 && line[len(line)-1] == '\r' {
  67. line = line[:len(line)-1]
  68. }
  69. params := bytes.Split(line, separatorBytes)
  70. p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): [%s] %s", client, params)
  71. var response []byte
  72. response, err = p.Exec(client, params)
  73. if err != nil {
  74. ctx := ""
  75. if parentErr := err.(protocol.ChildErr).Parent(); parentErr != nil {
  76. ctx = " - " + parentErr.Error()
  77. }
  78. p.nsqd.logf(LOG_ERROR, "[%s] - %s%s", client, err, ctx)
  79. sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
  80. if sendErr != nil {
  81. p.nsqd.logf(LOG_ERROR, "[%s] - %s%s", client, sendErr, ctx)
  82. break
  83. }
  84. // errors of type FatalClientErr should forceably close the connection
  85. if _, ok := err.(*protocol.FatalClientErr); ok {
  86. break
  87. }
  88. continue
  89. }
  90. if response != nil {
  91. err = p.Send(client, frameTypeResponse, response)
  92. if err != nil {
  93. err = fmt.Errorf("failed to send response - %s", err)
  94. break
  95. }
  96. }
  97. }
  98. p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] exiting ioloop", client)
  99. close(client.ExitChan)
  100. if client.Channel != nil {
  101. client.Channel.RemoveClient(client.ID)
  102. }
  103. return err
  104. }
  105. func (p *protocolV2) SendMessage(client *clientV2, msg *Message) error {
  106. p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): writing msg(%s) to client(%s) - %s", msg.ID, client, msg.Body)
  107. buf := bufferPoolGet()
  108. defer bufferPoolPut(buf)
  109. _, err := msg.WriteTo(buf)
  110. if err != nil {
  111. return err
  112. }
  113. err = p.Send(client, frameTypeMessage, buf.Bytes())
  114. if err != nil {
  115. return err
  116. }
  117. return nil
  118. }
  119. func (p *protocolV2) Send(client *clientV2, frameType int32, data []byte) error {
  120. client.writeLock.Lock()
  121. var zeroTime time.Time
  122. if client.HeartbeatInterval > 0 {
  123. client.SetWriteDeadline(time.Now().Add(client.HeartbeatInterval))
  124. } else {
  125. client.SetWriteDeadline(zeroTime)
  126. }
  127. _, err := protocol.SendFramedResponse(client.Writer, frameType, data)
  128. if err != nil {
  129. client.writeLock.Unlock()
  130. return err
  131. }
  132. if frameType != frameTypeMessage {
  133. err = client.Flush()
  134. }
  135. client.writeLock.Unlock()
  136. return err
  137. }
  138. func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) {
  139. if bytes.Equal(params[0], []byte("IDENTIFY")) {
  140. return p.IDENTIFY(client, params)
  141. }
  142. err := enforceTLSPolicy(client, p, params[0])
  143. if err != nil {
  144. return nil, err
  145. }
  146. switch {
  147. case bytes.Equal(params[0], []byte("FIN")):
  148. return p.FIN(client, params)
  149. case bytes.Equal(params[0], []byte("RDY")):
  150. return p.RDY(client, params)
  151. case bytes.Equal(params[0], []byte("REQ")):
  152. return p.REQ(client, params)
  153. case bytes.Equal(params[0], []byte("PUB")):
  154. return p.PUB(client, params)
  155. case bytes.Equal(params[0], []byte("MPUB")):
  156. return p.MPUB(client, params)
  157. case bytes.Equal(params[0], []byte("DPUB")):
  158. return p.DPUB(client, params)
  159. case bytes.Equal(params[0], []byte("NOP")):
  160. return p.NOP(client, params)
  161. case bytes.Equal(params[0], []byte("TOUCH")):
  162. return p.TOUCH(client, params)
  163. case bytes.Equal(params[0], []byte("SUB")):
  164. return p.SUB(client, params)
  165. case bytes.Equal(params[0], []byte("CLS")):
  166. return p.CLS(client, params)
  167. case bytes.Equal(params[0], []byte("AUTH")):
  168. return p.AUTH(client, params)
  169. }
  170. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", fmt.Sprintf("invalid command %s", params[0]))
  171. }
  172. func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
  173. var err error
  174. var memoryMsgChan chan *Message
  175. var backendMsgChan <-chan []byte
  176. var subChannel *Channel
  177. // NOTE: `flusherChan` is used to bound message latency for
  178. // the pathological case of a channel on a low volume topic
  179. // with >1 clients having >1 RDY counts
  180. var flusherChan <-chan time.Time
  181. var sampleRate int32
  182. subEventChan := client.SubEventChan
  183. identifyEventChan := client.IdentifyEventChan
  184. outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
  185. heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
  186. heartbeatChan := heartbeatTicker.C
  187. msgTimeout := client.MsgTimeout
  188. // v2 opportunistically buffers data to clients to reduce write system calls
  189. // we force flush in two cases:
  190. // 1. when the client is not ready to receive messages
  191. // 2. we're buffered and the channel has nothing left to send us
  192. // (ie. we would block in this loop anyway)
  193. //
  194. flushed := true
  195. // signal to the goroutine that started the messagePump
  196. // that we've started up
  197. close(startedChan)
  198. for {
  199. if subChannel == nil || !client.IsReadyForMessages() {
  200. // the client is not ready to receive messages...
  201. memoryMsgChan = nil
  202. backendMsgChan = nil
  203. flusherChan = nil
  204. // force flush
  205. client.writeLock.Lock()
  206. err = client.Flush()
  207. client.writeLock.Unlock()
  208. if err != nil {
  209. goto exit
  210. }
  211. flushed = true
  212. } else if flushed {
  213. // last iteration we flushed...
  214. // do not select on the flusher ticker channel
  215. memoryMsgChan = subChannel.memoryMsgChan
  216. backendMsgChan = subChannel.backend.ReadChan()
  217. flusherChan = nil
  218. } else {
  219. // we're buffered (if there isn't any more data we should flush)...
  220. // select on the flusher ticker channel, too
  221. memoryMsgChan = subChannel.memoryMsgChan
  222. backendMsgChan = subChannel.backend.ReadChan()
  223. flusherChan = outputBufferTicker.C
  224. }
  225. select {
  226. case <-flusherChan:
  227. // if this case wins, we're either starved
  228. // or we won the race between other channels...
  229. // in either case, force flush
  230. client.writeLock.Lock()
  231. err = client.Flush()
  232. client.writeLock.Unlock()
  233. if err != nil {
  234. goto exit
  235. }
  236. flushed = true
  237. case <-client.ReadyStateChan:
  238. case subChannel = <-subEventChan:
  239. // you can't SUB anymore
  240. subEventChan = nil
  241. case identifyData := <-identifyEventChan:
  242. // you can't IDENTIFY anymore
  243. identifyEventChan = nil
  244. outputBufferTicker.Stop()
  245. if identifyData.OutputBufferTimeout > 0 {
  246. outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
  247. }
  248. heartbeatTicker.Stop()
  249. heartbeatChan = nil
  250. if identifyData.HeartbeatInterval > 0 {
  251. heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
  252. heartbeatChan = heartbeatTicker.C
  253. }
  254. if identifyData.SampleRate > 0 {
  255. sampleRate = identifyData.SampleRate
  256. }
  257. msgTimeout = identifyData.MsgTimeout
  258. case <-heartbeatChan:
  259. err = p.Send(client, frameTypeResponse, heartbeatBytes)
  260. if err != nil {
  261. goto exit
  262. }
  263. case b := <-backendMsgChan:
  264. if sampleRate > 0 && rand.Int31n(100) > sampleRate {
  265. continue
  266. }
  267. msg, err := decodeMessage(b)
  268. if err != nil {
  269. p.nsqd.logf(LOG_ERROR, "failed to decode message - %s", err)
  270. continue
  271. }
  272. msg.Attempts++
  273. subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
  274. client.SendingMessage()
  275. err = p.SendMessage(client, msg)
  276. if err != nil {
  277. goto exit
  278. }
  279. flushed = false
  280. case msg := <-memoryMsgChan:
  281. if sampleRate > 0 && rand.Int31n(100) > sampleRate {
  282. continue
  283. }
  284. msg.Attempts++
  285. subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
  286. client.SendingMessage()
  287. err = p.SendMessage(client, msg)
  288. if err != nil {
  289. goto exit
  290. }
  291. flushed = false
  292. case <-client.ExitChan:
  293. goto exit
  294. }
  295. }
  296. exit:
  297. p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] exiting messagePump", client)
  298. heartbeatTicker.Stop()
  299. outputBufferTicker.Stop()
  300. if err != nil {
  301. p.nsqd.logf(LOG_ERROR, "PROTOCOL(V2): [%s] messagePump error - %s", client, err)
  302. }
  303. }
  304. func (p *protocolV2) IDENTIFY(client *clientV2, params [][]byte) ([]byte, error) {
  305. var err error
  306. if atomic.LoadInt32(&client.State) != stateInit {
  307. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot IDENTIFY in current state")
  308. }
  309. bodyLen, err := readLen(client.Reader, client.lenSlice)
  310. if err != nil {
  311. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body size")
  312. }
  313. if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
  314. return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
  315. fmt.Sprintf("IDENTIFY body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
  316. }
  317. if bodyLen <= 0 {
  318. return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
  319. fmt.Sprintf("IDENTIFY invalid body size %d", bodyLen))
  320. }
  321. body := make([]byte, bodyLen)
  322. _, err = io.ReadFull(client.Reader, body)
  323. if err != nil {
  324. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body")
  325. }
  326. // body is a json structure with producer information
  327. var identifyData identifyDataV2
  328. err = json.Unmarshal(body, &identifyData)
  329. if err != nil {
  330. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to decode JSON body")
  331. }
  332. p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): [%s] %+v", client, identifyData)
  333. err = client.Identify(identifyData)
  334. if err != nil {
  335. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY "+err.Error())
  336. }
  337. // bail out early if we're not negotiating features
  338. if !identifyData.FeatureNegotiation {
  339. return okBytes, nil
  340. }
  341. tlsv1 := p.nsqd.tlsConfig != nil && identifyData.TLSv1
  342. deflate := p.nsqd.getOpts().DeflateEnabled && identifyData.Deflate
  343. deflateLevel := 6
  344. if deflate && identifyData.DeflateLevel > 0 {
  345. deflateLevel = identifyData.DeflateLevel
  346. }
  347. if max := p.nsqd.getOpts().MaxDeflateLevel; max < deflateLevel {
  348. deflateLevel = max
  349. }
  350. snappy := p.nsqd.getOpts().SnappyEnabled && identifyData.Snappy
  351. if deflate && snappy {
  352. return nil, protocol.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression")
  353. }
  354. resp, err := json.Marshal(struct {
  355. MaxRdyCount int64 `json:"max_rdy_count"`
  356. Version string `json:"version"`
  357. MaxMsgTimeout int64 `json:"max_msg_timeout"`
  358. MsgTimeout int64 `json:"msg_timeout"`
  359. TLSv1 bool `json:"tls_v1"`
  360. Deflate bool `json:"deflate"`
  361. DeflateLevel int `json:"deflate_level"`
  362. MaxDeflateLevel int `json:"max_deflate_level"`
  363. Snappy bool `json:"snappy"`
  364. SampleRate int32 `json:"sample_rate"`
  365. AuthRequired bool `json:"auth_required"`
  366. OutputBufferSize int `json:"output_buffer_size"`
  367. OutputBufferTimeout int64 `json:"output_buffer_timeout"`
  368. }{
  369. MaxRdyCount: p.nsqd.getOpts().MaxRdyCount,
  370. Version: version.Binary,
  371. MaxMsgTimeout: int64(p.nsqd.getOpts().MaxMsgTimeout / time.Millisecond),
  372. MsgTimeout: int64(client.MsgTimeout / time.Millisecond),
  373. TLSv1: tlsv1,
  374. Deflate: deflate,
  375. DeflateLevel: deflateLevel,
  376. MaxDeflateLevel: p.nsqd.getOpts().MaxDeflateLevel,
  377. Snappy: snappy,
  378. SampleRate: client.SampleRate,
  379. AuthRequired: p.nsqd.IsAuthEnabled(),
  380. OutputBufferSize: client.OutputBufferSize,
  381. OutputBufferTimeout: int64(client.OutputBufferTimeout / time.Millisecond),
  382. })
  383. if err != nil {
  384. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  385. }
  386. err = p.Send(client, frameTypeResponse, resp)
  387. if err != nil {
  388. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  389. }
  390. if tlsv1 {
  391. p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to TLS", client)
  392. err = client.UpgradeTLS()
  393. if err != nil {
  394. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  395. }
  396. err = p.Send(client, frameTypeResponse, okBytes)
  397. if err != nil {
  398. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  399. }
  400. }
  401. if snappy {
  402. p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to snappy", client)
  403. err = client.UpgradeSnappy()
  404. if err != nil {
  405. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  406. }
  407. err = p.Send(client, frameTypeResponse, okBytes)
  408. if err != nil {
  409. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  410. }
  411. }
  412. if deflate {
  413. p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to deflate (level %d)", client, deflateLevel)
  414. err = client.UpgradeDeflate(deflateLevel)
  415. if err != nil {
  416. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  417. }
  418. err = p.Send(client, frameTypeResponse, okBytes)
  419. if err != nil {
  420. return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
  421. }
  422. }
  423. return nil, nil
  424. }
  425. func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
  426. if atomic.LoadInt32(&client.State) != stateInit {
  427. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot AUTH in current state")
  428. }
  429. if len(params) != 1 {
  430. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH invalid number of parameters")
  431. }
  432. bodyLen, err := readLen(client.Reader, client.lenSlice)
  433. if err != nil {
  434. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body size")
  435. }
  436. if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
  437. return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
  438. fmt.Sprintf("AUTH body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
  439. }
  440. if bodyLen <= 0 {
  441. return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
  442. fmt.Sprintf("AUTH invalid body size %d", bodyLen))
  443. }
  444. body := make([]byte, bodyLen)
  445. _, err = io.ReadFull(client.Reader, body)
  446. if err != nil {
  447. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body")
  448. }
  449. if client.HasAuthorizations() {
  450. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH already set")
  451. }
  452. if !client.nsqd.IsAuthEnabled() {
  453. return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH disabled")
  454. }
  455. if err := client.Auth(string(body)); err != nil {
  456. // we don't want to leak errors contacting the auth server to untrusted clients
  457. p.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
  458. return nil, protocol.NewFatalClientErr(err, "E_AUTH_FAILED", "AUTH failed")
  459. }
  460. if !client.HasAuthorizations() {
  461. return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH no authorizations found")
  462. }
  463. resp, err := json.Marshal(struct {
  464. Identity string `json:"identity"`
  465. IdentityURL string `json:"identity_url"`
  466. PermissionCount int `json:"permission_count"`
  467. }{
  468. Identity: client.AuthState.Identity,
  469. IdentityURL: client.AuthState.IdentityURL,
  470. PermissionCount: len(client.AuthState.Authorizations),
  471. })
  472. if err != nil {
  473. return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
  474. }
  475. err = p.Send(client, frameTypeResponse, resp)
  476. if err != nil {
  477. return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
  478. }
  479. return nil, nil
  480. }
  481. func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName string) error {
  482. // if auth is enabled, the client must have authorized already
  483. // compare topic/channel against cached authorization data (refetching if expired)
  484. if client.nsqd.IsAuthEnabled() {
  485. if !client.HasAuthorizations() {
  486. return protocol.NewFatalClientErr(nil, "E_AUTH_FIRST",
  487. fmt.Sprintf("AUTH required before %s", cmd))
  488. }
  489. ok, err := client.IsAuthorized(topicName, channelName)
  490. if err != nil {
  491. // we don't want to leak errors contacting the auth server to untrusted clients
  492. p.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
  493. return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
  494. }
  495. if !ok {
  496. return protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED",
  497. fmt.Sprintf("AUTH failed for %s on %q %q", cmd, topicName, channelName))
  498. }
  499. }
  500. return nil
  501. }
  502. func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) {
  503. if atomic.LoadInt32(&client.State) != stateInit {
  504. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB in current state")
  505. }
  506. if client.HeartbeatInterval <= 0 {
  507. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB with heartbeats disabled")
  508. }
  509. if len(params) < 3 {
  510. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "SUB insufficient number of parameters")
  511. }
  512. topicName := string(params[1])
  513. if !protocol.IsValidTopicName(topicName) {
  514. return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
  515. fmt.Sprintf("SUB topic name %q is not valid", topicName))
  516. }
  517. channelName := string(params[2])
  518. if !protocol.IsValidChannelName(channelName) {
  519. return nil, protocol.NewFatalClientErr(nil, "E_BAD_CHANNEL",
  520. fmt.Sprintf("SUB channel name %q is not valid", channelName))
  521. }
  522. if err := p.CheckAuth(client, "SUB", topicName, channelName); err != nil {
  523. return nil, err
  524. }
  525. // This retry-loop is a work-around for a race condition, where the
  526. // last client can leave the channel between GetChannel() and AddClient().
  527. // Avoid adding a client to an ephemeral channel / topic which has started exiting.
  528. var channel *Channel
  529. for i := 1; ; i++ {
  530. topic := p.nsqd.GetTopic(topicName)
  531. channel = topic.GetChannel(channelName)
  532. if err := channel.AddClient(client.ID, client); err != nil {
  533. return nil, protocol.NewFatalClientErr(err, "E_SUB_FAILED", "SUB failed "+err.Error())
  534. }
  535. if (channel.ephemeral && channel.Exiting()) || (topic.ephemeral && topic.Exiting()) {
  536. channel.RemoveClient(client.ID)
  537. if i < 2 {
  538. time.Sleep(100 * time.Millisecond)
  539. continue
  540. }
  541. return nil, protocol.NewFatalClientErr(nil, "E_SUB_FAILED", "SUB failed to deleted topic/channel")
  542. }
  543. break
  544. }
  545. atomic.StoreInt32(&client.State, stateSubscribed)
  546. client.Channel = channel
  547. // update message pump
  548. client.SubEventChan <- channel
  549. return okBytes, nil
  550. }
  551. func (p *protocolV2) RDY(client *clientV2, params [][]byte) ([]byte, error) {
  552. state := atomic.LoadInt32(&client.State)
  553. if state == stateClosing {
  554. // just ignore ready changes on a closing channel
  555. p.nsqd.logf(LOG_INFO,
  556. "PROTOCOL(V2): [%s] ignoring RDY after CLS in state ClientStateV2Closing",
  557. client)
  558. return nil, nil
  559. }
  560. if state != stateSubscribed {
  561. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot RDY in current state")
  562. }
  563. count := int64(1)
  564. if len(params) > 1 {
  565. b10, err := protocol.ByteToBase10(params[1])
  566. if err != nil {
  567. return nil, protocol.NewFatalClientErr(err, "E_INVALID",
  568. fmt.Sprintf("RDY could not parse count %s", params[1]))
  569. }
  570. count = int64(b10)
  571. }
  572. if count < 0 || count > p.nsqd.getOpts().MaxRdyCount {
  573. // this needs to be a fatal error otherwise clients would have
  574. // inconsistent state
  575. return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
  576. fmt.Sprintf("RDY count %d out of range 0-%d", count, p.nsqd.getOpts().MaxRdyCount))
  577. }
  578. client.SetReadyCount(count)
  579. return nil, nil
  580. }
  581. func (p *protocolV2) FIN(client *clientV2, params [][]byte) ([]byte, error) {
  582. state := atomic.LoadInt32(&client.State)
  583. if state != stateSubscribed && state != stateClosing {
  584. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot FIN in current state")
  585. }
  586. if len(params) < 2 {
  587. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "FIN insufficient number of params")
  588. }
  589. id, err := getMessageID(params[1])
  590. if err != nil {
  591. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
  592. }
  593. err = client.Channel.FinishMessage(client.ID, *id)
  594. if err != nil {
  595. return nil, protocol.NewClientErr(err, "E_FIN_FAILED",
  596. fmt.Sprintf("FIN %s failed %s", *id, err.Error()))
  597. }
  598. client.FinishedMessage()
  599. return nil, nil
  600. }
  601. func (p *protocolV2) REQ(client *clientV2, params [][]byte) ([]byte, error) {
  602. state := atomic.LoadInt32(&client.State)
  603. if state != stateSubscribed && state != stateClosing {
  604. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot REQ in current state")
  605. }
  606. if len(params) < 3 {
  607. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "REQ insufficient number of params")
  608. }
  609. id, err := getMessageID(params[1])
  610. if err != nil {
  611. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
  612. }
  613. timeoutMs, err := protocol.ByteToBase10(params[2])
  614. if err != nil {
  615. return nil, protocol.NewFatalClientErr(err, "E_INVALID",
  616. fmt.Sprintf("REQ could not parse timeout %s", params[2]))
  617. }
  618. timeoutDuration := time.Duration(timeoutMs) * time.Millisecond
  619. maxReqTimeout := p.nsqd.getOpts().MaxReqTimeout
  620. clampedTimeout := timeoutDuration
  621. if timeoutDuration < 0 {
  622. clampedTimeout = 0
  623. } else if timeoutDuration > maxReqTimeout {
  624. clampedTimeout = maxReqTimeout
  625. }
  626. if clampedTimeout != timeoutDuration {
  627. p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] REQ timeout %d out of range 0-%d. Setting to %d",
  628. client, timeoutDuration, maxReqTimeout, clampedTimeout)
  629. timeoutDuration = clampedTimeout
  630. }
  631. err = client.Channel.RequeueMessage(client.ID, *id, timeoutDuration)
  632. if err != nil {
  633. return nil, protocol.NewClientErr(err, "E_REQ_FAILED",
  634. fmt.Sprintf("REQ %s failed %s", *id, err.Error()))
  635. }
  636. client.RequeuedMessage()
  637. return nil, nil
  638. }
  639. func (p *protocolV2) CLS(client *clientV2, params [][]byte) ([]byte, error) {
  640. if atomic.LoadInt32(&client.State) != stateSubscribed {
  641. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot CLS in current state")
  642. }
  643. client.StartClose()
  644. return []byte("CLOSE_WAIT"), nil
  645. }
  // NOP does nothing: it returns no response and no error, so IOLoop sends
  // nothing back to the client.
  func (p *protocolV2) NOP(client *clientV2, params [][]byte) ([]byte, error) {
  	return nil, nil
  }
  649. func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) {
  650. var err error
  651. if len(params) < 2 {
  652. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "PUB insufficient number of parameters")
  653. }
  654. topicName := string(params[1])
  655. if !protocol.IsValidTopicName(topicName) {
  656. return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
  657. fmt.Sprintf("PUB topic name %q is not valid", topicName))
  658. }
  659. bodyLen, err := readLen(client.Reader, client.lenSlice)
  660. if err != nil {
  661. return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body size")
  662. }
  663. if bodyLen <= 0 {
  664. return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
  665. fmt.Sprintf("PUB invalid message body size %d", bodyLen))
  666. }
  667. if int64(bodyLen) > p.nsqd.getOpts().MaxMsgSize {
  668. return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
  669. fmt.Sprintf("PUB message too big %d > %d", bodyLen, p.nsqd.getOpts().MaxMsgSize))
  670. }
  671. messageBody := make([]byte, bodyLen)
  672. _, err = io.ReadFull(client.Reader, messageBody)
  673. if err != nil {
  674. return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body")
  675. }
  676. if err := p.CheckAuth(client, "PUB", topicName, ""); err != nil {
  677. return nil, err
  678. }
  679. topic := p.nsqd.GetTopic(topicName)
  680. msg := NewMessage(topic.GenerateID(), messageBody)
  681. err = topic.PutMessage(msg)
  682. if err != nil {
  683. return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
  684. }
  685. client.PublishedMessage(topicName, 1)
  686. return okBytes, nil
  687. }
  688. func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) {
  689. var err error
  690. if len(params) < 2 {
  691. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "MPUB insufficient number of parameters")
  692. }
  693. topicName := string(params[1])
  694. if !protocol.IsValidTopicName(topicName) {
  695. return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
  696. fmt.Sprintf("E_BAD_TOPIC MPUB topic name %q is not valid", topicName))
  697. }
  698. if err := p.CheckAuth(client, "MPUB", topicName, ""); err != nil {
  699. return nil, err
  700. }
  701. topic := p.nsqd.GetTopic(topicName)
  702. bodyLen, err := readLen(client.Reader, client.lenSlice)
  703. if err != nil {
  704. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read body size")
  705. }
  706. if bodyLen <= 0 {
  707. return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
  708. fmt.Sprintf("MPUB invalid body size %d", bodyLen))
  709. }
  710. if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
  711. return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
  712. fmt.Sprintf("MPUB body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
  713. }
  714. messages, err := readMPUB(client.Reader, client.lenSlice, topic,
  715. p.nsqd.getOpts().MaxMsgSize, p.nsqd.getOpts().MaxBodySize)
  716. if err != nil {
  717. return nil, err
  718. }
  719. // if we've made it this far we've validated all the input,
  720. // the only possible error is that the topic is exiting during
  721. // this next call (and no messages will be queued in that case)
  722. err = topic.PutMessages(messages)
  723. if err != nil {
  724. return nil, protocol.NewFatalClientErr(err, "E_MPUB_FAILED", "MPUB failed "+err.Error())
  725. }
  726. client.PublishedMessage(topicName, uint64(len(messages)))
  727. return okBytes, nil
  728. }
  729. func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) {
  730. var err error
  731. if len(params) < 3 {
  732. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "DPUB insufficient number of parameters")
  733. }
  734. topicName := string(params[1])
  735. if !protocol.IsValidTopicName(topicName) {
  736. return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
  737. fmt.Sprintf("DPUB topic name %q is not valid", topicName))
  738. }
  739. timeoutMs, err := protocol.ByteToBase10(params[2])
  740. if err != nil {
  741. return nil, protocol.NewFatalClientErr(err, "E_INVALID",
  742. fmt.Sprintf("DPUB could not parse timeout %s", params[2]))
  743. }
  744. timeoutDuration := time.Duration(timeoutMs) * time.Millisecond
  745. if timeoutDuration < 0 || timeoutDuration > p.nsqd.getOpts().MaxReqTimeout {
  746. return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
  747. fmt.Sprintf("DPUB timeout %d out of range 0-%d",
  748. timeoutMs, p.nsqd.getOpts().MaxReqTimeout/time.Millisecond))
  749. }
  750. bodyLen, err := readLen(client.Reader, client.lenSlice)
  751. if err != nil {
  752. return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body size")
  753. }
  754. if bodyLen <= 0 {
  755. return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
  756. fmt.Sprintf("DPUB invalid message body size %d", bodyLen))
  757. }
  758. if int64(bodyLen) > p.nsqd.getOpts().MaxMsgSize {
  759. return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
  760. fmt.Sprintf("DPUB message too big %d > %d", bodyLen, p.nsqd.getOpts().MaxMsgSize))
  761. }
  762. messageBody := make([]byte, bodyLen)
  763. _, err = io.ReadFull(client.Reader, messageBody)
  764. if err != nil {
  765. return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body")
  766. }
  767. if err := p.CheckAuth(client, "DPUB", topicName, ""); err != nil {
  768. return nil, err
  769. }
  770. topic := p.nsqd.GetTopic(topicName)
  771. msg := NewMessage(topic.GenerateID(), messageBody)
  772. msg.deferred = timeoutDuration
  773. err = topic.PutMessage(msg)
  774. if err != nil {
  775. return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed "+err.Error())
  776. }
  777. client.PublishedMessage(topicName, 1)
  778. return okBytes, nil
  779. }
  780. func (p *protocolV2) TOUCH(client *clientV2, params [][]byte) ([]byte, error) {
  781. state := atomic.LoadInt32(&client.State)
  782. if state != stateSubscribed && state != stateClosing {
  783. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot TOUCH in current state")
  784. }
  785. if len(params) < 2 {
  786. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "TOUCH insufficient number of params")
  787. }
  788. id, err := getMessageID(params[1])
  789. if err != nil {
  790. return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
  791. }
  792. client.writeLock.RLock()
  793. msgTimeout := client.MsgTimeout
  794. client.writeLock.RUnlock()
  795. err = client.Channel.TouchMessage(client.ID, *id, msgTimeout)
  796. if err != nil {
  797. return nil, protocol.NewClientErr(err, "E_TOUCH_FAILED",
  798. fmt.Sprintf("TOUCH %s failed %s", *id, err.Error()))
  799. }
  800. return nil, nil
  801. }
  802. func readMPUB(r io.Reader, tmp []byte, topic *Topic, maxMessageSize int64, maxBodySize int64) ([]*Message, error) {
  803. numMessages, err := readLen(r, tmp)
  804. if err != nil {
  805. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read message count")
  806. }
  807. // 4 == total num, 5 == length + min 1
  808. maxMessages := (maxBodySize - 4) / 5
  809. if numMessages <= 0 || int64(numMessages) > maxMessages {
  810. return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY",
  811. fmt.Sprintf("MPUB invalid message count %d", numMessages))
  812. }
  813. messages := make([]*Message, 0, numMessages)
  814. for i := int32(0); i < numMessages; i++ {
  815. messageSize, err := readLen(r, tmp)
  816. if err != nil {
  817. return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE",
  818. fmt.Sprintf("MPUB failed to read message(%d) body size", i))
  819. }
  820. if messageSize <= 0 {
  821. return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
  822. fmt.Sprintf("MPUB invalid message(%d) body size %d", i, messageSize))
  823. }
  824. if int64(messageSize) > maxMessageSize {
  825. return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
  826. fmt.Sprintf("MPUB message too big %d > %d", messageSize, maxMessageSize))
  827. }
  828. msgBody := make([]byte, messageSize)
  829. _, err = io.ReadFull(r, msgBody)
  830. if err != nil {
  831. return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "MPUB failed to read message body")
  832. }
  833. messages = append(messages, NewMessage(topic.GenerateID(), msgBody))
  834. }
  835. return messages, nil
  836. }
  837. // validate and cast the bytes on the wire to a message ID
  838. func getMessageID(p []byte) (*MessageID, error) {
  839. if len(p) != MsgIDLength {
  840. return nil, errors.New("Invalid Message ID")
  841. }
  842. return (*MessageID)(unsafe.Pointer(&p[0])), nil
  843. }
  844. func readLen(r io.Reader, tmp []byte) (int32, error) {
  845. _, err := io.ReadFull(r, tmp)
  846. if err != nil {
  847. return 0, err
  848. }
  849. return int32(binary.BigEndian.Uint32(tmp)), nil
  850. }
  851. func enforceTLSPolicy(client *clientV2, p *protocolV2, command []byte) error {
  852. if p.nsqd.getOpts().TLSRequired != TLSNotRequired && atomic.LoadInt32(&client.TLS) != 1 {
  853. return protocol.NewFatalClientErr(nil, "E_INVALID",
  854. fmt.Sprintf("cannot %s in current state (TLS required)", command))
  855. }
  856. return nil
  857. }