1 package sources
2
3
4
5
6
7 import (
8 "cmp"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "errors"
13 "fmt"
14 "net/url"
15 "os"
16 "strings"
17 "time"
18
19 jsoniter "github.com/json-iterator/go"
20
21 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
22 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
23 pgx "github.com/jackc/pgx/v5"
24 client "go.etcd.io/etcd/client/v3"
25 "go.uber.org/zap"
26 )
27
28
29 func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
30 resolvedDbs := make(SourceConns, 0, len(srcs))
31 for _, s := range srcs {
32 if !s.IsEnabled {
33 continue
34 }
35 dbs, e := s.ResolveDatabases()
36 err = errors.Join(err, e)
37 resolvedDbs = append(resolvedDbs, dbs...)
38 }
39 return resolvedDbs, err
40 }
41
42
43 func (s Source) ResolveDatabases() (SourceConns, error) {
44 switch s.Kind {
45 case SourcePatroni:
46 return ResolveDatabasesFromPatroni(s)
47 case SourcePostgresContinuous:
48 return ResolveDatabasesFromPostgres(s)
49 }
50 return SourceConns{NewSourceConn(s)}, nil
51 }
52
// PatroniClusterMember is one Patroni-managed PostgreSQL instance as found in
// the DCS.
type PatroniClusterMember struct {
	// Scope is the Patroni cluster (scope) the member belongs to.
	Scope string
	// Name is the member name.
	Name string
	// ConnURL is the connection URL the member advertises in the DCS.
	ConnURL string `yaml:"conn_url"`
	// Role is the role reported for the member in the DCS.
	Role string
}

