123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- package auth
- import (
- "errors"
- "fmt"
- "log"
- "math/rand"
- "net/url"
- "regexp"
- "strings"
- "time"
- "github.com/nsqio/nsq/internal/http_api"
- )
// Authorization is a single grant returned by an auth server. It confers the
// listed Permissions on topics matching the Topic regular expression and on
// channels matching any of the Channels regular expressions.
type Authorization struct {
	// Topic is a regular expression matched against the topic name.
	// NOTE(review): matching is unanchored, so "foo" also matches "foobar";
	// auth servers presumably need ^...$ for exact names — confirm.
	Topic string `json:"topic"`
	// Channels are regular expressions; a request passes if any of them
	// matches the channel name (the empty string for publish requests).
	Channels []string `json:"channels"`
	// Permissions lists granted actions; IsAllowed checks for "subscribe"
	// and "publish".
	Permissions []string `json:"permissions"`
}

// State is the decoded response of an auth query: the identity the auth
// server reported, the authorizations it granted, and when they expire.
type State struct {
	// TTL is the lifetime of this state in seconds, as reported by the server.
	TTL int `json:"ttl"`
	// Authorizations are the grants; any one of them may allow a request.
	Authorizations []Authorization `json:"authorizations"`
	// Identity is the identity string reported by the auth server.
	Identity string `json:"identity"`
	// IdentityURL is the identity URL reported by the auth server.
	IdentityURL string `json:"identity_url"`
	// Expires is the local time after which IsExpired reports true.
	// QueryAuthd sets it to the query time plus TTL seconds.
	Expires time.Time
}

// HasPermission reports whether permission appears in a.Permissions.
// The comparison is exact and case-sensitive.
func (a *Authorization) HasPermission(permission string) bool {
	for _, granted := range a.Permissions {
		if granted == permission {
			return true
		}
	}
	return false
}

// IsAllowed reports whether this authorization permits the request.
//
// A non-empty channel denotes a subscribe request and requires the
// "subscribe" permission; an empty channel denotes a publish request and
// requires "publish". In both cases topic must match a.Topic and channel
// (empty for publish) must match at least one of a.Channels, so a grant with
// no channel patterns allows nothing.
//
// Patterns are compiled on every call. A pattern that does not compile is
// treated as non-matching (access denied) instead of panicking: QueryAuthd
// rejects such patterns, but an Authorization built any other way must not
// be able to crash the caller.
func (a *Authorization) IsAllowed(topic, channel string) bool {
	required := "publish"
	if channel != "" {
		required = "subscribe"
	}
	if !a.HasPermission(required) {
		return false
	}

	if ok, err := regexp.MatchString(a.Topic, topic); err != nil || !ok {
		return false
	}

	for _, pattern := range a.Channels {
		if ok, err := regexp.MatchString(pattern, channel); err == nil && ok {
			return true
		}
	}
	return false
}

// IsAllowed reports whether any of s.Authorizations permits the request;
// see Authorization.IsAllowed for the matching rules. A State with no
// authorizations allows nothing.
func (s *State) IsAllowed(topic, channel string) bool {
	for i := range s.Authorizations {
		if s.Authorizations[i].IsAllowed(topic, channel) {
			return true
		}
	}
	return false
}

// IsExpired reports whether s.Expires lies in the past. A zero Expires
// (a State not filled in by QueryAuthd) is always expired.
func (s *State) IsExpired() bool {
	return s.Expires.Before(time.Now())
}
- func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
- connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
- start := rand.Int()
- n := len(authd)
- for i := 0; i < n; i++ {
- a := authd[(i+start)%n]
- authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, connectTimeout, requestTimeout)
- if err != nil {
- log.Printf("Error: failed auth against %s %s", a, err)
- continue
- }
- return authState, nil
- }
- return nil, errors.New("Unable to access auth server")
- }
- func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
- connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
- v := url.Values{}
- v.Set("remote_ip", remoteIP)
- if tlsEnabled {
- v.Set("tls", "true")
- } else {
- v.Set("tls", "false")
- }
- v.Set("secret", authSecret)
- v.Set("common_name", commonName)
- var endpoint string
- if strings.Contains(authd, "://") {
- endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
- } else {
- endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
- }
- var authState State
- client := http_api.NewClient(nil, connectTimeout, requestTimeout)
- if err := client.GETV1(endpoint, &authState); err != nil {
- return nil, err
- }
- // validation on response
- for _, auth := range authState.Authorizations {
- for _, p := range auth.Permissions {
- switch p {
- case "subscribe", "publish":
- default:
- return nil, fmt.Errorf("unknown permission %s", p)
- }
- }
- if _, err := regexp.Compile(auth.Topic); err != nil {
- return nil, fmt.Errorf("unable to compile topic %q %s", auth.Topic, err)
- }
- for _, channel := range auth.Channels {
- if _, err := regexp.Compile(channel); err != nil {
- return nil, fmt.Errorf("unable to compile channel %q %s", channel, err)
- }
- }
- }
- if authState.TTL <= 0 {
- return nil, fmt.Errorf("invalid TTL %d (must be >0)", authState.TTL)
- }
- authState.Expires = time.Now().Add(time.Duration(authState.TTL) * time.Second)
- return &authState, nil
- }
|