1 package sources
2
3 import (
4 "context"
5 "fmt"
6 "math"
7 "regexp"
8 "strconv"
9 "sync"
10 "time"
11
12 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
13 "github.com/jackc/pgx/v5"
14 "github.com/jackc/pgx/v5/pgxpool"
15 )
16
17
// NewConn and NewConnWithConfig are package-level aliases for db.New and
// db.NewWithConfig. Connect creates every pool through NewConnWithConfig.
// NOTE(review): presumably declared as variables (not direct calls) so tests
// can substitute a mock pool constructor — confirm.
var (
	NewConn           = db.New
	NewConnWithConfig = db.NewWithConfig
)
22
// Execution-environment labels stored in RuntimeInfo.ExecEnv. These are the
// same strings produced by the SQL CASE expression in DiscoverPlatform.
const (
	// EnvUnknown: none of the cloud-specific checks matched.
	EnvUnknown = "UNKNOWN"
	// EnvAzureSingle: pg_qs.host_database = 'azure_sys' and version() mentions "compiled by Visual C".
	EnvAzureSingle = "AZURE_SINGLE"
	// EnvAzureFlexible: pg_qs.host_database = 'azure_sys' and version() mentions "compiled by gcc".
	EnvAzureFlexible = "AZURE_FLEXIBLE"
	// EnvGoogle: the cloudsql.supported_extensions setting exists.
	EnvGoogle = "GOOGLE"
)
29
// RuntimeInfo holds facts about a monitored source that are discovered at
// runtime by FetchRuntimeInfo and cached on the SourceConn that embeds it.
type RuntimeInfo struct {
	// LastCheckedOn is when FetchRuntimeInfo last finished querying the source.
	// Results younger than two minutes are reused unless a refetch is forced.
	LastCheckedOn time.Time
	// IsInRecovery is pg_is_in_recovery(). GetMetricInterval uses it to pick
	// the standby metric intervals.
	IsInRecovery bool
	// VersionStr is the raw version text: version() for PostgreSQL, or the
	// output of SHOW VERSION / SHOW POOL_VERSION for pgbouncer / pgpool.
	VersionStr string
	// Version is server_version_num div 10000 (the major version) for
	// PostgreSQL, and VersionToInt(VersionStr) for pgbouncer / pgpool.
	Version int
	// RealDbname is current_database() as reported by the server.
	RealDbname string
	// SystemIdentifier is pg_control_system().system_identifier. It is used by
	// GetClusterIdentifier.
	SystemIdentifier string
	// IsSuperuser is current_setting('is_superuser') of the connected role.
	IsSuperuser bool
	// Extensions maps installed extension names to VersionToInt of their version.
	Extensions map[string]int
	// ExecEnv is one of the Env* labels reported by DiscoverPlatform. It is
	// empty until discovered or if discovery failed.
	ExecEnv string
	// ApproxDbSize is block_size * sum(relpages) over non-temporary relations,
	// i.e. an approximate size in bytes.
	ApproxDbSize int64
	// ChangeState is not used in this file.
	// NOTE(review): presumably per-object state for change detection — confirm.
	ChangeState map[string]map[string]string
}
43
44
45
46
type (
	// SourceConn couples a Source definition with its connection pool, the
	// parsed pool configuration and the discovered RuntimeInfo.
	// FetchRuntimeInfo holds the embedded write lock while refreshing
	// RuntimeInfo, and GetMetricInterval holds the read lock. Other methods in
	// this file access fields without locking.
	SourceConn struct {
		Source
		// Conn is nil until Connect creates the pool.
		Conn db.PgxPoolIface
		// ConnConfig is parsed lazily from ConnStr by ParseConfig.
		ConnConfig *pgxpool.Config
		RuntimeInfo
		sync.RWMutex
	}

	// SourceConns is a list of source connections. See GetMonitoredDatabase.
	SourceConns []*SourceConn
)
58
59 func NewSourceConn(s Source) *SourceConn {
60 return &SourceConn{
61 Source: s,
62 RuntimeInfo: RuntimeInfo{
63 Extensions: make(map[string]int),
64 ChangeState: make(map[string]map[string]string),
65 },
66 }
67 }
68
69
70 func (md *SourceConn) Ping(ctx context.Context) (err error) {
71 if md.Kind == SourcePgBouncer {
72
73 _, err = md.Conn.Exec(ctx, "SHOW VERSION")
74 return
75 }
76 return md.Conn.Ping(ctx)
77 }
78
79
80
81 func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
82 if md.Conn == nil {
83 if err = md.ParseConfig(); err != nil {
84 return err
85 }
86 if md.Kind == SourcePgBouncer {
87 md.ConnConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
88 }
89 if opts.MaxParallelConnectionsPerDb > 0 {
90 md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
91 }
92 md.Conn, err = NewConnWithConfig(ctx, md.ConnConfig)
93 if err != nil {
94 return err
95 }
96 }
97 return md.Ping(ctx)
98 }
99
100
101 func (md *SourceConn) ParseConfig() (err error) {
102 if md.ConnConfig == nil {
103 md.ConnConfig, err = pgxpool.ParseConfig(md.ConnStr)
104 return
105 }
106 return
107 }
108
109
110
111 func (md *SourceConn) GetClusterIdentifier() string {
112 if err := md.ParseConfig(); err != nil {
113 return ""
114 }
115 return fmt.Sprintf("%s:%s:%d", md.SystemIdentifier, md.ConnConfig.ConnConfig.Host, md.ConnConfig.ConnConfig.Port)
116 }
117
118
119 func (md *SourceConn) GetDatabaseName() string {
120 if err := md.ParseConfig(); err != nil {
121 return ""
122 }
123 return md.ConnConfig.ConnConfig.Database
124 }
125
126
127 func (md *SourceConn) GetMetricInterval(name string) float64 {
128 md.RLock()
129 defer md.RUnlock()
130 if md.IsInRecovery && len(md.MetricsStandby) > 0 {
131 return md.MetricsStandby[name]
132 }
133 return md.Metrics[name]
134 }
135
136
137 func (md *SourceConn) IsClientOnSameHost() bool {
138 ok, err := db.IsClientOnSameHost(md.Conn)
139 return ok && err == nil
140 }
141
142
143 func (md *SourceConn) SetDatabaseName(name string) {
144 if err := md.ParseConfig(); err != nil {
145 return
146 }
147 md.ConnStr = ""
148 md.ConnConfig.ConnConfig.Database = name
149 }
150
151 func (md *SourceConn) IsPostgresSource() bool {
152 return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool
153 }
154
155
156
157
158
// regVer matches up to three dot-separated numeric version components, e.g.
// "1.22.0" -> "1", "22", "0". The separators are escaped so that only a
// literal '.' joins components. An unescaped '.' would also join digit groups
// separated by any other character (e.g. "1 5" would be read as 1.5).
var regVer = regexp.MustCompile(`(\d+)\.?(\d*)\.?(\d*)`)

