1 package testutil_test
2
3 import (
4 "context"
5 "testing"
6
7 "github.com/cybertec-postgresql/pgwatch/v3/internal/testutil"
8 "github.com/stretchr/testify/assert"
9 "github.com/stretchr/testify/require"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/metadata"
12 "google.golang.org/grpc/status"
13 )
14
// TestLoadServerTLSCredentials checks that the package-level test certificate
// and key yield TLS server credentials, and that replacing them with data that
// is not valid PEM makes LoadServerTLSCredentials return an error and nil
// credentials.
func TestLoadServerTLSCredentials(t *testing.T) {
	t.Run("valid credentials", func(t *testing.T) {
		tlsCreds, loadErr := testutil.LoadServerTLSCredentials()
		assert.NoError(t, loadErr)
		assert.NotNil(t, tlsCreds)
		assert.Equal(t, "tls", tlsCreds.Info().SecurityProtocol)
	})

	t.Run("invalid certificate", func(t *testing.T) {
		// The test swaps package-level state, so put the original material
		// back once this subtest is done to keep other tests unaffected.
		savedCert, savedKey := testutil.Cert, testutil.PrivateKey
		t.Cleanup(func() {
			testutil.Cert, testutil.PrivateKey = savedCert, savedKey
		})

		testutil.Cert, testutil.PrivateKey = []byte("invalid cert"), []byte("invalid key")

		tlsCreds, loadErr := testutil.LoadServerTLSCredentials()
		assert.Error(t, loadErr)
		assert.Nil(t, tlsCreds)
	})
}
41
// TestAuthInterceptor drives testutil.AuthInterceptor with "username" and
// "password" values attached as incoming gRPC metadata. It checks that matching
// or empty credentials let the wrapped handler run and return its result, and
// that wrong credentials are rejected with codes.Unauthenticated and a nil
// result.
func TestAuthInterceptor(t *testing.T) {
	okHandler := func(context.Context, any) (any, error) {
		return "success", nil
	}

	// invoke runs the interceptor with the given credentials in the incoming
	// metadata of a fresh background context.
	invoke := func(user, pass string) (any, error) {
		md := metadata.Pairs("username", user, "password", pass)
		incoming := metadata.NewIncomingContext(context.Background(), md)
		return testutil.AuthInterceptor(incoming, nil, nil, okHandler)
	}

	cases := []struct {
		name     string
		user     string
		pass     string
		wantCode codes.Code // codes.OK means the call is expected to succeed
	}{
		{name: "valid credentials", user: "pgwatch", pass: "pgwatch", wantCode: codes.OK},
		{name: "empty credentials", user: "", pass: "", wantCode: codes.OK},
		{name: "invalid credentials", user: "wrong", pass: "wrong", wantCode: codes.Unauthenticated},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, callErr := invoke(tc.user, tc.pass)

			if tc.wantCode == codes.OK {
				assert.NoError(t, callErr)
				assert.Equal(t, "success", got)
				return
			}

			assert.Error(t, callErr)
			assert.Nil(t, got)

			grpcStatus, isStatus := status.FromError(callErr)
			require.True(t, isStatus)
			assert.Equal(t, tc.wantCode, grpcStatus.Code())
		})
	}
}
78
// TestSetupPostgresContainer starts a Postgres test container in two ways: the
// default setup, and a setup that also runs the bootstrap init script. For each
// one it checks that the container is running and returns a non-empty
// connection string. A variant whose container cannot be created is skipped,
// not failed, as the previous version of this test did.
//
// Each variant runs as its own subtest. That way its container is torn down
// when the subtest ends, instead of every container staying alive until the
// whole test returns (a defer inside a loop). It also means a setup failure in
// one variant skips only that variant.
func TestSetupPostgresContainer(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping container test in short mode")
	}

	t.Run("default", func(t *testing.T) {
		container, teardown, err := testutil.SetupPostgresContainer()
		if err != nil {
			// Skipf stops the subtest goroutine; nothing after it runs.
			t.Skipf("Skipping postgres container test: %v", err)
		}
		defer teardown()
		// require, not assert: the calls below would panic on a nil container.
		require.NotNil(t, container)

		state, err := container.State(context.Background())
		require.NoError(t, err)
		assert.True(t, state.Running)

		connStr, err := container.ConnectionString(context.Background())
		require.NoError(t, err)
		assert.NotEmpty(t, connStr)
	})

	t.Run("with init scripts", func(t *testing.T) {
		container, teardown, err := testutil.SetupPostgresContainerWithInitScripts("../../docker/bootstrap/create_role_db.sql")
		if err != nil {
			t.Skipf("Skipping postgres container test: %v", err)
		}
		defer teardown()
		require.NotNil(t, container)

		state, err := container.State(context.Background())
		require.NoError(t, err)
		assert.True(t, state.Running)

		connStr, err := container.ConnectionString(context.Background())
		require.NoError(t, err)
		assert.NotEmpty(t, connStr)
	})
}
109
// TestSetupEtcdContainer starts an etcd test container and checks that it is
// reported as running. Unlike the Postgres container test, a setup error fails
// this test instead of skipping it. The test is skipped in -short mode.
func TestSetupEtcdContainer(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping etcd container test in short mode")
	}

	ctr, stopEtcd, setupErr := testutil.SetupEtcdContainer()
	require.NoError(t, setupErr)
	t.Cleanup(stopEtcd)

	ctx := context.Background()
	ctrState, stateErr := ctr.State(ctx)
	require.NoError(t, stateErr)
	assert.True(t, ctrState.Running)
}