...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/sinks/postgres.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/sinks

     1  package sinks
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"errors"
     7  	"fmt"
     8  	"maps"
     9  	"slices"
    10  	"strings"
    11  	"time"
    12  
    13  	jsoniter "github.com/json-iterator/go"
    14  
    15  	"github.com/cybertec-postgresql/pgwatch/v5/internal/db"
    16  	"github.com/cybertec-postgresql/pgwatch/v5/internal/log"
    17  	"github.com/cybertec-postgresql/pgwatch/v5/internal/metrics"
    18  	migrator "github.com/cybertec-postgresql/pgx-migrator"
    19  	"github.com/jackc/pgx/v5"
    20  	"github.com/jackc/pgx/v5/pgconn"
    21  	"github.com/jackc/pgx/v5/pgxpool"
    22  )
    23  
var (
	// cacheLimit is both the input channel buffer size and the number of
	// cached measurements that triggers an early flush in poll().
	cacheLimit = 256
	// highLoadTimeout is how long Write blocks trying to enqueue before
	// silently dropping the message.
	highLoadTimeout = time.Second * 5
	// targetColumns is the column list used for COPY into every metric table.
	targetColumns = [...]string{"time", "dbname", "data", "tag_data"}
)
    29  
// SQL scripts embedded at build time; executed against the sink database to
// create the admin schema, its helper functions and the partition management
// routines (plain Postgres and TimescaleDB variants).

//go:embed sql/admin_schema.sql
var sqlMetricAdminSchema string

//go:embed sql/admin_functions.sql
var sqlMetricAdminFunctions string

//go:embed sql/ensure_partition_postgres.sql
var sqlMetricEnsurePartitionPostgres string

//go:embed sql/ensure_partition_timescale.sql
var sqlMetricEnsurePartitionTimescale string

//go:embed sql/change_chunk_interval.sql
var sqlMetricChangeChunkIntervalTimescale string

//go:embed sql/change_compression_interval.sql
var sqlMetricChangeCompressionIntervalTimescale string
    47  
var (
	// metricSchemaSQLs lists the bootstrap scripts in execution order; they
	// are run only when the "admin" schema does not exist yet (see init and
	// NewPostgresSinkMigrator).
	metricSchemaSQLs = []string{
		sqlMetricAdminSchema,
		sqlMetricAdminFunctions,
		sqlMetricEnsurePartitionPostgres,
		sqlMetricEnsurePartitionTimescale,
		sqlMetricChangeChunkIntervalTimescale,
		sqlMetricChangeCompressionIntervalTimescale,
	}
)
    58  
// PostgresWriter is a sink that writes metric measurements to a Postgres database.
// At the moment, it supports both Postgres and TimescaleDB as a storage backend.
// However, one is able to use any Postgres-compatible database as a storage backend,
// e.g. PGEE, Citus, Greenplum, CockroachDB, etc.
type PostgresWriter struct {
	ctx                      context.Context                             // lifetime of the writer; cancels poll loop and scheduled jobs
	sinkDb                   db.PgxPoolIface                             // connection pool to the measurements database
	metricSchema             DbStorageSchemaType                         // detected storage layout (postgres or timescale)
	opts                     *CmdOpts                                    // command-line options controlling batching, retention etc.
	retentionInterval        time.Duration                               // parsed from opts by init(); 0 disables retention
	maintenanceInterval      time.Duration                               // parsed from opts by init(); 0 disables the maintenance job
	input                    chan metrics.MeasurementEnvelope            // buffered queue between Write() and the poll loop
	lastError                chan error                                  // flush errors reported back to Write() callers
	forceRecreatePartitions  bool                                        // to signal override PG metrics storage cache
	partitionMapMetric       map[string]ExistingPartitionInfo            // metric = min/max bounds
	partitionMapMetricDbname map[string]map[string]ExistingPartitionInfo // metric[dbname = min/max bounds]
}
    76  
