1 package sources
2
3
4
5
6
7 import (
8 "cmp"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "encoding/json"
13 "errors"
14 "fmt"
15 "net/url"
16 "os"
17 "path"
18 "regexp"
19 "strings"
20 "time"
21
22 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
23 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
24 pgx "github.com/jackc/pgx/v5"
25 client "go.etcd.io/etcd/client/v3"
26 "go.uber.org/zap"
27 )
28
29
30 func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
31 resolvedDbs := make(SourceConns, 0, len(srcs))
32 for _, s := range srcs {
33 if !s.IsEnabled {
34 continue
35 }
36 dbs, e := s.ResolveDatabases()
37 err = errors.Join(err, e)
38 resolvedDbs = append(resolvedDbs, dbs...)
39 }
40 return resolvedDbs, err
41 }
42
43
44 func (s Source) ResolveDatabases() (SourceConns, error) {
45 switch s.Kind {
46 case SourcePatroni, SourcePatroniContinuous, SourcePatroniNamespace:
47 return ResolveDatabasesFromPatroni(s)
48 case SourcePostgresContinuous:
49 return ResolveDatabasesFromPostgres(s)
50 }
51 return SourceConns{&SourceConn{Source: *(&s).Clone()}}, nil
52 }
53
// PatroniClusterMember describes one Patroni-managed PostgreSQL instance as
// read from the DCS.
type PatroniClusterMember struct {
	// Scope is the Patroni cluster (scope) the member belongs to.
	Scope string
	// Name identifies the member; extractEtcdScopeMembers takes it from the
	// last element of the DCS key, optionally prefixed with "<scope>_".
	Name string
	// ConnURL is the member's "conn_url" value, e.g. postgres://host:port/db;
	// parseHostAndPortFromJdbcConnStr extracts the host and port from it.
	ConnURL string `yaml:"conn_url"`
	// Role is the member's Patroni role; ResolveDatabasesFromPatroni compares
	// it against "master" when OnlyIfMaster is set.
	Role string
}
60
// logger is the package-level logger used by the resolvers in this file; it
// starts out as log.FallbackLogger.
var logger log.LoggerIface = log.FallbackLogger

// lastFoundClusterMembers caches the most recent successful DCS lookup per
// source name so ResolveDatabasesFromPatroni can fall back to it when the DCS
// is unreachable.
// NOTE(review): this map is not guarded by a mutex, so it presumably assumes
// resolution never runs concurrently — TODO confirm.
var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
64
65
66 func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
67 return nil, errors.ErrUnsupported
68 }
69
70 func getZookeeperClusterMembers(Source) ([]PatroniClusterMember, error) {
71 return nil, errors.ErrUnsupported
72 }
73
// patroniConnURLRe extracts host and port from a Patroni member "conn_url" such
// as postgres://10.0.0.1:5432/postgres. The host group is greedy, so it takes
// everything up to the last ":<digits>/", which keeps bracketed IPv6 hosts like
// [::1] intact. It is compiled once here rather than on every call.
var patroniConnURLRe = regexp.MustCompile(`postgres://(.*):([0-9]+)/`)

