1 package testutil
2
3 import (
4 "context"
5 "crypto/tls"
6 "net"
7 "os"
8 "time"
9
10 "github.com/cybertec-postgresql/pgwatch/v3/api/pb"
11 "github.com/testcontainers/testcontainers-go"
12 "github.com/testcontainers/testcontainers-go/modules/etcd"
13 "github.com/testcontainers/testcontainers-go/modules/postgres"
14 "github.com/testcontainers/testcontainers-go/wait"
15 "google.golang.org/grpc"
16 "google.golang.org/grpc/codes"
17 "google.golang.org/grpc/credentials"
18 "google.golang.org/grpc/metadata"
19 "google.golang.org/grpc/status"
20 )
21
22 func SetupPostgresContainer() (*postgres.PostgresContainer, func(), error) {
23 pgContainer, err := postgres.Run(ctx,
24 pgImageName,
25 postgres.WithDatabase(MockDatabase),
26 testcontainers.WithWaitStrategy(
27 wait.ForLog("database system is ready to accept connections").
28 WithOccurrence(2).
29 WithStartupTimeout(5*time.Second)),
30 )
31
32 tearDown := func() {
33 _ = pgContainer.Terminate(ctx)
34 }
35
36 return pgContainer, tearDown, err
37 }
38
39 func SetupPostgresContainerWithInitScripts(scripts ...string) (*postgres.PostgresContainer, func(), error) {
40 pgContainer, err := postgres.Run(ctx,
41 pgImageName,
42 postgres.WithDatabase(MockDatabase),
43 postgres.WithInitScripts(scripts...),
44 testcontainers.WithWaitStrategy(
45 wait.ForLog("database system is ready to accept connections").
46 WithOccurrence(2).
47 WithStartupTimeout(5*time.Second)),
48 )
49
50 tearDown := func() {
51 _ = pgContainer.Terminate(ctx)
52 }
53
54 return pgContainer, tearDown, err
55 }
56
57 func SetupEtcdContainer() (*etcd.EtcdContainer, func(), error) {
58 etcdContainer, err := etcd.Run(ctx, etcdImage,
59 testcontainers.
60 WithWaitStrategy(wait.ForLog("ready to serve client requests").
61 WithStartupTimeout(15*time.Second)))
62
63 tearDown := func() {
64 _ = etcdContainer.Terminate(ctx)
65 }
66
67 return etcdContainer, tearDown, err
68 }
69
70
71
72 func LoadServerTLSCredentials() (credentials.TransportCredentials, error) {
73 cert, err := tls.X509KeyPair(Cert, PrivateKey)
74 if err != nil {
75 return nil, err
76 }
77
78 tlsConfig := &tls.Config{
79 Certificates: []tls.Certificate{cert},
80 }
81 return credentials.NewTLS(tlsConfig), nil
82 }
83
84 func AuthInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
85 md, _ := metadata.FromIncomingContext(ctx)
86
87 clientUsername := md.Get("username")[0]
88 clientPassword := md.Get("password")[0]
89
90 if clientUsername != "" && clientUsername != "pgwatch" && clientPassword != "pgwatch" {
91 return nil, status.Error(codes.Unauthenticated, "unauthenticated")
92 }
93
94 return handler(ctx, req)
95 }
96
97 func SetupRPCServers() (func(), error) {
98 err := os.WriteFile(CAFile, []byte(CA), 0644)
99 teardown := func() { _ = os.Remove(CAFile) }
100 if err != nil {
101 return teardown, err
102 }
103
104 addresses := [2]string{PlainServerAddress, TLSServerAddress}
105 for _, address := range addresses {
106 lis, err := net.Listen("tcp", address)
107 if err != nil {
108 return teardown, err
109 }
110
111 var creds credentials.TransportCredentials
112 if address == TLSServerAddress {
113 creds, err = LoadServerTLSCredentials()
114 if err != nil {
115 return nil, err
116 }
117 }
118
119 server := grpc.NewServer(
120 grpc.UnaryInterceptor(AuthInterceptor),
121 grpc.Creds(creds),
122 )
123
124 recv := new(Receiver)
125 pb.RegisterReceiverServer(server, recv)
126
127 go func() {
128 if err := server.Serve(lis); err != nil {
129 panic(err)
130 }
131 }()
132 }
133
134 time.Sleep(time.Second)
135 return teardown, nil
136 }
137