// Compile-time check: *PostgresWriter must satisfy the db.Migrator interface.
var _ db.Migrator = (*PostgresWriter)(nil)
    79  
    80  func NewPostgresWriter(ctx context.Context, connstr string, opts *CmdOpts) (pgw *PostgresWriter, err error) {
    81  	var conn db.PgxPoolIface
    82  	if conn, err = db.New(ctx, connstr); err != nil {
    83  		return
    84  	}
    85  	return NewWriterFromPostgresConn(ctx, conn, opts)
    86  }
    87  
// ErrNeedsMigration signals that the sink database schema is older than this
// binary expects and must be upgraded before the sink can be used.
var ErrNeedsMigration = errors.New("sink database schema is outdated, please run migrations using `pgwatch config upgrade` command")
    89  
// NewWriterFromPostgresConn wraps an existing connection pool in a PostgresWriter:
// it validates and bootstraps the measurements database, detects the storage
// schema (postgres vs timescale), pre-creates dummy tables for built-in metrics,
// schedules the periodic maintenance job (partition retention + unique-source
// bookkeeping) and starts the background poll loop that batches measurements.
func NewWriterFromPostgresConn(ctx context.Context, conn db.PgxPoolIface, opts *CmdOpts) (pgw *PostgresWriter, err error) {
	l := log.GetLogger(ctx).WithField("sink", "postgres").WithField("db", conn.Config().ConnConfig.Database)
	ctx = log.WithLogger(ctx, l)
	pgw = &PostgresWriter{
		ctx:                      ctx,
		opts:                     opts,
		input:                    make(chan metrics.MeasurementEnvelope, cacheLimit), // buffered so Write rarely blocks
		lastError:                make(chan error),
		sinkDb:                   conn,
		forceRecreatePartitions:  false,
		partitionMapMetric:       make(map[string]ExistingPartitionInfo),
		partitionMapMetricDbname: make(map[string]map[string]ExistingPartitionInfo),
	}
	l.Info("initialising measurements database...")
	// init must run first: it validates options and fills maintenanceInterval
	if err = pgw.init(); err != nil {
		return nil, err
	}
	if err = pgw.ReadMetricSchemaType(); err != nil {
		return nil, err
	}
	if err = pgw.EnsureBuiltinMetricDummies(); err != nil {
		return nil, err
	}
	// maintenanceInterval == 0 disables the job (see scheduleJob)
	pgw.scheduleJob(pgw.maintenanceInterval, func() {
		pgw.DeleteOldPartitions()
		pgw.MaintainUniqueSources()
	})
	go pgw.poll()
	l.Info(`measurements sink is activated`)
	return
}
   121  
// init validates the retention/maintenance/partition interval options —
// delegating interval parsing and arithmetic to the database itself — and
// bootstraps the "admin" schema with the embedded SQL scripts if it does not
// exist yet. The closure runs inside db.Init, which supplies the connection.
func (pgw *PostgresWriter) init() (err error) {
	return db.Init(pgw.ctx, pgw.sinkDb, func(ctx context.Context, conn db.PgxIface) error {
		var isValidPartitionInterval bool
		// let Postgres parse the user-supplied interval strings;
		// extract(epoch ...) yields seconds (note: assigns init's named return err)
		if err = conn.QueryRow(ctx,
			"SELECT extract(epoch from $1::interval), extract(epoch from $2::interval), $3::interval >= '1h'::interval",
			pgw.opts.RetentionInterval, pgw.opts.MaintenanceInterval, pgw.opts.PartitionInterval,
		).Scan(&pgw.retentionInterval, &pgw.maintenanceInterval, &isValidPartitionInterval); err != nil {
			return err
		}

		// epoch returns seconds but time.Duration represents nanoseconds
		pgw.retentionInterval *= time.Second
		pgw.maintenanceInterval *= time.Second

		if !isValidPartitionInterval {
			return fmt.Errorf("--partition-interval must be at least 1 hour, got: %s", pgw.opts.PartitionInterval)
		}
		if pgw.maintenanceInterval < 0 {
			return errors.New("--maintenance-interval must be a positive PostgreSQL interval or 0 to disable it")
		}
		if pgw.retentionInterval < time.Hour && pgw.retentionInterval != 0 {
			return errors.New("--retention must be at least 1 hour PostgreSQL interval or 0 to disable it")
		}

		// schema already present: nothing to bootstrap
		// (also returns early on lookup error)
		exists, err := db.DoesSchemaExist(ctx, conn, "admin")
		if err != nil || exists {
			return err
		}
		for _, sql := range metricSchemaSQLs {
			if _, err = conn.Exec(ctx, sql); err != nil {
				return err
			}
		}
		return nil
	})
}
   158  
