...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/conn.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  
     7  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
     8  	"github.com/jackc/pgx/v5"
     9  	"github.com/jackc/pgx/v5/pgxpool"
    10  )
    11  
    12  // NewConn and NewConnWithConfig are wrappers to allow testing
    13  var (
    14  	NewConn           = db.New
    15  	NewConnWithConfig = db.NewWithConfig
    16  )
    17  
    18  // SourceConn represents a single connection to monitor. Unlike source, it contains a database connection.
    19  // Continuous discovery sources (postgres-continuous-discovery, patroni-continuous-discovery, patroni-namespace-discovery)
    20  // will produce multiple monitored databases structs based on the discovered databases.
    21  type (
    22  	SourceConn struct {
    23  		Source
    24  		Conn       db.PgxPoolIface
    25  		ConnConfig *pgxpool.Config
    26  	}
    27  
    28  	SourceConns []*SourceConn
    29  )
    30  
    31  // Ping will try to ping the server to ensure the connection is still alive
    32  func (md *SourceConn) Ping(ctx context.Context) (err error) {
    33  	if md.Kind == SourcePgBouncer {
    34  		// pgbouncer is very picky about the queries it accepts
    35  		_, err = md.Conn.Exec(ctx, "SHOW VERSION")
    36  		return
    37  	}
    38  	return md.Conn.Ping(ctx)
    39  }
    40  
    41  // Connect will establish a connection to the database if it's not already connected.
    42  // If the connection is already established, it pings the server to ensure it's still alive.
    43  func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
    44  	if md.Conn == nil {
    45  		if err = md.ParseConfig(); err != nil {
    46  			return err
    47  		}
    48  		if md.Kind == SourcePgBouncer {
    49  			md.ConnConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
    50  		}
    51  		if opts.MaxParallelConnectionsPerDb > 0 {
    52  			md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
    53  		}
    54  		md.Conn, err = NewConnWithConfig(ctx, md.ConnConfig)
    55  		if err != nil {
    56  			return err
    57  		}
    58  	}
    59  	return md.Ping(ctx)
    60  }
    61  
    62  // ParseConfig will parse the connection string and store the result in the connection config
    63  func (md *SourceConn) ParseConfig() (err error) {
    64  	if md.ConnConfig == nil {
    65  		md.ConnConfig, err = pgxpool.ParseConfig(md.ConnStr)
    66  		return
    67  	}
    68  	return
    69  }
    70  
    71  // GetDatabaseName returns the database name from the connection string
    72  func (md *SourceConn) GetDatabaseName() string {
    73  	if err := md.ParseConfig(); err != nil {
    74  		return ""
    75  	}
    76  	return md.ConnConfig.ConnConfig.Database
    77  }
    78  
    79  // SetDatabaseName sets the database name in the connection config for resolved databases
    80  func (md *SourceConn) SetDatabaseName(name string) {
    81  	if err := md.ParseConfig(); err != nil {
    82  		return
    83  	}
    84  	md.ConnStr = "" // unset the connection string to force conn config usage
    85  	md.ConnConfig.ConnConfig.Database = name
    86  }
    87  
    88  func (md *SourceConn) IsPostgresSource() bool {
    89  	return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool
    90  }
    91  
    92  // TryDiscoverPlatform tries to discover the platform based on the database version string and some special settings
    93  // that are only available on certain platforms. Returns the platform name or "UNKNOWN" if not sure.
    94  func (md *SourceConn) DiscoverPlatform(ctx context.Context) (platform string) {
    95  	sql := `select /* pgwatch_generated */
    96  	case
    97  	  when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by Visual C' then 'AZURE_SINGLE'
    98  	  when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by gcc' then 'AZURE_FLEXIBLE'
    99  	  when exists (select * from pg_settings where name = 'cloudsql.supported_extensions') then 'GOOGLE'
   100  	else
   101  	  'UNKNOWN'
   102  	end as exec_env`
   103  	_ = md.Conn.QueryRow(ctx, sql).Scan(&platform)
   104  	return
   105  }
   106  
   107  // GetApproxSize returns the approximate size of the database in bytes
   108  func (md *SourceConn) GetApproxSize(ctx context.Context) (size int64, err error) {
   109  	sqlApproxDBSize := `select /* pgwatch_generated */ 
   110  	current_setting('block_size')::int8 * sum(relpages)
   111  from 
   112  	pg_class c
   113  where
   114  	c.relpersistence != 't'`
   115  	err = md.Conn.QueryRow(ctx, sqlApproxDBSize).Scan(&size)
   116  	return
   117  }
   118  
   119  // FunctionExists checks if a function exists in the database
   120  func (md *SourceConn) FunctionExists(ctx context.Context, functionName string) (exists bool) {
   121  	sql := `select /* pgwatch_generated */ true 
   122  from 
   123  	pg_proc join pg_namespace n on pronamespace = n.oid 
   124  where 
   125  	proname = $1 and n.nspname = 'public'`
   126  	_ = md.Conn.QueryRow(ctx, sql, functionName).Scan(&exists)
   127  	return
   128  }
   129  
   130  func (mds SourceConns) GetMonitoredDatabase(DBUniqueName string) *SourceConn {
   131  	for _, md := range mds {
   132  		if md.Name == DBUniqueName {
   133  			return md
   134  		}
   135  	}
   136  	return nil
   137  }
   138  
   139  // SyncFromReader will update the monitored databases with the latest configuration from the reader.
   140  // Any resolution errors will be returned, e.g. etcd unavailability.
   141  // It's up to the caller to proceed with the databases available or stop the execution due to errors.
   142  func (mds SourceConns) SyncFromReader(r Reader) (newmds SourceConns, err error) {
   143  	srcs, err := r.GetSources()
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	newmds, err = srcs.ResolveDatabases()
   148  	for _, newMD := range newmds {
   149  		md := mds.GetMonitoredDatabase(newMD.Name)
   150  		if md == nil {
   151  			continue
   152  		}
   153  		if reflect.DeepEqual(md.Source, newMD.Source) {
   154  			// keep the existing connection if the source is the same
   155  			newMD.Conn = md.Conn
   156  			newMD.ConnConfig = md.ConnConfig
   157  			continue
   158  		}
   159  		if md.Conn != nil {
   160  			md.Conn.Close()
   161  		}
   162  	}
   163  	return newmds, err
   164  }
   165