1 package sources_test
2
3 import (
4 "context"
5 "fmt"
6 "testing"
7 "time"
8
9 "github.com/jackc/pgx/v5"
10 "github.com/jackc/pgx/v5/pgconn"
11 "github.com/jackc/pgx/v5/pgxpool"
12 "github.com/pashagolub/pgxmock/v4"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15
16 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
17 "github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
18 )
19
20 func TestSourceConn_Connect(t *testing.T) {
21
22 t.Run("failed config parsing", func(t *testing.T) {
23 md := &sources.SourceConn{}
24 md.ConnStr = "invalid connection string"
25 err := md.Connect(ctx, sources.CmdOpts{})
26 assert.Error(t, err)
27 })
28
29 t.Run("failed connection", func(t *testing.T) {
30 md := &sources.SourceConn{}
31 sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
32 return nil, assert.AnError
33 }
34 err := md.Connect(ctx, sources.CmdOpts{})
35 assert.ErrorIs(t, err, assert.AnError)
36 })
37
38 t.Run("successful connection to pgbouncer", func(t *testing.T) {
39 mock, err := pgxmock.NewPool()
40 require.NoError(t, err)
41 sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
42 return mock, nil
43 }
44
45 md := &sources.SourceConn{}
46 md.Kind = sources.SourcePgBouncer
47
48 opts := sources.CmdOpts{}
49 opts.MaxParallelConnectionsPerDb = 3
50
51 mock.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
52
53 err = md.Connect(ctx, opts)
54 assert.NoError(t, err)
55
56 assert.NoError(t, mock.ExpectationsWereMet())
57 })
58 }
59
60 func TestSourceConn_ParseConfig(t *testing.T) {
61 md := &sources.SourceConn{}
62 assert.NoError(t, md.ParseConfig())
63
64 assert.NoError(t, md.ParseConfig())
65 }
66
67 func TestSourceConn_GetDatabaseName(t *testing.T) {
68 md := &sources.SourceConn{}
69 md.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
70 expected := "mydatabase"
71
72 got := md.GetDatabaseName()
73 assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
74
75 got = md.Source.GetDatabaseName()
76 assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
77
78 md = &sources.SourceConn{}
79 md.ConnStr = "foo boo"
80 expected = ""
81 got = md.GetDatabaseName()
82 assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
83 }
84
85 func TestSourceConn_SetDatabaseName(t *testing.T) {
86 md := &sources.SourceConn{}
87 md.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
88 expected := "mydatabase"
89
90 md.SetDatabaseName(expected)
91 got := md.GetDatabaseName()
92 assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
93
94 expected = "newdatabase"
95 md.SetDatabaseName(expected)
96 got = md.GetDatabaseName()
97 assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
98
99 md = &sources.SourceConn{}
100 md.ConnStr = "foo boo"
101 expected = ""
102 md.SetDatabaseName("ingored due to invalid ConnStr")
103 got = md.GetDatabaseName()
104 assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
105 }
106
107 func TestSourceConn_DiscoverPlatform(t *testing.T) {
108 ctx := context.Background()
109 mock, err := pgxmock.NewPool()
110 require.NoError(t, err)
111 md := &sources.SourceConn{Conn: mock}
112
113 mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"exec_env"}).AddRow("AZURE_SINGLE"))
114 md.ExecEnv = md.DiscoverPlatform(ctx)
115 assert.Equal(t, "AZURE_SINGLE", md.ExecEnv)
116 assert.Equal(t, "AZURE_SINGLE", md.DiscoverPlatform(ctx))
117 assert.NoError(t, mock.ExpectationsWereMet())
118 }
119
120 func TestSourceConn_GetApproxSize(t *testing.T) {
121 mock, err := pgxmock.NewPool()
122 require.NoError(t, err)
123 md := &sources.SourceConn{Conn: mock}
124
125 mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"size"}).AddRow(42))
126
127 assert.EqualValues(t, 42, md.FetchApproxSize(ctx))
128 assert.NoError(t, err)
129 assert.NoError(t, mock.ExpectationsWereMet())
130 }
131
132 func TestSourceConn_FunctionExists(t *testing.T) {
133 mock, err := pgxmock.NewPool()
134 require.NoError(t, err)
135 md := &sources.SourceConn{Conn: mock}
136
137 mock.ExpectQuery("select").WithArgs("get_foo").WillReturnRows(pgxmock.NewRows([]string{"exists"}))
138
139 assert.False(t, md.FunctionExists(ctx, "get_foo"))
140 assert.NoError(t, mock.ExpectationsWereMet())
141 }
142
143 func TestSourceConn_IsPostgresSource(t *testing.T) {
144 md := &sources.SourceConn{}
145 md.Kind = sources.SourcePostgres
146 assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true")
147
148 md.Kind = sources.SourcePgBouncer
149 assert.False(t, md.IsPostgresSource(), "IsPostgresSource() = true, want false")
150
151 md.Kind = sources.SourcePgPool
152 assert.False(t, md.IsPostgresSource(), "IsPostgresSource() = true, want false")
153
154 md.Kind = sources.SourcePatroni
155 assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true")
156 }
157
158 func TestSourceConn_Ping(t *testing.T) {
159 db, err := pgxmock.NewPool()
160 require.NoError(t, err)
161 md := &sources.SourceConn{Conn: db}
162
163 db.ExpectPing()
164 md.Kind = sources.SourcePostgres
165 assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")
166
167 db.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
168 md.Conn = db
169 md.Kind = sources.SourcePgBouncer
170 assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")
171 }
172
173 func TestSourceConn_GetMetricInterval(t *testing.T) {
174 md := &sources.SourceConn{
175 Source: sources.Source{
176 Metrics: map[string]float64{"foo": 1.5, "bar": 2.5},
177 MetricsStandby: map[string]float64{"foo": 3.5},
178 },
179 }
180
181 t.Run("primary uses Metrics", func(t *testing.T) {
182 md.IsInRecovery = false
183 assert.Equal(t, 1.5, md.GetMetricInterval("foo"))
184 assert.Equal(t, 2.5, md.GetMetricInterval("bar"))
185 })
186
187 t.Run("standby uses MetricsStandby if present", func(t *testing.T) {
188 md.IsInRecovery = true
189 assert.Equal(t, 3.5, md.GetMetricInterval("foo"))
190 assert.Equal(t, 0.0, md.GetMetricInterval("bar"))
191 })
192
193 t.Run("standby with empty MetricsStandby falls back to Metrics", func(t *testing.T) {
194 md.IsInRecovery = true
195 md.MetricsStandby = map[string]float64{}
196 assert.Equal(t, 1.5, md.GetMetricInterval("foo"))
197 })
198 }
199
200 func TestVersionToInt(t *testing.T) {
201 tests := []struct {
202 arg string
203 want int
204 }{
205 {"", 0},
206 {"foo", 0},
207 {"13", 13_00_00},
208 {"3.0", 3_00_00},
209 {"9.6.3", 9_06_03},
210 {"v9.6-beta2", 9_06_00},
211 }
212 for _, tt := range tests {
213 if got := sources.VersionToInt(tt.arg); got != tt.want {
214 t.Errorf("VersionToInt() = %v, want %v", got, tt.want)
215 }
216 }
217 }
218
// TestSourceConn_FetchRuntimeInfo exercises SourceConn.FetchRuntimeInfo in
// these situations: an already-cancelled context, a recently cached result,
// and fresh fetches from PgBouncer, Pgpool and PostgreSQL sources backed by
// pgxmock, plus a failing version query.
func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
	ctx := context.Background()

	// An already-cancelled context must make the fetch fail. The zero-value
	// SourceConn has no Conn, so presumably the context is checked before any
	// query is attempted — confirm against FetchRuntimeInfo.
	t.Run("cancelled context", func(t *testing.T) {
		ctxNew, cancel := context.WithCancel(ctx)
		cancel()
		err := (&sources.SourceConn{}).FetchRuntimeInfo(ctxNew, true)
		assert.Error(t, err)
	})

	// The second argument is false and LastCheckedOn is only one minute old,
	// so the previously stored Version is kept. This assumes the cache
	// lifetime exceeds one minute — TODO confirm.
	t.Run("cached version", func(t *testing.T) {
		md := &sources.SourceConn{
			RuntimeInfo: sources.RuntimeInfo{
				LastCheckedOn: time.Now().Add(-time.Minute),
				Version:       42,
			},
		}
		err := md.FetchRuntimeInfo(ctx, false)
		assert.NoError(t, err)
		assert.Equal(t, 42, md.Version)
	})

	// PgBouncer: the version is read with "SHOW VERSION", sent via the simple
	// query protocol.
	t.Run("pgbouncer version fetch", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgBouncer})
		md.Conn = mock
		mock.ExpectQuery("SHOW VERSION").
			WithArgs(pgx.QueryExecModeSimpleProtocol).
			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("PgBouncer 1.12.0"))
		err = md.FetchRuntimeInfo(ctx, true)
		assert.NoError(t, err)
		assert.Contains(t, md.VersionStr, "PgBouncer")
		assert.True(t, md.Version > 0)
		assert.NoError(t, mock.ExpectationsWereMet())
	})

	// Pgpool: the version is read with "SHOW POOL_VERSION" via the simple protocol.
	t.Run("pgpool version fetch", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgPool})
		md.Conn = mock
		mock.ExpectQuery("SHOW POOL_VERSION").
			WithArgs(pgx.QueryExecModeSimpleProtocol).
			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("4.1.2"))
		err = md.FetchRuntimeInfo(ctx, true)
		assert.NoError(t, err)
		assert.Contains(t, md.VersionStr, "4.1.2")
		assert.True(t, md.Version > 0)
		assert.NoError(t, mock.ExpectationsWereMet())
	})

	// PostgreSQL: four "select" queries are mocked. pgxmock matches
	// expectations in registration order by default, so this sequence must
	// mirror the order in which FetchRuntimeInfo issues its queries:
	//  1. server info
	//  2. execution environment
	//  3. approximate size
	//  4. installed extensions
	t.Run("postgres version and extensions", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePostgres})
		md.Conn = mock
		// Server info: version number/string, recovery state, database name,
		// system identifier and superuser flag.
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"ver", "version", "pg_is_in_recovery", "current_database", "system_identifier", "is_superuser"}).
				AddRow(13, "PostgreSQL 13.3", false, "testdb", "42424242", true),
		)
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"exec_env"}).AddRow("UNKNOWN"),
		)
		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"approx_size"}).AddRow(42),
		)

		mock.ExpectQuery("select").WillReturnRows(
			pgxmock.NewRows([]string{"extname", "extversion"}).AddRow("pg_stat_statements", "1.8"),
		)
		err = md.FetchRuntimeInfo(ctx, true)
		assert.NoError(t, err)
		assert.Equal(t, 13, md.Version)
		assert.Equal(t, "testdb", md.RealDbname)
		assert.Contains(t, md.Extensions, "pg_stat_statements")
		assert.NoError(t, mock.ExpectationsWereMet())
	})

	// A failing version query must surface as an error from FetchRuntimeInfo.
	t.Run("query error", func(t *testing.T) {
		mock, err := pgxmock.NewPool()
		require.NoError(t, err)
		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgBouncer})
		md.Conn = mock
		mock.ExpectQuery("SHOW VERSION").
			WithArgs(pgx.QueryExecModeSimpleProtocol).
			WillReturnError(fmt.Errorf("db error"))
		err = md.FetchRuntimeInfo(ctx, true)
		assert.Error(t, err)
		assert.NoError(t, mock.ExpectationsWereMet())
	})
}
311
312 func TestSourceConn_FetchVersion(t *testing.T) {
313 ctx := context.Background()
314
315 t.Run("valid version string", func(t *testing.T) {
316 mock, err := pgxmock.NewPool()
317 require.NoError(t, err)
318 md := &sources.SourceConn{Conn: mock}
319 mock.ExpectQuery("SHOW VERSION").
320 WithArgs(pgx.QueryExecModeSimpleProtocol).
321 WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("FooBar 1.12.0"))
322 verStr, verInt, err := md.FetchVersion(ctx, "SHOW VERSION")
323 assert.NoError(t, err)
324 assert.Equal(t, "FooBar 1.12.0", verStr)
325 assert.Equal(t, 1_12_00, verInt)
326 assert.NoError(t, mock.ExpectationsWereMet())
327 })
328
329 t.Run("invalid version string", func(t *testing.T) {
330 mock, err := pgxmock.NewPool()
331 require.NoError(t, err)
332 md := &sources.SourceConn{Conn: mock}
333 mock.ExpectQuery("SHOW VERSION").
334 WithArgs(pgx.QueryExecModeSimpleProtocol).
335 WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("invalid version"))
336 _, verInt, err := md.FetchVersion(ctx, "SHOW VERSION")
337 assert.Equal(t, 0, verInt)
338 assert.NoError(t, err)
339 assert.NoError(t, mock.ExpectationsWereMet())
340 })
341
342 t.Run("query error", func(t *testing.T) {
343 mock, err := pgxmock.NewPool()
344 require.NoError(t, err)
345 md := &sources.SourceConn{Conn: mock}
346 mock.ExpectQuery("SHOW VERSION").
347 WithArgs(pgx.QueryExecModeSimpleProtocol).
348 WillReturnError(assert.AnError)
349 _, _, err = md.FetchVersion(ctx, "SHOW VERSION")
350 assert.Error(t, err)
351 assert.NoError(t, mock.ExpectationsWereMet())
352 })
353 }
354
355 func TestSourceConn_GetClusterIdentifier(t *testing.T) {
356 md := &sources.SourceConn{
357 Source: sources.Source{
358 Name: "test",
359 Kind: sources.SourcePostgres,
360 ConnStr: "postgres://user:password@localhost:5432/mydatabase",
361 },
362 RuntimeInfo: sources.RuntimeInfo{
363 SystemIdentifier: "42424242",
364 },
365 }
366 assert.Equal(t, "42424242:localhost:5432", md.GetClusterIdentifier())
367
368 md = &sources.SourceConn{
369 Source: sources.Source{
370 Name: "test",
371 Kind: sources.SourcePostgres,
372 ConnStr: "foo boo",
373 },
374 }
375 assert.Equal(t, "", md.GetClusterIdentifier())
376 }
377