// ExistingPartitionInfo holds the time range covered by an already-created
// partition, used to decide whether new partitions must be created.
type ExistingPartitionInfo struct {
	StartTime time.Time
	EndTime   time.Time
}
   163  
// MeasurementMessagePostgres is a single measurement row in the shape stored
// by this sink: plain data columns plus separated tag columns.
// NOTE(review): not referenced elsewhere in this file — presumably used by
// other files in the package; confirm before removing.
type MeasurementMessagePostgres struct {
	Time    time.Time
	DBName  string
	Metric  string
	Data    map[string]any
	TagData map[string]string
}
   171  
// DbStorageSchemaType enumerates the supported measurement storage layouts.
type DbStorageSchemaType int

const (
	// DbStorageSchemaPostgres partitions per metric+dbname+time via admin functions.
	DbStorageSchemaPostgres DbStorageSchemaType = iota
	// DbStorageSchemaTimescale stores measurements in TimescaleDB partitions.
	DbStorageSchemaTimescale
)
   178  
   179  func (pgw *PostgresWriter) scheduleJob(interval time.Duration, job func()) {
   180  	if interval > 0 {
   181  		go func() {
   182  			for {
   183  				select {
   184  				case <-pgw.ctx.Done():
   185  					return
   186  				case <-time.After(interval):
   187  					job()
   188  				}
   189  			}
   190  		}()
   191  	}
   192  }
   193  
   194  func (pgw *PostgresWriter) ReadMetricSchemaType() (err error) {
   195  	var isTs bool
   196  	pgw.metricSchema = DbStorageSchemaPostgres
   197  	sqlSchemaType := `SELECT schema_type = 'timescale' FROM admin.storage_schema_type`
   198  	if err = pgw.sinkDb.QueryRow(pgw.ctx, sqlSchemaType).Scan(&isTs); err == nil && isTs {
   199  		pgw.metricSchema = DbStorageSchemaTimescale
   200  	}
   201  	return
   202  }
   203  
   204  // SyncMetric ensures that tables exist for newly added metrics and/or sources
   205  func (pgw *PostgresWriter) SyncMetric(sourceName, metricName string, op SyncOp) error {
   206  	if op == AddOp {
   207  		return errors.Join(
   208  			pgw.AddDBUniqueMetricToListingTable(sourceName, metricName),
   209  			pgw.EnsureMetricDummy(metricName), // ensure that there is at least an empty top-level table not to get ugly Grafana notifications
   210  		)
   211  	}
   212  	return nil
   213  }
   214  
   215  // EnsureBuiltinMetricDummies creates empty tables for all built-in metrics if they don't exist
   216  func (pgw *PostgresWriter) EnsureBuiltinMetricDummies() (err error) {
   217  	for _, name := range metrics.GetDefaultBuiltInMetrics() {
   218  		err = errors.Join(err, pgw.EnsureMetricDummy(name))
   219  	}
   220  	return
   221  }
   222  
   223  // EnsureMetricDummy creates an empty table for a metric measurements if it doesn't exist
   224  func (pgw *PostgresWriter) EnsureMetricDummy(metric string) (err error) {
   225  	_, err = pgw.sinkDb.Exec(pgw.ctx, "SELECT admin.ensure_dummy_metrics_table($1)", metric)
   226  	return
   227  }
   228  
