nsq_to_http.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. // This is an NSQ client that reads the specified topic/channel
  2. // and performs HTTP requests (GET/POST) to the specified endpoints
  3. package main
  4. import (
  5. "bytes"
  6. "flag"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "log"
  11. "math/rand"
  12. "net/http"
  13. "net/url"
  14. "os"
  15. "os/signal"
  16. "strings"
  17. "sync/atomic"
  18. "syscall"
  19. "time"
  20. "github.com/bitly/go-hostpool"
  21. "github.com/bitly/timer_metrics"
  22. "github.com/nsqio/go-nsq"
  23. "github.com/nsqio/nsq/internal/app"
  24. "github.com/nsqio/nsq/internal/http_api"
  25. "github.com/nsqio/nsq/internal/version"
  26. )
  // Upstream request modes understood by PublishHandler.HandleMessage.
  const (
  	// ModeAll publishes each message to every configured address.
  	ModeAll = iota
  	// ModeRoundRobin rotates through the configured addresses, one per message.
  	ModeRoundRobin
  	// ModeHostPool lets the host pool choose the address for each message.
  	ModeHostPool
  )
  32. var (
  33. showVersion = flag.Bool("version", false, "print version string")
  34. topic = flag.String("topic", "", "nsq topic")
  35. channel = flag.String("channel", "nsq_to_http", "nsq channel")
  36. maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight")
  37. numPublishers = flag.Int("n", 100, "number of concurrent publishers")
  38. mode = flag.String("mode", "hostpool", "the upstream request mode options: round-robin, hostpool (default), epsilon-greedy")
  39. sample = flag.Float64("sample", 1.0, "% of messages to publish (float b/w 0 -> 1)")
  40. httpConnectTimeout = flag.Duration("http-client-connect-timeout", 2*time.Second, "timeout for HTTP connect")
  41. httpRequestTimeout = flag.Duration("http-client-request-timeout", 20*time.Second, "timeout for HTTP request")
  42. statusEvery = flag.Int("status-every", 250, "the # of requests between logging status (per handler), 0 disables")
  43. contentType = flag.String("content-type", "application/octet-stream", "the Content-Type used for POST requests")
  44. getAddrs = app.StringArray{}
  45. postAddrs = app.StringArray{}
  46. customHeaders = app.StringArray{}
  47. nsqdTCPAddrs = app.StringArray{}
  48. lookupdHTTPAddrs = app.StringArray{}
  49. validCustomHeaders map[string]string
  50. )
  51. func init() {
  52. flag.Var(&postAddrs, "post", "HTTP address to make a POST request to. data will be in the body (may be given multiple times)")
  53. flag.Var(&customHeaders, "header", "Custom header for HTTP requests (may be given multiple times)")
  54. flag.Var(&getAddrs, "get", "HTTP address to make a GET request to. '%s' will be printf replaced with data (may be given multiple times)")
  55. flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)")
  56. flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)")
  57. }
  // Publisher delivers a single message body to the endpoint named by addr.
  // A non-nil error means the delivery did not succeed.
  type Publisher interface {
  	Publish(addr string, msg []byte) error
  }
  61. type PublishHandler struct {
  62. // 64bit atomic vars need to be first for proper alignment on 32bit platforms
  63. counter uint64
  64. Publisher
  65. addresses app.StringArray
  66. mode int
  67. hostPool hostpool.HostPool
  68. perAddressStatus map[string]*timer_metrics.TimerMetrics
  69. timermetrics *timer_metrics.TimerMetrics
  70. }
  71. func (ph *PublishHandler) HandleMessage(m *nsq.Message) error {
  72. if *sample < 1.0 && rand.Float64() > *sample {
  73. return nil
  74. }
  75. startTime := time.Now()
  76. switch ph.mode {
  77. case ModeAll:
  78. for _, addr := range ph.addresses {
  79. st := time.Now()
  80. err := ph.Publish(addr, m.Body)
  81. if err != nil {
  82. return err
  83. }
  84. ph.perAddressStatus[addr].Status(st)
  85. }
  86. case ModeRoundRobin:
  87. counter := atomic.AddUint64(&ph.counter, 1)
  88. idx := counter % uint64(len(ph.addresses))
  89. addr := ph.addresses[idx]
  90. err := ph.Publish(addr, m.Body)
  91. if err != nil {
  92. return err
  93. }
  94. ph.perAddressStatus[addr].Status(startTime)
  95. case ModeHostPool:
  96. hostPoolResponse := ph.hostPool.Get()
  97. addr := hostPoolResponse.Host()
  98. err := ph.Publish(addr, m.Body)
  99. hostPoolResponse.Mark(err)
  100. if err != nil {
  101. return err
  102. }
  103. ph.perAddressStatus[addr].Status(startTime)
  104. }
  105. ph.timermetrics.Status(startTime)
  106. return nil
  107. }
  108. type PostPublisher struct{}
  109. func (p *PostPublisher) Publish(addr string, msg []byte) error {
  110. buf := bytes.NewBuffer(msg)
  111. resp, err := HTTPPost(addr, buf)
  112. if err != nil {
  113. return err
  114. }
  115. io.Copy(ioutil.Discard, resp.Body)
  116. resp.Body.Close()
  117. if resp.StatusCode < 200 || resp.StatusCode >= 300 {
  118. return fmt.Errorf("got status code %d", resp.StatusCode)
  119. }
  120. return nil
  121. }
  122. type GetPublisher struct{}
  123. func (p *GetPublisher) Publish(addr string, msg []byte) error {
  124. endpoint := fmt.Sprintf(addr, url.QueryEscape(string(msg)))
  125. resp, err := HTTPGet(endpoint)
  126. if err != nil {
  127. return err
  128. }
  129. io.Copy(ioutil.Discard, resp.Body)
  130. resp.Body.Close()
  131. if resp.StatusCode != 200 {
  132. return fmt.Errorf("got status code %d", resp.StatusCode)
  133. }
  134. return nil
  135. }
  // hasArg reports whether the flag named s was explicitly set on the default
  // flag set. flag.Visit only walks flags that have been set, so a flag left
  // at its default value yields false.
  func hasArg(s string) bool {
  	found := false
  	flag.Visit(func(f *flag.Flag) {
  		found = found || f.Name == s
  	})
  	return found
  }
  145. func main() {
  146. var publisher Publisher
  147. var addresses app.StringArray
  148. var selectedMode int
  149. cfg := nsq.NewConfig()
  150. flag.Var(&nsq.ConfigFlag{cfg}, "consumer-opt", "option to passthrough to nsq.Consumer (may be given multiple times, http://godoc.org/github.com/nsqio/go-nsq#Config)")
  151. flag.Parse()
  152. httpclient = &http.Client{Transport: http_api.NewDeadlineTransport(*httpConnectTimeout, *httpRequestTimeout), Timeout: *httpRequestTimeout}
  153. if *showVersion {
  154. fmt.Printf("nsq_to_http v%s\n", version.Binary)
  155. return
  156. }
  157. if len(customHeaders) > 0 {
  158. var err error
  159. validCustomHeaders, err = parseCustomHeaders(customHeaders)
  160. if err != nil {
  161. log.Fatal("--header value format should be 'key=value'")
  162. }
  163. }
  164. if *topic == "" || *channel == "" {
  165. log.Fatal("--topic and --channel are required")
  166. }
  167. if *contentType != flag.Lookup("content-type").DefValue {
  168. if len(postAddrs) == 0 {
  169. log.Fatal("--content-type only used with --post")
  170. }
  171. if len(*contentType) == 0 {
  172. log.Fatal("--content-type requires a value when used")
  173. }
  174. }
  175. if len(nsqdTCPAddrs) == 0 && len(lookupdHTTPAddrs) == 0 {
  176. log.Fatal("--nsqd-tcp-address or --lookupd-http-address required")
  177. }
  178. if len(nsqdTCPAddrs) > 0 && len(lookupdHTTPAddrs) > 0 {
  179. log.Fatal("use --nsqd-tcp-address or --lookupd-http-address not both")
  180. }
  181. if len(getAddrs) == 0 && len(postAddrs) == 0 {
  182. log.Fatal("--get or --post required")
  183. }
  184. if len(getAddrs) > 0 && len(postAddrs) > 0 {
  185. log.Fatal("use --get or --post not both")
  186. }
  187. if len(getAddrs) > 0 {
  188. for _, get := range getAddrs {
  189. if strings.Count(get, "%s") != 1 {
  190. log.Fatal("invalid GET address - must be a printf string")
  191. }
  192. }
  193. }
  194. switch *mode {
  195. case "round-robin":
  196. selectedMode = ModeRoundRobin
  197. case "hostpool", "epsilon-greedy":
  198. selectedMode = ModeHostPool
  199. }
  200. if *sample > 1.0 || *sample < 0.0 {
  201. log.Fatal("ERROR: --sample must be between 0.0 and 1.0")
  202. }
  203. termChan := make(chan os.Signal, 1)
  204. signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)
  205. if len(postAddrs) > 0 {
  206. publisher = &PostPublisher{}
  207. addresses = postAddrs
  208. } else {
  209. publisher = &GetPublisher{}
  210. addresses = getAddrs
  211. }
  212. cfg.UserAgent = fmt.Sprintf("smq_to_http/%s go-smq/%s", version.Binary, nsq.VERSION)
  213. cfg.MaxInFlight = *maxInFlight
  214. consumer, err := nsq.NewConsumer(*topic, *channel, cfg)
  215. if err != nil {
  216. log.Fatal(err)
  217. }
  218. perAddressStatus := make(map[string]*timer_metrics.TimerMetrics)
  219. if len(addresses) == 1 {
  220. // disable since there is only one address
  221. perAddressStatus[addresses[0]] = timer_metrics.NewTimerMetrics(0, "")
  222. } else {
  223. for _, a := range addresses {
  224. perAddressStatus[a] = timer_metrics.NewTimerMetrics(*statusEvery,
  225. fmt.Sprintf("[%s]:", a))
  226. }
  227. }
  228. hostPool := hostpool.New(addresses)
  229. if *mode == "epsilon-greedy" {
  230. hostPool = hostpool.NewEpsilonGreedy(addresses, 0, &hostpool.LinearEpsilonValueCalculator{})
  231. }
  232. handler := &PublishHandler{
  233. Publisher: publisher,
  234. addresses: addresses,
  235. mode: selectedMode,
  236. hostPool: hostPool,
  237. perAddressStatus: perAddressStatus,
  238. timermetrics: timer_metrics.NewTimerMetrics(*statusEvery, "[aggregate]:"),
  239. }
  240. consumer.AddConcurrentHandlers(handler, *numPublishers)
  241. err = consumer.ConnectToNSQDs(nsqdTCPAddrs)
  242. if err != nil {
  243. log.Fatal(err)
  244. }
  245. err = consumer.ConnectToNSQLookupds(lookupdHTTPAddrs)
  246. if err != nil {
  247. log.Fatal(err)
  248. }
  249. for {
  250. select {
  251. case <-consumer.StopChan:
  252. return
  253. case <-termChan:
  254. consumer.Stop()
  255. }
  256. }
  257. }
  // parseCustomHeaders turns "key:value" strings into a header map.
  //
  // Each entry is split at its first ':'; key and value are trimmed of
  // surrounding whitespace. An entry without a ':' or with an empty key or
  // value yields a nil map and an error naming the offending entry. When the
  // same key appears more than once, the last value wins.
  func parseCustomHeaders(strs []string) (map[string]string, error) {
  	headers := make(map[string]string, len(strs))
  	for _, raw := range strs {
  		sep := strings.Index(raw, ":")
  		if sep < 0 {
  			return nil, fmt.Errorf("Invalid headers: %q", raw)
  		}
  		name := strings.TrimSpace(raw[:sep])
  		value := strings.TrimSpace(raw[sep+1:])
  		if name == "" || value == "" {
  			return nil, fmt.Errorf("Invalid headers: %q", raw)
  		}
  		headers[name] = value
  	}
  	return headers, nil
  }