// IsPrimary reports whether the member's role is "primary" or "master".
// The comparison is exact and case-sensitive.
func (pcm PatroniClusterMember) IsPrimary() bool {
	switch pcm.Role {
	case "primary", "master":
		return true
	}
	return false
}
63
64 var logger log.Logger = log.FallbackLogger
65
66 var lastFoundClusterMembers = make(map[string][]PatroniClusterMember)
67
68
69 func getConsulClusterMembers(Source) ([]PatroniClusterMember, error) {
70 return nil, errors.ErrUnsupported
71 }
72
73 func getZookeeperClusterMembers(Source, HostConfig) ([]PatroniClusterMember, error) {
74 return nil, errors.ErrUnsupported
75 }
76
77 func jsonTextToStringMap(jsonText string) (map[string]string, error) {
78 retmap := make(map[string]string)
79 if jsonText == "" {
80 return retmap, nil
81 }
82 var iMap map[string]any
83 if err := jsoniter.ConfigFastest.Unmarshal([]byte(jsonText), &iMap); err != nil {
84 return nil, err
85 }
86 for k, v := range iMap {
87 retmap[k] = fmt.Sprintf("%v", v)
88 }
89 return retmap, nil
90 }
91
92 func getTransport(conf HostConfig) (*tls.Config, error) {
93 var caCertPool *x509.CertPool
94
95
96 if conf.CAFile != "" {
97 caCert, err := os.ReadFile(conf.CAFile)
98 if err != nil {
99 return nil, fmt.Errorf("cannot load CA file: %s", err)
100 }
101
102 caCertPool = x509.NewCertPool()
103 caCertPool.AppendCertsFromPEM(caCert)
104 }
105
106 var certificates []tls.Certificate
107
108
109 if conf.CertFile != "" && conf.KeyFile != "" {
110 cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
111 if err != nil {
112 return nil, fmt.Errorf("cannot load client cert or key file: %s", err)
113 }
114
115 certificates = []tls.Certificate{cert}
116 }
117
118 tlsClientConfig := new(tls.Config)
119
120 if caCertPool != nil {
121 tlsClientConfig.RootCAs = caCertPool
122 if certificates != nil {
123 tlsClientConfig.Certificates = certificates
124 }
125 }
126
127 return tlsClientConfig, nil
128 }
129
130 func getEtcdClusterMembers(s Source, hc HostConfig) ([]PatroniClusterMember, error) {
131 var ret = make([]PatroniClusterMember, 0)
132 var cfg client.Config
133
134 if len(hc.DcsEndpoints) == 0 {
135 return ret, errors.New("missing ETCD connect info, make sure host config has a 'dcs_endpoints' key")
136 }
137
138 tlsConfig, err := getTransport(hc)
139 if err != nil {
140 return nil, err
141 }
142 cfg = client.Config{
143 Endpoints: hc.DcsEndpoints,
144 TLS: tlsConfig,
145 DialKeepAliveTimeout: time.Second,
146 Username: hc.Username,
147 Password: hc.Password,
148 DialTimeout: 5 * time.Second,
149 Logger: zap.NewNop(),
150 }
151
152 c, err := client.New(cfg)
153 if err != nil {
154 return ret, err
155 }
156 defer c.Close()
157
158 ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("etcd client timeout"))
159 defer cancel()
160
161
162
163
164
165 resp, err := c.Get(ctx, hc.Path, client.WithPrefix())
166 if err != nil {
167 return ret, cmp.Or(context.Cause(ctx), err)
168 }
169
170 for _, node := range resp.Kvs {
171
172 parts := strings.Split(strings.TrimPrefix(string(node.Key), "/"), "/")
173 if len(parts) < 4 || parts[2] != "members" {
174 continue
175 }
176 nodeData, err := jsonTextToStringMap(string(node.Value))
177 if err != nil {
178 logger.Errorf("Could not parse ETCD node data for node \"%s\": %s", node.Key, err)
179 continue
180 }
181 role := nodeData["role"]
182 connURL := nodeData["conn_url"]
183 scope := parts[1]
184 name := parts[3]
185 ret = append(ret, PatroniClusterMember{Scope: scope, ConnURL: connURL, Role: role, Name: name})
186 }
187
188 lastFoundClusterMembers[s.Name] = ret
189 return ret, nil
190 }
191
// Recognized DCS (distributed configuration store) types. Each value is also
// the URI scheme NewHostConfig accepts for that store.
const (
	dcsTypeConsul    = "consul"
	dcsTypeEtcd      = "etcd"
	dcsTypeZookeeper = "zookeeper"
)
197
// HostConfig describes how to reach the DCS (distributed configuration store)
// in which a Patroni cluster registers its members. NewHostConfig builds it
// from a source connection URI.
type HostConfig struct {
	// DcsType is one of dcsTypeEtcd, dcsTypeZookeeper or dcsTypeConsul.
	DcsType string `yaml:"dcs_type"`
	// DcsEndpoints lists the DCS server addresses to connect to.
	DcsEndpoints []string `yaml:"dcs_endpoints"`
	// Path is the key prefix the members are looked up under, either
	// "/<namespace>" or "/<namespace>/<scope>".
	Path string
	// Username and Password are optional DCS credentials.
	Username string
	Password string
	// CAFile, CertFile and KeyFile are optional PEM file paths used to build
	// the TLS client configuration.
	CAFile   string `yaml:"ca_file"`
	CertFile string `yaml:"cert_file"`
	KeyFile  string `yaml:"key_file"`
}

