1 package sources
2
3
4
5
6
7 import (
8 "cmp"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "errors"
13 "fmt"
14 "os"
15 "path"
16 "strings"
17 "time"
18
19 jsoniter "github.com/json-iterator/go"
20
21 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
22 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
23 pgx "github.com/jackc/pgx/v5"
24 client "go.etcd.io/etcd/client/v3"
25 "go.uber.org/zap"
26 )
27
28
29 func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
30 resolvedDbs := make(SourceConns, 0, len(srcs))
31 for _, s := range srcs {
32 if !s.IsEnabled {
33 continue
34 }
35 dbs, e := s.ResolveDatabases()
36 err = errors.Join(err, e)
37 resolvedDbs = append(resolvedDbs, dbs...)
38 }
39 return resolvedDbs, err
40 }
41
42
43 func (s Source) ResolveDatabases() (SourceConns, error) {
44 switch s.Kind {
45 case SourcePatroni, SourcePatroniContinuous, SourcePatroniNamespace:
46 return ResolveDatabasesFromPatroni(s)
47 case SourcePostgresContinuous:
48 return ResolveDatabasesFromPostgres(s)
49 }
50 return SourceConns{&SourceConn{Source: s}}, nil
51 }
52
// PatroniClusterMember describes one member of a Patroni cluster as read
// from the DCS.
type PatroniClusterMember struct {
	Scope, Name string // cluster scope and member name
	ConnURL     string `yaml:"conn_url"` // connection string published by the member
	Role        string // member role as reported in its DCS data
}
59
60 var logger log.Logger = log.FallbackLogger
61
62 var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
63
64
65 func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
66 return nil, errors.ErrUnsupported
67 }
68
69 func getZookeeperClusterMembers(Source) ([]PatroniClusterMember, error) {
70 return nil, errors.ErrUnsupported
71 }
72
73 func jsonTextToStringMap(jsonText string) (map[string]string, error) {
74 retmap := make(map[string]string)
75 if jsonText == "" {
76 return retmap, nil
77 }
78 var iMap map[string]any
79 if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
80 return nil, err
81 }
82 for k, v := range iMap {
83 retmap[k] = fmt.Sprintf("%v", v)
84 }
85 return retmap, nil
86 }
87
88 func getTransport(conf HostConfigAttrs) (*tls.Config, error) {
89 var caCertPool *x509.CertPool
90
91
92 if conf.CAFile != "" {
93 caCert, err := os.ReadFile(conf.CAFile)
94 if err != nil {
95 return nil, fmt.Errorf("cannot load CA file: %s", err)
96 }
97
98 caCertPool = x509.NewCertPool()
99 caCertPool.AppendCertsFromPEM(caCert)
100 }
101
102 var certificates []tls.Certificate
103
104
105 if conf.CertFile != "" && conf.KeyFile != "" {
106 cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
107 if err != nil {
108 return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
109 }
110
111 certificates = []tls.Certificate{cert}
112 }
113
114 tlsClientConfig := new(tls.Config)
115
116 if caCertPool != nil {
117 tlsClientConfig.RootCAs = caCertPool
118 if certificates != nil {
119 tlsClientConfig.Certificates = certificates
120 }
121 }
122
123 return tlsClientConfig, nil
124 }
125
126 func getEtcdClusterMembers(s Source) ([]PatroniClusterMember, error) {
127 var ret = make([]PatroniClusterMember, 0)
128 var cfg client.Config
129
130 if len(s.HostConfig.DcsEndpoints) == 0 {
131 return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
132 }
133
134 tlsConfig, err := getTransport(s.HostConfig)
135 if err != nil {
136 return nil, err
137 }
138 cfg = client.Config{
139 Endpoints: s.HostConfig.DcsEndpoints,
140 TLS: tlsConfig,
141 DialKeepAliveTimeout: time.Second,
142 Username: s.HostConfig.Username,
143 Password: s.HostConfig.Password,
144 DialTimeout: 5 * time.Second,
145 Logger: zap.NewNop(),
146 }
147
148 c, err := client.New(cfg)
149 if err != nil {
150 return ret, err
151 }
152 defer c.Close()
153
154 ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
155 defer cancel()
156
157
158
159
160
161 p := path.Join(cmp.Or(s.HostConfig.Namespace, "/service"), s.HostConfig.Scope)
162 resp, err := c.Get(ctx, p, client.WithPrefix())
163 if err != nil {
164 return ret, cmp.Or(context.Cause(ctx), err)
165 }
166
167 for _, node := range resp.Kvs {
168 nodeData, err := jsonTextToStringMap(string(node.Value))
169 if err != nil {
170 logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node, err)
171 continue
172 }
173
174 parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
175 if len(parts) < 3 {
176 return nil, errors.New("invalid ETCD key format")
177 }
178 role := nodeData["role"]
179 connURL := nodeData["conn_url"]
180 scope := parts[1]
181 name := parts[3]
182 ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
183 }
184
185 lastFoundClusterMembers[s.Name] = ret
186 return ret, nil
187 }
188
// Recognised values of a Patroni source's HostConfig.DcsType.
const (
	dcsTypeConsul    = "consul"
	dcsTypeEtcd      = "etcd"
	dcsTypeZookeeper = "zookeeper"
)
194
195 func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
196 var mds []*SourceConn
197 var clusterMembers []PatroniClusterMember
198 var err error
199 var ok bool
200
201 switch source.HostConfig.DcsType {
202 case dcsTypeEtcd:
203 clusterMembers, err = getEtcdClusterMembers(source)
204 case dcsTypeZookeeper:
205 clusterMembers, err = getZookeeperClusterMembers(source)
206 case dcsTypeConsul:
207 clusterMembers, err = getConsulClusterMembers(source)
208 default:
209 return nil, errors.New("unknown DCS")
210 }
211 logger := logger.WithField("sorce", source.Name)
212 if err != nil {
213 if errors.Is(err, errors.ErrUnsupported) {
214 return nil, err
215 }
216 logger.Debug("Failed to get info from DCS, using previous member info if any")
217 if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok {
218 err = nil
219 }
220 } else {
221 lastFoundClusterMembers[source.Name] = clusterMembers
222 }
223 if len(clusterMembers) == 0 {
224 return mds, err
225 }
226
227 for _, patroniMember := range clusterMembers {
228 logger.Info("Processing Patroni cluster member: ", patroniMember.Name)
229 if source.OnlyIfMaster && patroniMember.Role != "master" {
230 continue
231 }
232 src := *source.Clone()
233 src.ConnStr = patroniMember.ConnURL
234 if source.Kind == SourcePatroniNamespace {
235 src.Name += "_" + patroniMember.Scope
236 }
237 src.Name += "_" + patroniMember.Name
238 if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
239 mds = append(mds, dbs...)
240 } else {
241 logger.WithError(err).Error("Failed to resolve databases for Patroni member: ", patroniMember.Name)
242 }
243 }
244 return mds, err
245 }
246
247
248
249 func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
250 var (
251 c db.PgxPoolIface
252 dbname string
253 rows pgx.Rows
254 )
255 c, err = NewConn(context.TODO(), s.ConnStr)
256 if err != nil {
257 return
258 }
259 defer c.Close()
260
261 sql := `select /* pgwatch_generated */
262 datname
263 from pg_database
264 where not datistemplate
265 and datallowconn
266 and has_database_privilege (datname, 'CONNECT')
267 and case when length(trim($1)) > 0 then datname ~ $1 else true end
268 and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
269
270 if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
271 return nil, err
272 }
273 for rows.Next() {
274 if err = rows.Scan(&dbname); err != nil {
275 return nil, err
276 }
277 rdb := &SourceConn{Source: *s.Clone()}
278 rdb.Name += "_" + dbname
279 rdb.SetDatabaseName(dbname)
280 resolvedDbs = append(resolvedDbs, rdb)
281 }
282
283 if err := rows.Err(); err != nil {
284 return nil, err
285 }
286 return
287 }
288