1 package sources
2
3
4
5
6
7 import (
8 "cmp"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "errors"
13 "fmt"
14 "net/url"
15 "os"
16 "strings"
17 "time"
18
19 jsoniter "github.com/json-iterator/go"
20
21 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
22 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
23 pgx "github.com/jackc/pgx/v5"
24 client "go.etcd.io/etcd/client/v3"
25 "go.uber.org/zap"
26 )
27
28
29 func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
30 resolvedDbs := make(SourceConns, 0, len(srcs))
31 for _, s := range srcs {
32 if !s.IsEnabled {
33 continue
34 }
35 dbs, e := s.ResolveDatabases()
36 err = errors.Join(err, e)
37 resolvedDbs = append(resolvedDbs, dbs...)
38 }
39 return resolvedDbs, err
40 }
41
42
43 func (s Source) ResolveDatabases() (SourceConns, error) {
44 switch s.Kind {
45 case SourcePatroni:
46 return ResolveDatabasesFromPatroni(s)
47 case SourcePostgresContinuous:
48 return ResolveDatabasesFromPostgres(s)
49 }
50 return SourceConns{&SourceConn{Source: s}}, nil
51 }
52
// PatroniClusterMember describes one member of a Patroni cluster as read
// from the cluster's DCS.
type PatroniClusterMember struct {
	Scope   string // name of the Patroni cluster (scope) the member belongs to
	Name    string // member name
	ConnURL string `yaml:"conn_url"` // PostgreSQL connection URL published by the member
	Role    string // member role as reported by Patroni
}
59
60 var logger log.Logger = log.FallbackLogger
61
62 var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
63
64
65 func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
66 return nil, errors.ErrUnsupported
67 }
68
69 func getZookeeperClusterMembers(Source, HostConfig) ([]PatroniClusterMember, error) {
70 return nil, errors.ErrUnsupported
71 }
72
73 func jsonTextToStringMap(jsonText string) (map[string]string, error) {
74 retmap := make(map[string]string)
75 if jsonText == "" {
76 return retmap, nil
77 }
78 var iMap map[string]any
79 if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
80 return nil, err
81 }
82 for k, v := range iMap {
83 retmap[k] = fmt.Sprintf("%v", v)
84 }
85 return retmap, nil
86 }
87
88 func getTransport(conf HostConfig) (*tls.Config, error) {
89 var caCertPool *x509.CertPool
90
91
92 if conf.CAFile != "" {
93 caCert, err := os.ReadFile(conf.CAFile)
94 if err != nil {
95 return nil, fmt.Errorf("cannot load CA file: %s", err)
96 }
97
98 caCertPool = x509.NewCertPool()
99 caCertPool.AppendCertsFromPEM(caCert)
100 }
101
102 var certificates []tls.Certificate
103
104
105 if conf.CertFile != "" && conf.KeyFile != "" {
106 cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
107 if err != nil {
108 return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
109 }
110
111 certificates = []tls.Certificate{cert}
112 }
113
114 tlsClientConfig := new(tls.Config)
115
116 if caCertPool != nil {
117 tlsClientConfig.RootCAs = caCertPool
118 if certificates != nil {
119 tlsClientConfig.Certificates = certificates
120 }
121 }
122
123 return tlsClientConfig, nil
124 }
125
126 func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
127 var ret = make([]PatroniClusterMember, 0)
128 var cfg client.Config
129
130 if len(hc.DcsEndpoints) == 0 {
131 return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
132 }
133
134 tlsConfig, err := getTransport(hc)
135 if err != nil {
136 return nil, err
137 }
138 cfg = client.Config{
139 Endpoints: hc.DcsEndpoints,
140 TLS: tlsConfig,
141 DialKeepAliveTimeout: time.Second,
142 Username: hc.Username,
143 Password: hc.Password,
144 DialTimeout: 5 * time.Second,
145 Logger: zap.NewNop(),
146 }
147
148 c, err := client.New(cfg)
149 if err != nil {
150 return ret, err
151 }
152 defer c.Close()
153
154 ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
155 defer cancel()
156
157
158
159
160
161 resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
162 if err != nil {
163 return ret, cmp.Or(context.Cause(ctx), err)
164 }
165
166 for _, node := range resp.Kvs {
167 nodeData, err := jsonTextToStringMap(string(node.Value))
168 if err != nil {
169 logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node, err)
170 continue
171 }
172
173 parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
174 if len(parts) < 3 {
175 return nil, errors.New("invalid ETCD key format")
176 }
177 role := nodeData["role"]
178 connURL := nodeData["conn_url"]
179 scope := parts[1]
180 name := parts[3]
181 ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
182 }
183
184 lastFoundClusterMembers[s.Name] = ret
185 return ret, nil
186 }
187
// Supported Patroni DCS (distributed configuration store) kinds. NewHostConfig
// matches them against the scheme of a Patroni source connection string.
const (
	dcsTypeEtcd      = "etcd"
	dcsTypeZookeeper = "zookeeper"
	dcsTypeConsul    = "consul"
)

// HostConfig holds the DCS connection settings of a Patroni source, as
// produced by NewHostConfig.
type HostConfig struct {
	DcsType      string   `yaml:"dcs_type"`      // one of the dcsType* constants
	DcsEndpoints []string `yaml:"dcs_endpoints"` // DCS addresses; for etcd prefixed with http:// or https://
	Path         string   // key path (prefix) inside the DCS
	Username     string
	Password     string
	CAFile       string `yaml:"ca_file"`   // optional CA bundle for TLS
	CertFile     string `yaml:"cert_file"` // optional client certificate
	KeyFile      string `yaml:"key_file"`  // optional client key
}

