...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/testutil/setup.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/testutil

     1  package testutil
     2  
import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"os"
	"time"

	"github.com/cybertec-postgresql/pgwatch/v3/api/pb"
	"github.com/testcontainers/testcontainers-go"
	"github.com/testcontainers/testcontainers-go/modules/etcd"
	"github.com/testcontainers/testcontainers-go/modules/postgres"
	"github.com/testcontainers/testcontainers-go/wait"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)
    21  
    22  func SetupPostgresContainer() (*postgres.PostgresContainer, func(), error) {
    23  	pgContainer, err := postgres.Run(ctx,
    24  		pgImageName,
    25  		postgres.WithDatabase(MockDatabase),
    26  		testcontainers.WithWaitStrategy(
    27  			wait.ForLog("database system is ready to accept connections").
    28  				WithOccurrence(2).
    29  				WithStartupTimeout(5*time.Second)),
    30  	)
    31  
    32  	tearDown := func() {
    33  		_ = pgContainer.Terminate(ctx)
    34  	}
    35  
    36  	return pgContainer, tearDown, err
    37  }
    38  
    39  func SetupPostgresContainerWithInitScripts(scripts ...string) (*postgres.PostgresContainer, func(), error) {
    40  	pgContainer, err := postgres.Run(ctx,
    41  		pgImageName,
    42  		postgres.WithDatabase(MockDatabase),
    43  		postgres.WithInitScripts(scripts...),
    44  		testcontainers.WithWaitStrategy(
    45  			wait.ForLog("database system is ready to accept connections").
    46  				WithOccurrence(2).
    47  				WithStartupTimeout(5*time.Second)),
    48  	)
    49  
    50  	tearDown := func() {
    51  		_ = pgContainer.Terminate(ctx)
    52  	}
    53  
    54  	return pgContainer, tearDown, err
    55  }
    56  
    57  func SetupEtcdContainer() (*etcd.EtcdContainer, func(), error) {
    58  	etcdContainer, err := etcd.Run(ctx, etcdImage,
    59  		testcontainers.
    60  			WithWaitStrategy(wait.ForLog("ready to serve client requests").
    61  			WithStartupTimeout(15*time.Second)))
    62  
    63  	tearDown := func() {
    64  		_ = etcdContainer.Terminate(ctx)
    65  	}
    66  
    67  	return etcdContainer, tearDown, err
    68  }
    69  
    70  //-----------Setup gRPC test servers-----------------
    71  
    72  func LoadServerTLSCredentials() (credentials.TransportCredentials, error) {
    73  	cert, err := tls.X509KeyPair(Cert, PrivateKey)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	tlsConfig := &tls.Config{
    79  		Certificates: []tls.Certificate{cert},
    80  	}
    81  	return credentials.NewTLS(tlsConfig), nil
    82  }
    83  
    84  func AuthInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    85  	md, _ := metadata.FromIncomingContext(ctx)
    86  
    87  	clientUsername := md.Get("username")[0]
    88  	clientPassword := md.Get("password")[0]
    89  
    90  	if clientUsername != "" && clientUsername != "pgwatch" && clientPassword != "pgwatch" {
    91  		return nil, status.Error(codes.Unauthenticated, "unauthenticated")
    92  	}
    93  
    94  	return handler(ctx, req)
    95  }
    96  
    97  func SetupRPCServers() (func(), error) {
    98  	err := os.WriteFile(CAFile, []byte(CA), 0644)
    99  	teardown := func() { _ = os.Remove(CAFile) }
   100  	if err != nil {
   101  		return teardown, err
   102  	}
   103  
   104  	addresses := [2]string{PlainServerAddress, TLSServerAddress}
   105  	for _, address := range addresses {
   106  		lis, err := net.Listen("tcp", address)
   107  		if err != nil {
   108  			return teardown, err
   109  		}
   110  
   111  		var creds credentials.TransportCredentials
   112  		if address == TLSServerAddress {
   113  			creds, err = LoadServerTLSCredentials()
   114  			if err != nil {
   115  				return nil, err
   116  			}
   117  		}
   118  
   119  		server := grpc.NewServer(
   120  			grpc.UnaryInterceptor(AuthInterceptor),
   121  			grpc.Creds(creds),
   122  		)
   123  
   124  		recv := new(Receiver)
   125  		pb.RegisterReceiverServer(server, recv)
   126  
   127  		go func() {
   128  			if err := server.Serve(lis); err != nil {
   129  				panic(err)
   130  			}
   131  		}()
   132  	}
   133  	// wait a little for servers to start
   134  	time.Sleep(time.Second)
   135  	return teardown, nil
   136  }
   137