// parseHostAndPortFromJdbcConnStr returns the host and port embedded in a
// Patroni member connection URL of the form postgres://<host>:<port>/...
// It returns an error when connStr contains no such sequence; logging the
// failure is left to the caller.
func parseHostAndPortFromJdbcConnStr(connStr string) (string, string, error) {
	matches := patroniConnURLRe.FindStringSubmatch(connStr)
	if len(matches) != 3 {
		return "", "", fmt.Errorf("unexpected regex result groups: %v", matches)
	}
	return matches[1], matches[2], nil
}
83
// jsonTextToStringMap decodes a JSON object and renders each top-level value
// with fmt's default formatting, producing a flat string-to-string view.
// Empty input yields an empty, non-nil map; malformed JSON yields a nil map
// together with the decoding error.
func jsonTextToStringMap(jsonText string) (map[string]string, error) {
	if len(jsonText) == 0 {
		return map[string]string{}, nil
	}

	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonText), &decoded); err != nil {
		return nil, err
	}

	flat := make(map[string]string, len(decoded))
	for key, val := range decoded {
		// Sprint of a single operand is identical to Sprintf("%v", val).
		flat[key] = fmt.Sprint(val)
	}
	return flat, nil
}
98
99 func getTransport(conf HostConfigAttrs) (*tls.Config, error) {
100 var caCertPool *x509.CertPool
101
102
103 if conf.CAFile != "" {
104 caCert, err := os.ReadFile(conf.CAFile)
105 if err != nil {
106 return nil, fmt.Errorf("cannot load CA file: %s", err)
107 }
108
109 caCertPool = x509.NewCertPool()
110 caCertPool.AppendCertsFromPEM(caCert)
111 }
112
113 var certificates []tls.Certificate
114
115
116 if conf.CertFile != "" && conf.KeyFile != "" {
117 cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
118 if err != nil {
119 return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
120 }
121
122 certificates = []tls.Certificate{cert}
123 }
124
125 tlsClientConfig := new(tls.Config)
126
127 if caCertPool != nil {
128 tlsClientConfig.RootCAs = caCertPool
129 if certificates != nil {
130 tlsClientConfig.Certificates = certificates
131 }
132 }
133
134 return tlsClientConfig, nil
135 }
136
137 func getEtcdClusterMembers(s Source) ([]PatroniClusterMember, error) {
138 var ret = make([]PatroniClusterMember, 0)
139 var cfg client.Config
140
141 if len(s.HostConfig.DcsEndpoints) == 0 {
142 return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
143 }
144
145 tlsConfig, err := getTransport(s.HostConfig)
146 if err != nil {
147 return nil, err
148 }
149 cfg = client.Config{
150 Endpoints: s.HostConfig.DcsEndpoints,
151 TLS: tlsConfig,
152 DialKeepAliveTimeout: time.Second,
153 Username: s.HostConfig.Username,
154 Password: s.HostConfig.Password,
155 DialTimeout: 5 * time.Second,
156 Logger: zap.NewNop(),
157 }
158
159 c, err := client.New(cfg)
160 if err != nil {
161 return ret, err
162 }
163 defer c.Close()
164
165 ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
166 defer cancel()
167 kapi := c.KV
168
169 if s.Kind == SourcePatroniNamespace {
170 if len(s.GetDatabaseName()) > 0 {
171 return ret, fmt.Errorf("skipping Patroni entry %s - cannot specify a DB name when monitoring all scopes (regex patterns are supported though)", s.Name)
172 }
173 if s.HostConfig.Namespace == "" {
174 return ret, fmt.Errorf("skipping Patroni entry %s - search 'namespace' not specified", s.Name)
175 }
176 resp, err := kapi.Get(ctx, s.HostConfig.Namespace, client.WithPrefix(), client.WithKeysOnly())
177 if err != nil {
178 return ret, cmp.Or(context.Cause(ctx), err)
179 }
180
181
182
183
184
185
186 scopes := make(map[string]bool, len(resp.Kvs))
187 for _, node := range resp.Kvs {
188 pathSuffix := strings.TrimPrefix(string(node.Key), s.HostConfig.Namespace)
189 scope := strings.SplitN(pathSuffix, "/", 2)[0]
190 scopes[scope] = true
191 }
192
193 for scope := range scopes {
194 scopeMembers, err := extractEtcdScopeMembers(ctx, s, scope, kapi, true)
195 if err != nil {
196 continue
197 }
198 ret = append(ret, scopeMembers...)
199 }
200 } else {
201 ret, err = extractEtcdScopeMembers(ctx, s, s.HostConfig.Scope, kapi, false)
202 if err != nil {
203 return ret, cmp.Or(context.Cause(ctx), err)
204 }
205 }
206 lastFoundClusterMembers[s.Name] = ret
207 return ret, nil
208 }
209
210 func extractEtcdScopeMembers(ctx context.Context, s Source, scope string, kapi client.KV, addScopeToName bool) ([]PatroniClusterMember, error) {
211 var ret = make([]PatroniClusterMember, 0)
212 var name string
213 membersPath := path.Join(s.HostConfig.Namespace, scope, "members")
214
215 resp, err := kapi.Get(ctx, membersPath, client.WithPrefix())
216 if err != nil {
217 return nil, err
218 }
219 logger.Debugf("ETCD response for %s scope %s: %+v", s.Name, scope, resp)
220
221 for _, node := range resp.Kvs {
222 logger.Debugf("Found a cluster member from etcd [%s:%s]: %+v", s.Name, scope, node.Value)
223 nodeData, err := jsonTextToStringMap(string(node.Value))
224 if err != nil {
225 logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node, err)
226 continue
227 }
228 role := nodeData["role"]
229 connURL := nodeData["conn_url"]
230 if addScopeToName {
231 name = scope + "_" + path.Base(string(node.Key))
232 } else {
233 name = path.Base(string(node.Key))
234 }
235
236 ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
237 }
238 return ret, nil
239 }
240
// Values accepted in a Patroni source's HostConfig.DcsType; they select the
// DCS reader in ResolveDatabasesFromPatroni, and any other value is rejected
// there as an unknown DCS.
const (
	dcsTypeEtcd      = "etcd"      // read via getEtcdClusterMembers
	dcsTypeZookeeper = "zookeeper" // getZookeeperClusterMembers currently returns errors.ErrUnsupported
	dcsTypeConsul    = "consul"    // getConsulClusterMembers currently returns errors.ErrUnsupported
)
246
247 func ResolveDatabasesFromPatroni(ce Source) ([]*SourceConn, error) {
248 var mds []*SourceConn
249 var clusterMembers []PatroniClusterMember
250 var err error
251 var ok bool
252 var dbUnique string
253
254 switch ce.HostConfig.DcsType {
255 case dcsTypeEtcd:
256 clusterMembers, err = getEtcdClusterMembers(ce)
257 case dcsTypeZookeeper:
258 clusterMembers, err = getZookeeperClusterMembers(ce)
259 case dcsTypeConsul:
260 clusterMembers, err = getConsulClusterMembers(ce)
261 default:
262 return nil, errors.New("unknown DCS")
263 }
264 if err != nil {
265 logger.WithField("source", ce.Name).Debug("Failed to get info from DCS, using previous member info if any")
266 if clusterMembers, ok = lastFoundClusterMembers[ce.Name]; ok {
267 err = nil
268 }
269 } else {
270 lastFoundClusterMembers[ce.Name] = clusterMembers
271 }
272 if len(clusterMembers) == 0 {
273 return mds, err
274 }
275
276 for _, m := range clusterMembers {
277 logger.Infof("Processing Patroni cluster member [%s:%s]", ce.Name, m.Name)
278 if ce.OnlyIfMaster && m.Role != "master" {
279 logger.Infof("Skipping over Patroni cluster member [%s:%s] as not a master", ce.Name, m.Name)
280 continue
281 }
282 host, port, err := parseHostAndPortFromJdbcConnStr(m.ConnURL)
283 if err != nil {
284 logger.Errorf("Could not parse Patroni conn str \"%s\" [%s:%s]: %v", m.ConnURL, ce.Name, m.Scope, err)
285 continue
286 }
287 if ce.OnlyIfMaster {
288 dbUnique = ce.Name
289 if ce.Kind == SourcePatroniNamespace {
290 dbUnique = ce.Name + "_" + m.Scope
291 }
292 } else {
293 dbUnique = ce.Name + "_" + m.Name
294 }
295 if ce.GetDatabaseName() != "" {
296 c := &SourceConn{Source: *ce.Clone()}
297 c.Name = dbUnique
298 mds = append(mds, c)
299 continue
300 }
301 connURL, err := url.Parse(ce.ConnStr)
302 if err != nil {
303 logger.Errorf("Could not contact Patroni member [%s:%s]: %v", ce.Name, m.Scope, err)
304 continue
305 }
306 connURL.Scheme = "postgresql"
307 connURL.Host = host + ":" + port
308 connURL.Path = "template1"
309 c, err := db.New(context.TODO(), connURL.String())
310 if err != nil {
311 logger.Errorf("Could not contact Patroni member [%s:%s]: %v", ce.Name, m.Scope, err)
312 continue
313 }
314 defer c.Close()
315 sql := `select datname::text as datname,
316 quote_ident(datname)::text as datname_escaped
317 from pg_database
318 where not datistemplate
319 and datallowconn
320 and has_database_privilege (datname, 'CONNECT')
321 and case when length(trim($1)) > 0 then datname ~ $1 else true end
322 and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
323
324 rows, err := c.Query(context.TODO(), sql, ce.IncludePattern, ce.ExcludePattern)
325 if err != nil {
326 return nil, err
327 }
328 data, err := pgx.CollectRows(rows, pgx.RowToMap)
329 if err != nil {
330 logger.Errorf("Could not get DB name listing from Patroni member [%s:%s]: %v", ce.Name, m.Scope, err)
331 continue
332 }
333
334 for _, d := range data {
335 connURL.Path = d["datname"].(string)
336 c := ce.Clone()
337 c.Name = dbUnique + "_" + d["datname_escaped"].(string)
338 c.ConnStr = connURL.String()
339 mds = append(mds, &SourceConn{Source: *c})
340 }
341
342 }
343
344 return mds, err
345 }
346
347
348
349 func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
350 var (
351 c db.PgxPoolIface
352 dbname string
353 rows pgx.Rows
354 )
355 c, err = NewConn(context.TODO(), s.ConnStr)
356 if err != nil {
357 return
358 }
359 defer c.Close()
360
361 sql := `select /* pgwatch_generated */
362 datname
363 from pg_database
364 where not datistemplate
365 and datallowconn
366 and has_database_privilege (datname, 'CONNECT')
367 and case when length(trim($1)) > 0 then datname ~ $1 else true end
368 and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
369
370 if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
371 return nil, err
372 }
373 for rows.Next() {
374 if err = rows.Scan(&dbname); err != nil {
375 return nil, err
376 }
377 rdb := &SourceConn{Source: *s.Clone()}
378 rdb.Name += "_" + dbname
379 rdb.SetDatabaseName(dbname)
380 resolvedDbs = append(resolvedDbs, rdb)
381 }
382
383 if err := rows.Err(); err != nil {
384 return nil, err
385 }
386 return
387 }
388