// IsScopeSpecified reports whether Path contains at least two '/' characters,
// e.g. "/service/batman".
// NOTE(review): a trailing slash alone ("/service/") also counts — confirm
// this matches the intended "scope given in the path" semantics.
func (hc HostConfig) IsScopeSpecified() bool {
	return strings.Count(hc.Path, "/") >= 2
}

// NewHostConfig parses a Patroni DCS URI of the form
//
//	<dcs_type>://[user[:password]@]host1[,host2...]/path[?ca_file=..&cert_file=..&key_file=..]
//
// where dcs_type is etcd, zookeeper or consul.
//
// For etcd every comma-separated host becomes an endpoint prefixed with
// "http://", or with "https://" when ca_file or cert_file is given. The etcd
// client ignores TLS settings for http:// endpoints, so without the https
// prefix those options would never take effect. For consul the hosts are
// split on commas unchanged; for zookeeper the whole host part is a single
// endpoint.
//
// Any other or missing scheme yields an error and a zero HostConfig.
func NewHostConfig(URI string) (hc HostConfig, err error) {
	// Use the package-level url.Parse. Calling the (*url.URL).Parse method on
	// a nil pointer panics for scheme-less input.
	u, err := url.Parse(URI)
	if err != nil {
		return HostConfig{}, err
	}

	q := u.Query()
	hc = HostConfig{
		Path:     u.Path,
		Username: u.User.Username(), // nil-safe when the URI has no userinfo
		CAFile:   q.Get("ca_file"),
		CertFile: q.Get("cert_file"),
		KeyFile:  q.Get("key_file"),
	}
	// The password is optional in userinfo; an absent one stays empty.
	hc.Password, _ = u.User.Password()

	switch u.Scheme {
	case dcsTypeEtcd:
		hc.DcsType = dcsTypeEtcd
		endpointScheme := "http://"
		if hc.CAFile != "" || hc.CertFile != "" {
			endpointScheme = "https://"
		}
		for _, h := range strings.Split(u.Host, ",") {
			hc.DcsEndpoints = append(hc.DcsEndpoints, endpointScheme+h)
		}
	case dcsTypeZookeeper:
		hc.DcsType = dcsTypeZookeeper
		hc.DcsEndpoints = []string{u.Host}
	case dcsTypeConsul:
		hc.DcsType = dcsTypeConsul
		hc.DcsEndpoints = strings.Split(u.Host, ",")
	default:
		return HostConfig{}, fmt.Errorf("unsupported DCS type: %s", u.Scheme)
	}

	return hc, nil
}
244
245 func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
246 var mds []*SourceConn
247 var clusterMembers []PatroniClusterMember
248 var err error
249 var ok bool
250
251 hostConfig, err := NewHostConfig(source.ConnStr)
252 if err != nil {
253 return nil, err
254 }
255
256 switch hostConfig.DcsType {
257 case dcsTypeEtcd:
258 clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
259 case dcsTypeZookeeper:
260 clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
261 case dcsTypeConsul:
262 clusterMembers, err = getConsulClusterMembers(source)
263 default:
264 return nil, errors.New("unknown DCS")
265 }
266 logger := logger.WithField("sorce", source.Name)
267 if err != nil {
268 if errors.Is(err, errors.ErrUnsupported) {
269 return nil, err
270 }
271 logger.Debug("Failed to get info from DCS, using previous member info if any")
272 if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok {
273 err = nil
274 }
275 } else {
276 lastFoundClusterMembers[source.Name] = clusterMembers
277 }
278 if len(clusterMembers) == 0 {
279 return mds, err
280 }
281
282 for _, patroniMember := range clusterMembers {
283 logger.Info("Processing Patroni cluster member: ", patroniMember.Name)
284 if source.OnlyIfMaster && patroniMember.Role != "master" {
285 continue
286 }
287 src := *source.Clone()
288 src.ConnStr = patroniMember.ConnURL
289 if hostConfig.IsScopeSpecified() {
290 src.Name += "_" + patroniMember.Scope
291 }
292 src.Name += "_" + patroniMember.Name
293 if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
294 mds = append(mds, dbs...)
295 } else {
296 logger.WithError(err).Error("Failed to resolve databases for Patroni member: ", patroniMember.Name)
297 }
298 }
299 return mds, err
300 }
301
302
303
304 func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
305 var (
306 c db.PgxPoolIface
307 dbname string
308 rows pgx.Rows
309 )
310 c, err = NewConn(context.TODO(), s.ConnStr)
311 if err != nil {
312 return
313 }
314 defer c.Close()
315
316 sql := `select /* pgwatch_generated */
317 datname
318 from pg_database
319 where not datistemplate
320 and datallowconn
321 and has_database_privilege (datname, 'CONNECT')
322 and case when length(trim($1)) > 0 then datname ~ $1 else true end
323 and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
324
325 if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
326 return nil, err
327 }
328 for rows.Next() {
329 if err = rows.Scan(&dbname); err != nil {
330 return nil, err
331 }
332 rdb := &SourceConn{Source: *s.Clone()}
333 rdb.Name += "_" + dbname
334 rdb.SetDatabaseName(dbname)
335 resolvedDbs = append(resolvedDbs, rdb)
336 }
337
338 if err := rows.Err(); err != nil {
339 return nil, err
340 }
341 return
342 }
343