1 package sources_test
2
3 import (
4 "context"
5 "fmt"
6 "testing"
7 "time"
8
9 "github.com/jackc/pgx/v5"
10 "github.com/jackc/pgx/v5/pgconn"
11 "github.com/jackc/pgx/v5/pgxpool"
12 "github.com/pashagolub/pgxmock/v4"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15
16 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
17 "github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
18 )
19
// ImageName is the PostgreSQL 17 (Alpine) container image reference,
// presumably used by container-backed tests elsewhere in this package.
const (
	ImageName = "docker.io/postgres:17-alpine"
)
21
// TestSourceConn_Connect covers SourceConn.Connect for an unparsable
// connection string, a failing pool constructor, and a successful PgBouncer
// connection backed by pgxmock.
func TestSourceConn_Connect(t *testing.T) {
	// The subtests below stub the package-level sources.NewConnWithConfig hook.
	// Restore the original once this test (and all of its subtests) finish so
	// the stubs cannot leak into other tests in the package.
	origNewConnWithConfig := sources.NewConnWithConfig
	t.Cleanup(func() { sources.NewConnWithConfig = origNewConnWithConfig })

	t.Run("failed config parsing", func(t *testing.T) {
		md := &sources.SourceConn{}
		md.ConnStr = "invalid connection string"
		err := md.Connect(ctx, sources.CmdOpts{})
		assert.Error(t, err)
	})

	t.Run("failed connection", func(t *testing.T) {
		md := &sources.SourceConn{}
		// Force pool creation to fail; Connect must surface that error.
		sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
			return nil, assert.AnError
		}
		err := md.Connect(ctx, sources.CmdOpts{})
		assert.ErrorIs(t, err, assert.AnError)
	})

	t.Run("successful connection to pgbouncer", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
			return mock, nil
		}

		md := &sources.SourceConn{}
		md.Kind = sources.SourcePgBouncer

		opts := sources.CmdOpts{}
		opts.MaxParallelConnectionsPerDb = 3

		// For a PgBouncer source, Connect is expected to verify the
		// connection with "SHOW VERSION".
		mock.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))

		err = md.Connect(ctx, opts)
		assert.NoError(t, err)

		assert.NoError(t, mock.ExpectationsWereMet())
	})
}
61
// TestSourceConn_ParseConfig checks that ParseConfig succeeds on a zero-value
// SourceConn (empty ConnStr) and keeps succeeding when called again.
func TestSourceConn_ParseConfig(t *testing.T) {
	conn := &sources.SourceConn{}
	// The repeat call presumably exercises the "already parsed" path — confirm
	// against ParseConfig's implementation.
	for i := 0; i < 2; i++ {
		assert.NoError(t, conn.ParseConfig())
	}
}
68
// TestSourceConn_GetDatabaseName checks database-name extraction from a
// URL-style connection string and from an unparsable one.
func TestSourceConn_GetDatabaseName(t *testing.T) {
	check := func(want, got string) {
		t.Helper()
		assert.Equal(t, want, got, "GetDatabaseName() = %v, want %v", got, want)
	}

	// The name must be reported both by SourceConn and by the embedded Source.
	conn := &sources.SourceConn{}
	conn.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
	check("mydatabase", conn.GetDatabaseName())
	check("mydatabase", conn.Source.GetDatabaseName())

	// An unparsable connection string yields an empty name.
	broken := &sources.SourceConn{}
	broken.ConnStr = "foo boo"
	check("", broken.GetDatabaseName())
}
86
// TestSourceConn_SetDatabaseName checks that SetDatabaseName is reflected by
// GetDatabaseName for a valid connection string, and that an unparsable
// connection string still reads back an empty name after a set attempt.
func TestSourceConn_SetDatabaseName(t *testing.T) {
	conn := &sources.SourceConn{}
	conn.ConnStr = "postgres://user:password@localhost:5432/mydatabase"

	// Re-set the current name first, then switch to a new one.
	for _, name := range []string{"mydatabase", "newdatabase"} {
		conn.SetDatabaseName(name)
		got := conn.GetDatabaseName()
		assert.Equal(t, name, got, "GetDatabaseName() = %v, want %v", got, name)
	}

	broken := &sources.SourceConn{}
	broken.ConnStr = "foo boo"
	broken.SetDatabaseName("ingored due to invalid ConnStr")
	got := broken.GetDatabaseName()
	assert.Equal(t, "", got, "GetDatabaseName() = %v, want %v", got, "")
}
108
109 func TestSourceConn_DiscoverPlatform(t *testing.T) {
110 ctx := context.Background()
111 mock, err := pgxmock.NewPool()
112 require.NoError(t, err)
113 md := &sources.SourceConn{Conn: mock}
114
115 mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"exec_env"}).AddRow("AZURE_SINGLE"))
116 md.ExecEnv = md.DiscoverPlatform(ctx)
117 assert.Equal(t, "AZURE_SINGLE", md.ExecEnv)
118 assert.Equal(t, "AZURE_SINGLE", md.DiscoverPlatform(ctx))
119 assert.NoError(t, mock.ExpectationsWereMet())
120 }
121
// TestSourceConn_GetApproxSize checks that FetchApproxSize returns the value
// produced by its size query.
func TestSourceConn_GetApproxSize(t *testing.T) {
	mock, err := pgxmock.NewPool()
	require.NoError(t, err)
	md := &sources.SourceConn{Conn: mock}

	mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"size"}).AddRow(42))

	// FetchApproxSize returns a single value, so there is no error to check
	// here; the mock expectations below confirm the query was issued.
	assert.EqualValues(t, 42, md.FetchApproxSize(ctx))
	assert.NoError(t, mock.ExpectationsWereMet())
}
133
// TestSourceConn_FunctionExists checks that FunctionExists reports false when
// the lookup query for the given function name returns no rows.
func TestSourceConn_FunctionExists(t *testing.T) {
	pool, err := pgxmock.NewPool()
	require.NoError(t, err)
	conn := &sources.SourceConn{Conn: pool}

	noRows := pgxmock.NewRows([]string{"exists"})
	pool.ExpectQuery("select").WithArgs("get_foo").WillReturnRows(noRows)

	exists := conn.FunctionExists(ctx, "get_foo")
	assert.False(t, exists)
	assert.NoError(t, pool.ExpectationsWereMet())
}
144
// TestSourceConn_IsPostgresSource checks which source kinds are classified as
// Postgres sources: Postgres and Patroni are, PgBouncer and Pgpool are not.
func TestSourceConn_IsPostgresSource(t *testing.T) {
	cases := []struct {
		src  sources.Source
		want bool
	}{
		{sources.Source{Kind: sources.SourcePostgres}, true},
		{sources.Source{Kind: sources.SourcePgBouncer}, false},
		{sources.Source{Kind: sources.SourcePgPool}, false},
		{sources.Source{Kind: sources.SourcePatroni}, true},
	}
	for _, tc := range cases {
		conn := &sources.SourceConn{Source: tc.src}
		if tc.want {
			assert.True(t, conn.IsPostgresSource(), "IsPostgresSource() = false, want true")
		} else {
			assert.False(t, conn.IsPostgresSource(), "IsPostgresSource() = true, want false")
		}
	}
}
159
// TestSourceConn_Ping checks that Ping uses a pool ping for Postgres sources
// and a "SHOW VERSION" exec for PgBouncer sources.
func TestSourceConn_Ping(t *testing.T) {
	// Named mock (not db) so the imported db package is not shadowed.
	mock, err := pgxmock.NewPool()
	require.NoError(t, err)
	md := &sources.SourceConn{Conn: mock}

	mock.ExpectPing()
	md.Kind = sources.SourcePostgres
	assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")

	mock.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
	md.Kind = sources.SourcePgBouncer
	assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")

	// Confirms each Kind actually took its expected code path.
	assert.NoError(t, mock.ExpectationsWereMet())
}
174
// TestSourceConn_GetMetricInterval checks which interval map GetMetricInterval
// consults: Metrics on a primary, MetricsStandby on a standby when that map is
// non-empty, and Metrics again on a standby whose MetricsStandby is empty.
func TestSourceConn_GetMetricInterval(t *testing.T) {
	// newConn builds a SourceConn with the shared fixture maps and the given
	// recovery state, so every subtest starts from identical data.
	newConn := func(inRecovery bool) *sources.SourceConn {
		conn := &sources.SourceConn{
			Source: sources.Source{
				Metrics:        map[string]float64{"foo": 1.5, "bar": 2.5},
				MetricsStandby: map[string]float64{"foo": 3.5},
			},
		}
		conn.IsInRecovery = inRecovery
		return conn
	}

	t.Run("primary uses Metrics", func(t *testing.T) {
		conn := newConn(false)
		assert.Equal(t, 1.5, conn.GetMetricInterval("foo"))
		assert.Equal(t, 2.5, conn.GetMetricInterval("bar"))
	})

	t.Run("standby uses MetricsStandby if present", func(t *testing.T) {
		conn := newConn(true)
		assert.Equal(t, 3.5, conn.GetMetricInterval("foo"))
		// "bar" exists only in Metrics; with a non-empty MetricsStandby it is
		// not picked up from there.
		assert.Equal(t, 0.0, conn.GetMetricInterval("bar"))
	})

	t.Run("standby with empty MetricsStandby falls back to Metrics", func(t *testing.T) {
		conn := newConn(true)
		conn.MetricsStandby = map[string]float64{}
		assert.Equal(t, 1.5, conn.GetMetricInterval("foo"))
	})
}
201
// TestVersionToInt checks that version strings are converted to the
// MAJOR*10000 + MINOR*100 + PATCH integer form (e.g. "9.6.3" -> 90603), with a
// leading "v" and trailing suffixes ignored and unparsable input yielding 0.
func TestVersionToInt(t *testing.T) {
	tests := []struct {
		arg  string
		want int
	}{
		{"", 0},
		{"foo", 0},
		{"13", 13_00_00},
		{"3.0", 3_00_00},
		{"9.6.3", 9_06_03},
		{"v9.6-beta2", 9_06_00},
	}
	for _, tt := range tests {
		// One subtest per input so a failure names the offending case.
		t.Run(tt.arg, func(t *testing.T) {
			if got := sources.VersionToInt(tt.arg); got != tt.want {
				t.Errorf("VersionToInt(%q) = %v, want %v", tt.arg, got, tt.want)
			}
		})
	}
}
220
// TestSourceConn_FetchRuntimeInfo exercises SourceConn.FetchRuntimeInfo for a
// cancelled context, a still-fresh cached result, the PgBouncer and Pgpool
// version queries, the multi-query Postgres path, and a failing query.
func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
	ctx := context.Background()

	t.Run("cancelled context", func(t *testing.T) {
		ctxNew, cancel := context.WithCancel(ctx)
		cancel()
		err := (&sources.SourceConn{}).FetchRuntimeInfo(ctxNew, true)
		assert.Error(t, err)
	})

	t.Run("cached version", func(t *testing.T) {
		// LastCheckedOn is only a minute old and the refresh flag is false. No
		// Conn is set, so this relies on FetchRuntimeInfo keeping the cached
		// values without querying.
		md := &sources.SourceConn{
			RuntimeInfo: sources.RuntimeInfo{
				LastCheckedOn: time.Now().Add(-time.Minute),
				Version:       42,
			},
		}
		err := md.FetchRuntimeInfo(ctx, false)
		assert.NoError(t, err)
		assert.Equal(t, 42, md.Version)
	})

	t.Run("pgbouncer version fetch", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := &sources.SourceConn{
			Conn:   mock,
			Source: sources.Source{Kind: sources.SourcePgBouncer},
		}
		mock.ExpectQuery("SHOW VERSION").
			WithArgs(pgx.QueryExecModeSimpleProtocol).
			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("PgBouncer 1.12.0"))
		err = md.FetchRuntimeInfo(ctx, true)
		assert.NoError(t, err)
		assert.Contains(t, md.VersionStr, "PgBouncer")
		assert.Positive(t, md.Version)
		assert.NoError(t, mock.ExpectationsWereMet())
	})

	t.Run("pgpool version fetch", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := &sources.SourceConn{
			Conn:   mock,
			Source: sources.Source{Kind: sources.SourcePgPool},
		}
		mock.ExpectQuery("SHOW POOL_VERSION").
			WithArgs(pgx.QueryExecModeSimpleProtocol).
			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("4.1.2"))
		err = md.FetchRuntimeInfo(ctx, true)
		assert.NoError(t, err)
		assert.Contains(t, md.VersionStr, "4.1.2")
		assert.Positive(t, md.Version)
		assert.NoError(t, mock.ExpectationsWereMet())
	})

	t.Run("postgres version and extensions", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := &sources.SourceConn{
			Conn:   mock,
			Source: sources.Source{Kind: sources.SourcePostgres},
		}
		// pgxmock matches expectations in registration order by default, so
		// these must follow the order of FetchRuntimeInfo's queries. Judging
		// by the mocked columns, that is: server info, execution platform,
		// approximate size, then installed extensions.
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"ver", "version", "pg_is_in_recovery", "current_database", "system_identifier", "is_superuser"}).
				AddRow(13, "PostgreSQL 13.3", false, "testdb", "42424242", true),
		)
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"exec_env"}).AddRow("UNKNOWN"),
		)
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"approx_size"}).AddRow(42),
		)
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"extname", "extversion"}).AddRow("pg_stat_statements", "1.8"),
		)
		err = md.FetchRuntimeInfo(ctx, true)
		assert.NoError(t, err)
		assert.Equal(t, 13, md.Version)
		assert.Equal(t, "testdb", md.RealDbname)
		assert.Contains(t, md.Extensions, "pg_stat_statements")
		assert.NoError(t, mock.ExpectationsWereMet())
	})

	t.Run("query error", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := &sources.SourceConn{
			Conn:   mock,
			Source: sources.Source{Kind: sources.SourcePgBouncer},
		}
		mock.ExpectQuery("SHOW VERSION").
			WithArgs(pgx.QueryExecModeSimpleProtocol).
			WillReturnError(fmt.Errorf("db error"))
		err = md.FetchRuntimeInfo(ctx, true)
		assert.Error(t, err)
		assert.NoError(t, mock.ExpectationsWereMet())
	})
}
321
// TestSourceConn_FetchVersion checks FetchVersion against a parseable version
// string, an unparseable one (reported as 0 without an error), and a query
// failure. Every case expects the query to be sent in simple-protocol mode.
func TestSourceConn_FetchVersion(t *testing.T) {
	ctx := context.Background()

	// newConn wires a fresh pgxmock pool into a SourceConn and registers the
	// "SHOW VERSION" expectation, leaving its outcome for the caller to set.
	newConn := func(t *testing.T) (pgxmock.PgxPoolIface, *sources.SourceConn, *pgxmock.ExpectedQuery) {
		t.Helper()
		pool, err := pgxmock.NewPool()
		require.NoError(t, err)
		query := pool.ExpectQuery("SHOW VERSION").WithArgs(pgx.QueryExecModeSimpleProtocol)
		return pool, &sources.SourceConn{Conn: pool}, query
	}

	t.Run("valid version string", func(t *testing.T) {
		pool, conn, query := newConn(t)
		query.WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("FooBar 1.12.0"))

		verStr, verInt, err := conn.FetchVersion(ctx, "SHOW VERSION")
		assert.NoError(t, err)
		assert.Equal(t, "FooBar 1.12.0", verStr)
		assert.Equal(t, 1_12_00, verInt)
		assert.NoError(t, pool.ExpectationsWereMet())
	})

	t.Run("invalid version string", func(t *testing.T) {
		pool, conn, query := newConn(t)
		query.WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("invalid version"))

		_, verInt, err := conn.FetchVersion(ctx, "SHOW VERSION")
		assert.Equal(t, 0, verInt)
		assert.NoError(t, err)
		assert.NoError(t, pool.ExpectationsWereMet())
	})

	t.Run("query error", func(t *testing.T) {
		pool, conn, query := newConn(t)
		query.WillReturnError(assert.AnError)

		_, _, err := conn.FetchVersion(ctx, "SHOW VERSION")
		assert.Error(t, err)
		assert.NoError(t, pool.ExpectationsWereMet())
	})
}
364
// TestSourceConn_GetClusterIdentifier checks that the cluster identifier is
// built as "<system identifier>:<host>:<port>" for a URL-style connection
// string, and is empty when the connection string cannot be parsed.
func TestSourceConn_GetClusterIdentifier(t *testing.T) {
	cases := []struct {
		conn *sources.SourceConn
		want string
	}{
		{
			conn: &sources.SourceConn{
				Source: sources.Source{
					Name:    "test",
					Kind:    sources.SourcePostgres,
					ConnStr: "postgres://user:password@localhost:5432/mydatabase",
				},
				RuntimeInfo: sources.RuntimeInfo{SystemIdentifier: "42424242"},
			},
			want: "42424242:localhost:5432",
		},
		{
			conn: &sources.SourceConn{
				Source: sources.Source{
					Name:    "test",
					Kind:    sources.SourcePostgres,
					ConnStr: "foo boo",
				},
			},
			want: "",
		},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, tc.conn.GetClusterIdentifier())
	}
}
387