...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/resolver.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources
     2  
     3  // This file contains the implementation of Patroni and PostgreSQL resolvers for continuous monitoring.
     4  // Patroni resolver will return the list of databases from the Patroni cluster.
     5  // Postgres resolver will return the list of databases from the given Postgres instance.
     6  
     7  import (
     8  	"cmp"
     9  	"context"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"errors"
    13  	"fmt"
    14  	"net/url"
    15  	"os"
    16  	"strings"
    17  	"time"
    18  
    19  	jsoniter "github.com/json-iterator/go"
    20  
    21  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
    22  	"github.com/cybertec-postgresql/pgwatch/v3/internal/log"
    23  	pgx "github.com/jackc/pgx/v5"
    24  	client "go.etcd.io/etcd/client/v3"
    25  	"go.uber.org/zap"
    26  )
    27  
    28  // ResolveDatabases() updates list of monitored objects from continuous monitoring sources, e.g. patroni
    29  func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
    30  	resolvedDbs := make(SourceConns, 0, len(srcs))
    31  	for _, s := range srcs {
    32  		if !s.IsEnabled {
    33  			continue
    34  		}
    35  		dbs, e := s.ResolveDatabases()
    36  		err = errors.Join(err, e)
    37  		resolvedDbs = append(resolvedDbs, dbs...)
    38  	}
    39  	return resolvedDbs, err
    40  }
    41  
    42  // ResolveDatabases() return a slice of found databases for continuous monitoring sources, e.g. patroni
    43  func (s Source) ResolveDatabases() (SourceConns, error) {
    44  	switch s.Kind {
    45  	case SourcePatroni:
    46  		return ResolveDatabasesFromPatroni(s)
    47  	case SourcePostgresContinuous:
    48  		return ResolveDatabasesFromPostgres(s)
    49  	}
    50  	return SourceConns{&SourceConn{Source: s}}, nil
    51  }
    52  
// PatroniClusterMember describes one member of a Patroni cluster as discovered in the DCS.
type PatroniClusterMember struct {
	// Scope is the cluster scope, taken from the second segment of the DCS key.
	Scope   string
	// Name is the member name, taken from the fourth segment of the DCS key.
	Name    string
	// ConnURL is the "conn_url" value stored for the member; it is used as the
	// connection string when resolving the member's databases.
	ConnURL string `yaml:"conn_url"`
	// Role is the "role" value stored for the member, e.g. "master".
	Role    string
}
    59  
// logger is the package-level logger; it starts out as log.FallbackLogger.
var logger log.Logger = log.FallbackLogger

// lastFoundClusterMembers remembers, per source name, the member list most
// recently read from the DCS. It is needed for cases where the DCS is
// temporarily down: we don't want to immediately remove monitoring of DBs.
// NOTE(review): this is a plain map with no locking visible in this file —
// confirm sources are never resolved concurrently.
var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
    64  
// getConsulClusterMembers is a placeholder for Consul-based member discovery.
// It ignores its argument and always returns errors.ErrUnsupported.
func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
	return nil, errors.ErrUnsupported
}
    68  
