...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/conn.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"regexp"
     8  	"strconv"
     9  	"time"
    10  
    11  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
    12  	"github.com/jackc/pgx/v5"
    13  	"github.com/jackc/pgx/v5/pgxpool"
    14  )
    15  
    16  // NewConn and NewConnWithConfig are wrappers to allow testing
    17  var (
    18  	NewConn           = db.New
    19  	NewConnWithConfig = db.NewWithConfig
    20  )
    21  
    22  const (
    23  	EnvUnknown       = "UNKNOWN"
    24  	EnvAzureSingle   = "AZURE_SINGLE" //discontinued
    25  	EnvAzureFlexible = "AZURE_FLEXIBLE"
    26  	EnvGoogle        = "GOOGLE"
    27  )
    28  
// RuntimeInfo holds facts about a monitored server that FetchRuntimeInfo reads
// from the server and caches on the SourceConn. For PgBouncer and PgPool sources
// only LastCheckedOn, VersionStr and Version are filled in; the remaining fields
// keep their zero values (Extensions is an empty map).
type RuntimeInfo struct {
	// LastCheckedOn is when the info was last fetched; FetchRuntimeInfo reuses
	// the cached values for 2 minutes unless a refetch is forced.
	LastCheckedOn    time.Time
	// IsInRecovery is the result of pg_is_in_recovery(); GetMetricInterval uses
	// it to choose the standby metric set.
	IsInRecovery     bool
	// VersionStr is the full version() output, or the raw output of
	// SHOW VERSION / SHOW POOL_VERSION for pooler sources.
	VersionStr       string
	// Version is the major version only (server_version_num / 10000) for
	// PostgreSQL sources, but the VersionToInt encoding of VersionStr for poolers.
	Version          int
	// RealDbname is current_database() as reported by the server.
	RealDbname       string
	// SystemIdentifier is system_identifier from pg_control_system().
	SystemIdentifier string
	// IsSuperuser is current_setting('is_superuser') for the connected role.
	IsSuperuser      bool
	// Extensions maps installed extension names to their VersionToInt-encoded versions.
	Extensions       map[string]int
	// ExecEnv is the hosting platform detected by DiscoverPlatform.
	ExecEnv          string
	// ApproxDbSize is the approximate database size in bytes from FetchApproxSize.
	ApproxDbSize     int64
}
    41  
    42  // SourceConn represents a single connection to monitor. Unlike source, it contains a database connection.
    43  // Continuous discovery sources (postgres-continuous-discovery, patroni-continuous-discovery, patroni-namespace-discovery)
    44  // will produce multiple monitored databases structs based on the discovered databases.
    45  type (
    46  	SourceConn struct {
    47  		Source
    48  		Conn       db.PgxPoolIface
    49  		ConnConfig *pgxpool.Config
    50  		RuntimeInfo
    51  	}
    52  
    53  	SourceConns []*SourceConn
    54  )
    55  
    56  // Ping will try to ping the server to ensure the connection is still alive
    57  func (md *SourceConn) Ping(ctx context.Context) (err error) {
    58  	if md.Kind == SourcePgBouncer {
    59  		// pgbouncer is very picky about the queries it accepts
    60  		_, err = md.Conn.Exec(ctx, "SHOW VERSION")
    61  		return
    62  	}
    63  	return md.Conn.Ping(ctx)
    64  }
    65  
    66  // Connect will establish a connection to the database if it's not already connected.
    67  // If the connection is already established, it pings the server to ensure it's still alive.
    68  func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
    69  	if md.Conn == nil {
    70  		if err = md.ParseConfig(); err != nil {
    71  			return err
    72  		}
    73  		if md.Kind == SourcePgBouncer {
    74  			md.ConnConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
    75  		}
    76  		if opts.MaxParallelConnectionsPerDb > 0 {
    77  			md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
    78  		}
    79  		md.Conn, err = NewConnWithConfig(ctx, md.ConnConfig)
    80  		if err != nil {
    81  			return err
    82  		}
    83  	}
    84  	return md.Ping(ctx)
    85  }
    86  
    87  // ParseConfig will parse the connection string and store the result in the connection config
    88  func (md *SourceConn) ParseConfig() (err error) {
    89  	if md.ConnConfig == nil {
    90  		md.ConnConfig, err = pgxpool.ParseConfig(md.ConnStr)
    91  		return
    92  	}
    93  	return
    94  }
    95  
    96  // GetUniqueIdentifier returns a unique identifier for the host assuming SysId is the same for
    97  // primary and all replicas but connection information is different
    98  func (md *SourceConn) GetClusterIdentifier() string {
    99  	if err := md.ParseConfig(); err != nil {
   100  		return ""
   101  	}
   102  	return fmt.Sprintf("%s:%s:%d", md.SystemIdentifier, md.ConnConfig.ConnConfig.Host, md.ConnConfig.ConnConfig.Port)
   103  }
   104  
   105  // GetDatabaseName returns the database name from the connection string
   106  func (md *SourceConn) GetDatabaseName() string {
   107  	if err := md.ParseConfig(); err != nil {
   108  		return ""
   109  	}
   110  	return md.ConnConfig.ConnConfig.Database
   111  }
   112  
   113  // GetMetricInterval returns the metric interval for the connection
   114  func (md *SourceConn) GetMetricInterval(name string) float64 {
   115  	if md.IsInRecovery && len(md.MetricsStandby) > 0 {
   116  		return md.MetricsStandby[name]
   117  	}
   118  	return md.Metrics[name]
   119  }
   120  
   121  // SetDatabaseName sets the database name in the connection config for resolved databases
   122  func (md *SourceConn) SetDatabaseName(name string) {
   123  	if err := md.ParseConfig(); err != nil {
   124  		return
   125  	}
   126  	md.ConnStr = "" // unset the connection string to force conn config usage
   127  	md.ConnConfig.ConnConfig.Database = name
   128  }
   129  
   130  func (md *SourceConn) IsPostgresSource() bool {
   131  	return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool
   132  }
   133  
   134  // VersionToInt parses a given version and returns an integer  or
   135  // an error if unable to parse the version. Only parses valid semantic versions.
   136  // Performs checking that can find errors within the version.
   137  // Examples: v1.2 -> 01_02_00, v9.6.3 -> 09_06_03, v11 -> 11_00_00
   138  var regVer = regexp.MustCompile(`(\d+).?(\d*).?(\d*)`)
   139  
   140  func VersionToInt(version string) (v int) {
   141  	if matches := regVer.FindStringSubmatch(version); len(matches) > 1 {
   142  		for i, match := range matches[1:] {
   143  			v += func() (m int) { m, _ = strconv.Atoi(match); return }() * int(math.Pow10(4-i*2))
   144  		}
   145  	}
   146  	return
   147  }
   148  
