1 package sources
2
3
4
5
6
7 import (
8 "cmp"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "errors"
13 "fmt"
14 "net/url"
15 "os"
16 "strings"
17 "time"
18
19 jsoniter "github.com/json-iterator/go"
20
21 "github.com/cybertec-postgresql/pgwatch/v5/internal/db"
22 "github.com/cybertec-postgresql/pgwatch/v5/internal/log"
23 pgx "github.com/jackc/pgx/v5"
24 client "go.etcd.io/etcd/client/v3"
25 "go.uber.org/zap"
26 )
27
28
29 func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
30 resolvedDbs := make(SourceConns, 0, len(srcs))
31 for _, s := range srcs {
32 if !s.IsEnabled {
33 continue
34 }
35 dbs, e := s.ResolveDatabases()
36 err = errors.Join(err, e)
37 resolvedDbs = append(resolvedDbs, dbs...)
38 }
39 return resolvedDbs, err
40 }
41
42
43 func (s Source) ResolveDatabases() (SourceConns, error) {
44 switch s.Kind {
45 case SourcePatroni:
46 return ResolveDatabasesFromPatroni(s)
47 case SourcePostgresContinuous:
48 return ResolveDatabasesFromPostgres(s)
49 }
50 return SourceConns{NewSourceConn(s)}, nil
51 }
52
// PatroniClusterMember is one Patroni-managed instance as registered in the
// DCS.
type PatroniClusterMember struct {
	Scope   string // cluster (scope) the member belongs to
	Name    string // member name
	ConnURL string `yaml:"conn_url"` // connection string published by the member
	Role    string // role string reported by Patroni, see IsPrimary
}

// IsPrimary reports whether the member's role is "primary" or "master".
// The comparison is exact (case-sensitive).
func (pcm PatroniClusterMember) IsPrimary() bool {
	switch pcm.Role {
	case "primary", "master":
		return true
	default:
		return false
	}
}
63
64 var logger log.Logger = log.FallbackLogger
65
66 var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
67
68
69 func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
70 return nil, errors.ErrUnsupported
71 }
72
73 func getZookeeperClusterMembers(Source, HostConfig) ([]PatroniClusterMember, error) {
74 return nil, errors.ErrUnsupported
75 }
76
77 func jsonTextToStringMap(jsonText string) (map[string]string, error) {
78 retmap := make(map[string]string)
79 if jsonText == "" {
80 return retmap, nil
81 }
82 var iMap map[string]any
83 if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
84 return nil, err
85 }
86 for k, v := range iMap {
87 retmap[k] = fmt.Sprintf("%v", v)
88 }
89 return retmap, nil
90 }
91
92 func getTransport(conf HostConfig) (*tls.Config, error) {
93 var caCertPool *x509.CertPool
94
95
96 if conf.CAFile != "" {
97 caCert, err := os.ReadFile(conf.CAFile)
98 if err != nil {
99 return nil, fmt.Errorf("cannot load CA file: %s", err)
100 }
101
102 caCertPool = x509.NewCertPool()
103 caCertPool.AppendCertsFromPEM(caCert)
104 }
105
106 var certificates []tls.Certificate
107
108
109 if conf.CertFile != "" && conf.KeyFile != "" {
110 cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
111 if err != nil {
112 return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
113 }
114
115 certificates = []tls.Certificate{cert}
116 }
117
118 tlsClientConfig := new(tls.Config)
119
120 if caCertPool != nil {
121 tlsClientConfig.RootCAs = caCertPool
122 if certificates != nil {
123 tlsClientConfig.Certificates = certificates
124 }
125 }
126
127 return tlsClientConfig, nil
128 }
129
130 func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
131 var ret = make([]PatroniClusterMember, 0)
132 var cfg client.Config
133
134 if len(hc.DcsEndpoints) == 0 {
135 return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
136 }
137
138 tlsConfig, err := getTransport(hc)
139 if err != nil {
140 return nil, err
141 }
142 cfg = client.Config{
143 Endpoints: hc.DcsEndpoints,
144 TLS: tlsConfig,
145 DialKeepAliveTimeout: time.Second,
146 Username: hc.Username,
147 Password: hc.Password,
148 DialTimeout: 5 * time.Second,
149 Logger: zap.NewNop(),
150 }
151
152 c, err := client.New(cfg)
153 if err != nil {
154 return ret, err
155 }
156 defer c.Close()
157
158 ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
159 defer cancel()
160
161
162
163
164
165 resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
166 if err != nil {
167 return ret, cmp.Or(context.Cause(ctx), err)
168 }
169
170 for _, node := range resp.Kvs {
171
172 parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
173 if len(parts) < 4 || parts[2] != "members" {
174 continue
175 }
176 nodeData, err := jsonTextToStringMap(string(node.Value))
177 if err != nil {
178 logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node.Key, err)
179 continue
180 }
181 role := nodeData["role"]
182 connURL := nodeData["conn_url"]
183 scope := parts[1]
184 name := parts[3]
185 ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
186 }
187
188 lastFoundClusterMembers[s.Name] = ret
189 return ret, nil
190 }
191
// Supported Patroni DCS types; each is also the URI scheme accepted by
// NewHostConfig.
const (
	dcsTypeEtcd      = "etcd"
	dcsTypeZookeeper = "zookeeper"
	dcsTypeConsul    = "consul"
)

