1 package db
2
3 import (
4 "context"
5 "encoding/binary"
6 "os"
7 "path/filepath"
8 "reflect"
9
10 jsoniter "github.com/json-iterator/go"
11
12 pgx "github.com/jackc/pgx/v5"
13 pgconn "github.com/jackc/pgx/v5/pgconn"
14 pgxpool "github.com/jackc/pgx/v5/pgxpool"
15 )
16
17 type Querier interface {
18 Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
19 }
20
21
22 type PgxIface interface {
23 Begin(ctx context.Context) (pgx.Tx, error)
24 Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
25 QueryRow(context.Context, string, ...interface{}) pgx.Row
26 Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
27 CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
28 }
29
30
31 type PgxConnIface interface {
32 PgxIface
33 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
34 Close(ctx context.Context) error
35 Ping(ctx context.Context) error
36 }
37
38
39 type PgxPoolIface interface {
40 PgxIface
41 Acquire(ctx context.Context) (*pgxpool.Conn, error)
42 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
43 Close()
44 Config() *pgxpool.Config
45 Ping(ctx context.Context) error
46 Stat() *pgxpool.Stat
47 }
48
49
// Migrator is implemented by storage backends that manage their own schema.
// NeedsMigration (the package-level function) consults it via a type
// assertion.
type Migrator interface {
	// NeedsMigration reports whether a schema migration is pending.
	NeedsMigration() (bool, error)
	// Migrate applies any pending schema migration.
	Migrate() error
}
54
55 func NeedsMigration(storage any, needsMigrationErr error) error {
56 if m, ok := storage.(Migrator); ok {
57 if needsMigration, err := m.NeedsMigration(); err != nil {
58 return err
59 } else if needsMigration {
60 return needsMigrationErr
61 }
62 }
63 return nil
64 }
65
66 func MarshallParamToJSONB(v any) any {
67 if v == nil {
68 return nil
69 }
70 val := reflect.ValueOf(v)
71 switch val.Kind() {
72 case reflect.Map, reflect.Slice:
73 if val.Len() == 0 {
74 return nil
75 }
76 case reflect.Struct:
77 if reflect.DeepEqual(v, reflect.Zero(val.Type()).Interface()) {
78 return nil
79 }
80 }
81 if b, err := jsoniter.ConfigFastest.Marshal(v); err == nil {
82 return string(b)
83 }
84 return nil
85 }
86
87 func IsPgConnStr(arg string) bool {
88 _, err := pgx.ParseConfig(arg)
89 return err == nil
90 }
91
92
93 func IsClientOnSameHost(conn PgxIface) (bool, error) {
94 ctx := context.Background()
95
96
97 var isUnixSocket bool
98 err := conn.QueryRow(ctx, "SELECT COALESCE(inet_client_addr(), inet_server_addr()) IS NULL").Scan(&isUnixSocket)
99 if err != nil || isUnixSocket {
100 return isUnixSocket, err
101 }
102
103
104
105
106 var dataDirectory string
107 if err := conn.QueryRow(ctx, "SHOW data_directory").Scan(&dataDirectory); err != nil {
108 return false, err
109 }
110
111 var systemIdentifier uint64
112 if err := conn.QueryRow(ctx, "SELECT system_identifier FROM pg_control_system()").Scan(&systemIdentifier); err != nil {
113 return false, err
114 }
115
116
117 pgControlFile := filepath.Join(dataDirectory, "global", "pg_control")
118 file, err := os.Open(pgControlFile)
119 if err != nil {
120 return false, err
121 }
122 defer file.Close()
123
124 var fileSystemIdentifier uint64
125 if err := binary.Read(file, binary.LittleEndian, &fileSystemIdentifier); err != nil {
126 return false, err
127 }
128
129
130 return fileSystemIdentifier == systemIdentifier, nil
131 }
132