1 package testutil
2
3 import (
4 "context"
5 "crypto/tls"
6 "net"
7 "os"
8 "time"
9
10 "github.com/cybertec-postgresql/pgwatch/v5/api/pb"
11 "github.com/testcontainers/testcontainers-go"
12 "github.com/testcontainers/testcontainers-go/modules/etcd"
13 "github.com/testcontainers/testcontainers-go/modules/postgres"
14 "github.com/testcontainers/testcontainers-go/wait"
15 "google.golang.org/grpc"
16 "google.golang.org/grpc/codes"
17 "google.golang.org/grpc/credentials"
18 "google.golang.org/grpc/metadata"
19 "google.golang.org/grpc/status"
20 )
21
22 func SetupPostgresContainer() (*postgres.PostgresContainer, func(), error) {
23 pgContainer, err := postgres.Run(TestContext,
24 PostgresImage,
25 postgres.WithDatabase(MockDatabase),
26 testcontainers.WithWaitStrategy(
27 wait.ForLog("database system is ready to accept connections").
28 WithOccurrence(2).
29 WithStartupTimeout(5*time.Second)),
30 )
31
32 tearDown := func() {
33 _ = pgContainer.Terminate(TestContext)
34 }
35
36 return pgContainer, tearDown, err
37 }
38
39
40
41 func SetupPostgresContainerWithConfig(configPath string) (*postgres.PostgresContainer, func(), error) {
42 pgContainer, err := postgres.Run(TestContext,
43 PostgresImage,
44 postgres.WithDatabase(MockDatabase),
45 postgres.WithConfigFile(configPath),
46 testcontainers.WithWaitStrategy(
47 wait.ForListeningPort("5432/tcp").WithStartupTimeout(5*time.Second)),
48 )
49
50 tearDown := func() {
51 _ = pgContainer.Terminate(TestContext)
52 }
53
54 return pgContainer, tearDown, err
55 }
56
57 func SetupPostgresContainerWithInitScripts(scripts ...string) (*postgres.PostgresContainer, func(), error) {
58 pgContainer, err := postgres.Run(TestContext,
59 PostgresImage,
60 postgres.WithDatabase(MockDatabase),
61 postgres.WithInitScripts(scripts...),
62 testcontainers.WithWaitStrategy(
63 wait.ForLog("database system is ready to accept connections").
64 WithOccurrence(2).
65 WithStartupTimeout(5*time.Second)),
66 )
67
68 tearDown := func() {
69 _ = pgContainer.Terminate(TestContext)
70 }
71
72 return pgContainer, tearDown, err
73 }
74
75 func SetupEtcdContainer() (*etcd.EtcdContainer, func(), error) {
76 etcdContainer, err := etcd.Run(TestContext, EtcdImage,
77 testcontainers.
78 WithWaitStrategy(wait.ForLog("ready to serve client requests").
79 WithStartupTimeout(15*time.Second)))
80
81 tearDown := func() {
82 _ = etcdContainer.Terminate(TestContext)
83 }
84
85 return etcdContainer, tearDown, err
86 }
87
88
89
90 func LoadServerTLSCredentials() (credentials.TransportCredentials, error) {
91 cert, err := tls.X509KeyPair(Cert, PrivateKey)
92 if err != nil {
93 return nil, err
94 }
95
96 tlsConfig := &tls.Config{
97 Certificates: []tls.Certificate{cert},
98 }
99 return credentials.NewTLS(tlsConfig), nil
100 }
101
102 func AuthInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
103 md, _ := metadata.FromIncomingContext(ctx)
104
105 clientUsername := md.Get("username")[0]
106 clientPassword := md.Get("password")[0]
107
108 if clientUsername != "" && clientUsername != "pgwatch" && clientPassword != "pgwatch" {
109 return nil, status.Error(codes.Unauthenticated, "unauthenticated")
110 }
111
112 return handler(ctx, req)
113 }
114
115 func SetupRPCServers() (func(), error) {
116 err := os.WriteFile(CAFile, []byte(CA), 0644)
117 teardown := func() { _ = os.Remove(CAFile) }
118 if err != nil {
119 return teardown, err
120 }
121
122 addresses := [2]string{PlainServerAddress, TLSServerAddress}
123 for _, address := range addresses {
124 lis, err := net.Listen("tcp", address)
125 if err != nil {
126 return teardown, err
127 }
128
129 var creds credentials.TransportCredentials
130 if address == TLSServerAddress {
131 creds, err = LoadServerTLSCredentials()
132 if err != nil {
133 return nil, err
134 }
135 }
136
137 server := grpc.NewServer(
138 grpc.UnaryInterceptor(AuthInterceptor),
139 grpc.Creds(creds),
140 )
141
142 recv := new(Receiver)
143 pb.RegisterReceiverServer(server, recv)
144
145 go func() {
146 if err := server.Serve(lis); err != nil {
147 panic(err)
148 }
149 }()
150 }
151
152 time.Sleep(time.Second)
153 return teardown, nil
154 }
155