...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sinks/rpc.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sinks

     1  package sinks
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"net/url"
    10  	"os"
    11  	"time"
    12  
    13  	"github.com/cybertec-postgresql/pgwatch/v3/api/pb"
    14  	"github.com/cybertec-postgresql/pgwatch/v3/internal/log"
    15  	"github.com/cybertec-postgresql/pgwatch/v3/internal/metrics"
    16  	jsoniter "github.com/json-iterator/go"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/codes"
    19  	"google.golang.org/grpc/credentials"
    20  	"google.golang.org/grpc/credentials/insecure"
    21  	"google.golang.org/grpc/metadata"
    22  	"google.golang.org/grpc/status"
    23  	"google.golang.org/protobuf/types/known/structpb"
    24  )
    25  
    26  // RPCWriter sends metric measurements to a remote server using gRPC.
    27  // Remote servers should make use the .proto file under api/pb/ to integrate with it.
    28  // It's up to the implementer to define the behavior of the server.
    29  // It can be a simple logger, external storage, alerting system, or an analytics system.
    30  type RPCWriter struct {
    31  	ctx    context.Context
    32  	conn   *grpc.ClientConn
    33  	client pb.ReceiverClient
    34  }
    35  
    36  // convertSyncOp converts sinks.SyncOp to pb.SyncOp
    37  func convertSyncOp(op SyncOp) pb.SyncOp {
    38  	switch op {
    39  	case AddOp:
    40  		return pb.SyncOp_AddOp
    41  	case DeleteOp:
    42  		return pb.SyncOp_DeleteOp
    43  	case DefineOp:
    44  		return pb.SyncOp_DefineOp
    45  	default:
    46  		return pb.SyncOp_InvalidOp
    47  	}
    48  }
    49  
    50  
    51  func NewRPCWriter(ctx context.Context, connStr string) (*RPCWriter, error) {
    52  	uri, err := url.Parse(connStr)
    53  	if err != nil {
    54  		return nil, fmt.Errorf("error parsing gRPC URI: %s", err)
    55  	}
    56  
    57  	l := log.GetLogger(ctx).WithField("sink", "grpc").WithField("address", uri.Host)
    58  	ctx = log.WithLogger(ctx, l)
    59  
    60  	params, err := url.ParseQuery(uri.RawQuery)
    61  	if err != nil {
    62  		return nil, fmt.Errorf("error parsing gRPC URI parameters: %s", err)
    63  	}
    64  
    65  	creds := insecure.NewCredentials()
    66  
    67  	CAFile, ok := params["sslrootca"]
    68  	if ok {
    69  		creds, err = LoadTLSCredentials(CAFile[0])
    70  		if err != nil {
    71  			return nil, err
    72  		}
    73  		log.GetLogger(ctx).Infof("Valid CA File %s loaded - enabling TLS", CAFile)
    74  	}
    75  
    76  	conn, err := grpc.NewClient(uri.Host, grpc.WithTransportCredentials(creds))
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	
    81  	password, _ := uri.User.Password()
    82  	md := metadata.Pairs(  
    83  		"username", uri.User.Username(),  
    84  		"password", password,
    85  	)  
    86  	newCtx := metadata.NewOutgoingContext(ctx, md)
    87  	
    88  	client := pb.NewReceiverClient(conn)
    89  	rw := &RPCWriter{
    90  		ctx:    newCtx,
    91  		conn:   conn,
    92  		client: client,
    93  	}
    94  
    95  	if err = rw.Ping(); err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	go rw.watchCtx()
   100  	return rw, nil
   101  }
   102  
   103  func (rw *RPCWriter) Ping() error {
   104  	err := rw.SyncMetric("", "", InvalidOp)
   105  	st, ok := status.FromError(err)
   106  	if ok && st.Code() == codes.Unavailable {
   107  		return err
   108  	}
   109  	return nil
   110  }
   111  
   112  // Sends Measurement Message to RPC Sink
   113  func (rw *RPCWriter) Write(msg metrics.MeasurementEnvelope) error {
   114  	if rw.ctx.Err() != nil {
   115  		return rw.ctx.Err()
   116  	}
   117  
   118  	dataLength := len(msg.Data)
   119  	failCnt := 0
   120  	measurements := make([]*structpb.Struct, 0, dataLength)
   121  	for _, item := range msg.Data {
   122  		st, err := structpb.NewStruct(item)
   123  		if err != nil {
   124  			failCnt++
   125  			continue
   126  		}
   127  		measurements = append(measurements, st)
   128  	}
   129  	if failCnt > 0 {
   130  		log.GetLogger(rw.ctx).WithField("database", msg.DBName).WithField("metric",
   131  			msg.MetricName).Warningf("gRPC sink failed to encode %d rows", failCnt)
   132  	}
   133  
   134  	envelope := &pb.MeasurementEnvelope{
   135  		DBName:     msg.DBName,
   136  		MetricName: msg.MetricName,
   137  		CustomTags: msg.CustomTags,
   138  		Data:       measurements,
   139  	}
   140  
   141  	t1 := time.Now()
   142  	reply, err := rw.client.UpdateMeasurements(rw.ctx, envelope)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	diff := time.Since(t1)
   148  	log.GetLogger(rw.ctx).WithField("rows", dataLength).WithField("elapsed", diff).Info("measurements written")
   149  	if reply.GetLogmsg() != "" {
   150  		log.GetLogger(rw.ctx).Info(reply.GetLogmsg())
   151  	}
   152  	return nil
   153  }
   154  
   155  // SyncMetric synchronizes a metric and monitored source with the remote server
   156  func (rw *RPCWriter) SyncMetric(sourceName, metricName string, op SyncOp) error {
   157  	syncReq := &pb.SyncReq{
   158  		DBName:     sourceName,
   159  		MetricName: metricName,
   160  		Operation:  convertSyncOp(op),
   161  	}
   162  
   163  	reply, err := rw.client.SyncMetric(rw.ctx, syncReq)
   164  	if err != nil {
   165  		return err
   166  	}
   167  
   168  	if reply.GetLogmsg() != "" {
   169  		log.GetLogger(rw.ctx).Info(reply.GetLogmsg())
   170  	}
   171  	return nil
   172  }
   173  
   174  // DefineMetrics sends metric definitions to the remote server
   175  func (rw *RPCWriter) DefineMetrics(metrics *metrics.Metrics) error {
   176  	var json = jsoniter.ConfigFastest
   177  
   178  	// Convert metrics to JSON first, then to structpb.Struct
   179  	// to automatically handle all the type conversions
   180  	jsonData, err := json.Marshal(metrics)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	var metricMap map[string]any
   186  	if err := json.Unmarshal(jsonData, &metricMap); err != nil {
   187  		return err
   188  	}
   189  
   190  	metricStruct, err := structpb.NewStruct(metricMap)
   191  	if err != nil {
   192  		return err
   193  	}
   194  
   195  	t1 := time.Now()
   196  	reply, err := rw.client.DefineMetrics(rw.ctx, metricStruct)
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	diff := time.Since(t1)
   202  	log.GetLogger(rw.ctx).WithField("elapsed", diff).Info("metric definitions written")
   203  	if reply.GetLogmsg() != "" {
   204  		log.GetLogger(rw.ctx).Info(reply.GetLogmsg())
   205  	}
   206  	return nil
   207  }
   208  
   209  func (rw *RPCWriter) watchCtx() {
   210  	<-rw.ctx.Done()
   211  	rw.conn.Close()
   212  }
   213  
   214  func LoadTLSCredentials(CAFile string) (credentials.TransportCredentials, error) {
   215  	ca, err := os.ReadFile(CAFile)
   216  	if err != nil {
   217  		return nil, fmt.Errorf("error loading CA file: %v", err)
   218  	}
   219  
   220  	certPool := x509.NewCertPool()
   221  	ok := certPool.AppendCertsFromPEM(ca)
   222  	if !ok {
   223  		return nil, errors.New("invalid CA file")
   224  	}
   225  
   226  	tlsClientConfig := &tls.Config{
   227  		RootCAs: certPool,
   228  	}
   229  	return credentials.NewTLS(tlsClientConfig), nil
   230  }