// nsq_to_nsq.go
  1. // This is an NSQ client that reads the specified topic/channel
  2. // and re-publishes the messages to destination nsqd via TCP
  3. package main
  4. import (
  5. "encoding/json"
  6. "errors"
  7. "flag"
  8. "fmt"
  9. "log"
  10. "os"
  11. "os/signal"
  12. "strconv"
  13. "sync/atomic"
  14. "syscall"
  15. "time"
  16. "github.com/bitly/go-hostpool"
  17. "github.com/bitly/timer_metrics"
  18. "github.com/nsqio/go-nsq"
  19. "github.com/nsqio/nsq/internal/app"
  20. "github.com/nsqio/nsq/internal/protocol"
  21. "github.com/nsqio/nsq/internal/version"
  22. )
// Destination selection strategies stored in PublishHandler.mode.
const (
	// ModeRoundRobin picks the destination address by indexing the address
	// list with an atomically incremented counter.
	ModeRoundRobin = iota
	// ModeHostPool asks the configured hostpool.HostPool for the destination
	// and reports each publish outcome back to it.
	ModeHostPool
)
// Command-line configuration. The *flag.X values are registered here; the
// app.StringArray values are registered as repeatable flags in init().
var (
	// showVersion makes main print the version string and return.
	showVersion = flag.Bool("version", false, "print version string")

	// channel is the NSQ channel every consumer subscribes on.
	channel = flag.String("channel", "nsq_to_nsq", "nsq channel")

	// destTopic, when non-empty, overrides the topic messages are published to.
	destTopic = flag.String("destination-topic", "", "use this destination topic for all consumed topics (default is consumed topic name)")

	// maxInFlight is copied into the consumer config's MaxInFlight.
	maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight")

	// statusEvery is passed to the timer_metrics constructors in main.
	statusEvery = flag.Int("status-every", 250, "the # of requests between logging status (per destination), 0 disables")

	// mode names the destination selection strategy; main maps it to
	// ModeRoundRobin or ModeHostPool.
	mode = flag.String("mode", "hostpool", "the upstream request mode options: round-robin, hostpool (default), epsilon-greedy")

	// Source nsqd addresses, lookupd addresses, destination nsqd addresses,
	// JSON fields to keep, and topics to consume.
	nsqdTCPAddrs = app.StringArray{}
	lookupdHTTPAddrs = app.StringArray{}
	destNsqdTCPAddrs = app.StringArray{}
	whitelistJSONFields = app.StringArray{}
	topics = app.StringArray{}

	// requireJSONField and requireJSONValue drive the filter implemented by
	// PublishHandler.shouldPassMessage.
	requireJSONField = flag.String("require-json-field", "", "for JSON messages: only pass messages that contain this field")
	requireJSONValue = flag.String("require-json-value", "", "for JSON messages: only pass messages in which the required field has this value")
)
  42. func init() {
  43. flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)")
  44. flag.Var(&destNsqdTCPAddrs, "destination-nsqd-tcp-address", "destination nsqd TCP address (may be given multiple times)")
  45. flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)")
  46. flag.Var(&topics, "topic", "nsq topic (may be given multiple times)")
  47. flag.Var(&whitelistJSONFields, "whitelist-json-field", "for JSON messages: pass this field (may be given multiple times)")
  48. }
