...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/db/conn.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/db

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"os"
     7  	"path/filepath"
     8  	"reflect"
     9  
    10  	jsoniter "github.com/json-iterator/go"
    11  
    12  	pgx "github.com/jackc/pgx/v5"
    13  	pgconn "github.com/jackc/pgx/v5/pgconn"
    14  	pgxpool "github.com/jackc/pgx/v5/pgxpool"
    15  )
    16  
    17  type Querier interface {
    18  	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
    19  }
    20  
    21  // PgxIface is common interface for every pgx class
    22  type PgxIface interface {
    23  	Begin(ctx context.Context) (pgx.Tx, error)
    24  	Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
    25  	QueryRow(context.Context, string, ...interface{}) pgx.Row
    26  	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
    27  	CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
    28  }
    29  
    30  // PgxConnIface is interface representing pgx connection
    31  type PgxConnIface interface {
    32  	PgxIface
    33  	BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
    34  	Close(ctx context.Context) error
    35  	Ping(ctx context.Context) error
    36  }
    37  
    38  // PgxPoolIface is interface representing pgx pool
    39  type PgxPoolIface interface {
    40  	PgxIface
    41  	Acquire(ctx context.Context) (*pgxpool.Conn, error)
    42  	BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
    43  	Close()
    44  	Config() *pgxpool.Config
    45  	Ping(ctx context.Context) error
    46  	Stat() *pgxpool.Stat
    47  }
    48  
    49  func MarshallParamToJSONB(v any) any {
    50  	if v == nil {
    51  		return nil
    52  	}
    53  	val := reflect.ValueOf(v)
    54  	switch val.Kind() {
    55  	case reflect.Map, reflect.Slice:
    56  		if val.Len() == 0 {
    57  			return nil
    58  		}
    59  	case reflect.Struct:
    60  		if reflect.DeepEqual(v, reflect.Zero(val.Type()).Interface()) {
    61  			return nil
    62  		}
    63  	}
    64  	if b, err := jsoniter.ConfigFastest.Marshal(v); err == nil {
    65  		return string(b)
    66  	}
    67  	return nil
    68  }
    69  
    70  // Function to determine if the client is connected to the same host as the PostgreSQL server
    71  func IsClientOnSameHost(conn PgxIface) (bool, error) {
    72  	ctx := context.Background()
    73  
    74  	// Step 1: Check connection type using SQL
    75  	var isUnixSocket bool
    76  	err := conn.QueryRow(ctx, "SELECT COALESCE(inet_client_addr(), inet_server_addr()) IS NULL").Scan(&isUnixSocket)
    77  	if err != nil || isUnixSocket {
    78  		return isUnixSocket, err
    79  	}
    80  
    81  	// Step 2: Retrieve unique cluster identifier
    82  	var dataDirectory string
    83  	if err := conn.QueryRow(ctx, "SHOW data_directory").Scan(&dataDirectory); err != nil {
    84  		return false, err
    85  	}
    86  
    87  	var systemIdentifier uint64
    88  	if err := conn.QueryRow(ctx, "SELECT system_identifier FROM pg_control_system()").Scan(&systemIdentifier); err != nil {
    89  		return false, err
    90  	}
    91  
    92  	// Step 3: Compare system identifier from file system
    93  	pgControlFile := filepath.Join(dataDirectory, "global", "pg_control")
    94  	file, err := os.Open(pgControlFile)
    95  	if err != nil {
    96  		return false, err
    97  	}
    98  	defer file.Close()
    99  
   100  	var fileSystemIdentifier uint64
   101  	if err := binary.Read(file, binary.LittleEndian, &fileSystemIdentifier); err != nil {
   102  		return false, err
   103  	}
   104  
   105  	return fileSystemIdentifier == systemIdentifier, nil
   106  }
   107