...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/sources/conn.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/sources

     1  package sources
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"maps"
     8  	"math"
     9  	"regexp"
    10  	"slices"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/cybertec-postgresql/pgwatch/v5/internal/db"
    17  	"github.com/jackc/pgx/v5"
    18  	"github.com/jackc/pgx/v5/pgxpool"
    19  )
    20  
    21  // NewConn and NewConnWithConfig are wrappers to allow testing
    22  var (
    23  	NewConn           = db.New
    24  	NewConnWithConfig = db.NewWithConfig
    25  )
    26  
    27  const (
    28  	EnvUnknown       = "UNKNOWN"
    29  	EnvAzureSingle   = "AZURE_SINGLE" //discontinued
    30  	EnvAzureFlexible = "AZURE_FLEXIBLE"
    31  	EnvGoogle        = "GOOGLE"
    32  )
    33  
// RuntimeInfo holds facts about a monitored server that are collected at runtime by
// SourceConn.FetchRuntimeInfo (as opposed to the static Source configuration).
type RuntimeInfo struct {
	// LastCheckedOn is the time of the last refresh; FetchRuntimeInfo reuses the cached
	// data for 2 minutes after it unless a refetch is forced.
	LastCheckedOn time.Time
	// IsInRecovery is the result of pg_is_in_recovery(); GetMetricInterval uses it to
	// switch to the standby metric intervals.
	IsInRecovery bool
	// VersionStr is version() for Postgres sources, or the SHOW VERSION / SHOW POOL_VERSION
	// output for pgbouncer / pgpool sources.
	VersionStr string
	// Version is server_version_num / 10000 (major version only, e.g. 16) for Postgres
	// sources, but VersionToInt(VersionStr) (major*10000+minor*100+patch) for pgbouncer /
	// pgpool sources. NOTE(review): the two encodings differ — compare accordingly.
	Version int
	// RealDbname is current_database() as reported by the server.
	RealDbname string
	// SystemIdentifier is pg_control_system().system_identifier; GetClusterIdentifier
	// combines it with host and port.
	SystemIdentifier string
	// IsSuperuser is current_setting('is_superuser') of the connected role.
	IsSuperuser bool
	// Extensions maps installed extension names to their VersionToInt-encoded versions
	// (filled for Postgres sources only).
	Extensions map[string]int
	// ExecEnv is the platform reported by DiscoverPlatform, normally one of the Env*
	// constant values; it stays empty while detection has not succeeded.
	ExecEnv string
	// ApproxDbSize is the approximate database size in bytes from FetchApproxSize.
	ApproxDbSize int64
	// ChangeState is not used in this file; presumably maintained by change-detection
	// code elsewhere in the package — confirm.
	ChangeState map[string]map[string]string // ["category"][object_identifier] = state
}
    47  
    48  // SourceConn represents a single connection to monitor. Unlike source, it contains a database connection.
    49  // Continuous discovery sources (postgres-continuous-discovery, patroni-continuous-discovery, patroni-namespace-discovery)
    50  // will produce multiple monitored databases structs based on the discovered databases.
    51  type (
    52  	SourceConn struct {
    53  		Source
    54  		Conn       db.PgxPoolIface
    55  		ConnConfig *pgxpool.Config
    56  		RuntimeInfo
    57  		sync.RWMutex
    58  	}
    59  
    60  	SourceConns []*SourceConn
    61  )
    62  
    63  func NewSourceConn(s Source) *SourceConn {
    64  	return &SourceConn{
    65  		Source: s,
    66  		RuntimeInfo: RuntimeInfo{
    67  			Extensions:  make(map[string]int),
    68  			ChangeState: make(map[string]map[string]string),
    69  		},
    70  	}
    71  }
    72  
    73  // Ping will try to ping the server to ensure the connection is still alive
    74  func (md *SourceConn) Ping(ctx context.Context) (err error) {
    75  	if md.Kind == SourcePgBouncer {
    76  		// pgbouncer is very picky about the queries it accepts
    77  		_, err = md.Conn.Exec(ctx, "SHOW VERSION")
    78  		return
    79  	}
    80  	return md.Conn.Ping(ctx)
    81  }
    82  
    83  // Connect will establish a connection to the database if it's not already connected.
    84  // If the connection is already established, it pings the server to ensure it's still alive.
    85  func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
    86  	if md.Conn == nil {
    87  		if err = md.ParseConfig(); err != nil {
    88  			return err
    89  		}
    90  		if md.Kind == SourcePgBouncer {
    91  			md.ConnConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
    92  		}
    93  		if opts.MaxParallelConnectionsPerDb > 0 {
    94  			md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
    95  		}
    96  		md.Conn, err = NewConnWithConfig(ctx, md.ConnConfig)
    97  		if err != nil {
    98  			return err
    99  		}
   100  	}
   101  	return md.Ping(ctx)
   102  }
   103  
   104  // ParseConfig will parse the connection string and store the result in the connection config
   105  func (md *SourceConn) ParseConfig() (err error) {
   106  	if md.ConnConfig == nil {
   107  		md.ConnConfig, err = pgxpool.ParseConfig(md.ConnStr)
   108  		return
   109  	}
   110  	return
   111  }
   112  
   113  // GetUniqueIdentifier returns a unique identifier for the host assuming SysId is the same for
   114  // primary and all replicas but connection information is different
   115  func (md *SourceConn) GetClusterIdentifier() string {
   116  	if err := md.ParseConfig(); err != nil {
   117  		return ""
   118  	}
   119  	return fmt.Sprintf("%s:%s:%d", md.SystemIdentifier, md.ConnConfig.ConnConfig.Host, md.ConnConfig.ConnConfig.Port)
   120  }
   121  
   122  // GetDatabaseName returns the database name from the connection string
   123  func (md *SourceConn) GetDatabaseName() string {
   124  	if err := md.ParseConfig(); err != nil {
   125  		return ""
   126  	}
   127  	return md.ConnConfig.ConnConfig.Database
   128  }
   129  
   130  // GetMetricInterval returns the metric interval for the connection
   131  func (md *SourceConn) GetMetricInterval(name string) time.Duration {
   132  	md.RLock()
   133  	defer md.RUnlock()
   134  	if md.IsInRecovery && len(md.MetricsStandby) > 0 {
   135  		return time.Duration(md.MetricsStandby[name]) * time.Second
   136  	}
   137  	return time.Duration(md.Metrics[name]) * time.Second
   138  }
   139  
   140  // IsClientOnSameHost checks if the pgwatch client is running on the same host as the PostgreSQL server
   141  func (md *SourceConn) IsClientOnSameHost() bool {
   142  	ok, err := db.IsClientOnSameHost(md.Conn)
   143  	return ok && err == nil
   144  }
   145  
   146  // SetDatabaseName sets the database name in the connection config for resolved databases
   147  func (md *SourceConn) SetDatabaseName(name string) {
   148  	if err := md.ParseConfig(); err != nil {
   149  		return
   150  	}
   151  	md.ConnStr = "" // unset the connection string to force conn config usage
   152  	md.ConnConfig.ConnConfig.Database = name
   153  }
   154  
   155  func (md *SourceConn) IsPostgresSource() bool {
   156  	return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool
   157  }
   158  
   159  // VersionToInt parses a given version and returns an integer  or
   160  // an error if unable to parse the version. Only parses valid semantic versions.
   161  // Performs checking that can find errors within the version.
   162  // Examples: v1.2 -> 01_02_00, v9.6.3 -> 09_06_03, v11 -> 11_00_00
   163  var regVer = regexp.MustCompile(`(\d+).?(\d*).?(\d*)`)
   164  
   165  func VersionToInt(version string) (v int) {
   166  	if matches := regVer.FindStringSubmatch(version); len(matches) > 1 {
   167  		for i, match := range matches[1:] {
   168  			v += func() (m int) { m, _ = strconv.Atoi(match); return }() * int(math.Pow10(4-i*2))
   169  		}
   170  	}
   171  	return
   172  }
   173  
   174  func (md *SourceConn) FetchRuntimeInfo(ctx context.Context, forceRefetch bool) (err error) {
   175  	md.Lock()
   176  	defer md.Unlock()
   177  	if ctx.Err() != nil {
   178  		return ctx.Err()
   179  	}
   180  
   181  	if !forceRefetch && md.LastCheckedOn.After(time.Now().Add(time.Minute*-2)) { // use cached version for 2 min
   182  		return nil
   183  	}
   184  	switch md.Kind {
   185  	case SourcePgBouncer, SourcePgPool:
   186  		if md.VersionStr, md.Version, err = md.FetchVersion(ctx, func() string {
   187  			if md.Kind == SourcePgBouncer {
   188  				return "SHOW VERSION"
   189  			}
   190  			return "SHOW POOL_VERSION"
   191  		}()); err != nil {
   192  			return
   193  		}
   194  	default:
   195  		sql := `select /* pgwatch_generated */ 
   196  	div(current_setting('server_version_num')::int, 10000) as ver, 
   197  	version(), 
   198  	pg_is_in_recovery(), 
   199  	current_database()::TEXT,
   200  	system_identifier,
   201  	current_setting('is_superuser')::bool
   202  FROM
   203  	pg_control_system()`
   204  
   205  		err = md.Conn.QueryRow(ctx, sql).
   206  			Scan(&md.Version, &md.VersionStr,
   207  				&md.IsInRecovery, &md.RealDbname,
   208  				&md.SystemIdentifier, &md.IsSuperuser)
   209  		if err != nil {
   210  			return err
   211  		}
   212  
   213  		md.ExecEnv = md.DiscoverPlatform(ctx)
   214  		md.ApproxDbSize = md.FetchApproxSize(ctx)
   215  
   216  		sqlExtensions := `select /* pgwatch_generated */ extname::text, (regexp_matches(extversion, $$\d+\.?\d+?$$))[1]::text as extversion from pg_extension order by 1;`
   217  		var res pgx.Rows
   218  		res, err = md.Conn.Query(ctx, sqlExtensions)
   219  		if err == nil {
   220  			var ext string
   221  			var ver string
   222  			_, err = pgx.ForEachRow(res, []any{&ext, &ver}, func() error {
   223  				extver := VersionToInt(ver)
   224  				if extver == 0 {
   225  					return fmt.Errorf("unexpected extension %s version input: %s", ext, ver)
   226  				}
   227  				md.Extensions[ext] = extver
   228  				return nil
   229  			})
   230  		}
   231  
   232  	}
   233  	md.LastCheckedOn = time.Now()
   234  	return err
   235  }
   236  
   237  func (md *SourceConn) FetchVersion(ctx context.Context, sql string) (version string, ver int, err error) {
   238  	if err = md.Conn.QueryRow(ctx, sql, pgx.QueryExecModeSimpleProtocol).Scan(&version); err != nil {
   239  		return
   240  	}
   241  	ver = VersionToInt(version)
   242  	return
   243  }
   244  
   245  // TryDiscoverPlatform tries to discover the platform based on the database version string and some special settings
   246  // that are only available on certain platforms. Returns the platform name or "UNKNOWN" if not sure.
   247  func (md *SourceConn) DiscoverPlatform(ctx context.Context) (platform string) {
   248  	if md.ExecEnv != "" {
   249  		return md.ExecEnv // carry over as not likely to change ever
   250  	}
   251  	sql := `select /* pgwatch_generated */
   252  	case
   253  	  when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by Visual C' then 'AZURE_SINGLE'
   254  	  when exists (select * from pg_settings where name = 'pg_qs.host_database' and setting = 'azure_sys') and version() ~* 'compiled by gcc' then 'AZURE_FLEXIBLE'
   255  	  when exists (select * from pg_settings where name = 'cloudsql.supported_extensions') then 'GOOGLE'
   256  	else
   257  	  'UNKNOWN'
   258  	end as exec_env`
   259  	_ = md.Conn.QueryRow(ctx, sql).Scan(&platform)
   260  	return
   261  }
   262  
   263  // FetchApproxSize returns the approximate size of the database in bytes
   264  func (md *SourceConn) FetchApproxSize(ctx context.Context) (size int64) {
   265  	sqlApproxDBSize := `select /* pgwatch_generated */ current_setting('block_size')::int8 * sum(relpages) from pg_class c where c.relpersistence != 't'`
   266  	_ = md.Conn.QueryRow(ctx, sqlApproxDBSize).Scan(&size)
   267  	return
   268  }
   269  
   270  // FunctionExists checks if a function exists in the database
   271  func (md *SourceConn) FunctionExists(ctx context.Context, functionName string) (exists bool) {
   272  	sql := `select /* pgwatch_generated */ true 
   273  from 
   274  	pg_proc join pg_namespace n on pronamespace = n.oid 
   275  where 
   276  	proname = $1 and n.nspname = 'public'`
   277  	_ = md.Conn.QueryRow(ctx, sql, functionName).Scan(&exists)
   278  	return
   279  }
   280  
   281  // TryCreateMissingExtensions should be called once on daemon startup if some commonly wanted extension (most notably pg_stat_statements) is missing.
   282  // With newer Postgres version can even succeed if the user is not a real superuser due to some cloud-specific
   283  // whitelisting or "trusted extensions"
   284  func (md *SourceConn) TryCreateMissingExtensions(ctx context.Context, extensions []string) (string, error) {
   285  	md.RLock()
   286  	defer md.RUnlock()
   287  
   288  	sqlAvailableExts := `select name::text from pg_available_extensions order by 1`
   289  	CreatedExts := make([]string, 0)
   290  
   291  	// For security reasons don't allow to execute random strings but check that it's an existing extension
   292  	data, err := md.Conn.Query(ctx, sqlAvailableExts)
   293  	if err != nil {
   294  		return "", err
   295  	}
   296  	availableExts, err := pgx.CollectRows(data, pgx.RowTo[string])
   297  	if err != nil {
   298  		return "", err
   299  	}
   300  
   301  	for _, extToCreate := range extensions {
   302  		if _, ok := md.Extensions[extToCreate]; ok {
   303  			continue
   304  		}
   305  		if _, ok := slices.BinarySearch(availableExts, extToCreate); !ok {
   306  			err = errors.Join(err, fmt.Errorf("requested extension %s is not available on instance", extToCreate))
   307  			continue
   308  		}
   309  		if _, e := md.Conn.Exec(ctx, fmt.Sprintf(`create extension if not exists "%s"`, extToCreate)); e != nil {
   310  			err = errors.Join(err, fmt.Errorf("failed to create extension %s: %w", extToCreate, e))
   311  		} else {
   312  			CreatedExts = append(CreatedExts, extToCreate)
   313  		}
   314  	}
   315  	return strings.Join(CreatedExts, ","), err
   316  }
   317  
   318  // TryCreateMetricsHelpers should be called once on daemon startup to try to create "metric fetching helper" functions automatically
   319  func (md *SourceConn) TryCreateMetricsHelpers(ctx context.Context, getSQLFn func(string) string) (err error) {
   320  	md.RLock()
   321  	defer md.RUnlock()
   322  	var sql string
   323  	metrics := maps.Clone(md.Metrics)
   324  	maps.Insert(metrics, maps.All(md.MetricsStandby))
   325  	for metricName := range metrics {
   326  		if sql = getSQLFn(metricName); sql == "" {
   327  			continue
   328  		}
   329  		if _, e := md.Conn.Exec(ctx, sql); e != nil {
   330  			err = errors.Join(err, fmt.Errorf("failed to create helper for metric %s: %w", metricName, e))
   331  		}
   332  	}
   333  	return
   334  }
   335  
   336  func (mds SourceConns) GetMonitoredDatabase(DBUniqueName string) *SourceConn {
   337  	for _, md := range mds {
   338  		if md.Name == DBUniqueName {
   339  			return md
   340  		}
   341  	}
   342  	return nil
   343  }
   344