// Write queues a measurement envelope for asynchronous storage by the poll loop.
// If the buffer stays full for highLoadTimeout the message is dropped so the
// collector is never blocked indefinitely by a slow sink. Errors from earlier
// flushes are surfaced here via a non-blocking read — i.e. with a delay, and
// possibly attributed to a later Write call than the one that failed.
func (pgw *PostgresWriter) Write(msg metrics.MeasurementEnvelope) error {
	if pgw.ctx.Err() != nil {
		return pgw.ctx.Err()
	}
	select {
	case pgw.input <- msg:
		// msgs sent
	case <-time.After(highLoadTimeout):
		// msgs dropped due to a huge load, check stdout or file for detailed log
	}
	select {
	case err := <-pgw.lastError:
		return err
	default:
		return nil
	}
}
   247  
   248  // poll is the main loop that reads from the input channel and flushes the data to the database
   249  func (pgw *PostgresWriter) poll() {
   250  	cache := make([]metrics.MeasurementEnvelope, 0, cacheLimit)
   251  	cacheTimeout := pgw.opts.BatchingDelay
   252  	tick := time.NewTicker(cacheTimeout)
   253  	for {
   254  		select {
   255  		case <-pgw.ctx.Done(): //check context with high priority
   256  			return
   257  		default:
   258  			select {
   259  			case entry := <-pgw.input:
   260  				cache = append(cache, entry)
   261  				if len(cache) < cacheLimit {
   262  					break
   263  				}
   264  				tick.Stop()
   265  				pgw.flush(cache)
   266  				cache = cache[:0]
   267  				tick = time.NewTicker(cacheTimeout)
   268  			case <-tick.C:
   269  				pgw.flush(cache)
   270  				cache = cache[:0]
   271  			case <-pgw.ctx.Done():
   272  				return
   273  			}
   274  		}
   275  	}
   276  }
   277  
// newCopyFromMeasurements wraps a metric-name-sorted envelope slice in a
// pgx CopyFrom source; both indexes start at -1 so the first Next() call
// advances to the very first measurement.
func newCopyFromMeasurements(rows []metrics.MeasurementEnvelope) *copyFromMeasurements {
	return &copyFromMeasurements{envelopes: rows, envelopeIdx: -1, measurementIdx: -1}
}
   281  
// copyFromMeasurements adapts a slice of envelopes (pre-sorted by metric name)
// to pgx's CopyFromSource interface. One instance serves several consecutive
// CopyFrom calls: Next reports false at each metric-name boundary so every
// metric's rows are copied into that metric's own table (see flush).
type copyFromMeasurements struct {
	envelopes      []metrics.MeasurementEnvelope
	envelopeIdx    int    // index of the envelope currently consumed; -1 before start
	measurementIdx int    // index of the current measurement in the envelope
	metricName     string // metric of the current run; reset to "" at boundaries
}
   288  
// NextEnvelope advances to the following envelope and rewinds the measurement
// cursor; it reports false once the envelope slice is exhausted.
func (c *copyFromMeasurements) NextEnvelope() bool {
	c.envelopeIdx++
	c.measurementIdx = -1
	return c.envelopeIdx < len(c.envelopes)
}
   294  
// Next implements pgx.CopyFromSource. It advances to the next measurement but
// deliberately returns false at metric-name boundaries, so that a single
// copyFromMeasurements instance can feed one CopyFrom call per metric.
// Requires c.envelopes to be sorted by metric name (flush sorts before use).
func (c *copyFromMeasurements) Next() bool {
	for {
		// Check if we need to advance to the next envelope
		if c.envelopeIdx < 0 || c.measurementIdx+1 >= len(c.envelopes[c.envelopeIdx].Data) {
			// Advance to next envelope
			if ok := c.NextEnvelope(); !ok {
				return false // No more envelopes
			}
			// Set metric name from first envelope, or detect metric boundary
			if c.metricName == "" {
				c.metricName = c.envelopes[c.envelopeIdx].MetricName
			} else if c.metricName != c.envelopes[c.envelopeIdx].MetricName {
				// We've hit a different metric - we're done with current metric
				// Reset position to process this envelope on next call
				c.envelopeIdx--
				c.measurementIdx = len(c.envelopes[c.envelopeIdx].Data) // Set to length so we've "finished" this envelope
				c.metricName = ""                                       // Reset for next metric
				return false
			}
		}

		// Advance to next measurement in current envelope
		c.measurementIdx++
		if c.measurementIdx < len(c.envelopes[c.envelopeIdx].Data) {
			return true // Found valid measurement
		}
		// If we reach here, we've exhausted current envelope, loop will advance to next envelope
	}
}
   324  
