1 package sources
2
3
4
5
6
7 import (
8 "cmp"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "errors"
13 "fmt"
14 "net/url"
15 "os"
16 "strings"
17 "sync"
18 "time"
19
20 jsoniter "github.com/json-iterator/go"
21
22 "github.com/cybertec-postgresql/pgwatch/v5/internal/db"
23 "github.com/cybertec-postgresql/pgwatch/v5/internal/log"
24 pgx "github.com/jackc/pgx/v5"
25 client "go.etcd.io/etcd/client/v3"
26 "go.uber.org/zap"
27 )
28
29
30
// ResolveDatabases expands every source in srcs into its concrete database
// connections. Each source is resolved in its own goroutine and the combined
// result keeps the order of srcs.
//
// For every source whose resolution fails, onError (when non-nil) is called
// with the source name, the failure is logged, and the error is joined into
// the returned error. Connections produced by the other sources (and any
// returned alongside an error) are still included in the result.
func (srcs Sources) ResolveDatabases(onError func(string)) (_ SourceConns, err error) {
	conns := make([]SourceConns, len(srcs))
	errs := make([]error, len(srcs))

	var wg sync.WaitGroup
	for idx := range srcs {
		wg.Go(func() {
			// Each goroutine writes only its own slot, so no locking is needed.
			conns[idx], errs[idx] = srcs[idx].ResolveDatabases()
		})
	}
	wg.Wait()

	out := make(SourceConns, 0, len(srcs))
	for idx, resolveErr := range errs {
		if resolveErr != nil {
			name := srcs[idx].Name
			if onError != nil {
				onError(name)
			}
			logger.WithField("source", name).WithError(resolveErr).Error("could not resolve databases from source")
			err = errors.Join(err, resolveErr)
		}
		out = append(out, conns[idx]...)
	}
	return out, err
}
58
59
60 func (s Source) ResolveDatabases() (SourceConns, error) {
61 switch s.Kind {
62 case SourcePatroni:
63 return ResolveDatabasesFromPatroni(s)
64 case SourcePostgresContinuous:
65 return ResolveDatabasesFromPostgres(s)
66 }
67 return SourceConns{NewSourceConn(s)}, nil
68 }
69
// PatroniClusterMember is one cluster member as read from the DCS: the scope
// (cluster) it belongs to, its member name, its connection URL and its role.
type PatroniClusterMember struct {
	Scope   string
	Name    string
	ConnURL string `yaml:"conn_url"`
	Role    string
}

// IsPrimary reports whether the member's Role is "primary" or the older
// spelling "master". The comparison is case-sensitive.
func (pcm PatroniClusterMember) IsPrimary() bool {
	switch pcm.Role {
	case "primary", "master":
		return true
	default:
		return false
	}
}
80
81 var logger log.Logger = log.FallbackLogger
82
83 var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
84
85
// getConsulClusterMembers is a placeholder for Consul-based member discovery.
// It always returns a nil member list and errors.ErrUnsupported.
func getConsulClusterMembers(Source) (members []PatroniClusterMember, err error) {
	err = errors.ErrUnsupported
	return members, err
}

