1 package testutil_test
2
3 import (
4 "context"
5 "os"
6 "testing"
7
8 "github.com/cybertec-postgresql/pgwatch/v5/internal/testutil"
9 "github.com/stretchr/testify/assert"
10 "github.com/stretchr/testify/require"
11 "google.golang.org/grpc/codes"
12 "google.golang.org/grpc/metadata"
13 "google.golang.org/grpc/status"
14 )
15
16 func TestLoadServerTLSCredentials(t *testing.T) {
17 t.Run("valid credentials", func(t *testing.T) {
18 creds, err := testutil.LoadServerTLSCredentials()
19 assert.NoError(t, err)
20 assert.NotNil(t, creds)
21 assert.Equal(t, "tls", creds.Info().SecurityProtocol)
22 })
23
24 t.Run("invalid certificate", func(t *testing.T) {
25
26 origCert := testutil.Cert
27 origKey := testutil.PrivateKey
28 defer func() {
29 testutil.Cert = origCert
30 testutil.PrivateKey = origKey
31 }()
32
33
34 testutil.Cert = []byte("invalid cert")
35 testutil.PrivateKey = []byte("invalid key")
36
37 creds, err := testutil.LoadServerTLSCredentials()
38 assert.Error(t, err)
39 assert.Nil(t, creds)
40 })
41 }
42
43 func TestAuthInterceptor(t *testing.T) {
44 handler := func(context.Context, any) (any, error) {
45 return "success", nil
46 }
47
48 t.Run("valid credentials", func(t *testing.T) {
49 md := metadata.Pairs("username", "pgwatch", "password", "pgwatch")
50 ctx := metadata.NewIncomingContext(context.Background(), md)
51
52 result, err := testutil.AuthInterceptor(ctx, nil, nil, handler)
53 assert.NoError(t, err)
54 assert.Equal(t, "success", result)
55 })
56
57 t.Run("empty credentials", func(t *testing.T) {
58 md := metadata.Pairs("username", "", "password", "")
59 ctx := metadata.NewIncomingContext(context.Background(), md)
60
61 result, err := testutil.AuthInterceptor(ctx, nil, nil, handler)
62 assert.NoError(t, err)
63 assert.Equal(t, "success", result)
64 })
65
66 t.Run("invalid credentials", func(t *testing.T) {
67 md := metadata.Pairs("username", "wrong", "password", "wrong")
68 ctx := metadata.NewIncomingContext(context.Background(), md)
69
70 result, err := testutil.AuthInterceptor(ctx, nil, nil, handler)
71 assert.Error(t, err)
72 assert.Nil(t, result)
73
74 st, ok := status.FromError(err)
75 require.True(t, ok)
76 assert.Equal(t, codes.Unauthenticated, st.Code())
77 })
78 }
79
80 func TestSetupPostgresContainer(t *testing.T) {
81 if testing.Short() {
82 t.Skip("Skipping container test in short mode")
83 }
84
85 container, teardown, err := testutil.SetupPostgresContainer()
86 for i := range 2 {
87 if i == 1 {
88 container, teardown, err = testutil.SetupPostgresContainerWithInitScripts("../../docker/bootstrap/create_role_db.sql")
89 }
90
91 if err != nil {
92 t.Skipf("Skipping postgres container test: %v", err)
93 return
94 }
95 defer teardown()
96
97 assert.NotNil(t, container)
98
99
100 state, err := container.State(context.Background())
101 require.NoError(t, err)
102 assert.True(t, state.Running)
103
104
105 connStr, err := container.ConnectionString(context.Background())
106 require.NoError(t, err)
107 assert.NotEmpty(t, connStr)
108 }
109 }
110
111 func TestSetupPostgresContainerWithConfig(t *testing.T) {
112 if testing.Short() {
113 t.Skip("Skipping container test in short mode")
114 }
115
116
117 tempDir := t.TempDir()
118 configPath := tempDir + "/postgresql.conf"
119 configContent := `
120 listen_addresses = '*'
121 log_destination = 'csvlog'
122 logging_collector = on
123 log_directory = 'pg_log'
124 log_filename = 'pgwatch.csv'
125 `
126 err := os.WriteFile(configPath, []byte(configContent), 0644)
127 require.NoError(t, err)
128
129 container, teardown, err := testutil.SetupPostgresContainerWithConfig(configPath)
130 require.NoError(t, err)
131 defer teardown()
132
133 require.NotNil(t, container)
134
135
136 state, err := container.State(context.Background())
137 require.NoError(t, err)
138 assert.True(t, state.Running)
139
140
141 connStr, err := container.ConnectionString(context.Background())
142 assert.NoError(t, err)
143 assert.NotEmpty(t, connStr)
144 }
145
146 func TestSetupEtcdContainer(t *testing.T) {
147 if testing.Short() {
148 t.Skip("Skipping etcd container test in short mode")
149 }
150
151 etcdContainer, etcdTeardown, err := testutil.SetupEtcdContainer()
152 require.NoError(t, err)
153 defer etcdTeardown()
154
155
156 state, err := etcdContainer.State(context.Background())
157 require.NoError(t, err)
158 assert.True(t, state.Running)
159 }
160