// HostConfig holds the DCS connection settings parsed from a Patroni
// source's connection URI (see NewHostConfig).
type HostConfig struct {
	DcsType      string   `yaml:"dcs_type"`      // one of the dcsType* constants
	DcsEndpoints []string `yaml:"dcs_endpoints"` // one entry per host in the URI
	Path         string   // DCS key prefix taken from the URI path
	Username     string
	Password     string
	CAFile       string `yaml:"ca_file"`
	CertFile     string `yaml:"cert_file"`
	KeyFile      string `yaml:"key_file"`
}

// IsScopeSpecified reports whether Path names a specific Patroni cluster,
// i.e. has at least two non-empty segments (/<namespace>/<scope>), rather
// than only a namespace such as "/service". Empty segments produced by
// leading, trailing or doubled slashes are ignored, so "/service/" is a
// namespace only (counting raw slashes wrongly reported it as scoped).
func (hc HostConfig) IsScopeSpecified() bool {
	segments := 0
	for _, part := range strings.Split(hc.Path, "/") {
		if part != "" {
			segments++
		}
	}
	return segments >= 2
}

// NewHostConfig parses a Patroni DCS URI of the form
//
//	scheme://[user[:password]@]host1[,host2...][/path][?ca_file=..&cert_file=..&key_file=..]
//
// where scheme is "etcd", "zookeeper" or "consul". net/url cannot parse a
// comma-separated host list, so the hosts are split off by hand and only the
// first one is passed to url.Parse, which validates the rest of the URI and
// yields the path, credentials and query parameters. Empty entries in the
// host list are dropped.
//
// For etcd each host becomes an endpoint URL. It uses "https://" when any of
// ca_file, cert_file or key_file is given; otherwise it uses "http://". The
// etcd client decides whether to use TLS from the endpoint scheme, so with a
// plain "http://" endpoint the TLS files would have no effect.
//
// On error the returned HostConfig is the zero value.
func NewHostConfig(URI string) (hc HostConfig, err error) {
	scheme, remainder, ok := strings.Cut(URI, "://")
	if !ok {
		return HostConfig{}, errors.New("invalid URI: missing scheme")
	}

	// The authority ends at the first '/' (path) or '?' (query).
	hostPart, pathAndQuery := remainder, ""
	if i := strings.IndexAny(remainder, "/?"); i >= 0 {
		hostPart, pathAndQuery = remainder[:i], remainder[i:]
	}

	// Split at the last '@' so an unescaped '@' in the password stays in the
	// user info.
	var userInfo string
	if i := strings.LastIndex(hostPart, "@"); i >= 0 {
		userInfo, hostPart = hostPart[:i], hostPart[i+1:]
	}

	var hosts []string
	for _, h := range strings.Split(hostPart, ",") {
		if h != "" {
			hosts = append(hosts, h)
		}
	}

	firstHost, _, _ := strings.Cut(hostPart, ",")
	cleanURI := scheme + "://"
	if userInfo != "" {
		cleanURI += userInfo + "@"
	}
	cleanURI += firstHost + pathAndQuery

	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(cleanURI)
	if err != nil {
		return HostConfig{}, err
	}

	query := u.Query()
	hc = HostConfig{
		Path:     u.Path,
		Username: u.User.Username(),
		CAFile:   query.Get("ca_file"),
		CertFile: query.Get("cert_file"),
		KeyFile:  query.Get("key_file"),
	}
	hc.Password, _ = u.User.Password() // second result only says whether a password was present

	switch u.Scheme {
	case dcsTypeEtcd:
		hc.DcsType = dcsTypeEtcd
		endpointScheme := "http://"
		if hc.CAFile != "" || hc.CertFile != "" || hc.KeyFile != "" {
			endpointScheme = "https://"
		}
		for _, h := range hosts {
			hc.DcsEndpoints = append(hc.DcsEndpoints, endpointScheme+h)
		}
	case dcsTypeZookeeper:
		hc.DcsType = dcsTypeZookeeper
		hc.DcsEndpoints = hosts
	case dcsTypeConsul:
		hc.DcsType = dcsTypeConsul
		hc.DcsEndpoints = hosts
	default:
		return HostConfig{}, fmt.Errorf("unsupported DCS type: %s", u.Scheme)
	}

	return hc, nil
}
283
284 func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
285 var mds []*SourceConn
286 var clusterMembers []PatroniClusterMember
287 var err error
288 var ok bool
289
290 hostConfig, err := NewHostConfig(source.ConnStr)
291 if err != nil {
292 return nil, err
293 }
294
295 switch hostConfig.DcsType {
296 case dcsTypeEtcd:
297 clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
298 case dcsTypeZookeeper:
299 clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
300 case dcsTypeConsul:
301 clusterMembers, err = getConsulClusterMembers(source)
302 default:
303 return nil, errors.New("unknown DCS")
304 }
305 logger := logger.WithField("source", source.Name)
306 if err != nil {
307 if errors.Is(err, errors.ErrUnsupported) {
308 return nil, err
309 }
310 logger.Debug("failed to get info from DCS, using previous member info if any")
311 if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok {
312 err = nil
313 }
314 } else {
315 lastFoundClusterMembers[source.Name] = clusterMembers
316 }
317 if len(clusterMembers) == 0 {
318 return mds, err
319 }
320
321 for _, patroniMember := range clusterMembers {
322 logger.Info("processing Patroni cluster member: ", patroniMember.Name)
323 if source.OnlyIfMaster && !patroniMember.IsPrimary() {
324 continue
325 }
326 src := *source.Clone()
327 src.ConnStr = patroniMember.ConnURL
328 if !hostConfig.IsScopeSpecified() {
329 src.Name += "_" + patroniMember.Scope
330 }
331 src.Name += "_" + patroniMember.Name
332 if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
333 mds = append(mds, dbs...)
334 } else {
335 logger.WithError(err).Error("failed to resolve databases for Patroni member: ", patroniMember.Name)
336 }
337 }
338 return mds, err
339 }
340
341
342
343 func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
344 var (
345 c db.PgxPoolIface
346 dbname string
347 rows pgx.Rows
348 )
349 c, err = NewConn(context.TODO(), s.ConnStr)
350 if err != nil {
351 return
352 }
353 defer c.Close()
354
355 sql := `select /* pgwatch_generated */
356 datname
357 from pg_database
358 where not datistemplate
359 and datallowconn
360 and has_database_privilege (datname, 'CONNECT')
361 and case when length(trim($1)) > 0 then datname ~ $1 else true end
362 and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
363
364 if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
365 return nil, err
366 }
367 for rows.Next() {
368 if err = rows.Scan(&dbname); err != nil {
369 return nil, err
370 }
371 rdb := NewSourceConn(*s.Clone())
372 rdb.Name += "_" + dbname
373 rdb.SetDatabaseName(dbname)
374 resolvedDbs = append(resolvedDbs, rdb)
375 }
376
377 if err := rows.Err(); err != nil {
378 return nil, err
379 }
380 return
381 }
382