// FetchRuntimeInfo refreshes md.RuntimeInfo with facts read from the server.
//
// If ctx is already done, its error is returned immediately. Unless forceRefetch
// is set, info fetched less than 2 minutes ago (per LastCheckedOn) is reused and
// nothing is queried.
//
// For PgBouncer and PgPool sources only the version is fetched, via
// "SHOW VERSION" or "SHOW POOL_VERSION". For PostgreSQL sources the following are
// collected:
//   - the major version, the version() string and the recovery status;
//   - the current database, the system identifier and the superuser flag;
//   - the execution platform and the approximate database size;
//   - the versions of installed extensions.
//
// If the version query (poolers) or the main query (PostgreSQL) fails, that
// error is returned and md.RuntimeInfo is left untouched. An error from the
// extension query, or from an unparsable extension version, is also returned.
// In that case the newly collected settings, with a possibly incomplete
// Extensions map, have already been stored and LastCheckedOn updated.
func (md *SourceConn) FetchRuntimeInfo(ctx context.Context, forceRefetch bool) (err error) {
	if ctx.Err() != nil {
		return ctx.Err()
	}

	if !forceRefetch && md.LastCheckedOn.After(time.Now().Add(time.Minute*-2)) { // use cached version for 2 min
		return nil
	}

	// Collect into a fresh value so a failed main query leaves md.RuntimeInfo intact.
	dbNewSettings := RuntimeInfo{
		Extensions: make(map[string]int),
	}

	switch md.Kind {
	case SourcePgBouncer, SourcePgPool:
		// Poolers only understand their own SHOW commands; Version gets the
		// VersionToInt encoding of the returned string.
		if dbNewSettings.VersionStr, dbNewSettings.Version, err = md.FetchVersion(ctx, func() string {
			if md.Kind == SourcePgBouncer {
				return "SHOW VERSION"
			}
			return "SHOW POOL_VERSION"
		}()); err != nil {
			return
		}
	default:
		// Note: "ver" is the major version only (server_version_num / 10000).
		sql := `select /* pgwatch_generated */ 
	div(current_setting('server_version_num')::int, 10000) as ver, 
	version(), 
	pg_is_in_recovery(), 
	current_database()::TEXT,
	system_identifier,
	current_setting('is_superuser')::bool
FROM
	pg_control_system()`

		err = md.Conn.QueryRow(ctx, sql).
			Scan(&dbNewSettings.Version, &dbNewSettings.VersionStr,
				&dbNewSettings.IsInRecovery, &dbNewSettings.RealDbname,
				&dbNewSettings.SystemIdentifier, &dbNewSettings.IsSuperuser)
		if err != nil {
			return err
		}

		// Both helpers are best effort and never fail this function.
		dbNewSettings.ExecEnv = md.DiscoverPlatform(ctx)
		dbNewSettings.ApproxDbSize = md.FetchApproxSize(ctx)

		// The SQL regexp trims extversion down to its leading numeric part before
		// it is encoded with VersionToInt.
		sqlExtensions := `select /* pgwatch_generated */ extname::text, (regexp_matches(extversion, $$\d+\.?\d+?$$))[1]::text as extversion from pg_extension order by 1;`
		var res pgx.Rows
		res, err = md.Conn.Query(ctx, sqlExtensions)
		if err == nil {
			var ext string
			var ver string
			// A version that encodes to 0 aborts the iteration; extensions read
			// before it stay in the map and the error falls through to the final return.
			_, err = pgx.ForEachRow(res, []any{&ext, &ver}, func() error {
				extver := VersionToInt(ver)
				if extver == 0 {
					return fmt.Errorf("unexpected extension %s version input: %s", ext, ver)
				}
				dbNewSettings.Extensions[ext] = extver
				return nil
			})
		}

	}
	// Reached with a nil err or with an extension-query error; the collected
	// settings are stored (and cached) in both cases.
	dbNewSettings.LastCheckedOn = time.Now()
	md.RuntimeInfo = dbNewSettings // store the new settings in the struct
	return err
}
   215  
   216  func (md *SourceConn) FetchVersion(ctx context.Context, sql string) (version string, ver int, err error) {
   217  	if err = md.Conn.QueryRow(ctx, sql, pgx.QueryExecModeSimpleProtocol).Scan(&version); err != nil {
   218  		return
   219  	}
   220  	ver = VersionToInt(version)
   221  	return
   222  }
   223  
   224  // TryDiscoverPlatform tries to discover the platform based on the database version string and some special settings
   225  // that are only available on certain platforms. Returns the platform name or "UNKNOWN" if not sure.
   226  func (md *SourceConn) DiscoverPlatform(ctx context.Context) (platform string) {
   227  	if md.ExecEnv != "" {
   228  		return md.ExecEnv // carry over as not likely to change ever
   229  	}
   230  	sql := `select /* pgwatch_generated */
   231  	case
   232  	  when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by Visual C' then 'AZURE_SINGLE'
   233  	  when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by gcc' then 'AZURE_FLEXIBLE'
   234  	  when exists (select * from pg_settings where name = 'cloudsql.supported_extensions') then 'GOOGLE'
   235  	else
   236  	  'UNKNOWN'
   237  	end as exec_env`
   238  	_ = md.Conn.QueryRow(ctx, sql).Scan(&platform)
   239  	return
   240  }
   241  
   242  // FetchApproxSize returns the approximate size of the database in bytes
   243  func (md *SourceConn) FetchApproxSize(ctx context.Context) (size int64) {
   244  	sqlApproxDBSize := `select /* pgwatch_generated */ current_setting('block_size')::int8 * sum(relpages) from pg_class c where c.relpersistence != 't'`
   245  	_ = md.Conn.QueryRow(ctx, sqlApproxDBSize).Scan(&size)
   246  	return
   247  }
   248  
   249  // FunctionExists checks if a function exists in the database
   250  func (md *SourceConn) FunctionExists(ctx context.Context, functionName string) (exists bool) {
   251  	sql := `select /* pgwatch_generated */ true 
   252  from 
   253  	pg_proc join pg_namespace n on pronamespace = n.oid 
   254  where 
   255  	proname = $1 and n.nspname = 'public'`
   256  	_ = md.Conn.QueryRow(ctx, sql, functionName).Scan(&exists)
   257  	return
   258  }
   259  
   260  func (mds SourceConns) GetMonitoredDatabase(DBUniqueName string) *SourceConn {
   261  	for _, md := range mds {
   262  		if md.Name == DBUniqueName {
   263  			return md
   264  		}
   265  	}
   266  	return nil
   267  }
   268