...
1 package db
2
3 import (
4 "context"
5 "encoding/binary"
6 "os"
7 "path/filepath"
8 "reflect"
9
10 jsoniter "github.com/json-iterator/go"
11
12 pgx "github.com/jackc/pgx/v5"
13 pgconn "github.com/jackc/pgx/v5/pgconn"
14 pgxpool "github.com/jackc/pgx/v5/pgxpool"
15 )
16
17 type Querier interface {
18 Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
19 }
20
21
22 type PgxIface interface {
23 Begin(ctx context.Context) (pgx.Tx, error)
24 Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
25 QueryRow(context.Context, string, ...interface{}) pgx.Row
26 Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
27 CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
28 }
29
30
31 type PgxConnIface interface {
32 PgxIface
33 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
34 Close(ctx context.Context) error
35 Ping(ctx context.Context) error
36 }
37
38
39 type PgxPoolIface interface {
40 PgxIface
41 Acquire(ctx context.Context) (*pgxpool.Conn, error)
42 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
43 Close()
44 Config() *pgxpool.Config
45 Ping(ctx context.Context) error
46 Stat() *pgxpool.Stat
47 }
48
49 func MarshallParamToJSONB(v any) any {
50 if v == nil {
51 return nil
52 }
53 val := reflect.ValueOf(v)
54 switch val.Kind() {
55 case reflect.Map, reflect.Slice:
56 if val.Len() == 0 {
57 return nil
58 }
59 case reflect.Struct:
60 if reflect.DeepEqual(v, reflect.Zero(val.Type()).Interface()) {
61 return nil
62 }
63 }
64 if b, err := jsoniter.ConfigFastest.Marshal(v); err == nil {
65 return string(b)
66 }
67 return nil
68 }
69
70
71 func IsClientOnSameHost(conn PgxIface) (bool, error) {
72 ctx := context.Background()
73
74
75 var isUnixSocket bool
76 err := conn.QueryRow(ctx, "SELECT COALESCE(inet_client_addr(), inet_server_addr()) IS NULL").Scan(&isUnixSocket)
77 if err != nil || isUnixSocket {
78 return isUnixSocket, err
79 }
80
81
82 var dataDirectory string
83 if err := conn.QueryRow(ctx, "SHOW data_directory").Scan(&dataDirectory); err != nil {
84 return false, err
85 }
86
87 var systemIdentifier uint64
88 if err := conn.QueryRow(ctx, "SELECT system_identifier FROM pg_control_system()").Scan(&systemIdentifier); err != nil {
89 return false, err
90 }
91
92
93 pgControlFile := filepath.Join(dataDirectory, "global", "pg_control")
94 file, err := os.Open(pgControlFile)
95 if err != nil {
96 return false, err
97 }
98 defer file.Close()
99
100 var fileSystemIdentifier uint64
101 if err := binary.Read(file, binary.LittleEndian, &fileSystemIdentifier); err != nil {
102 return false, err
103 }
104
105 return fileSystemIdentifier == systemIdentifier, nil
106 }
107