1 package sinks
2
3 import (
4 "context"
5 "crypto/tls"
6 "crypto/x509"
7 "errors"
8 "fmt"
9 "net/url"
10 "os"
11 "time"
12
13 "github.com/cybertec-postgresql/pgwatch/v3/api/pb"
14 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
15 "github.com/cybertec-postgresql/pgwatch/v3/internal/metrics"
16 jsoniter "github.com/json-iterator/go"
17 "google.golang.org/grpc"
18 "google.golang.org/grpc/codes"
19 "google.golang.org/grpc/credentials"
20 "google.golang.org/grpc/credentials/insecure"
21 "google.golang.org/grpc/metadata"
22 "google.golang.org/grpc/status"
23 "google.golang.org/protobuf/types/known/structpb"
24 )
25
26
27
28
29
30 type RPCWriter struct {
31 ctx context.Context
32 conn *grpc.ClientConn
33 client pb.ReceiverClient
34 }
35
36
37 func convertSyncOp(op SyncOp) pb.SyncOp {
38 switch op {
39 case AddOp:
40 return pb.SyncOp_AddOp
41 case DeleteOp:
42 return pb.SyncOp_DeleteOp
43 case DefineOp:
44 return pb.SyncOp_DefineOp
45 default:
46 return pb.SyncOp_InvalidOp
47 }
48 }
49
50
51 func NewRPCWriter(ctx context.Context, connStr string) (*RPCWriter, error) {
52 uri, err := url.Parse(connStr)
53 if err != nil {
54 return nil, fmt.Errorf("error parsing gRPC URI: %s", err)
55 }
56
57 l := log.GetLogger(ctx).WithField("sink", "grpc").WithField("address", uri.Host)
58 ctx = log.WithLogger(ctx, l)
59
60 params, err := url.ParseQuery(uri.RawQuery)
61 if err != nil {
62 return nil, fmt.Errorf("error parsing gRPC URI parameters: %s", err)
63 }
64
65 creds := insecure.NewCredentials()
66
67 CAFile, ok := params["sslrootca"]
68 if ok {
69 creds, err = LoadTLSCredentials(CAFile[0])
70 if err != nil {
71 return nil, err
72 }
73 log.GetLogger(ctx).Infof("Valid CA File %s loaded - enabling TLS", CAFile)
74 }
75
76 conn, err := grpc.NewClient(uri.Host, grpc.WithTransportCredentials(creds))
77 if err != nil {
78 return nil, err
79 }
80
81 password, _ := uri.User.Password()
82 md := metadata.Pairs(
83 "username", uri.User.Username(),
84 "password", password,
85 )
86 newCtx := metadata.NewOutgoingContext(ctx, md)
87
88 client := pb.NewReceiverClient(conn)
89 rw := &RPCWriter{
90 ctx: newCtx,
91 conn: conn,
92 client: client,
93 }
94
95 if err = rw.Ping(); err != nil {
96 return nil, err
97 }
98
99 go rw.watchCtx()
100 return rw, nil
101 }
102
103 func (rw *RPCWriter) Ping() error {
104 err := rw.SyncMetric("", "", InvalidOp)
105 st, ok := status.FromError(err)
106 if ok && st.Code() == codes.Unavailable {
107 return err
108 }
109 return nil
110 }
111
112
113 func (rw *RPCWriter) Write(msg metrics.MeasurementEnvelope) error {
114 if rw.ctx.Err() != nil {
115 return rw.ctx.Err()
116 }
117
118 dataLength := len(msg.Data)
119 failCnt := 0
120 measurements := make([]*structpb.Struct, 0, dataLength)
121 for _, item := range msg.Data {
122 st, err := structpb.NewStruct(item)
123 if err != nil {
124 failCnt++
125 continue
126 }
127 measurements = append(measurements, st)
128 }
129 if failCnt > 0 {
130 log.GetLogger(rw.ctx).WithField("database", msg.DBName).WithField("metric",
131 msg.MetricName).Warningf("gRPC sink failed to encode %d rows", failCnt)
132 }
133
134 envelope := &pb.MeasurementEnvelope{
135 DBName: msg.DBName,
136 MetricName: msg.MetricName,
137 CustomTags: msg.CustomTags,
138 Data: measurements,
139 }
140
141 t1 := time.Now()
142 reply, err := rw.client.UpdateMeasurements(rw.ctx, envelope)
143 if err != nil {
144 return err
145 }
146
147 diff := time.Since(t1)
148 log.GetLogger(rw.ctx).WithField("rows", dataLength).WithField("elapsed", diff).Info("measurements written")
149 if reply.GetLogmsg() != "" {
150 log.GetLogger(rw.ctx).Info(reply.GetLogmsg())
151 }
152 return nil
153 }
154
155
156 func (rw *RPCWriter) SyncMetric(sourceName, metricName string, op SyncOp) error {
157 syncReq := &pb.SyncReq{
158 DBName: sourceName,
159 MetricName: metricName,
160 Operation: convertSyncOp(op),
161 }
162
163 reply, err := rw.client.SyncMetric(rw.ctx, syncReq)
164 if err != nil {
165 return err
166 }
167
168 if reply.GetLogmsg() != "" {
169 log.GetLogger(rw.ctx).Info(reply.GetLogmsg())
170 }
171 return nil
172 }
173
174
175 func (rw *RPCWriter) DefineMetrics(metrics *metrics.Metrics) error {
176 var json = jsoniter.ConfigFastest
177
178
179
180 jsonData, err := json.Marshal(metrics)
181 if err != nil {
182 return err
183 }
184
185 var metricMap map[string]any
186 if err := json.Unmarshal(jsonData, &metricMap); err != nil {
187 return err
188 }
189
190 metricStruct, err := structpb.NewStruct(metricMap)
191 if err != nil {
192 return err
193 }
194
195 t1 := time.Now()
196 reply, err := rw.client.DefineMetrics(rw.ctx, metricStruct)
197 if err != nil {
198 return err
199 }
200
201 diff := time.Since(t1)
202 log.GetLogger(rw.ctx).WithField("elapsed", diff).Info("metric definitions written")
203 if reply.GetLogmsg() != "" {
204 log.GetLogger(rw.ctx).Info(reply.GetLogmsg())
205 }
206 return nil
207 }
208
209 func (rw *RPCWriter) watchCtx() {
210 <-rw.ctx.Done()
211 rw.conn.Close()
212 }
213
214 func LoadTLSCredentials(CAFile string) (credentials.TransportCredentials, error) {
215 ca, err := os.ReadFile(CAFile)
216 if err != nil {
217 return nil, fmt.Errorf("error loading CA file: %v", err)
218 }
219
220 certPool := x509.NewCertPool()
221 ok := certPool.AppendCertsFromPEM(ca)
222 if !ok {
223 return nil, errors.New("invalid CA file")
224 }
225
226 tlsClientConfig := &tls.Config{
227 RootCAs: certPool,
228 }
229 return credentials.NewTLS(tlsClientConfig), nil
230 }