...
1 package db
2
3 import (
4 "context"
5 "time"
6
7 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
8 "github.com/jackc/pgx/v5"
9 "github.com/jackc/pgx/v5/pgconn"
10 "github.com/jackc/pgx/v5/pgxpool"
11 "github.com/jackc/pgx/v5/tracelog"
12 retry "github.com/sethvargo/go-retry"
13 )
14
const (
	// applicationName is the value written into the application_name runtime
	// parameter of pooled connections.
	applicationName = "pgwatch"
	// pgConnRecycleSeconds is the maximum lifetime of a pooled connection,
	// expressed in seconds (30 minutes).
	pgConnRecycleSeconds = 30 * 60
)
19
// Ping opens a single connection to connStr and closes it straight away.
// It returns the error reported by pgx.Connect, or nil if the connection
// could be established.
func Ping(ctx context.Context, connStr string) error {
	conn, connectErr := pgx.Connect(ctx, connStr)
	if conn == nil {
		return connectErr
	}
	// The close error is deliberately ignored: only reachability matters here.
	_ = conn.Close(ctx)
	return connectErr
}
27
28 type ConnConfigCallback = func(*pgxpool.Config) error
29
30
// New parses connStr into a pgxpool configuration and delegates to
// NewWithConfig, forwarding the callbacks unchanged. If connStr cannot be
// parsed, the parse error is returned and no pool is created.
func New(ctx context.Context, connStr string, callbacks ...ConnConfigCallback) (PgxPoolIface, error) {
	cfg, parseErr := pgxpool.ParseConfig(connStr)
	if parseErr == nil {
		return NewWithConfig(ctx, cfg, callbacks...)
	}
	return nil, parseErr
}
38
39
// NewWithConfig creates a connection pool from connConfig after applying
// pgwatch's connection defaults. connConfig is modified in place.
//
// The defaults are:
//   - a 5 second connect timeout, unless a non-zero timeout is already set;
//   - MaxConnIdleTime of 15 seconds;
//   - MaxConnLifetime of pgConnRecycleSeconds;
//   - application_name set to applicationName, replacing any existing value;
//   - server notices logged at info level through the context logger;
//   - a tracelog tracer with LogLevelDebug that writes to the context logger.
//
// The callbacks run afterwards, in order, so they can override any of these
// defaults. The first callback error aborts pool creation and is returned
// unchanged. On any failure the returned PgxPoolIface is a true nil interface.
func NewWithConfig(ctx context.Context, connConfig *pgxpool.Config, callbacks ...ConnConfigCallback) (PgxPoolIface, error) {
	logger := log.GetLogger(ctx)
	if connConfig.ConnConfig.ConnectTimeout == 0 {
		connConfig.ConnConfig.ConnectTimeout = time.Second * 5
	}
	connConfig.MaxConnIdleTime = 15 * time.Second
	connConfig.MaxConnLifetime = pgConnRecycleSeconds * time.Second
	connConfig.ConnConfig.RuntimeParams["application_name"] = applicationName
	connConfig.ConnConfig.OnNotice = func(_ *pgconn.PgConn, n *pgconn.Notice) {
		logger.WithField("severity", n.Severity).WithField("notice", n.Message).Info("Notice received")
	}
	connConfig.ConnConfig.Tracer = &tracelog.TraceLog{
		Logger:   log.NewPgxLogger(logger),
		LogLevel: tracelog.LogLevelDebug,
	}
	for _, f := range callbacks {
		if err := f(connConfig); err != nil {
			return nil, err
		}
	}
	// Do not return pgxpool.NewWithConfig's result directly: on failure it
	// yields a nil *pgxpool.Pool. Converting that to the PgxPoolIface
	// interface produces a NON-nil interface value holding a nil pointer.
	pool, err := pgxpool.NewWithConfig(ctx, connConfig)
	if err != nil {
		return nil, err
	}
	return pool, nil
}
63
// ConnInitCallback is run by Init against the database handle once that
// handle has answered a Ping. Its error becomes Init's return value.
type ConnInitCallback = func(ctx context.Context, conn PgxIface) error
65
66
// Init waits until db answers Ping and then runs init against it.
//
// The ping is attempted once and then retried up to maxRetries times, with a
// constant retryDelay between attempts. Every failed attempt is logged. If
// all attempts fail, or ctx is done, that error is returned and init is not
// called. A nil init is treated as "nothing to initialize": Init returns nil
// once the ping succeeds.
func Init(ctx context.Context, db PgxPoolIface, init ConnInitCallback) error {
	const (
		maxRetries = 3
		retryDelay = 1 * time.Second
	)
	logger := log.GetLogger(ctx)
	backoff := retry.WithMaxRetries(maxRetries, retry.NewConstant(retryDelay))
	attempt := 0
	err := retry.Do(ctx, backoff, func(ctx context.Context) error {
		attempt++
		if err := db.Ping(ctx); err != nil {
			logger.WithError(err).Error("connection failed")
			// The final attempt is not followed by a sleep, so only announce one
			// when a retry is still coming.
			if attempt <= maxRetries {
				logger.Info("sleeping before reconnecting...")
			}
			return retry.RetryableError(err)
		}
		return nil
	})
	if err != nil {
		return err
	}
	if init == nil {
		return nil
	}
	return init(ctx, db)
}
82
83
// DoesSchemaExist reports whether pg_namespace contains a schema named schema,
// querying through conn. Any error from the query or the scan is returned
// alongside the bool as it stands at that point.
func DoesSchemaExist(ctx context.Context, conn PgxIface, schema string) (bool, error) {
	const query = "SELECT EXISTS(SELECT 1 FROM pg_namespace WHERE nspname = $1)"
	found := false
	if scanErr := conn.QueryRow(ctx, query, schema).Scan(&found); scanErr != nil {
		return found, scanErr
	}
	return found, nil
}
90