// getZookeeperClusterMembers is a placeholder for ZooKeeper-based member discovery.
// It ignores its arguments and always returns errors.ErrUnsupported.
func getZookeeperClusterMembers(Source, HostConfig) ([]PatroniClusterMember, error) {
	return nil, errors.ErrUnsupported
}
    72  
    73  func jsonTextToStringMap(jsonText string) (map[string]string, error) {
    74  	retmap := make(map[string]string)
    75  	if jsonText == "" {
    76  		return retmap, nil
    77  	}
    78  	var iMap map[string]any
    79  	if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
    80  		return nil, err
    81  	}
    82  	for k, v := range iMap {
    83  		retmap[k] = fmt.Sprintf("%v", v)
    84  	}
    85  	return retmap, nil
    86  }
    87  
    88  func getTransport(conf HostConfig) (*tls.Config, error) {
    89  	var caCertPool *x509.CertPool
    90  
    91  	// create valid CertPool only if the ca certificate file exists
    92  	if conf.CAFile != "" {
    93  		caCert, err := os.ReadFile(conf.CAFile)
    94  		if err != nil {
    95  			return nil, fmt.Errorf("cannot load CA file: %s", err)
    96  		}
    97  
    98  		caCertPool = x509.NewCertPool()
    99  		caCertPool.AppendCertsFromPEM(caCert)
   100  	}
   101  
   102  	var certificates []tls.Certificate
   103  
   104  	// create valid []Certificate only if the client cert and key files exists
   105  	if conf.CertFile != "" && conf.KeyFile != "" {
   106  		cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
   107  		if err != nil {
   108  			return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
   109  		}
   110  
   111  		certificates = []tls.Certificate{cert}
   112  	}
   113  
   114  	tlsClientConfig := new(tls.Config)
   115  
   116  	if caCertPool != nil {
   117  		tlsClientConfig.RootCAs = caCertPool
   118  		if certificates != nil {
   119  			tlsClientConfig.Certificates = certificates
   120  		}
   121  	}
   122  
   123  	return tlsClientConfig, nil
   124  }
   125  
   126  func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
   127  	var ret = make([]PatroniClusterMember, 0)
   128  	var cfg client.Config
   129  
   130  	if len(hc.DcsEndpoints) == 0 {
   131  		return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
   132  	}
   133  
   134  	tlsConfig, err := getTransport(hc)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	cfg = client.Config{
   139  		Endpoints:            hc.DcsEndpoints,
   140  		TLS:                  tlsConfig,
   141  		DialKeepAliveTimeout: time.Second,
   142  		Username:             hc.Username,
   143  		Password:             hc.Password,
   144  		DialTimeout:          5 * time.Second,
   145  		Logger:               zap.NewNop(),
   146  	}
   147  
   148  	c, err := client.New(cfg)
   149  	if err != nil {
   150  		return ret, err
   151  	}
   152  	defer c.Close()
   153  
   154  	ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
   155  	defer cancel()
   156  
   157  	// etcd3 does not have a dir node.
   158  	// Key="/namespace/scope/leader", e.g. "/service/batman/leader"
   159  	// Key="/namespace/scope/members/node", e.g. "/service/batman/members/pg1"
   160  
   161  	resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
   162  	if err != nil {
   163  		return ret, cmp.Or(context.Cause(ctx), err)
   164  	}
   165  
   166  	for _, node := range resp.Kvs {
   167  		nodeData, err := jsonTextToStringMap(string(node.Value))
   168  		if err != nil {
   169  			logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node, err)
   170  			continue
   171  		}
   172  		// remove leading slash and split by "/"
   173  		parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
   174  		if len(parts) < 3 {
   175  			return nil, errors.New("invalid ETCD key format")
   176  		}
   177  		role := nodeData["role"]
   178  		connURL := nodeData["conn_url"]
   179  		scope := parts[1]
   180  		name := parts[3]
   181  		ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
   182  	}
   183  
   184  	lastFoundClusterMembers[s.Name] = ret
   185  	return ret, nil
   186  }
   187  
// DCS types recognised in the scheme of a Patroni source connection string.
const (
	dcsTypeEtcd      = "etcd"
	dcsTypeZookeeper = "zookeeper"
	dcsTypeConsul    = "consul"
)
   193  