// VersionToInt extracts the first version-like substring of version and
// encodes it as major*10000 + minor*100 + patch. For example, "9.6.3" -> 90603,
// "PgBouncer 1.22.0" -> 12200 and "13" -> 130000. Missing components count as
// 0, and 0 is returned when version contains no digits at all.
func VersionToInt(version string) (v int) {
	matches := regVer.FindStringSubmatch(version)
	if len(matches) < 2 {
		return 0
	}
	for i, component := range matches[1:] {
		// Optional groups that did not match are "", which Atoi turns into 0.
		n, _ := strconv.Atoi(component)
		v += n * int(math.Pow10(4-2*i))
	}
	return v
}
169
170 func (md *SourceConn) FetchRuntimeInfo(ctx context.Context, forceRefetch bool) (err error) {
171 md.Lock()
172 defer md.Unlock()
173 if ctx.Err() != nil {
174 return ctx.Err()
175 }
176
177 if !forceRefetch && md.LastCheckedOn.After(time.Now().Add(time.Minute*-2)) {
178 return nil
179 }
180 switch md.Kind {
181 case SourcePgBouncer, SourcePgPool:
182 if md.VersionStr, md.Version, err = md.FetchVersion(ctx, func() string {
183 if md.Kind == SourcePgBouncer {
184 return "SHOW VERSION"
185 }
186 return "SHOW POOL_VERSION"
187 }()); err != nil {
188 return
189 }
190 default:
191 sql := `select /* pgwatch_generated */
192 div(current_setting('server_version_num')::int, 10000) as ver,
193 version(),
194 pg_is_in_recovery(),
195 current_database()::TEXT,
196 system_identifier,
197 current_setting('is_superuser')::bool
198 FROM
199 pg_control_system()`
200
201 err = md.Conn.QueryRow(ctx, sql).
202 Scan(&md.Version, &md.VersionStr,
203 &md.IsInRecovery, &md.RealDbname,
204 &md.SystemIdentifier, &md.IsSuperuser)
205 if err != nil {
206 return err
207 }
208
209 md.ExecEnv = md.DiscoverPlatform(ctx)
210 md.ApproxDbSize = md.FetchApproxSize(ctx)
211
212 sqlExtensions := `select /* pgwatch_generated */ extname::text, (regexp_matches(extversion, $$\d+\.?\d+?$$))[1]::text as extversion from pg_extension order by 1;`
213 var res pgx.Rows
214 res, err = md.Conn.Query(ctx, sqlExtensions)
215 if err == nil {
216 var ext string
217 var ver string
218 _, err = pgx.ForEachRow(res, []any{&ext, &ver}, func() error {
219 extver := VersionToInt(ver)
220 if extver == 0 {
221 return fmt.Errorf("unexpected extension %s version input: %s", ext, ver)
222 }
223 md.Extensions[ext] = extver
224 return nil
225 })
226 }
227
228 }
229 md.LastCheckedOn = time.Now()
230 return err
231 }
232
233 func (md *SourceConn) FetchVersion(ctx context.Context, sql string) (version string, ver int, err error) {
234 if err = md.Conn.QueryRow(ctx, sql, pgx.QueryExecModeSimpleProtocol).Scan(&version); err != nil {
235 return
236 }
237 ver = VersionToInt(version)
238 return
239 }
240
241
242
243 func (md *SourceConn) DiscoverPlatform(ctx context.Context) (platform string) {
244 if md.ExecEnv != "" {
245 return md.ExecEnv
246 }
247 sql := `select /* pgwatch_generated */
248 case
249 when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by Visual C' then 'AZURE_SINGLE'
250 when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by gcc' then 'AZURE_FLEXIBLE'
251 when exists (select * from pg_settings where name = 'cloudsql.supported_extensions') then 'GOOGLE'
252 else
253 'UNKNOWN'
254 end as exec_env`
255 _ = md.Conn.QueryRow(ctx, sql).Scan(&platform)
256 return
257 }
258
259
260 func (md *SourceConn) FetchApproxSize(ctx context.Context) (size int64) {
261 sqlApproxDBSize := `select /* pgwatch_generated */ current_setting('block_size')::int8 * sum(relpages) from pg_class c where c.relpersistence != 't'`
262 _ = md.Conn.QueryRow(ctx, sqlApproxDBSize).Scan(&size)
263 return
264 }
265
266
267 func (md *SourceConn) FunctionExists(ctx context.Context, functionName string) (exists bool) {
268 sql := `select /* pgwatch_generated */ true
269 from
270 pg_proc join pg_namespace n on pronamespace = n.oid
271 where
272 proname = $1 and n.nspname = 'public'`
273 _ = md.Conn.QueryRow(ctx, sql, functionName).Scan(&exists)
274 return
275 }
276
277 func (mds SourceConns) GetMonitoredDatabase(DBUniqueName string) *SourceConn {
278 for _, md := range mds {
279 if md.Name == DBUniqueName {
280 return md
281 }
282 }
283 return nil
284 }
285