// EOF reports whether every envelope has been consumed; flush uses it to decide
// whether another CopyFrom round (for the next metric) is still needed.
func (c *copyFromMeasurements) EOF() bool {
	return c.envelopeIdx >= len(c.envelopes)
}
   328  
   329  func (c *copyFromMeasurements) Values() ([]any, error) {
   330  	row := maps.Clone(c.envelopes[c.envelopeIdx].Data[c.measurementIdx])
   331  	tagRow := maps.Clone(c.envelopes[c.envelopeIdx].CustomTags)
   332  	if tagRow == nil {
   333  		tagRow = make(map[string]string)
   334  	}
   335  	for k, v := range row {
   336  		if after, ok := strings.CutPrefix(k, metrics.TagPrefix); ok {
   337  			tagRow[after] = fmt.Sprintf("%v", v)
   338  			delete(row, k)
   339  		}
   340  	}
   341  	jsonTags, terr := jsoniter.ConfigFastest.MarshalToString(tagRow)
   342  	json, err := jsoniter.ConfigFastest.MarshalToString(row)
   343  	if err != nil || terr != nil {
   344  		return nil, errors.Join(err, terr)
   345  	}
   346  	return []any{time.Unix(0, c.envelopes[c.envelopeIdx].Data.GetEpoch()), c.envelopes[c.envelopeIdx].DBName, json, jsonTags}, nil
   347  }
   348  
// Err implements pgx.CopyFromSource. Row-building errors are reported directly
// from Values, so there is never a deferred iteration error to surface here.
func (c *copyFromMeasurements) Err() error {
	return nil
}
   352  
