1 package sources_test
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13 client "go.etcd.io/etcd/client/v3"
14
15 "github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
16 testcontainers "github.com/testcontainers/testcontainers-go"
17 "github.com/testcontainers/testcontainers-go/modules/etcd"
18 "github.com/testcontainers/testcontainers-go/modules/postgres"
19 "github.com/testcontainers/testcontainers-go/wait"
20 )
21
22 func TestMonitoredDatabase_ResolveDatabasesFromPostgres(t *testing.T) {
23 pgContainer, err := postgres.Run(ctx,
24 ImageName,
25 postgres.WithDatabase("mydatabase"),
26 testcontainers.WithWaitStrategy(
27 wait.ForLog("database system is ready to accept connections").
28 WithOccurrence(2).
29 WithStartupTimeout(5*time.Second)),
30 )
31 require.NoError(t, err)
32 defer func() { assert.NoError(t, pgContainer.Terminate(ctx)) }()
33
34
35 md := sources.Source{}
36 md.Name = "continuous"
37 md.Kind = sources.SourcePostgresContinuous
38 md.ConnStr, err = pgContainer.ConnectionString(ctx, "sslmode=disable")
39 assert.NoError(t, err)
40
41
42 dbs, err := md.ResolveDatabases()
43 assert.NoError(t, err)
44 assert.True(t, len(dbs) == 2)
45
46
47 db := dbs.GetMonitoredDatabase(md.Name + "_mydatabase")
48 assert.NotNil(t, db)
49 assert.Equal(t, "mydatabase", db.GetDatabaseName())
50
51
52 db = dbs.GetMonitoredDatabase(md.Name + "_unexpected")
53 assert.Nil(t, db)
54 }
55
56 func TestMonitoredDatabase_ResolveDatabasesFromPatroni(t *testing.T) {
57 etcdContainer, err := etcd.Run(ctx, "gcr.io/etcd-development/etcd:v3.5.14",
58 testcontainers.WithWaitStrategy(wait.ForLog("ready to serve client requests").
59 WithStartupTimeout(15*time.Second)))
60 require.NoError(t, err)
61 defer func() { assert.NoError(t, etcdContainer.Terminate(ctx)) }()
62
63 endpoint, err := etcdContainer.ClientEndpoint(ctx)
64 require.NoError(t, err)
65
66 cli, err := client.New(client.Config{
67 Endpoints: []string{endpoint},
68 DialTimeout: 10 * time.Second,
69 })
70 require.NoError(t, err, "failed to create etcd client")
71 defer cli.Close()
72
73
74 pgContainer, err := postgres.Run(ctx,
75 ImageName,
76 postgres.WithDatabase("mydatabase"),
77 postgres.WithInitScripts("../../docker/bootstrap/create_role_db.sql"),
78 testcontainers.WithWaitStrategy(
79 wait.ForLog("database system is ready to accept connections").
80 WithOccurrence(2).
81 WithStartupTimeout(5*time.Second)),
82 )
83 require.NoError(t, err)
84 defer func() { assert.NoError(t, pgContainer.Terminate(ctx)) }()
85
86
87 cancelCtx, cancel := context.WithTimeout(context.Background(), time.Second)
88 defer cancel()
89 connStr, err := pgContainer.ConnectionString(cancelCtx, "sslmode=disable")
90 require.NoError(t, err)
91 _, err = cli.Put(cancelCtx, "/service/batman/members/pg1",
92 fmt.Sprintf(`{"role":"master","conn_url":"%s"}`, connStr))
93 require.NoError(t, err)
94 _, err = cli.Put(cancelCtx, "/service/batman/members/pg2",
95 `{"role":"standby","conn_url":"must_be_skipped"}`)
96 require.NoError(t, err)
97
98 md := sources.Source{}
99 md.Name = "continuous"
100 md.OnlyIfMaster = true
101
102 t.Run("simple patroni discovery", func(t *testing.T) {
103 md.Kind = sources.SourcePatroni
104 md.ConnStr = "etcd://" + strings.TrimPrefix(endpoint, "http://")
105 md.ConnStr += "/service"
106 md.ConnStr += "/batman"
107
108
109 dbs, err := md.ResolveDatabases()
110 assert.NoError(t, err)
111 assert.NotNil(t, dbs)
112 assert.Len(t, dbs, 4)
113 })
114
115 t.Run("several endpoints patroni discovery", func(t *testing.T) {
116 md.Kind = sources.SourcePatroni
117 e := strings.TrimPrefix(endpoint, "http://")
118 md.ConnStr = "etcd://" + strings.Join([]string{e, e, e}, ",")
119 md.ConnStr += "/service"
120 md.ConnStr += "/batman"
121
122
123 dbs, err := md.ResolveDatabases()
124 assert.NoError(t, err)
125 assert.NotNil(t, dbs)
126 assert.Len(t, dbs, 4)
127 })
128
129 t.Run("namespace patroni discovery", func(t *testing.T) {
130 md.Kind = sources.SourcePatroni
131 md.ConnStr = "etcd://" + strings.TrimPrefix(endpoint, "http://")
132
133
134 dbs, err := md.ResolveDatabases()
135 assert.NoError(t, err)
136 assert.NotNil(t, dbs)
137 assert.Len(t, dbs, 4)
138 })
139 }
140
141 func TestMonitoredDatabase_UnsupportedDCS(t *testing.T) {
142 md := sources.Source{}
143 md.Name = "continuous"
144 md.Kind = sources.SourcePatroni
145
146 md.ConnStr = "consul://foo"
147 _, err := md.ResolveDatabases()
148 assert.ErrorIs(t, err, errors.ErrUnsupported)
149
150 md.ConnStr = "zookeeper://foo"
151 _, err = md.ResolveDatabases()
152 assert.ErrorIs(t, err, errors.ErrUnsupported)
153
154 md.ConnStr = "unknown://foo"
155 _, err = md.ResolveDatabases()
156 assert.EqualError(t, err, "unsupported DCS type: unknown")
157
158 }
159