// getZookeeperClusterMembers is a placeholder for ZooKeeper-based member
// discovery. It always returns a nil member list and errors.ErrUnsupported.
func getZookeeperClusterMembers(Source, HostConfig) (members []PatroniClusterMember, err error) {
	err = errors.ErrUnsupported
	return members, err
}
93
// jsonTextToStringMap decodes jsonText as a JSON object and renders every
// top-level value with fmt's default (%v) formatting.
//
// An empty string yields an empty, non-nil map and no error. A JSON decoding
// failure yields a nil map and the decoder's error.
func jsonTextToStringMap(jsonText string) (map[string]string, error) {
	if len(jsonText) == 0 {
		return map[string]string{}, nil
	}

	var decoded map[string]any
	err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &decoded)
	if err != nil {
		return nil, err
	}

	out := make(map[string]string, len(decoded))
	for key, val := range decoded {
		// A single operand to Sprint is formatted exactly like %v.
		out[key] = fmt.Sprint(val)
	}
	return out, nil
}
108
109 func getTransport(conf HostConfig) (*tls.Config, error) {
110 var caCertPool *x509.CertPool
111
112
113 if conf.CAFile != "" {
114 caCert, err := os.ReadFile(conf.CAFile)
115 if err != nil {
116 return nil, fmt.Errorf("cannot load CA file: %s", err)
117 }
118
119 caCertPool = x509.NewCertPool()
120 caCertPool.AppendCertsFromPEM(caCert)
121 }
122
123 var certificates []tls.Certificate
124
125
126 if conf.CertFile != "" && conf.KeyFile != "" {
127 cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
128 if err != nil {
129 return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
130 }
131
132 certificates = []tls.Certificate{cert}
133 }
134
135 tlsClientConfig := new(tls.Config)
136
137 if caCertPool != nil {
138 tlsClientConfig.RootCAs = caCertPool
139 if certificates != nil {
140 tlsClientConfig.Certificates = certificates
141 }
142 }
143
144 return tlsClientConfig, nil
145 }
146
147 func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
148 var ret = make([]PatroniClusterMember, 0)
149 var cfg client.Config
150
151 if len(hc.DcsEndpoints) == 0 {
152 return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
153 }
154
155 tlsConfig, err := getTransport(hc)
156 if err != nil {
157 return nil, err
158 }
159 cfg = client.Config{
160 Endpoints: hc.DcsEndpoints,
161 TLS: tlsConfig,
162 DialKeepAliveTimeout: time.Second,
163 Username: hc.Username,
164 Password: hc.Password,
165 DialTimeout: 5 * time.Second,
166 Logger: zap.NewNop(),
167 }
168
169 c, err := client.New(cfg)
170 if err != nil {
171 return ret, err
172 }
173 defer c.Close()
174
175 ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
176 defer cancel()
177
178
179
180
181
182 resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
183 if err != nil {
184 return ret, cmp.Or(context.Cause(ctx), err)
185 }
186
187 for _, node := range resp.Kvs {
188
189 parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
190 if len(parts) < 4 || parts[2] != "members" {
191 continue
192 }
193 nodeData, err := jsonTextToStringMap(string(node.Value))
194 if err != nil {
195 logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node.Key, err)
196 continue
197 }
198 role := nodeData["role"]
199 connURL := nodeData["conn_url"]
200 scope := parts[1]
201 name := parts[3]
202 ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
203 }
204
205 lastFoundClusterMembers[s.Name] = ret
206 return ret, nil
207 }
208
// Supported DCS (distributed configuration store) kinds. They are matched
// against the URI scheme in NewHostConfig and stored in HostConfig.DcsType.
const (
	dcsTypeConsul    = "consul"
	dcsTypeEtcd      = "etcd"
	dcsTypeZookeeper = "zookeeper"
)
214
215 type HostConfig struct {
216 DcsType string `yaml:"dcs_type"`
217 DcsEndpoints []string `yaml:"dcs_endpoints"`
218 Path string
219 Username string
220 Password string
221 CAFile string `yaml:"ca_file"`
222 CertFile string `yaml:"cert_file"`
223 KeyFile string `yaml:"key_file"`
224 }
225
226 func (hc HostConfig) IsScopeSpecified() bool {
227
228
229 return strings.Count(hc.Path, "/") >= 2
230 }
231
// NewHostConfig parses a Patroni DCS connection string of the form
//
//	<dcs_type>://[user[:password]@]host1[:port][,host2[:port]...][/path][?ca_file=...&cert_file=...&key_file=...]
//
// dcs_type must be etcd, zookeeper or consul. net/url cannot parse a
// comma-separated host list, so the hosts are split off first. Only the first
// host is passed through url.Parse (which validates the URI and extracts path,
// credentials and query). Every non-empty, space-trimmed host becomes a
// DcsEndpoint.
//
// For etcd each endpoint is prefixed with "https://" when any of ca_file,
// cert_file or key_file is present, and with "http://" otherwise. The etcd
// client ignores its TLS config for http:// endpoints, so the TLS files only
// take effect with https://. Zookeeper and consul endpoints are bare hosts.
//
// On error the returned HostConfig is the zero value.
func NewHostConfig(URI string) (hc HostConfig, err error) {
	scheme, remainder, ok := strings.Cut(URI, "://")
	if !ok {
		return hc, errors.New("invalid URI: missing scheme")
	}

	// The authority (credentials and hosts) ends at the first '/' or '?'.
	hostPart, pathAndQuery := remainder, ""
	if end := strings.IndexAny(remainder, "/?"); end != -1 {
		hostPart, pathAndQuery = remainder[:end], remainder[end:]
	}

	// Credentials are everything up to the last '@' of the authority.
	var userInfo string
	if at := strings.LastIndex(hostPart, "@"); at != -1 {
		userInfo, hostPart = hostPart[:at], hostPart[at+1:]
	}

	rawHosts := strings.Split(hostPart, ",")

	// Rebuild a single-host URI that net/url can parse.
	cleanURI := scheme + "://"
	if userInfo != "" {
		cleanURI += userInfo + "@"
	}
	cleanURI += rawHosts[0] + pathAndQuery

	// Use the package-level url.Parse. Naming a variable "url" here would
	// shadow the package and turn this into a method call on a nil *url.URL.
	u, err := url.Parse(cleanURI)
	if err != nil {
		return hc, err
	}
	query := u.Query()
	useTLS := query.Get("ca_file") != "" || query.Get("cert_file") != "" || query.Get("key_file") != ""

	// Drop empty entries (e.g. "etcd:///path" or "h1,,h2") so callers see a
	// missing endpoint list instead of an unusable "http://" endpoint.
	hosts := make([]string, 0, len(rawHosts))
	for _, h := range rawHosts {
		if h = strings.TrimSpace(h); h != "" {
			hosts = append(hosts, h)
		}
	}

	switch u.Scheme {
	case dcsTypeEtcd:
		hc.DcsType = dcsTypeEtcd
		endpointScheme := "http://"
		if useTLS {
			endpointScheme = "https://"
		}
		for _, h := range hosts {
			hc.DcsEndpoints = append(hc.DcsEndpoints, endpointScheme+h)
		}
	case dcsTypeZookeeper:
		hc.DcsType = dcsTypeZookeeper
		hc.DcsEndpoints = hosts
	case dcsTypeConsul:
		hc.DcsType = dcsTypeConsul
		hc.DcsEndpoints = hosts
	default:
		return hc, fmt.Errorf("unsupported DCS type: %s", u.Scheme)
	}

	hc.Path = u.Path
	hc.Username = u.User.Username()
	// The second result only reports whether a password was present.
	hc.Password, _ = u.User.Password()
	hc.CAFile = query.Get("ca_file")
	hc.CertFile = query.Get("cert_file")
	hc.KeyFile = query.Get("key_file")

	return hc, nil
}
300
301 func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
302 var mds []*SourceConn
303 var clusterMembers []PatroniClusterMember
304 var err error
305 var ok bool
306
307 hostConfig, err := NewHostConfig(source.ConnStr)
308 if err != nil {
309 return nil, err
310 }
311
312 switch hostConfig.DcsType {
313 case dcsTypeEtcd:
314 clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
315 case dcsTypeZookeeper:
316 clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
317 case dcsTypeConsul:
318 clusterMembers, err = getConsulClusterMembers(source)
319 default:
320 return nil, errors.New("unknown DCS")
321 }
322 logger := logger.WithField("source", source.Name)
323 if err != nil {
324 if errors.Is(err, errors.ErrUnsupported) {
325 return nil, err
326 }
327 logger.Debug("failed to get info from DCS, using previous member info if any")
328 if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok {
329 err = nil
330 }
331 } else {
332 lastFoundClusterMembers[source.Name] = clusterMembers
333 }
334 if len(clusterMembers) == 0 {
335 return mds, err
336 }
337
338 for _, patroniMember := range clusterMembers {
339 logger.Info("processing Patroni cluster member: ", patroniMember.Name)
340 if source.OnlyIfMaster && !patroniMember.IsPrimary() {
341 continue
342 }
343 src := *source.Clone()
344 src.ConnStr = patroniMember.ConnURL
345 if !hostConfig.IsScopeSpecified() {
346 src.Name += "_" + patroniMember.Scope
347 }
348 src.Name += "_" + patroniMember.Name
349 if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
350 mds = append(mds, dbs...)
351 } else {
352 logger.WithError(err).Error("failed to resolve databases for Patroni member: ", patroniMember.Name)
353 }
354 }
355 return mds, err
356 }
357
358
359
// ResolveDatabasesFromPostgres connects to s.ConnStr and returns one
// SourceConn per database listed in pg_database that
//   - is not a template,
//   - allows connections,
//   - grants CONNECT to the current role.
//
// When s.IncludePattern / s.ExcludePattern are non-blank they are applied to
// the database name as PostgreSQL regular expressions (~ / not ~).
//
// Each returned connection is built from a clone of s, is named
// "<s.Name>_<datname>" and has its database set to datname. On any error the
// result is nil.
func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
	var (
		c      db.PgxPoolIface
		dbname string
		rows   pgx.Rows
	)
	if c, err = NewConn(context.TODO(), s.ConnStr); err != nil {
		return nil, err
	}
	defer c.Close()

	sql := `select /* pgwatch_generated */
	datname
	from pg_database
	where not datistemplate
	and datallowconn
	and has_database_privilege (datname, 'CONNECT')
	and case when length(trim($1)) > 0 then datname ~ $1 else true end
	and case when length(trim($2)) > 0 then not datname ~ $2 else true end`

	if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
		return nil, err
	}
	// pgx requires Rows to be closed to release the underlying connection.
	// Without this, an early return on a Scan error would keep it held.
	// NOTE(review): if c is a pgxpool.Pool, its Close also waits for held
	// connections. Deferred after c.Close, so it runs first.
	defer rows.Close()

	for rows.Next() {
		if err = rows.Scan(&dbname); err != nil {
			return nil, err
		}
		rdb := NewSourceConn(*s.Clone())
		rdb.Name += "_" + dbname
		rdb.SetDatabaseName(dbname)
		resolvedDbs = append(resolvedDbs, rdb)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return resolvedDbs, nil
}
399