// MetricName returns the table identifier for the metric run that the NEXT
// CopyFrom call will consume. It is called before rows are read: envelopeIdx
// still points at the last envelope of the previous (finished) metric — or -1
// before the first run — so the upcoming metric's name lives at envelopeIdx+1.
// Returns the zero identifier once all envelopes are exhausted.
func (c *copyFromMeasurements) MetricName() (ident pgx.Identifier) {
	if c.envelopeIdx+1 < len(c.envelopes) {
		// Metric name is taken from the next envelope
		ident = pgx.Identifier{c.envelopes[c.envelopeIdx+1].MetricName}
	}
	return
}
   360  
   361  // flush sends the cached measurements to the database
   362  func (pgw *PostgresWriter) flush(msgs []metrics.MeasurementEnvelope) {
   363  	if len(msgs) == 0 {
   364  		return
   365  	}
   366  	logger := log.GetLogger(pgw.ctx)
   367  	pgPartBounds := make(map[string]ExistingPartitionInfo)                  // metric=min/max
   368  	pgPartBoundsDbName := make(map[string]map[string]ExistingPartitionInfo) // metric=[dbname=min/max]
   369  	var err error
   370  
   371  	slices.SortFunc(msgs, func(a, b metrics.MeasurementEnvelope) int {
   372  		if a.MetricName < b.MetricName {
   373  			return -1
   374  		} else if a.MetricName > b.MetricName {
   375  			return 1
   376  		}
   377  		return 0
   378  	})
   379  
   380  	for _, msg := range msgs {
   381  		for _, dataRow := range msg.Data {
   382  			epochTime := time.Unix(0, metrics.Measurement(dataRow).GetEpoch())
   383  			switch pgw.metricSchema {
   384  			case DbStorageSchemaTimescale:
   385  				// set min/max timestamps to check/create partitions
   386  				bounds, ok := pgPartBounds[msg.MetricName]
   387  				if !ok || (ok && epochTime.Before(bounds.StartTime)) {
   388  					bounds.StartTime = epochTime
   389  					pgPartBounds[msg.MetricName] = bounds
   390  				}
   391  				if !ok || (ok && epochTime.After(bounds.EndTime)) {
   392  					bounds.EndTime = epochTime
   393  					pgPartBounds[msg.MetricName] = bounds
   394  				}
   395  			case DbStorageSchemaPostgres:
   396  				_, ok := pgPartBoundsDbName[msg.MetricName]
   397  				if !ok {
   398  					pgPartBoundsDbName[msg.MetricName] = make(map[string]ExistingPartitionInfo)
   399  				}
   400  				bounds, ok := pgPartBoundsDbName[msg.MetricName][msg.DBName]
   401  				if !ok || (ok && epochTime.Before(bounds.StartTime)) {
   402  					bounds.StartTime = epochTime
   403  					pgPartBoundsDbName[msg.MetricName][msg.DBName] = bounds
   404  				}
   405  				if !ok || (ok && epochTime.After(bounds.EndTime)) {
   406  					bounds.EndTime = epochTime
   407  					pgPartBoundsDbName[msg.MetricName][msg.DBName] = bounds
   408  				}
   409  			default:
   410  				logger.Fatal("unknown storage schema...")
   411  			}
   412  		}
   413  	}
   414  
   415  	switch pgw.metricSchema {
   416  	case DbStorageSchemaPostgres:
   417  		err = pgw.EnsureMetricDbnameTime(pgPartBoundsDbName)
   418  	case DbStorageSchemaTimescale:
   419  		err = pgw.EnsureMetricTimescale(pgPartBounds)
   420  	default:
   421  		logger.Fatal("unknown storage schema...")
   422  	}
   423  	pgw.forceRecreatePartitions = false
   424  	if err != nil {
   425  		pgw.lastError <- err
   426  	}
   427  
   428  	var rowsBatched, n int64
   429  	t1 := time.Now()
   430  	cfm := newCopyFromMeasurements(msgs)
   431  	for !cfm.EOF() {
   432  		n, err = pgw.sinkDb.CopyFrom(context.Background(), cfm.MetricName(), targetColumns[:], cfm)
   433  		rowsBatched += n
   434  		if err != nil {
   435  			logger.Error(err)
   436  			if PgError, ok := err.(*pgconn.PgError); ok {
   437  				pgw.forceRecreatePartitions = PgError.Code == "23514"
   438  			}
   439  			if pgw.forceRecreatePartitions {
   440  				logger.Warning("Some metric partitions might have been removed, halting all metric storage. Trying to re-create all needed partitions on next run")
   441  			}
   442  		}
   443  	}
   444  	diff := time.Since(t1)
   445  	if err == nil {
   446  		logger.WithField("rows", rowsBatched).WithField("elapsed", diff).Info("measurements written")
   447  		return
   448  	}
   449  	pgw.lastError <- err
   450  }
   451  
   452  func (pgw *PostgresWriter) EnsureMetricTimescale(pgPartBounds map[string]ExistingPartitionInfo) (err error) {
   453  	logger := log.GetLogger(pgw.ctx)
   454  	sqlEnsure := `select * from admin.ensure_partition_timescale($1)`
   455  	for metric := range pgPartBounds {
   456  		if _, ok := pgw.partitionMapMetric[metric]; !ok {
   457  			if _, err = pgw.sinkDb.Exec(pgw.ctx, sqlEnsure, metric); err != nil {
   458  				logger.Errorf("Failed to create a TimescaleDB table for metric '%s': %v", metric, err)
   459  				return err
   460  			}
   461  			pgw.partitionMapMetric[metric] = ExistingPartitionInfo{}
   462  		}
   463  	}
   464  	return
   465  }
   466  
