...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/sources/resolver.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/sources

     1  package sources
     2  
     3  // This file contains the implementation of Patroni and PostgreSQL resolvers for continuous monitoring.
     4  // Patroni resolver will return the list of databases from the Patroni cluster.
     5  // Postgres resolver will return the list of databases from the given Postgres instance.
     6  
     7  import (
     8  	"cmp"
     9  	"context"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"errors"
    13  	"fmt"
    14  	"net/url"
    15  	"os"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	jsoniter "github.com/json-iterator/go"
    21  
    22  	"github.com/cybertec-postgresql/pgwatch/v5/internal/db"
    23  	"github.com/cybertec-postgresql/pgwatch/v5/internal/log"
    24  	pgx "github.com/jackc/pgx/v5"
    25  	client "go.etcd.io/etcd/client/v3"
    26  	"go.uber.org/zap"
    27  )
    28  
    29  // ResolveDatabases() updates list of monitored objects from continuous monitoring sources, e.g. patroni.
    30  // Each source is resolved concurrently so that a slow or unreachable source does not block the others.
    31  func (srcs Sources) ResolveDatabases(onError func(string)) (_ SourceConns, err error) {
    32  	type result struct {
    33  		dbs SourceConns
    34  		err error
    35  	}
    36  	results := make([]result, len(srcs))
    37  	var wg sync.WaitGroup
    38  	for i, s := range srcs {
    39  		wg.Go(func() {
    40  			dbs, e := s.ResolveDatabases()
    41  			results[i] = result{dbs, e}
    42  		})
    43  	}
    44  	wg.Wait()
    45  	resolvedDbs := make(SourceConns, 0, len(srcs))
    46  	for i, res := range results {
    47  		if res.err != nil {
    48  			if onError != nil {
    49  				onError(srcs[i].Name)
    50  			}
    51  			logger.WithField("source", srcs[i].Name).WithError(res.err).Error("could not resolve databases from source")
    52  			err = errors.Join(err, res.err)
    53  		}
    54  		resolvedDbs = append(resolvedDbs, res.dbs...)
    55  	}
    56  	return resolvedDbs, err
    57  }
    58  
    59  // ResolveDatabases() return a slice of found databases for continuous monitoring sources, e.g. patroni
    60  func (s Source) ResolveDatabases() (SourceConns, error) {
    61  	switch s.Kind {
    62  	case SourcePatroni:
    63  		return ResolveDatabasesFromPatroni(s)
    64  	case SourcePostgresContinuous:
    65  		return ResolveDatabasesFromPostgres(s)
    66  	}
    67  	return SourceConns{NewSourceConn(s)}, nil
    68  }
    69  
    70  type PatroniClusterMember struct {
    71  	Scope   string
    72  	Name    string
    73  	ConnURL string `yaml:"conn_url"`
    74  	Role    string
    75  }
    76  
    77  func (pcm PatroniClusterMember) IsPrimary() bool {
    78  	return pcm.Role == "primary" || pcm.Role == "master"
    79  }
    80  
    81  var logger log.Logger = log.FallbackLogger
    82  
    83  var lastFoundClusterMembers = make(map[string][]PatroniClusterMember) // needed for cases where DCS is temporarily down
    84  // don't want to immediately remove monitoring of DBs
    85  
    86  func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
    87  	return nil, errors.ErrUnsupported
    88  }
    89  
    90  func getZookeeperClusterMembers(Source, HostConfig) ([]PatroniClusterMember, error) {
    91  	return nil, errors.ErrUnsupported
    92  }
    93  
    94  func jsonTextToStringMap(jsonText string) (map[string]string, error) {
    95  	retmap := make(map[string]string)
    96  	if jsonText == "" {
    97  		return retmap, nil
    98  	}
    99  	var iMap map[string]any
   100  	if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
   101  		return nil, err
   102  	}
   103  	for k, v := range iMap {
   104  		retmap[k] = fmt.Sprintf("%v", v)
   105  	}
   106  	return retmap, nil
   107  }
   108  
   109  func getTransport(conf HostConfig) (*tls.Config, error) {
   110  	var caCertPool *x509.CertPool
   111  
   112  	// create valid CertPool only if the ca certificate file exists
   113  	if conf.CAFile != "" {
   114  		caCert, err := os.ReadFile(conf.CAFile)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("cannot load CA file: %s", err)
   117  		}
   118  
   119  		caCertPool = x509.NewCertPool()
   120  		caCertPool.AppendCertsFromPEM(caCert)
   121  	}
   122  
   123  	var certificates []tls.Certificate
   124  
   125  	// create valid []Certificate only if the client cert and key files exists
   126  	if conf.CertFile != "" && conf.KeyFile != "" {
   127  		cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
   128  		if err != nil {
   129  			return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
   130  		}
   131  
   132  		certificates = []tls.Certificate{cert}
   133  	}
   134  
   135  	tlsClientConfig := new(tls.Config)
   136  
   137  	if caCertPool != nil {
   138  		tlsClientConfig.RootCAs = caCertPool
   139  		if certificates != nil {
   140  			tlsClientConfig.Certificates = certificates
   141  		}
   142  	}
   143  
   144  	return tlsClientConfig, nil
   145  }
   146  
   147  func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
   148  	var ret = make([]PatroniClusterMember, 0)
   149  	var cfg client.Config
   150  
   151  	if len(hc.DcsEndpoints) == 0 {
   152  		return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
   153  	}
   154  
   155  	tlsConfig, err := getTransport(hc)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	cfg = client.Config{
   160  		Endpoints:            hc.DcsEndpoints,
   161  		TLS:                  tlsConfig,
   162  		DialKeepAliveTimeout: time.Second,
   163  		Username:             hc.Username,
   164  		Password:             hc.Password,
   165  		DialTimeout:          5 * time.Second,
   166  		Logger:               zap.NewNop(),
   167  	}
   168  
   169  	c, err := client.New(cfg)
   170  	if err != nil {
   171  		return ret, err
   172  	}
   173  	defer c.Close()
   174  
   175  	ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
   176  	defer cancel()
   177  
   178  	// etcd3 does not have a dir node.
   179  	// Key="/namespace/scope/leader", e.g. "/service/batman/leader"
   180  	// Key="/namespace/scope/members/node", e.g. "/service/batman/members/pg1"
   181  
   182  	resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
   183  	if err != nil {
   184  		return ret, cmp.Or(context.Cause(ctx), err)
   185  	}
   186  
   187  	for _, node := range resp.Kvs {
   188  		// remove leading slash and split by "/"
   189  		parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
   190  		if len(parts) < 4 || parts[2] != "members" {
   191  			continue // skip non-member keys
   192  		}
   193  		nodeData, err := jsonTextToStringMap(string(node.Value))
   194  		if err != nil {
   195  			logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node.Key, err)
   196  			continue
   197  		}
   198  		role := nodeData["role"]
   199  		connURL := nodeData["conn_url"]
   200  		scope := parts[1]
   201  		name := parts[3]
   202  		ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
   203  	}
   204  
   205  	lastFoundClusterMembers[s.Name] = ret
   206  	return ret, nil
   207  }
   208  
   209  const (
   210  	dcsTypeEtcd      = "etcd"
   211  	dcsTypeZookeeper = "zookeeper"
   212  	dcsTypeConsul    = "consul"
   213  )
   214  
   215  type HostConfig struct {
   216  	DcsType      string   `yaml:"dcs_type"`
   217  	DcsEndpoints []string `yaml:"dcs_endpoints"`
   218  	Path         string
   219  	Username     string
   220  	Password     string
   221  	CAFile       string `yaml:"ca_file"`
   222  	CertFile     string `yaml:"cert_file"`
   223  	KeyFile      string `yaml:"key_file"`
   224  }
   225  
   226  func (hc HostConfig) IsScopeSpecified() bool {
   227  	// Path is usually "/namespace/scope"
   228  	// so we check if it has at least 2 slashes
   229  	return strings.Count(hc.Path, "/") >= 2
   230  }
   231  
   232  func NewHostConfig(URI string) (hc HostConfig, err error) {
   233  	// Extract scheme
   234  	before, after, ok := strings.Cut(URI, "://")
   235  	if !ok {
   236  		return hc, fmt.Errorf("invalid URI: missing scheme")
   237  	}
   238  	scheme := before
   239  	remainder := after // skip "://"
   240  
   241  	// Find where the host portion ends (at first '/' or '?' or end of string)
   242  	hostEnd := strings.IndexAny(remainder, "/?")
   243  	var hostPart, pathAndQuery string
   244  	if hostEnd == -1 {
   245  		hostPart = remainder
   246  		pathAndQuery = ""
   247  	} else {
   248  		hostPart = remainder[:hostEnd]
   249  		pathAndQuery = remainder[hostEnd:]
   250  	}
   251  
   252  	// Check for user info (username:password@)
   253  	var userInfo string
   254  	if atIdx := strings.LastIndex(hostPart, "@"); atIdx != -1 {
   255  		userInfo = hostPart[:atIdx]
   256  		hostPart = hostPart[atIdx+1:]
   257  	}
   258  
   259  	// Split hosts by comma for multiple endpoints
   260  	hosts := strings.Split(hostPart, ",")
   261  
   262  	// Parse a clean URL with just the first host to extract other components
   263  	cleanURI := scheme + "://"
   264  	if userInfo != "" {
   265  		cleanURI += userInfo + "@"
   266  	}
   267  	cleanURI += hosts[0] + pathAndQuery
   268  
   269  	var url *url.URL
   270  	url, err = url.Parse(cleanURI)
   271  	if err != nil {
   272  		return
   273  	}
   274  
   275  	switch url.Scheme {
   276  	case dcsTypeEtcd:
   277  		hc.DcsType = dcsTypeEtcd
   278  		for _, h := range hosts {
   279  			hc.DcsEndpoints = append(hc.DcsEndpoints, "http://"+h)
   280  		}
   281  	case dcsTypeZookeeper:
   282  		hc.DcsType = dcsTypeZookeeper
   283  		hc.DcsEndpoints = hosts // Use the split hosts directly
   284  	case dcsTypeConsul:
   285  		hc.DcsType = dcsTypeConsul
   286  		hc.DcsEndpoints = hosts // Use the split hosts directly
   287  	default:
   288  		return hc, fmt.Errorf("unsupported DCS type: %s", url.Scheme)
   289  	}
   290  
   291  	hc.Path = url.Path
   292  	hc.Username = url.User.Username()
   293  	hc.Password, _ = url.User.Password() // password is optional, so we ignore the error
   294  	hc.CAFile = url.Query().Get("ca_file")
   295  	hc.CertFile = url.Query().Get("cert_file")
   296  	hc.KeyFile = url.Query().Get("key_file")
   297  
   298  	return hc, nil
   299  }
   300  
   301  func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
   302  	var mds []*SourceConn
   303  	var clusterMembers []PatroniClusterMember
   304  	var err error
   305  	var ok bool
   306  
   307  	hostConfig, err := NewHostConfig(source.ConnStr)
   308  	if err != nil {
   309  		return nil, err
   310  	}
   311  
   312  	switch hostConfig.DcsType {
   313  	case dcsTypeEtcd:
   314  		clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
   315  	case dcsTypeZookeeper:
   316  		clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
   317  	case dcsTypeConsul:
   318  		clusterMembers, err = getConsulClusterMembers(source)
   319  	default:
   320  		return nil, errors.New("unknown DCS")
   321  	}
   322  	logger := logger.WithField("source", source.Name)
   323  	if err != nil {
   324  		if errors.Is(err, errors.ErrUnsupported) {
   325  			return nil, err
   326  		}
   327  		logger.Debug("failed to get info from DCS, using previous member info if any")
   328  		if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok { // mask error from main loop not to remove monitored DBs due to "jitter"
   329  			err = nil
   330  		}
   331  	} else {
   332  		lastFoundClusterMembers[source.Name] = clusterMembers
   333  	}
   334  	if len(clusterMembers) == 0 {
   335  		return mds, err
   336  	}
   337  
   338  	for _, patroniMember := range clusterMembers {
   339  		logger.Info("processing Patroni cluster member: ", patroniMember.Name)
   340  		if source.OnlyIfMaster && !patroniMember.IsPrimary() {
   341  			continue
   342  		}
   343  		src := *source.Clone()
   344  		src.ConnStr = patroniMember.ConnURL
   345  		if !hostConfig.IsScopeSpecified() {
   346  			src.Name += "_" + patroniMember.Scope
   347  		}
   348  		src.Name += "_" + patroniMember.Name
   349  		if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
   350  			mds = append(mds, dbs...)
   351  		} else {
   352  			logger.WithError(err).Error("failed to resolve databases for Patroni member: ", patroniMember.Name)
   353  		}
   354  	}
   355  	return mds, err
   356  }
   357  
   358  // ResolveDatabasesFromPostgres reads all the databases from the given cluster,
   359  // additionally matching/not matching specified regex patterns
   360  func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
   361  	var (
   362  		c      db.PgxPoolIface
   363  		dbname string
   364  		rows   pgx.Rows
   365  	)
   366  	c, err = NewConn(context.TODO(), s.ConnStr)
   367  	if err != nil {
   368  		return
   369  	}
   370  	defer c.Close()
   371  
   372  	sql := `select /* pgwatch_generated */
   373  	datname
   374  	from pg_database
   375  	where not datistemplate
   376  	and datallowconn
   377  	and has_database_privilege (datname, 'CONNECT')
   378  	and case when length(trim($1)) > 0 then datname ~ $1 else true end
   379  	and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
   380  
   381  	if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
   382  		return nil, err
   383  	}
   384  	for rows.Next() {
   385  		if err = rows.Scan(&dbname); err != nil {
   386  			return nil, err
   387  		}
   388  		rdb := NewSourceConn(*s.Clone())
   389  		rdb.Name += "_" + dbname
   390  		rdb.SetDatabaseName(dbname)
   391  		resolvedDbs = append(resolvedDbs, rdb)
   392  	}
   393  
   394  	if err := rows.Err(); err != nil {
   395  		return nil, err
   396  	}
   397  	return
   398  }
   399