authorizations.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package auth
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "math/rand"
  7. "net/url"
  8. "regexp"
  9. "strings"
  10. "time"
  11. "github.com/nsqio/nsq/internal/http_api"
  12. )
  // Authorization is one grant from an auth server's response. It gives the
  // listed Permissions on topics whose name matches the Topic pattern, and on
  // channels whose name matches any of the Channels patterns.
  type Authorization struct {
  	// Topic is a regular expression that is matched against topic names.
  	Topic string `json:"topic"`
  	// Channels holds regular expressions that are matched against channel names.
  	Channels []string `json:"channels"`
  	// Permissions lists the granted actions. QueryAuthd only accepts
  	// "subscribe" and "publish".
  	Permissions []string `json:"permissions"`
  }
  18. type State struct {
  19. TTL int `json:"ttl"`
  20. Authorizations []Authorization `json:"authorizations"`
  21. Identity string `json:"identity"`
  22. IdentityURL string `json:"identity_url"`
  23. Expires time.Time
  24. }
  25. func (a *Authorization) HasPermission(permission string) bool {
  26. for _, p := range a.Permissions {
  27. if permission == p {
  28. return true
  29. }
  30. }
  31. return false
  32. }
  33. func (a *Authorization) IsAllowed(topic, channel string) bool {
  34. if channel != "" {
  35. if !a.HasPermission("subscribe") {
  36. return false
  37. }
  38. } else {
  39. if !a.HasPermission("publish") {
  40. return false
  41. }
  42. }
  43. topicRegex := regexp.MustCompile(a.Topic)
  44. if !topicRegex.MatchString(topic) {
  45. return false
  46. }
  47. for _, c := range a.Channels {
  48. channelRegex := regexp.MustCompile(c)
  49. if channelRegex.MatchString(channel) {
  50. return true
  51. }
  52. }
  53. return false
  54. }
  55. func (a *State) IsAllowed(topic, channel string) bool {
  56. for _, aa := range a.Authorizations {
  57. if aa.IsAllowed(topic, channel) {
  58. return true
  59. }
  60. }
  61. return false
  62. }
  63. func (a *State) IsExpired() bool {
  64. if a.Expires.Before(time.Now()) {
  65. return true
  66. }
  67. return false
  68. }
  69. func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
  70. connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
  71. start := rand.Int()
  72. n := len(authd)
  73. for i := 0; i < n; i++ {
  74. a := authd[(i+start)%n]
  75. authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, connectTimeout, requestTimeout)
  76. if err != nil {
  77. log.Printf("Error: failed auth against %s %s", a, err)
  78. continue
  79. }
  80. return authState, nil
  81. }
  82. return nil, errors.New("Unable to access auth server")
  83. }
  84. func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
  85. connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
  86. v := url.Values{}
  87. v.Set("remote_ip", remoteIP)
  88. if tlsEnabled {
  89. v.Set("tls", "true")
  90. } else {
  91. v.Set("tls", "false")
  92. }
  93. v.Set("secret", authSecret)
  94. v.Set("common_name", commonName)
  95. var endpoint string
  96. if strings.Contains(authd, "://") {
  97. endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
  98. } else {
  99. endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
  100. }
  101. var authState State
  102. client := http_api.NewClient(nil, connectTimeout, requestTimeout)
  103. if err := client.GETV1(endpoint, &authState); err != nil {
  104. return nil, err
  105. }
  106. // validation on response
  107. for _, auth := range authState.Authorizations {
  108. for _, p := range auth.Permissions {
  109. switch p {
  110. case "subscribe", "publish":
  111. default:
  112. return nil, fmt.Errorf("unknown permission %s", p)
  113. }
  114. }
  115. if _, err := regexp.Compile(auth.Topic); err != nil {
  116. return nil, fmt.Errorf("unable to compile topic %q %s", auth.Topic, err)
  117. }
  118. for _, channel := range auth.Channels {
  119. if _, err := regexp.Compile(channel); err != nil {
  120. return nil, fmt.Errorf("unable to compile channel %q %s", channel, err)
  121. }
  122. }
  123. }
  124. if authState.TTL <= 0 {
  125. return nil, fmt.Errorf("invalid TTL %d (must be >0)", authState.TTL)
  126. }
  127. authState.Expires = time.Now().Add(time.Duration(authState.TTL) * time.Second)
  128. return &authState, nil
  129. }