...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/testutil/setup.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/testutil

     1  package testutil
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"net"
     7  	"os"
     8  	"time"
     9  
    10  	"github.com/cybertec-postgresql/pgwatch/v5/api/pb"
    11  	"github.com/testcontainers/testcontainers-go"
    12  	"github.com/testcontainers/testcontainers-go/modules/etcd"
    13  	"github.com/testcontainers/testcontainers-go/modules/postgres"
    14  	"github.com/testcontainers/testcontainers-go/wait"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/credentials"
    18  	"google.golang.org/grpc/metadata"
    19  	"google.golang.org/grpc/status"
    20  )
    21  
    22  func SetupPostgresContainer() (*postgres.PostgresContainer, func(), error) {
    23  	pgContainer, err := postgres.Run(TestContext,
    24  		PostgresImage,
    25  		postgres.WithDatabase(MockDatabase),
    26  		testcontainers.WithWaitStrategy(
    27  			wait.ForLog("database system is ready to accept connections").
    28  				WithOccurrence(2).
    29  				WithStartupTimeout(5*time.Second)),
    30  	)
    31  
    32  	tearDown := func() {
    33  		_ = pgContainer.Terminate(TestContext)
    34  	}
    35  
    36  	return pgContainer, tearDown, err
    37  }
    38  
    39  // Creates a PostgreSQL container with CSV logging enabled.
    40  // This is useful for testing log parsing functionality with server_log_event_counts metric.
    41  func SetupPostgresContainerWithConfig(configPath string) (*postgres.PostgresContainer, func(), error) {
    42  	pgContainer, err := postgres.Run(TestContext,
    43  		PostgresImage,
    44  		postgres.WithDatabase(MockDatabase),
    45  		postgres.WithConfigFile(configPath),
    46  		testcontainers.WithWaitStrategy(
    47  			wait.ForListeningPort("5432/tcp").WithStartupTimeout(5*time.Second)),
    48  	)
    49  
    50  	tearDown := func() {
    51  		_ = pgContainer.Terminate(TestContext)
    52  	}
    53  
    54  	return pgContainer, tearDown, err
    55  }
    56  
    57  func SetupPostgresContainerWithInitScripts(scripts ...string) (*postgres.PostgresContainer, func(), error) {
    58  	pgContainer, err := postgres.Run(TestContext,
    59  		PostgresImage,
    60  		postgres.WithDatabase(MockDatabase),
    61  		postgres.WithInitScripts(scripts...),
    62  		testcontainers.WithWaitStrategy(
    63  			wait.ForLog("database system is ready to accept connections").
    64  				WithOccurrence(2).
    65  				WithStartupTimeout(5*time.Second)),
    66  	)
    67  
    68  	tearDown := func() {
    69  		_ = pgContainer.Terminate(TestContext)
    70  	}
    71  
    72  	return pgContainer, tearDown, err
    73  }
    74  
    75  func SetupEtcdContainer() (*etcd.EtcdContainer, func(), error) {
    76  	etcdContainer, err := etcd.Run(TestContext, EtcdImage,
    77  		testcontainers.
    78  			WithWaitStrategy(wait.ForLog("ready to serve client requests").
    79  				WithStartupTimeout(15*time.Second)))
    80  
    81  	tearDown := func() {
    82  		_ = etcdContainer.Terminate(TestContext)
    83  	}
    84  
    85  	return etcdContainer, tearDown, err
    86  }
    87  
    88  //-----------Setup gRPC test servers-----------------
    89  
    90  func LoadServerTLSCredentials() (credentials.TransportCredentials, error) {
    91  	cert, err := tls.X509KeyPair(Cert, PrivateKey)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	tlsConfig := &tls.Config{
    97  		Certificates: []tls.Certificate{cert},
    98  	}
    99  	return credentials.NewTLS(tlsConfig), nil
   100  }
   101  
   102  func AuthInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   103  	md, _ := metadata.FromIncomingContext(ctx)
   104  
   105  	clientUsername := md.Get("username")[0]
   106  	clientPassword := md.Get("password")[0]
   107  
   108  	if clientUsername != "" && clientUsername != "pgwatch" && clientPassword != "pgwatch" {
   109  		return nil, status.Error(codes.Unauthenticated, "unauthenticated")
   110  	}
   111  
   112  	return handler(ctx, req)
   113  }
   114  
   115  func SetupRPCServers() (func(), error) {
   116  	err := os.WriteFile(CAFile, []byte(CA), 0644)
   117  	teardown := func() { _ = os.Remove(CAFile) }
   118  	if err != nil {
   119  		return teardown, err
   120  	}
   121  
   122  	addresses := [2]string{PlainServerAddress, TLSServerAddress}
   123  	for _, address := range addresses {
   124  		lis, err := net.Listen("tcp", address)
   125  		if err != nil {
   126  			return teardown, err
   127  		}
   128  
   129  		var creds credentials.TransportCredentials
   130  		if address == TLSServerAddress {
   131  			creds, err = LoadServerTLSCredentials()
   132  			if err != nil {
   133  				return nil, err
   134  			}
   135  		}
   136  
   137  		server := grpc.NewServer(
   138  			grpc.UnaryInterceptor(AuthInterceptor),
   139  			grpc.Creds(creds),
   140  		)
   141  
   142  		recv := new(Receiver)
   143  		pb.RegisterReceiverServer(server, recv)
   144  
   145  		go func() {
   146  			if err := server.Serve(lis); err != nil {
   147  				panic(err)
   148  			}
   149  		}()
   150  	}
   151  	// wait a little for servers to start
   152  	time.Sleep(time.Second)
   153  	return teardown, nil
   154  }
   155