1 package db
2
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"reflect"

	jsoniter "github.com/json-iterator/go"

	pgx "github.com/jackc/pgx/v5"
	pgconn "github.com/jackc/pgx/v5/pgconn"
	pgxpool "github.com/jackc/pgx/v5/pgxpool"
)
16
17 type Querier interface {
18 Query(ctx context.Context, query string, args ...any) (pgx.Rows, error)
19 }
20
21
22 type PgxIface interface {
23 Begin(ctx context.Context) (pgx.Tx, error)
24 Exec(context.Context, string, ...any) (pgconn.CommandTag, error)
25 QueryRow(context.Context, string, ...any) pgx.Row
26 Query(ctx context.Context, query string, args ...any) (pgx.Rows, error)
27 CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
28 }
29
30
31 type PgxConnIface interface {
32 PgxIface
33 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
34 Close(ctx context.Context) error
35 Ping(ctx context.Context) error
36 }
37
38
39 type PgxPoolIface interface {
40 PgxIface
41 Acquire(ctx context.Context) (*pgxpool.Conn, error)
42 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
43 Close()
44 Config() *pgxpool.Config
45 Ping(ctx context.Context) error
46 SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults
47 Stat() *pgxpool.Stat
48 }
49
50
// Migrator is implemented by storage backends that manage a schema.
// NeedsMigration reports whether the schema is behind; Migrate brings it
// up to date.
type Migrator interface {
	NeedsMigration() (bool, error)
	Migrate() error
}
55
56 func NeedsMigration(storage any, needsMigrationErr error) error {
57 if m, ok := storage.(Migrator); ok {
58 if needsMigration, err := m.NeedsMigration(); err != nil {
59 return err
60 } else if needsMigration {
61 return needsMigrationErr
62 }
63 }
64 return nil
65 }
66
67 func MarshallParamToJSONB(v any) any {
68 if v == nil {
69 return nil
70 }
71 val := reflect.ValueOf(v)
72 switch val.Kind() {
73 case reflect.Map, reflect.Slice:
74 if val.Len() == 0 {
75 return nil
76 }
77 case reflect.Struct:
78 if reflect.DeepEqual(v, reflect.Zero(val.Type()).Interface()) {
79 return nil
80 }
81 }
82 if b, err := jsoniter.ConfigFastest.Marshal(v); err == nil {
83 return string(b)
84 }
85 return nil
86 }
87
88 func IsPgConnStr(arg string) bool {
89 _, err := pgx.ParseConfig(arg)
90 return err == nil
91 }
92
93
94 func IsClientOnSameHost(conn PgxIface) (bool, error) {
95 ctx := context.Background()
96
97
98 var isUnixSocket bool
99 err := conn.QueryRow(ctx, "SELECT COALESCE(inet_client_addr(), inet_server_addr()) IS NULL").Scan(&isUnixSocket)
100 if err != nil || isUnixSocket {
101 return isUnixSocket, err
102 }
103
104
105
106
107 var dataDirectory string
108 if err := conn.QueryRow(ctx, "SHOW data_directory").Scan(&dataDirectory); err != nil {
109 return false, err
110 }
111
112 var systemIdentifier uint64
113 if err := conn.QueryRow(ctx, "SELECT system_identifier FROM pg_control_system()").Scan(&systemIdentifier); err != nil {
114 return false, err
115 }
116
117
118 pgControlFile := filepath.Join(dataDirectory, "global", "pg_control")
119 file, err := os.Open(pgControlFile)
120 if err != nil {
121 return false, err
122 }
123 defer file.Close()
124
125 var fileSystemIdentifier uint64
126 if err := binary.Read(file, binary.LittleEndian, &fileSystemIdentifier); err != nil {
127 return false, err
128 }
129
130
131 return fileSystemIdentifier == systemIdentifier, nil
132 }
133