// HostConfig holds the DCS connection settings for a Patroni source.
type HostConfig struct {
	// DcsType is one of the dcsType* constants.
	DcsType      string   `yaml:"dcs_type"`
	// DcsEndpoints lists the DCS hosts to connect to.
	DcsEndpoints []string `yaml:"dcs_endpoints"`
	// Path is the key prefix to read from the DCS, usually "/namespace/scope".
	Path         string
	// Username and Password are the credentials passed to the DCS client.
	Username     string
	Password     string
	// CAFile, CertFile and KeyFile are paths to the PEM files used for TLS.
	CAFile       string `yaml:"ca_file"`
	CertFile     string `yaml:"cert_file"`
	KeyFile      string `yaml:"key_file"`
}
   204  
   205  func (hc HostConfig) IsScopeSpecified() bool {
   206  	// Path is usually "/namespace/scope"
   207  	// so we check if it has at least 2 slashes
   208  	return strings.Count(hc.Path, "/") >= 2
   209  }
   210  
   211  func NewHostConfig(URI string) (hc HostConfig, err error) {
   212  	var url *url.URL
   213  	url, err = url.Parse(URI)
   214  	if err != nil {
   215  		return
   216  	}
   217  
   218  	switch url.Scheme {
   219  	case dcsTypeEtcd:
   220  		hc.DcsType = dcsTypeEtcd
   221  		for h := range strings.SplitSeq(url.Host, ",") {
   222  			hc.DcsEndpoints = append(hc.DcsEndpoints, "http://"+h)
   223  		}
   224  	case dcsTypeZookeeper:
   225  		hc.DcsType = dcsTypeZookeeper
   226  		hc.DcsEndpoints = []string{url.Host} // Zookeeper usually has a
   227  		// single endpoint, but can be a list of hosts separated by commas
   228  	case dcsTypeConsul:
   229  		hc.DcsType = dcsTypeConsul
   230  		hc.DcsEndpoints = strings.Split(url.Host, ",")
   231  	default:
   232  		return hc, fmt.Errorf("unsupported DCS type: %s", url.Scheme)
   233  	}
   234  
   235  	hc.Path = url.Path
   236  	hc.Username = url.User.Username()
   237  	hc.Password, _ = url.User.Password() // password is optional, so we ignore the error
   238  	hc.CAFile = url.Query().Get("ca_file")
   239  	hc.CertFile = url.Query().Get("cert_file")
   240  	hc.KeyFile = url.Query().Get("key_file")
   241  
   242  	return hc, nil
   243  }
   244  
   245  func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
   246  	var mds []*SourceConn
   247  	var clusterMembers []PatroniClusterMember
   248  	var err error
   249  	var ok bool
   250  
   251  	hostConfig, err := NewHostConfig(source.ConnStr)
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  
   256  	switch hostConfig.DcsType {
   257  	case dcsTypeEtcd:
   258  		clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
   259  	case dcsTypeZookeeper:
   260  		clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
   261  	case dcsTypeConsul:
   262  		clusterMembers, err = getConsulClusterMembers(source)
   263  	default:
   264  		return nil, errors.New("unknown DCS")
   265  	}
   266  	logger := logger.WithField("sorce", source.Name)
   267  	if err != nil {
   268  		if errors.Is(err, errors.ErrUnsupported) {
   269  			return nil, err
   270  		}
   271  		logger.Debug("Failed to get info from DCS, using previous member info if any")
   272  		if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok { // mask error from main loop not to remove monitored DBs due to "jitter"
   273  			err = nil
   274  		}
   275  	} else {
   276  		lastFoundClusterMembers[source.Name] = clusterMembers
   277  	}
   278  	if len(clusterMembers) == 0 {
   279  		return mds, err
   280  	}
   281  
   282  	for _, patroniMember := range clusterMembers {
   283  		logger.Info("Processing Patroni cluster member: ", patroniMember.Name)
   284  		if source.OnlyIfMaster && patroniMember.Role != "master" {
   285  			continue
   286  		}
   287  		src := *source.Clone()
   288  		src.ConnStr = patroniMember.ConnURL
   289  		if hostConfig.IsScopeSpecified() {
   290  			src.Name += "_" + patroniMember.Scope
   291  		}
   292  		src.Name += "_" + patroniMember.Name
   293  		if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
   294  			mds = append(mds, dbs...)
   295  		} else {
   296  			logger.WithError(err).Error("Failed to resolve databases for Patroni member: ", patroniMember.Name)
   297  		}
   298  	}
   299  	return mds, err
   300  }
   301  
   302  // ResolveDatabasesFromPostgres reads all the databases from the given cluster,
   303  // additionally matching/not matching specified regex patterns
   304  func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
   305  	var (
   306  		c      db.PgxPoolIface
   307  		dbname string
   308  		rows   pgx.Rows
   309  	)
   310  	c, err = NewConn(context.TODO(), s.ConnStr)
   311  	if err != nil {
   312  		return
   313  	}
   314  	defer c.Close()
   315  
   316  	sql := `select /* pgwatch_generated */
   317  	datname
   318  	from pg_database
   319  	where not datistemplate
   320  	and datallowconn
   321  	and has_database_privilege (datname, 'CONNECT')
   322  	and case when length(trim($1)) > 0 then datname ~ $1 else true end
   323  	and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
   324  
   325  	if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
   326  		return nil, err
   327  	}
   328  	for rows.Next() {
   329  		if err = rows.Scan(&dbname); err != nil {
   330  			return nil, err
   331  		}
   332  		rdb := &SourceConn{Source: *s.Clone()}
   333  		rdb.Name += "_" + dbname
   334  		rdb.SetDatabaseName(dbname)
   335  		resolvedDbs = append(resolvedDbs, rdb)
   336  	}
   337  
   338  	if err := rows.Err(); err != nil {
   339  		return nil, err
   340  	}
   341  	return
   342  }
   343