1 package sources
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "maps"
8 "math"
9 "regexp"
10 "slices"
11 "strconv"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/cybertec-postgresql/pgwatch/v5/internal/db"
17 "github.com/jackc/pgx/v5"
18 "github.com/jackc/pgx/v5/pgxpool"
19 )
20
21
// NewConn and NewConnWithConfig are package-level aliases of db.New and
// db.NewWithConfig. Connect creates its pool through NewConnWithConfig via this
// variable rather than calling db directly — NOTE(review): presumably so tests
// can substitute a fake pool constructor; confirm intent.
var (
	NewConn           = db.New
	NewConnWithConfig = db.NewWithConfig
)
26
// Execution environments a monitored Postgres instance can be running in.
// The values are the exact labels produced by the platform-detection query in
// DiscoverPlatform.
const (
	// EnvUnknown is reported when no known hosting platform was recognised.
	EnvUnknown = "UNKNOWN"
	// EnvAzureSingle is Azure Database for PostgreSQL - Single Server
	// (detected via pg_qs.host_database and an MSVC-compiled server).
	EnvAzureSingle = "AZURE_SINGLE"
	// EnvAzureFlexible is Azure Database for PostgreSQL - Flexible Server
	// (detected via pg_qs.host_database and a gcc-compiled server).
	EnvAzureFlexible = "AZURE_FLEXIBLE"
	// EnvGoogle is Google Cloud SQL (detected via cloudsql.* settings).
	EnvGoogle = "GOOGLE"
)
33
// RuntimeInfo holds facts about a monitored source that are discovered from
// the live server (mostly by FetchRuntimeInfo) rather than configured.
type RuntimeInfo struct {
	// LastCheckedOn is when FetchRuntimeInfo last completed a refresh; unless
	// forced, refreshes are skipped while this is less than two minutes old.
	LastCheckedOn time.Time
	// IsInRecovery is the result of pg_is_in_recovery(); when true,
	// GetMetricInterval prefers the standby metric intervals.
	IsInRecovery bool
	// VersionStr is the raw version text: version() for Postgres, or the
	// output of SHOW VERSION / SHOW POOL_VERSION for PgBouncer / Pgpool.
	VersionStr string
	// Version is numeric. For Postgres it is server_version_num divided by
	// 10000 (the major version, e.g. 16); for PgBouncer / Pgpool it is
	// VersionToInt(VersionStr), e.g. 12301 for "1.23.1".
	Version int
	// RealDbname is current_database() as reported by the server.
	RealDbname string
	// SystemIdentifier comes from pg_control_system(); it is part of the
	// value returned by GetClusterIdentifier.
	SystemIdentifier string
	// IsSuperuser reflects current_setting('is_superuser') for the session.
	IsSuperuser bool
	// Extensions maps installed extension names to VersionToInt-encoded
	// versions, as read from pg_extension by FetchRuntimeInfo.
	Extensions map[string]int
	// ExecEnv is the platform label detected by DiscoverPlatform.
	ExecEnv string
	// ApproxDbSize is the estimate from FetchApproxSize: block_size times the
	// summed relpages of non-temporary relations.
	ApproxDbSize int64
	// ChangeState is initialised by NewSourceConn but not otherwise used in
	// this file — NOTE(review): presumably consumed by change-detection code
	// elsewhere; confirm its key layout there.
	ChangeState map[string]map[string]string
}
47
48
49
50
type (
	// SourceConn couples a configured Source with its connection pool and the
	// RuntimeInfo discovered from the server.
	SourceConn struct {
		Source
		// Conn is the pool; Connect creates it lazily when it is nil.
		Conn db.PgxPoolIface
		// ConnConfig is parsed from ConnStr on demand by ParseConfig.
		ConnConfig *pgxpool.Config
		RuntimeInfo
		// RWMutex is locked for writing by FetchRuntimeInfo and for reading by
		// GetMetricInterval, TryCreateMissingExtensions and
		// TryCreateMetricsHelpers; other methods in this file do not lock it.
		sync.RWMutex
	}

	// SourceConns is a list of monitored sources; see GetMonitoredDatabase.
	SourceConns []*SourceConn
)
62
63 func NewSourceConn(s Source) *SourceConn {
64 return &SourceConn{
65 Source: s,
66 RuntimeInfo: RuntimeInfo{
67 Extensions: make(map[string]int),
68 ChangeState: make(map[string]map[string]string),
69 },
70 }
71 }
72
73
// Ping verifies that the source answers. Regular sources use the pool's own
// Ping. PgBouncer instead gets a SHOW VERSION round-trip — presumably because
// its admin console only understands its own SHOW/admin commands.
func (md *SourceConn) Ping(ctx context.Context) (err error) {
	if md.Kind != SourcePgBouncer {
		return md.Conn.Ping(ctx)
	}
	_, err = md.Conn.Exec(ctx, "SHOW VERSION")
	return err
}
82
83
84
// Connect makes sure md has a connection pool and that the source responds.
//
// On the first call (md.Conn == nil) the connection string is parsed, and for
// PgBouncer the simple query protocol is forced. MaxConns is capped to
// opts.MaxParallelConnectionsPerDb when that is positive. Then a pool is
// created via NewConnWithConfig. Every call ends with Ping.
//
// md.Conn is only assigned once pool creation succeeded. Assigning directly
// from a failing constructor could store a typed-nil pool in the interface.
// That would make md.Conn non-nil, so later calls would skip reconnecting and
// then Ping a nil pool.
func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
	if md.Conn == nil {
		if err = md.ParseConfig(); err != nil {
			return err
		}
		if md.Kind == SourcePgBouncer {
			// PgBouncer's admin console does not support extended-protocol queries.
			md.ConnConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
		}
		if opts.MaxParallelConnectionsPerDb > 0 {
			md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
		}
		conn, connErr := NewConnWithConfig(ctx, md.ConnConfig)
		if connErr != nil {
			return connErr
		}
		md.Conn = conn
	}
	return md.Ping(ctx)
}
103
104
105 func (md *SourceConn) ParseConfig() (err error) {
106 if md.ConnConfig == nil {
107 md.ConnConfig, err = pgxpool.ParseConfig(md.ConnStr)
108 return
109 }
110 return
111 }
112
113
114
// GetClusterIdentifier returns "<system identifier>:<host>:<port>" for the
// source. It returns "" when the connection string cannot be parsed.
// SystemIdentifier is only known after FetchRuntimeInfo has run; until then
// that part is empty.
func (md *SourceConn) GetClusterIdentifier() string {
	if md.ParseConfig() != nil {
		return ""
	}
	cfg := md.ConnConfig.ConnConfig
	return fmt.Sprintf("%s:%s:%d", md.SystemIdentifier, cfg.Host, cfg.Port)
}
121
122
123 func (md *SourceConn) GetDatabaseName() string {
124 if err := md.ParseConfig(); err != nil {
125 return ""
126 }
127 return md.ConnConfig.ConnConfig.Database
128 }
129
130
// GetMetricInterval returns how often metric name should be collected.
// The configured value is taken as a number of seconds. The standby
// configuration (MetricsStandby) is used only while the server is in recovery
// and a standby configuration exists; otherwise Metrics is used. Unknown
// metrics yield 0.
func (md *SourceConn) GetMetricInterval(name string) time.Duration {
	md.RLock()
	defer md.RUnlock()
	if !md.IsInRecovery || len(md.MetricsStandby) == 0 {
		return time.Duration(md.Metrics[name]) * time.Second
	}
	return time.Duration(md.MetricsStandby[name]) * time.Second
}
139
140
141 func (md *SourceConn) IsClientOnSameHost() bool {
142 ok, err := db.IsClientOnSameHost(md.Conn)
143 return ok && err == nil
144 }
145
146
147 func (md *SourceConn) SetDatabaseName(name string) {
148 if err := md.ParseConfig(); err != nil {
149 return
150 }
151 md.ConnStr = ""
152 md.ConnConfig.ConnConfig.Database = name
153 }
154
155 func (md *SourceConn) IsPostgresSource() bool {
156 return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool
157 }
158
159
160
161
162
163 var regVer = regexp.MustCompile(`(\d+).?(\d*).?(\d*)`)
164
165 func VersionToInt(version string) (v int) {
166 if matches := regVer.FindStringSubmatch(version); len(matches) > 1 {
167 for i, match := range matches[1:] {
168 v += func() (m int) { m, _ = strconv.Atoi(match); return }() * int(math.Pow10(4-i*2))
169 }
170 }
171 return
172 }
173
// FetchRuntimeInfo refreshes md.RuntimeInfo from the server while holding the
// write lock. Unless forceRefetch is set, it does nothing when the previous
// refresh is less than two minutes old.
//
// For PgBouncer / Pgpool only the version is fetched. For Postgres it reads
// the following:
//   - version, recovery state, database name, system identifier and
//     superuser flag;
//   - the execution platform and the approximate database size;
//   - the installed extensions.
//
// The extension map is rebuilt from scratch on every successful read. Dropped
// extensions therefore disappear, and a nil Extensions map cannot cause a
// panic. If reading extensions fails, the previous map is kept and the error
// is returned. LastCheckedOn is still updated in that case.
func (md *SourceConn) FetchRuntimeInfo(ctx context.Context, forceRefetch bool) (err error) {
	md.Lock()
	defer md.Unlock()
	if err = ctx.Err(); err != nil {
		return err
	}

	if !forceRefetch && md.LastCheckedOn.After(time.Now().Add(-2*time.Minute)) {
		return nil
	}
	switch md.Kind {
	case SourcePgBouncer, SourcePgPool:
		versionSQL := "SHOW POOL_VERSION"
		if md.Kind == SourcePgBouncer {
			versionSQL = "SHOW VERSION"
		}
		if md.VersionStr, md.Version, err = md.FetchVersion(ctx, versionSQL); err != nil {
			return err
		}
	default:
		sql := `select /* pgwatch_generated */
div(current_setting('server_version_num')::int, 10000) as ver,
version(),
pg_is_in_recovery(),
current_database()::TEXT,
system_identifier,
current_setting('is_superuser')::bool
FROM
pg_control_system()`

		err = md.Conn.QueryRow(ctx, sql).
			Scan(&md.Version, &md.VersionStr,
				&md.IsInRecovery, &md.RealDbname,
				&md.SystemIdentifier, &md.IsSuperuser)
		if err != nil {
			return err
		}

		md.ExecEnv = md.DiscoverPlatform(ctx)
		md.ApproxDbSize = md.FetchApproxSize(ctx)

		// Take the leading "major[.minor]" of extversion. The previous pattern
		// ended in a lazy `\d+?`, which always captured exactly one minor digit
		// ("1.10" -> "1.1") and could not match single-component versions like
		// "1", so such extensions were silently skipped.
		sqlExtensions := `select /* pgwatch_generated */ extname::text, (regexp_matches(extversion, $$\d+\.?\d*$$))[1]::text as extversion from pg_extension order by 1;`
		var res pgx.Rows
		if res, err = md.Conn.Query(ctx, sqlExtensions); err == nil {
			extensions := make(map[string]int)
			var ext, ver string
			_, err = pgx.ForEachRow(res, []any{&ext, &ver}, func() error {
				extver := VersionToInt(ver)
				if extver == 0 {
					return fmt.Errorf("unexpected extension %s version input: %s", ext, ver)
				}
				extensions[ext] = extver
				return nil
			})
			if err == nil {
				md.Extensions = extensions
			}
		}
	}
	md.LastCheckedOn = time.Now()
	return err
}
236
// FetchVersion runs the single-column query sql using the simple query
// protocol. It returns the raw text together with its VersionToInt encoding.
// On a scan error ver is 0 and err is set.
func (md *SourceConn) FetchVersion(ctx context.Context, sql string) (version string, ver int, err error) {
	row := md.Conn.QueryRow(ctx, sql, pgx.QueryExecModeSimpleProtocol)
	if err = row.Scan(&version); err == nil {
		ver = VersionToInt(version)
	}
	return version, ver, err
}
244
245
246
// DiscoverPlatform identifies the hosting platform of the server by
// inspecting pg_settings and version(). The result is one of "AZURE_SINGLE",
// "AZURE_FLEXIBLE", "GOOGLE" or "UNKNOWN".
//
// An already known md.ExecEnv is returned without querying. Query errors are
// deliberately ignored. In that case platform is typically left empty, so the
// next call tries again.
func (md *SourceConn) DiscoverPlatform(ctx context.Context) (platform string) {
	if known := md.ExecEnv; known != "" {
		return known
	}
	query := `select /* pgwatch_generated */
case
when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by Visual C' then 'AZURE_SINGLE'
when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by gcc' then 'AZURE_FLEXIBLE'
when exists (select * from pg_settings where name = 'cloudsql.supported_extensions') then 'GOOGLE'
else
'UNKNOWN'
end as exec_env`
	row := md.Conn.QueryRow(ctx, query)
	_ = row.Scan(&platform) // best effort, see doc comment
	return platform
}
262
263
// FetchApproxSize estimates the database size as block_size multiplied by the
// summed relpages of all non-temporary relations in pg_class. It is best
// effort: query errors are ignored and typically leave the result at 0.
func (md *SourceConn) FetchApproxSize(ctx context.Context) (size int64) {
	query := `select /* pgwatch_generated */ current_setting('block_size')::int8 * sum(relpages) from pg_class c where c.relpersistence != 't'`
	row := md.Conn.QueryRow(ctx, query)
	_ = row.Scan(&size) // best effort
	return size
}
269
270
// FunctionExists reports whether a function called functionName exists in
// the public schema. No matching row, or any other query error, yields false.
func (md *SourceConn) FunctionExists(ctx context.Context, functionName string) (exists bool) {
	query := `select /* pgwatch_generated */ true
from
pg_proc join pg_namespace n on pronamespace = n.oid
where
proname = $1 and n.nspname = 'public'`
	row := md.Conn.QueryRow(ctx, query, functionName)
	_ = row.Scan(&exists) // a scan error (e.g. no rows) leaves exists false
	return exists
}
280
281
282
283
284 func (md *SourceConn) TryCreateMissingExtensions(ctx context.Context, extensions []string) (string, error) {
285 md.RLock()
286 defer md.RUnlock()
287
288 sqlAvailableExts := `select name::text from pg_available_extensions order by 1`
289 CreatedExts := make([]string, 0)
290
291
292 data, err := md.Conn.Query(ctx, sqlAvailableExts)
293 if err != nil {
294 return "", err
295 }
296 availableExts, err := pgx.CollectRows(data, pgx.RowTo[string])
297 if err != nil {
298 return "", err
299 }
300
301 for _, extToCreate := range extensions {
302 if _, ok := md.Extensions[extToCreate]; ok {
303 continue
304 }
305 if _, ok := slices.BinarySearch(availableExts, extToCreate); !ok {
306 err = errors.Join(err, fmt.Errorf("requested extension %s is not available on instance", extToCreate))
307 continue
308 }
309 if _, e := md.Conn.Exec(ctx, fmt.Sprintf(`create extension if not exists "%s"`, extToCreate)); e != nil {
310 err = errors.Join(err, fmt.Errorf("failed to create extension %s: %w", extToCreate, e))
311 } else {
312 CreatedExts = append(CreatedExts, extToCreate)
313 }
314 }
315 return strings.Join(CreatedExts, ","), err
316 }
317
318
319 func (md *SourceConn) TryCreateMetricsHelpers(ctx context.Context, getSQLFn func(string) string) (err error) {
320 md.RLock()
321 defer md.RUnlock()
322 var sql string
323 metrics := maps.Clone(md.Metrics)
324 maps.Insert(metrics, maps.All(md.MetricsStandby))
325 for metricName := range metrics {
326 if sql = getSQLFn(metricName); sql == "" {
327 continue
328 }
329 if _, e := md.Conn.Exec(ctx, sql); e != nil {
330 err = errors.Join(err, fmt.Errorf("failed to create helper for metric %s: %w", metricName, e))
331 }
332 }
333 return
334 }
335
// GetMonitoredDatabase returns the first source whose Name equals
// DBUniqueName, or nil when no source matches.
func (mds SourceConns) GetMonitoredDatabase(DBUniqueName string) *SourceConn {
	idx := slices.IndexFunc(mds, func(candidate *SourceConn) bool {
		return candidate.Name == DBUniqueName
	})
	if idx < 0 {
		return nil
	}
	return mds[idx]
}
344