// EnsureMetricDbnameTime creates (or verifies) time partitions per metric and
// source dbname for the requested bounds. A partition call is made only when
// the cached partition range does not already cover the requested start/end
// time, or when forceRecreatePartitions is set after a COPY check-violation.
// The cache (partitionMapMetricDbname) is refreshed from the database's reply.
func (pgw *PostgresWriter) EnsureMetricDbnameTime(metricDbnamePartBounds map[string]map[string]ExistingPartitionInfo) (err error) {
	var rows pgx.Rows
	sqlEnsure := `select * from admin.ensure_partition_metric_dbname_time($1, $2, $3, $4)`
	for metric, dbnameTimestampMap := range metricDbnamePartBounds {
		_, ok := pgw.partitionMapMetricDbname[metric]
		if !ok {
			pgw.partitionMapMetricDbname[metric] = make(map[string]ExistingPartitionInfo)
		}

		for dbname, pb := range dbnameTimestampMap {
			if pb.StartTime.IsZero() || pb.EndTime.IsZero() {
				return fmt.Errorf("zero StartTime/EndTime in partitioning request: [%s:%v]", metric, pb)
			}
			partInfo, ok := pgw.partitionMapMetricDbname[metric][dbname]
			// ensure a partition covering the earliest requested timestamp
			if !ok || (ok && (pb.StartTime.Before(partInfo.StartTime))) || pgw.forceRecreatePartitions {
				if rows, err = pgw.sinkDb.Query(pgw.ctx, sqlEnsure, metric, dbname, pb.StartTime, pgw.opts.PartitionInterval); err != nil {
					return
				}
				if partInfo, err = pgx.CollectOneRow(rows, pgx.RowToStructByPos[ExistingPartitionInfo]); err != nil {
					return err
				}
				pgw.partitionMapMetricDbname[metric][dbname] = partInfo
			}
			// ensure a partition covering the latest requested timestamp
			// (EndTime equal to the cached bound also triggers, since partition
			// upper bounds are treated as exclusive here)
			if pb.EndTime.After(partInfo.EndTime) || pb.EndTime.Equal(partInfo.EndTime) || pgw.forceRecreatePartitions {
				if rows, err = pgw.sinkDb.Query(pgw.ctx, sqlEnsure, metric, dbname, pb.EndTime, pgw.opts.PartitionInterval); err != nil {
					return
				}
				if partInfo, err = pgx.CollectOneRow(rows, pgx.RowToStructByPos[ExistingPartitionInfo]); err != nil {
					return err
				}
				pgw.partitionMapMetricDbname[metric][dbname] = partInfo
			}
		}
	}
	return nil
}
   503  
   504  // DeleteOldPartitions is a background task that deletes old partitions from the measurements DB
   505  func (pgw *PostgresWriter) DeleteOldPartitions() {
   506  	l := log.GetLogger(pgw.ctx)
   507  	var partsDropped int
   508  	err := pgw.sinkDb.QueryRow(pgw.ctx, `SELECT admin.drop_old_time_partitions(older_than => $1::interval)`,
   509  		pgw.opts.RetentionInterval).Scan(&partsDropped)
   510  	if err != nil {
   511  		l.Error("Could not drop old time partitions:", err)
   512  	} else if partsDropped > 0 {
   513  		l.Infof("Dropped %d old time partitions", partsDropped)
   514  	}
   515  }
   516  
   517  // MaintainUniqueSources is a background task that maintains a mapping of unique sources
   518  // in each metric table in admin.all_distinct_dbname_metrics.
   519  // This is used to avoid listing the same source multiple times in Grafana dropdowns.
   520  func (pgw *PostgresWriter) MaintainUniqueSources() {
   521  	logger := log.GetLogger(pgw.ctx)
   522  	var rowsAffected int
   523  	if err := pgw.sinkDb.QueryRow(pgw.ctx, `SELECT admin.maintain_unique_sources()`).Scan(&rowsAffected); err != nil {
   524  		logger.Error("Failed to run admin.all_distinct_dbname_metrics maintenance:", err)
   525  		return
   526  	}
   527  	logger.WithField("rows", rowsAffected).Info("Successfully processed admin.all_distinct_dbname_metrics")
   528  }
   529  
   530  func (pgw *PostgresWriter) AddDBUniqueMetricToListingTable(dbUnique, metric string) error {
   531  	sql := `INSERT INTO admin.all_distinct_dbname_metrics
   532  			SELECT $1, $2
   533  			WHERE NOT EXISTS (
   534  				SELECT * FROM admin.all_distinct_dbname_metrics WHERE dbname = $1 AND metric = $2
   535  			)`
   536  	_, err := pgw.sinkDb.Exec(pgw.ctx, sql, dbUnique, metric)
   537  	return err
   538  }
   539  
   540  func NewPostgresSinkMigrator(ctx context.Context, connStr string) (db.Migrator, error) {
   541  	conn, err := pgxpool.New(ctx, connStr)
   542  	if err != nil {
   543  		return nil, err
   544  	}
   545  	pgw := &PostgresWriter{
   546  		ctx:    ctx,
   547  		sinkDb: conn,
   548  	}
   549  	exists, err := db.DoesSchemaExist(ctx, conn, "admin")
   550  	if err != nil {
   551  		return nil, err
   552  	}
   553  	if exists {
   554  		return pgw, nil
   555  	}
   556  	for _, sql := range metricSchemaSQLs {
   557  		if _, err = conn.Exec(ctx, sql); err != nil {
   558  			return nil, err
   559  		}
   560  	}
   561  	return pgw, nil
   562  }
   563  
