...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/resolver.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources
     2  
// This file contains the implementation of the Patroni and PostgreSQL resolvers used for continuous monitoring.
// The Patroni resolver returns the list of databases from the Patroni cluster.
// The Postgres resolver returns the list of databases from the given Postgres instance.
     6  
     7  import (
     8  	"cmp"
     9  	"context"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"errors"
    13  	"fmt"
    14  	"net/url"
    15  	"os"
    16  	"strings"
    17  	"time"
    18  
    19  	jsoniter "github.com/json-iterator/go"
    20  
    21  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
    22  	"github.com/cybertec-postgresql/pgwatch/v3/internal/log"
    23  	pgx "github.com/jackc/pgx/v5"
    24  	client "go.etcd.io/etcd/client/v3"
    25  	"go.uber.org/zap"
    26  )
    27  
    28  // ResolveDatabases() updates list of monitored objects from continuous monitoring sources, e.g. patroni
    29  func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
    30  	resolvedDbs := make(SourceConns, 0, len(srcs))
    31  	for _, s := range srcs {
    32  		if !s.IsEnabled {
    33  			continue
    34  		}
    35  		dbs, e := s.ResolveDatabases()
    36  		err = errors.Join(err, e)
    37  		resolvedDbs = append(resolvedDbs, dbs...)
    38  	}
    39  	return resolvedDbs, err
    40  }
    41  
    42  // ResolveDatabases() return a slice of found databases for continuous monitoring sources, e.g. patroni
    43  func (s Source) ResolveDatabases() (SourceConns, error) {
    44  	switch s.Kind {
    45  	case SourcePatroni:
    46  		return ResolveDatabasesFromPatroni(s)
    47  	case SourcePostgresContinuous:
    48  		return ResolveDatabasesFromPostgres(s)
    49  	}
    50  	return SourceConns{NewSourceConn(s)}, nil
    51  }
    52  
    53  type PatroniClusterMember struct {
    54  	Scope   string
    55  	Name    string
    56  	ConnURL string `yaml:"conn_url"`
    57  	Role    string
    58  }
    59  
    60  func (pcm PatroniClusterMember) IsPrimary() bool {
    61  	return pcm.Role == "primary" || pcm.Role == "master"
    62  }
    63  
    64  var logger log.Logger = log.FallbackLogger
    65  
    66  var lastFoundClusterMembers = make(map[string][]PatroniClusterMember) // needed for cases where DCS is temporarily down
    67  // don't want to immediately remove monitoring of DBs
    68  
    69  func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
    70  	return nil, errors.ErrUnsupported
    71  }
    72  
    73  func getZookeeperClusterMembers(Source, HostConfig) ([]PatroniClusterMember, error) {
    74  	return nil, errors.ErrUnsupported
    75  }
    76  
    77  func jsonTextToStringMap(jsonText string) (map[string]string, error) {
    78  	retmap := make(map[string]string)
    79  	if jsonText == "" {
    80  		return retmap, nil
    81  	}
    82  	var iMap map[string]any
    83  	if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
    84  		return nil, err
    85  	}
    86  	for k, v := range iMap {
    87  		retmap[k] = fmt.Sprintf("%v", v)
    88  	}
    89  	return retmap, nil
    90  }
    91  
    92  func getTransport(conf HostConfig) (*tls.Config, error) {
    93  	var caCertPool *x509.CertPool
    94  
    95  	// create valid CertPool only if the ca certificate file exists
    96  	if conf.CAFile != "" {
    97  		caCert, err := os.ReadFile(conf.CAFile)
    98  		if err != nil {
    99  			return nil, fmt.Errorf("cannot load CA file: %s", err)
   100  		}
   101  
   102  		caCertPool = x509.NewCertPool()
   103  		caCertPool.AppendCertsFromPEM(caCert)
   104  	}
   105  
   106  	var certificates []tls.Certificate
   107  
   108  	// create valid []Certificate only if the client cert and key files exists
   109  	if conf.CertFile != "" && conf.KeyFile != "" {
   110  		cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
   111  		if err != nil {
   112  			return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
   113  		}
   114  
   115  		certificates = []tls.Certificate{cert}
   116  	}
   117  
   118  	tlsClientConfig := new(tls.Config)
   119  
   120  	if caCertPool != nil {
   121  		tlsClientConfig.RootCAs = caCertPool
   122  		if certificates != nil {
   123  			tlsClientConfig.Certificates = certificates
   124  		}
   125  	}
   126  
   127  	return tlsClientConfig, nil
   128  }
   129  
   130  func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
   131  	var ret = make([]PatroniClusterMember, 0)
   132  	var cfg client.Config
   133  
   134  	if len(hc.DcsEndpoints) == 0 {
   135  		return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
   136  	}
   137  
   138  	tlsConfig, err := getTransport(hc)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  	cfg = client.Config{
   143  		Endpoints:            hc.DcsEndpoints,
   144  		TLS:                  tlsConfig,
   145  		DialKeepAliveTimeout: time.Second,
   146  		Username:             hc.Username,
   147  		Password:             hc.Password,
   148  		DialTimeout:          5 * time.Second,
   149  		Logger:               zap.NewNop(),
   150  	}
   151  
   152  	c, err := client.New(cfg)
   153  	if err != nil {
   154  		return ret, err
   155  	}
   156  	defer c.Close()
   157  
   158  	ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
   159  	defer cancel()
   160  
   161  	// etcd3 does not have a dir node.
   162  	// Key="/namespace/scope/leader", e.g. "/service/batman/leader"
   163  	// Key="/namespace/scope/members/node", e.g. "/service/batman/members/pg1"
   164  
   165  	resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
   166  	if err != nil {
   167  		return ret, cmp.Or(context.Cause(ctx), err)
   168  	}
   169  
   170  	for _, node := range resp.Kvs {
   171  		// remove leading slash and split by "/"
   172  		parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
   173  		if len(parts) < 4 || parts[2] != "members" {
   174  			continue // skip non-member keys
   175  		}
   176  		nodeData, err := jsonTextToStringMap(string(node.Value))
   177  		if err != nil {
   178  			logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node.Key, err)
   179  			continue
   180  		}
   181  		role := nodeData["role"]
   182  		connURL := nodeData["conn_url"]
   183  		scope := parts[1]
   184  		name := parts[3]
   185  		ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
   186  	}
   187  
   188  	lastFoundClusterMembers[s.Name] = ret
   189  	return ret, nil
   190  }
   191  
   192  const (
   193  	dcsTypeEtcd      = "etcd"
   194  	dcsTypeZookeeper = "zookeeper"
   195  	dcsTypeConsul    = "consul"
   196  )
   197  
   198  type HostConfig struct {
   199  	DcsType      string   `yaml:"dcs_type"`
   200  	DcsEndpoints []string `yaml:"dcs_endpoints"`
   201  	Path         string
   202  	Username     string
   203  	Password     string
   204  	CAFile       string `yaml:"ca_file"`
   205  	CertFile     string `yaml:"cert_file"`
   206  	KeyFile      string `yaml:"key_file"`
   207  }
   208  
   209  func (hc HostConfig) IsScopeSpecified() bool {
   210  	// Path is usually "/namespace/scope"
   211  	// so we check if it has at least 2 slashes
   212  	return strings.Count(hc.Path, "/") >= 2
   213  }
   214  
   215  func NewHostConfig(URI string) (hc HostConfig, err error) {
   216  	var url *url.URL
   217  	url, err = url.Parse(URI)
   218  	if err != nil {
   219  		return
   220  	}
   221  
   222  	switch url.Scheme {
   223  	case dcsTypeEtcd:
   224  		hc.DcsType = dcsTypeEtcd
   225  		for h := range strings.SplitSeq(url.Host, ",") {
   226  			hc.DcsEndpoints = append(hc.DcsEndpoints, "http://"+h)
   227  		}
   228  	case dcsTypeZookeeper:
   229  		hc.DcsType = dcsTypeZookeeper
   230  		hc.DcsEndpoints = []string{url.Host} // Zookeeper usually has a
   231  		// single endpoint, but can be a list of hosts separated by commas
   232  	case dcsTypeConsul:
   233  		hc.DcsType = dcsTypeConsul
   234  		hc.DcsEndpoints = strings.Split(url.Host, ",")
   235  	default:
   236  		return hc, fmt.Errorf("unsupported DCS type: %s", url.Scheme)
   237  	}
   238  
   239  	hc.Path = url.Path
   240  	hc.Username = url.User.Username()
   241  	hc.Password, _ = url.User.Password() // password is optional, so we ignore the error
   242  	hc.CAFile = url.Query().Get("ca_file")
   243  	hc.CertFile = url.Query().Get("cert_file")
   244  	hc.KeyFile = url.Query().Get("key_file")
   245  
   246  	return hc, nil
   247  }
   248  
   249  func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
   250  	var mds []*SourceConn
   251  	var clusterMembers []PatroniClusterMember
   252  	var err error
   253  	var ok bool
   254  
   255  	hostConfig, err := NewHostConfig(source.ConnStr)
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  
   260  	switch hostConfig.DcsType {
   261  	case dcsTypeEtcd:
   262  		clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
   263  	case dcsTypeZookeeper:
   264  		clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
   265  	case dcsTypeConsul:
   266  		clusterMembers, err = getConsulClusterMembers(source)
   267  	default:
   268  		return nil, errors.New("unknown DCS")
   269  	}
   270  	logger := logger.WithField("sorce", source.Name)
   271  	if err != nil {
   272  		if errors.Is(err, errors.ErrUnsupported) {
   273  			return nil, err
   274  		}
   275  		logger.Debug("failed to get info from DCS, using previous member info if any")
   276  		if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok { // mask error from main loop not to remove monitored DBs due to "jitter"
   277  			err = nil
   278  		}
   279  	} else {
   280  		lastFoundClusterMembers[source.Name] = clusterMembers
   281  	}
   282  	if len(clusterMembers) == 0 {
   283  		return mds, err
   284  	}
   285  
   286  	for _, patroniMember := range clusterMembers {
   287  		logger.Info("processing Patroni cluster member: ", patroniMember.Name)
   288  		if source.OnlyIfMaster && !patroniMember.IsPrimary() {
   289  			continue
   290  		}
   291  		src := *source.Clone()
   292  		src.ConnStr = patroniMember.ConnURL
   293  		if !hostConfig.IsScopeSpecified() {
   294  			src.Name += "_" + patroniMember.Scope
   295  		}
   296  		src.Name += "_" + patroniMember.Name
   297  		if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
   298  			mds = append(mds, dbs...)
   299  		} else {
   300  			logger.WithError(err).Error("failed to resolve databases for Patroni member: ", patroniMember.Name)
   301  		}
   302  	}
   303  	return mds, err
   304  }
   305  
   306  // ResolveDatabasesFromPostgres reads all the databases from the given cluster,
   307  // additionally matching/not matching specified regex patterns
   308  func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
   309  	var (
   310  		c      db.PgxPoolIface
   311  		dbname string
   312  		rows   pgx.Rows
   313  	)
   314  	c, err = NewConn(context.TODO(), s.ConnStr)
   315  	if err != nil {
   316  		return
   317  	}
   318  	defer c.Close()
   319  
   320  	sql := `select /* pgwatch_generated */
   321  	datname
   322  	from pg_database
   323  	where not datistemplate
   324  	and datallowconn
   325  	and has_database_privilege (datname, 'CONNECT')
   326  	and case when length(trim($1)) > 0 then datname ~ $1 else true end
   327  	and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
   328  
   329  	if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
   330  		return nil, err
   331  	}
   332  	for rows.Next() {
   333  		if err = rows.Scan(&dbname); err != nil {
   334  			return nil, err
   335  		}
   336  		rdb := NewSourceConn(*s.Clone())
   337  		rdb.Name += "_" + dbname
   338  		rdb.SetDatabaseName(dbname)
   339  		resolvedDbs = append(resolvedDbs, rdb)
   340  	}
   341  
   342  	if err := rows.Err(); err != nil {
   343  		return nil, err
   344  	}
   345  	return
   346  }
   347