...
1 package sinks
2
3 import (
4 "context"
5 "crypto/tls"
6 "crypto/x509"
7 "fmt"
8 "net/rpc"
9 "net/url"
10 "os"
11 "time"
12
13 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
14 "github.com/cybertec-postgresql/pgwatch/v3/internal/metrics"
15 )
16
17 func NewRPCWriter(ctx context.Context, ConnStr string) (*RPCWriter, error) {
18 uri, err := url.Parse(ConnStr)
19 if err != nil {
20 return nil, fmt.Errorf("error parsing RPC URI: %s", err)
21 }
22
23 params, err := url.ParseQuery(uri.RawQuery)
24 if err != nil {
25 return nil, fmt.Errorf("error parsing RPC URI parameters: %s", err)
26 }
27
28 RootCA, exists := params["sslrootca"]
29 var client *rpc.Client
30 if exists {
31 client, err = connectViaTLS(uri.Host, RootCA[0])
32 } else {
33 client, err = rpc.DialHTTP("tcp", uri.Host)
34 }
35
36 if err != nil {
37 return nil, err
38 }
39
40 l := log.GetLogger(ctx).WithField("sink", "rpc").WithField("address", uri.Host)
41 ctx = log.WithLogger(ctx, l)
42 rw := &RPCWriter{
43 ctx: ctx,
44 client: client,
45 }
46 go rw.watchCtx()
47 return rw, nil
48 }
49
// connectViaTLS dials address over TLS, trusting only the CA certificate(s)
// read from the PEM file at path RootCA, and returns a net/rpc client that
// speaks the RPC protocol directly over that TLS connection.
//
// Errors are returned when the CA file cannot be read, when it contains no
// parseable PEM certificate, or when the TLS dial fails.
//
// NOTE(review): rpc.NewClient performs no HTTP CONNECT handshake, unlike the
// rpc.DialHTTP path used for plain TCP, so the server presumably serves RPC
// directly on its TLS listener; confirm against the receiver implementation.
func connectViaTLS(address, RootCA string) (*rpc.Client, error) {
	ca, err := os.ReadFile(RootCA)
	if err != nil {
		return nil, fmt.Errorf("cannot load CA file: %w", err)
	}

	certPool := x509.NewCertPool()
	// AppendCertsFromPEM reports false when no certificate could be parsed.
	// Ignoring that would leave an empty pool and turn a bad CA file into an
	// opaque "certificate signed by unknown authority" handshake failure.
	if !certPool.AppendCertsFromPEM(ca) {
		return nil, fmt.Errorf("cannot load CA file: no valid PEM certificates found in %q", RootCA)
	}

	conn, err := tls.Dial("tcp", address, &tls.Config{RootCAs: certPool})
	if err != nil {
		return nil, err
	}
	return rpc.NewClient(conn), nil
}
68 }
69
70
71 func (rw *RPCWriter) Write(msg metrics.MeasurementEnvelope) error {
72 if rw.ctx.Err() != nil {
73 return rw.ctx.Err()
74 }
75
76 t1 := time.Now()
77 var logMsg string
78 if err := rw.client.Call("Receiver.UpdateMeasurements", &msg, &logMsg); err != nil {
79 return err
80 }
81
82 diff := time.Since(t1)
83 written := len(msg.Data)
84 log.GetLogger(rw.ctx).WithField("rows", written).WithField("elapsed", diff).Info("measurements written")
85 if len(logMsg) > 0 {
86 log.GetLogger(rw.ctx).Info(logMsg)
87 }
88 return nil
89 }
90
91 func (rw *RPCWriter) SyncMetric(dbUnique, metricName string, op SyncOp) error {
92 var logMsg string
93 if err := rw.client.Call("Receiver.SyncMetric", &SyncReq{
94 Operation: op,
95 DbName: dbUnique,
96 MetricName: metricName,
97 }, &logMsg); err != nil {
98 return err
99 }
100 if len(logMsg) > 0 {
101 log.GetLogger(rw.ctx).Info(logMsg)
102 }
103 return nil
104 }
105
106 func (rw *RPCWriter) watchCtx() {
107 <-rw.ctx.Done()
108 rw.client.Close()
109 }
110