// IsScopeSpecified reports whether Path names both a namespace and a cluster
// scope, i.e. contains at least two non-empty "/"-separated segments, as in
// "/service/batman". Empty segments are ignored, so a namespace-only path
// with a trailing slash ("/service/") does not count as scope-specified.
func (hc HostConfig) IsScopeSpecified() bool {
	segments := strings.FieldsFunc(hc.Path, func(r rune) bool { return r == '/' })
	return len(segments) >= 2
}
214
215 func NewHostConfig(URI string) (hc HostConfig, err error) {
216 var url *url.URL
217 url, err = url.Parse(URI)
218 if err != nil {
219 return
220 }
221
222 switch url.Scheme {
223 case dcsTypeEtcd:
224 hc.DcsType = dcsTypeEtcd
225 for h := range strings.SplitSeq(url.Host, ",") {
226 hc.DcsEndpoints = append(hc.DcsEndpoints, "http://"+h)
227 }
228 case dcsTypeZookeeper:
229 hc.DcsType = dcsTypeZookeeper
230 hc.DcsEndpoints = []string{url.Host}
231
232 case dcsTypeConsul:
233 hc.DcsType = dcsTypeConsul
234 hc.DcsEndpoints = strings.Split(url.Host, ",")
235 default:
236 return hc, fmt.Errorf("unsupported DCS type: %s", url.Scheme)
237 }
238
239 hc.Path = url.Path
240 hc.Username = url.User.Username()
241 hc.Password, _ = url.User.Password()
242 hc.CAFile = url.Query().Get("ca_file")
243 hc.CertFile = url.Query().Get("cert_file")
244 hc.KeyFile = url.Query().Get("key_file")
245
246 return hc, nil
247 }
248
249 func ResolveDatabasesFromPatroni(source Source) (SourceConns, error) {
250 var mds []*SourceConn
251 var clusterMembers []PatroniClusterMember
252 var err error
253 var ok bool
254
255 hostConfig, err := NewHostConfig(source.ConnStr)
256 if err != nil {
257 return nil, err
258 }
259
260 switch hostConfig.DcsType {
261 case dcsTypeEtcd:
262 clusterMembers, err = getEtcdClusterMembers(source, hostConfig)
263 case dcsTypeZookeeper:
264 clusterMembers, err = getZookeeperClusterMembers(source, hostConfig)
265 case dcsTypeConsul:
266 clusterMembers, err = getConsulClusterMembers(source)
267 default:
268 return nil, errors.New("unknown DCS")
269 }
270 logger := logger.WithField("sorce", source.Name)
271 if err != nil {
272 if errors.Is(err, errors.ErrUnsupported) {
273 return nil, err
274 }
275 logger.Debug("failed to get info from DCS, using previous member info if any")
276 if clusterMembers, ok = lastFoundClusterMembers[source.Name]; ok {
277 err = nil
278 }
279 } else {
280 lastFoundClusterMembers[source.Name] = clusterMembers
281 }
282 if len(clusterMembers) == 0 {
283 return mds, err
284 }
285
286 for _, patroniMember := range clusterMembers {
287 logger.Info("processing Patroni cluster member: ", patroniMember.Name)
288 if source.OnlyIfMaster && !patroniMember.IsPrimary() {
289 continue
290 }
291 src := *source.Clone()
292 src.ConnStr = patroniMember.ConnURL
293 if !hostConfig.IsScopeSpecified() {
294 src.Name += "_" + patroniMember.Scope
295 }
296 src.Name += "_" + patroniMember.Name
297 if dbs, err := ResolveDatabasesFromPostgres(src); err == nil {
298 mds = append(mds, dbs...)
299 } else {
300 logger.WithError(err).Error("failed to resolve databases for Patroni member: ", patroniMember.Name)
301 }
302 }
303 return mds, err
304 }
305
306
307
308 func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error) {
309 var (
310 c db.PgxPoolIface
311 dbname string
312 rows pgx.Rows
313 )
314 c, err = NewConn(context.TODO(), s.ConnStr)
315 if err != nil {
316 return
317 }
318 defer c.Close()
319
320 sql := `select /* pgwatch_generated */
321 datname
322 from pg_database
323 where not datistemplate
324 and datallowconn
325 and has_database_privilege (datname, 'CONNECT')
326 and case when length(trim($1)) > 0 then datname ~ $1 else true end
327 and case when length(trim($2)) > 0 then not datname ~ $2 else true end`
328
329 if rows, err = c.Query(context.TODO(), sql, s.IncludePattern, s.ExcludePattern); err != nil {
330 return nil, err
331 }
332 for rows.Next() {
333 if err = rows.Scan(&dbname); err != nil {
334 return nil, err
335 }
336 rdb := NewSourceConn(*s.Clone())
337 rdb.Name += "_" + dbname
338 rdb.SetDatabaseName(dbname)
339 resolvedDbs = append(resolvedDbs, rdb)
340 }
341
342 if err := rows.Err(); err != nil {
343 return nil, err
344 }
345 return
346 }
347