...
1 package sources
2
3 import (
4 "context"
5 "reflect"
6
7 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
8 "github.com/jackc/pgx/v5"
9 "github.com/jackc/pgx/v5/pgxpool"
10 )
11
12
13 var (
14 NewConn = db.New
15 NewConnWithConfig = db.NewWithConfig
16 )
17
18
19
20
21 type (
22 SourceConn struct {
23 Source
24 Conn db.PgxPoolIface
25 ConnConfig *pgxpool.Config
26 }
27
28 SourceConns []*SourceConn
29 )
30
31
32 func (md *SourceConn) Ping(ctx context.Context) (err error) {
33 if md.Kind == SourcePgBouncer {
34
35 _, err = md.Conn.Exec(ctx, "SHOW VERSION")
36 return
37 }
38 return md.Conn.Ping(ctx)
39 }
40
41
42
43 func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
44 if md.Conn == nil {
45 if err = md.ParseConfig(); err != nil {
46 return err
47 }
48 if md.Kind == SourcePgBouncer {
49 md.ConnConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
50 }
51 if opts.MaxParallelConnectionsPerDb > 0 {
52 md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
53 }
54 md.Conn, err = NewConnWithConfig(ctx, md.ConnConfig)
55 if err != nil {
56 return err
57 }
58 }
59 return md.Ping(ctx)
60 }
61
62
63 func (md *SourceConn) ParseConfig() (err error) {
64 if md.ConnConfig == nil {
65 md.ConnConfig, err = pgxpool.ParseConfig(md.ConnStr)
66 return
67 }
68 return
69 }
70
71
72 func (md *SourceConn) GetDatabaseName() string {
73 if err := md.ParseConfig(); err != nil {
74 return ""
75 }
76 return md.ConnConfig.ConnConfig.Database
77 }
78
79
80 func (md *SourceConn) SetDatabaseName(name string) {
81 if err := md.ParseConfig(); err != nil {
82 return
83 }
84 md.ConnStr = ""
85 md.ConnConfig.ConnConfig.Database = name
86 }
87
88 func (md *SourceConn) IsPostgresSource() bool {
89 return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool
90 }
91
92
93
94 func (md *SourceConn) DiscoverPlatform(ctx context.Context) (platform string) {
95 sql := `select /* pgwatch_generated */
96 case
97 when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by Visual C' then 'AZURE_SINGLE'
98 when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by gcc' then 'AZURE_FLEXIBLE'
99 when exists (select * from pg_settings where name = 'cloudsql.supported_extensions') then 'GOOGLE'
100 else
101 'UNKNOWN'
102 end as exec_env`
103 _ = md.Conn.QueryRow(ctx, sql).Scan(&platform)
104 return
105 }
106
107
108 func (md *SourceConn) GetApproxSize(ctx context.Context) (size int64, err error) {
109 sqlApproxDBSize := `select /* pgwatch_generated */
110 current_setting('block_size')::int8 * sum(relpages)
111 from
112 pg_class c
113 where
114 c.relpersistence != 't'`
115 err = md.Conn.QueryRow(ctx, sqlApproxDBSize).Scan(&size)
116 return
117 }
118
119
120 func (md *SourceConn) FunctionExists(ctx context.Context, functionName string) (exists bool) {
121 sql := `select /* pgwatch_generated */ true
122 from
123 pg_proc join pg_namespace n on pronamespace = n.oid
124 where
125 proname = $1 and n.nspname = 'public'`
126 _ = md.Conn.QueryRow(ctx, sql, functionName).Scan(&exists)
127 return
128 }
129
130 func (mds SourceConns) GetMonitoredDatabase(DBUniqueName string) *SourceConn {
131 for _, md := range mds {
132 if md.Name == DBUniqueName {
133 return md
134 }
135 }
136 return nil
137 }
138
139
140
141
142 func (mds SourceConns) SyncFromReader(r Reader) (newmds SourceConns, err error) {
143 srcs, err := r.GetSources()
144 if err != nil {
145 return nil, err
146 }
147 newmds, err = srcs.ResolveDatabases()
148 for _, newMD := range newmds {
149 md := mds.GetMonitoredDatabase(newMD.Name)
150 if md == nil {
151 continue
152 }
153 if reflect.DeepEqual(md.Source, newMD.Source) {
154
155 newMD.Conn = md.Conn
156 newMD.ConnConfig = md.ConnConfig
157 continue
158 }
159 if md.Conn != nil {
160 md.Conn.Close()
161 }
162 }
163 return newmds, err
164 }
165