// initMigrator builds the migrator instance that tracks applied migrations in
// the admin.migration table and logs migrator notices through the writer's
// logger. Declared as a package-level variable — presumably so it can be
// substituted (e.g. in tests); confirm before converting to a plain func.
var initMigrator = func(pgw *PostgresWriter) (*migrator.Migrator, error) {
	return migrator.New(
		migrator.TableName("admin.migration"),
		migrator.SetNotice(func(s string) {
			log.GetLogger(pgw.ctx).Info(s)
		}),
		migrations(),
	)
}
   573  
   574  // Migrate upgrades database with all migrations
   575  func (pgw *PostgresWriter) Migrate() error {
   576  	m, err := initMigrator(pgw)
   577  	if err != nil {
   578  		return fmt.Errorf("cannot initialize migration: %w", err)
   579  	}
   580  	return m.Migrate(pgw.ctx, pgw.sinkDb)
   581  }
   582  
   583  // NeedsMigration checks if database needs migration
   584  func (pgw *PostgresWriter) NeedsMigration() (bool, error) {
   585  	m, err := initMigrator(pgw)
   586  	if err != nil {
   587  		return false, err
   588  	}
   589  	return m.NeedUpgrade(pgw.ctx, pgw.sinkDb)
   590  }
   591  
// MigrationsCount is the total number of migrations in admin.migration table
// NOTE(review): migrations() below registers two entries while this constant
// is 1 — confirm whether it should be 2 or intentionally counts a subset.
const MigrationsCount = 1
   594  
// migrations holds function returning all upgrade migrations needed.
// Migration names are persisted in admin.migration and therefore must never be
// changed once released; new migrations are appended at the end.
var migrations func() migrator.Option = func() migrator.Option {
	return migrator.Migrations(
		&migrator.Migration{
			Name: "01110 Apply postgres sink schema migrations",
			Func: func(context.Context, pgx.Tx) error {
				// "migration" table will be created automatically
				return nil
			},
		},

		&migrator.Migration{
			Name: "01180 Apply admin functions migrations for v5",
			Func: func(ctx context.Context, tx pgx.Tx) error {
				// drop superseded v4 helper functions before re-creating the
				// current versions from the embedded scripts
				_, err := tx.Exec(ctx, `
					DROP FUNCTION IF EXISTS admin.ensure_partition_metric_dbname_time;
					DROP FUNCTION IF EXISTS admin.ensure_partition_metric_time;
					DROP FUNCTION IF EXISTS admin.get_old_time_partitions(integer, text);
					DROP FUNCTION IF EXISTS admin.drop_old_time_partitions(integer, boolean, text);
				`)
				if err != nil {
					return err
				}

				_, err = tx.Exec(ctx, sqlMetricEnsurePartitionPostgres)
				if err != nil {
					return err
				}
				_, err = tx.Exec(ctx, sqlMetricAdminFunctions)
				return err
			},
		},

		// adding new migration here, update "admin"."migration" in "admin_schema.sql"!

		// &migrator.Migration{
		// 	Name: "000XX Short description of a migration",
		// 	Func: func(ctx context.Context, tx pgx.Tx) error {
		// 		return executeMigrationScript(ctx, tx, "000XX.sql")
		// 	},
		// },
	)
}
   638