// PublishHandler re-publishes consumed messages to one of the destination
// producers and tracks the outcome of each asynchronous publish.
type PublishHandler struct {
	// 64bit atomic vars need to be first for proper alignment on 32bit platforms
	// counter is incremented atomically to choose an address in round-robin mode.
	counter uint64

	// addresses lists the destination addresses indexed by counter.
	addresses app.StringArray
	// producers maps a destination address to its producer.
	producers map[string]*nsq.Producer
	// mode is ModeRoundRobin or ModeHostPool.
	mode int
	// hostPool selects the destination in ModeHostPool.
	hostPool hostpool.HostPool
	// respChan receives the result of every PublishAsync call; responder
	// goroutines drain it.
	respChan chan *nsq.ProducerTransaction

	// Parsed form of the require-json-value flag, filled in by
	// shouldPassMessage the first time it is needed.
	requireJSONValueParsed   bool
	requireJSONValueIsNumber bool
	requireJSONNumber        float64

	// perAddressStatus holds one timer per destination address;
	// timermetrics is updated for every response regardless of address.
	perAddressStatus map[string]*timer_metrics.TimerMetrics
	timermetrics     *timer_metrics.TimerMetrics
}
// TopicHandler binds a shared PublishHandler to the destination topic that
// messages from one consumed topic are published to.
type TopicHandler struct {
	// publishHandler performs the actual publish.
	publishHandler *PublishHandler
	// destinationTopic is passed to publishHandler.HandleMessage.
	destinationTopic string
}
  67. func (ph *PublishHandler) responder() {
  68. var msg *nsq.Message
  69. var startTime time.Time
  70. var address string
  71. var hostPoolResponse hostpool.HostPoolResponse
  72. for t := range ph.respChan {
  73. switch ph.mode {
  74. case ModeRoundRobin:
  75. msg = t.Args[0].(*nsq.Message)
  76. startTime = t.Args[1].(time.Time)
  77. hostPoolResponse = nil
  78. address = t.Args[2].(string)
  79. case ModeHostPool:
  80. msg = t.Args[0].(*nsq.Message)
  81. startTime = t.Args[1].(time.Time)
  82. hostPoolResponse = t.Args[2].(hostpool.HostPoolResponse)
  83. address = hostPoolResponse.Host()
  84. }
  85. success := t.Error == nil
  86. if hostPoolResponse != nil {
  87. if !success {
  88. hostPoolResponse.Mark(errors.New("failed"))
  89. } else {
  90. hostPoolResponse.Mark(nil)
  91. }
  92. }
  93. if success {
  94. msg.Finish()
  95. } else {
  96. msg.Requeue(-1)
  97. }
  98. ph.perAddressStatus[address].Status(startTime)
  99. ph.timermetrics.Status(startTime)
  100. }
  101. }
  102. func (ph *PublishHandler) shouldPassMessage(js map[string]interface{}) (bool, bool) {
  103. pass := true
  104. backoff := false
  105. if *requireJSONField == "" {
  106. return pass, backoff
  107. }
  108. if *requireJSONValue != "" && !ph.requireJSONValueParsed {
  109. // cache conversion in case needed while filtering json
  110. var err error
  111. ph.requireJSONNumber, err = strconv.ParseFloat(*requireJSONValue, 64)
  112. ph.requireJSONValueIsNumber = (err == nil)
  113. ph.requireJSONValueParsed = true
  114. }
  115. v, ok := js[*requireJSONField]
  116. if !ok {
  117. pass = false
  118. if *requireJSONValue != "" {
  119. log.Printf("ERROR: missing field to check required value")
  120. backoff = true
  121. }
  122. } else if *requireJSONValue != "" {
  123. // if command-line argument can't convert to float, then it can't match a number
  124. // if it can, also integers (up to 2^53 or so) can be compared as float64
  125. if s, ok := v.(string); ok {
  126. if s != *requireJSONValue {
  127. pass = false
  128. }
  129. } else if ph.requireJSONValueIsNumber {
  130. f, ok := v.(float64)
  131. if !ok || f != ph.requireJSONNumber {
  132. pass = false
  133. }
  134. } else {
  135. // json value wasn't a plain string, and argument wasn't a number
  136. // give up on comparisons of other types
  137. pass = false
  138. }
  139. }
  140. return pass, backoff
  141. }
  142. func filterMessage(js map[string]interface{}, rawMsg []byte) ([]byte, error) {
  143. if len(whitelistJSONFields) == 0 {
  144. // no change
  145. return rawMsg, nil
  146. }
  147. newMsg := make(map[string]interface{}, len(whitelistJSONFields))
  148. for _, key := range whitelistJSONFields {
  149. value, ok := js[key]
  150. if ok {
  151. // avoid printing int as float (go 1.0)
  152. switch tvalue := value.(type) {
  153. case float64:
  154. ivalue := int64(tvalue)
  155. if float64(ivalue) == tvalue {
  156. newMsg[key] = ivalue
  157. } else {
  158. newMsg[key] = tvalue
  159. }
  160. default:
  161. newMsg[key] = value
  162. }
  163. }
  164. }
  165. newRawMsg, err := json.Marshal(newMsg)
  166. if err != nil {
  167. return nil, fmt.Errorf("unable to marshal filtered message %v", newMsg)
  168. }
  169. return newRawMsg, nil
  170. }
  171. func (t *TopicHandler) HandleMessage(m *nsq.Message) error {
  172. return t.publishHandler.HandleMessage(m, t.destinationTopic)
  173. }
  174. func (ph *PublishHandler) HandleMessage(m *nsq.Message, destinationTopic string) error {
  175. var err error
  176. msgBody := m.Body
  177. if *requireJSONField != "" || len(whitelistJSONFields) > 0 {
  178. var js map[string]interface{}
  179. err = json.Unmarshal(msgBody, &js)
  180. if err != nil {
  181. log.Printf("ERROR: Unable to decode json: %s", msgBody)
  182. return nil
  183. }
  184. if pass, backoff := ph.shouldPassMessage(js); !pass {
  185. if backoff {
  186. return errors.New("backoff")
  187. }
  188. return nil
  189. }
  190. msgBody, err = filterMessage(js, msgBody)
  191. if err != nil {
  192. log.Printf("ERROR: filterMessage() failed: %s", err)
  193. return err
  194. }
  195. }
  196. startTime := time.Now()
  197. switch ph.mode {
  198. case ModeRoundRobin:
  199. counter := atomic.AddUint64(&ph.counter, 1)
  200. idx := counter % uint64(len(ph.addresses))
  201. addr := ph.addresses[idx]
  202. p := ph.producers[addr]
  203. err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, addr)
  204. case ModeHostPool:
  205. hostPoolResponse := ph.hostPool.Get()
  206. p := ph.producers[hostPoolResponse.Host()]
  207. err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, hostPoolResponse)
  208. if err != nil {
  209. hostPoolResponse.Mark(err)
  210. }
  211. }
  212. if err != nil {
  213. return err
  214. }
  215. m.DisableAutoResponse()
  216. return nil
  217. }
  218. func hasArg(s string) bool {
  219. argExist := false
  220. flag.Visit(func(f *flag.Flag) {
  221. if f.Name == s {
  222. argExist = true
  223. }
  224. })
  225. return argExist
  226. }
  227. func main() {
  228. var selectedMode int
  229. cCfg := nsq.NewConfig()
  230. pCfg := nsq.NewConfig()
  231. flag.Var(&nsq.ConfigFlag{cCfg}, "consumer-opt", "option to passthrough to nsq.Consumer (may be given multiple times, see http://godoc.org/github.com/nsqio/go-nsq#Config)")
  232. flag.Var(&nsq.ConfigFlag{pCfg}, "producer-opt", "option to passthrough to nsq.Producer (may be given multiple times, see http://godoc.org/github.com/nsqio/go-nsq#Config)")
  233. flag.Parse()
  234. if *showVersion {
  235. fmt.Printf("nsq_to_nsq v%s\n", version.Binary)
  236. return
  237. }
  238. if len(topics) == 0 || *channel == "" {
  239. log.Fatal("--topic and --channel are required")
  240. }
  241. for _, topic := range topics {
  242. if !protocol.IsValidTopicName(topic) {
  243. log.Fatal("--topic is invalid")
  244. }
  245. }
  246. if *destTopic != "" && !protocol.IsValidTopicName(*destTopic) {
  247. log.Fatal("--destination-topic is invalid")
  248. }
  249. if !protocol.IsValidChannelName(*channel) {
  250. log.Fatal("--channel is invalid")
  251. }
  252. if len(nsqdTCPAddrs) == 0 && len(lookupdHTTPAddrs) == 0 {
  253. log.Fatal("--nsqd-tcp-address or --lookupd-http-address required")
  254. }
  255. if len(nsqdTCPAddrs) > 0 && len(lookupdHTTPAddrs) > 0 {
  256. log.Fatal("use --nsqd-tcp-address or --lookupd-http-address not both")
  257. }
  258. if len(destNsqdTCPAddrs) == 0 {
  259. log.Fatal("--destination-nsqd-tcp-address required")
  260. }
  261. switch *mode {
  262. case "round-robin":
  263. selectedMode = ModeRoundRobin
  264. case "hostpool", "epsilon-greedy":
  265. selectedMode = ModeHostPool
  266. }
  267. termChan := make(chan os.Signal, 1)
  268. signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)
  269. defaultUA := fmt.Sprintf("nsq_to_nsq/%s go-nsq/%s", version.Binary, nsq.VERSION)
  270. cCfg.UserAgent = defaultUA
  271. cCfg.MaxInFlight = *maxInFlight
  272. pCfg.UserAgent = defaultUA
  273. producers := make(map[string]*nsq.Producer)
  274. for _, addr := range destNsqdTCPAddrs {
  275. producer, err := nsq.NewProducer(addr, pCfg)
  276. if err != nil {
  277. log.Fatalf("failed creating producer %s", err)
  278. }
  279. producers[addr] = producer
  280. }
  281. perAddressStatus := make(map[string]*timer_metrics.TimerMetrics)
  282. if len(destNsqdTCPAddrs) == 1 {
  283. // disable since there is only one address
  284. perAddressStatus[destNsqdTCPAddrs[0]] = timer_metrics.NewTimerMetrics(0, "")
  285. } else {
  286. for _, a := range destNsqdTCPAddrs {
  287. perAddressStatus[a] = timer_metrics.NewTimerMetrics(*statusEvery,
  288. fmt.Sprintf("[%s]:", a))
  289. }
  290. }
  291. hostPool := hostpool.New(destNsqdTCPAddrs)
  292. if *mode == "epsilon-greedy" {
  293. hostPool = hostpool.NewEpsilonGreedy(destNsqdTCPAddrs, 0, &hostpool.LinearEpsilonValueCalculator{})
  294. }
  295. var consumerList []*nsq.Consumer
  296. publisher := &PublishHandler{
  297. addresses: destNsqdTCPAddrs,
  298. producers: producers,
  299. mode: selectedMode,
  300. hostPool: hostPool,
  301. respChan: make(chan *nsq.ProducerTransaction, len(destNsqdTCPAddrs)),
  302. perAddressStatus: perAddressStatus,
  303. timermetrics: timer_metrics.NewTimerMetrics(*statusEvery, "[aggregate]:"),
  304. }
  305. for _, topic := range topics {
  306. consumer, err := nsq.NewConsumer(topic, *channel, cCfg)
  307. consumerList = append(consumerList, consumer)
  308. if err != nil {
  309. log.Fatal(err)
  310. }
  311. publishTopic := topic
  312. if *destTopic != "" {
  313. publishTopic = *destTopic
  314. }
  315. topicHandler := &TopicHandler{
  316. publishHandler: publisher,
  317. destinationTopic: publishTopic,
  318. }
  319. consumer.AddConcurrentHandlers(topicHandler, len(destNsqdTCPAddrs))
  320. }
  321. for i := 0; i < len(destNsqdTCPAddrs); i++ {
  322. go publisher.responder()
  323. }
  324. for _, consumer := range consumerList {
  325. err := consumer.ConnectToNSQDs(nsqdTCPAddrs)
  326. if err != nil {
  327. log.Fatal(err)
  328. }
  329. }
  330. for _, consumer := range consumerList {
  331. err := consumer.ConnectToNSQLookupds(lookupdHTTPAddrs)
  332. if err != nil {
  333. log.Fatal(err)
  334. }
  335. }
  336. <-termChan // wait for signal
  337. for _, consumer := range consumerList {
  338. consumer.Stop()
  339. }
  340. for _, consumer := range consumerList {
  341. <-consumer.StopChan
  342. }
  343. }