Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/sinks/postgres_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/sinks

package sinks

import (
	"context"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/cybertec-postgresql/pgwatch/v5/internal/log"
	"github.com/cybertec-postgresql/pgwatch/v5/internal/metrics"
	"github.com/cybertec-postgresql/pgwatch/v5/internal/testutil"
	"github.com/jackc/pgx/v5"
	jsoniter "github.com/json-iterator/go"
	"github.com/pashagolub/pgxmock/v4"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

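// ctx carries a no-op logger so the writer code under test can log
// without producing output during test runs.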
var ctx = log.WithLogger(context.Background(), log.NewNoopLogger())

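// TestReadMetricSchemaType verifies that ReadMetricSchemaType propagates
// query errors and maps a true schema_type result to DbStorageSchemaTimescale.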
func TestReadMetricSchemaType(t *testing.T) {
	conn, err := pgxmock.NewPool()
	assert.NoError(t, err)

	pgw := PostgresWriter{
		ctx:    ctx,
		sinkDb: conn,
	}

	conn.ExpectQuery("SELECT schema_type").
		WillReturnError(errors.New("expected"))
	assert.Error(t, pgw.ReadMetricSchemaType())

	conn.ExpectQuery("SELECT schema_type").
		WillReturnRows(pgxmock.NewRows([]string{"schema_type"}).AddRow(true))
	assert.NoError(t, pgw.ReadMetricSchemaType())
	assert.Equal(t, DbStorageSchemaTimescale, pgw.metricSchema)
}

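// TestNewWriterFromPostgresConn covers the happy path of building a writer
// from an existing connection as well as failures during the initial ping,
// schema type detection, and dummy metric table creation.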
func TestNewWriterFromPostgresConn(t *testing.T) {
	a := assert.New(t)
	opts := &CmdOpts{
		BatchingDelay:       time.Hour,
		RetentionInterval:   "1 day",
		MaintenanceInterval: "1 day",
		PartitionInterval:   "1 hour",
	}

	t.Run("Success", func(*testing.T) {
		conn, err := pgxmock.NewPool()
		a.NoError(err)

		conn.ExpectPing()
		conn.ExpectQuery("SELECT extract").WithArgs("1 day", "1 day", "1 hour").WillReturnRows(
			pgxmock.NewRows([]string{"col1", "col2", "col3"}).AddRow((24 * time.Hour).Seconds(), (24 * time.Hour).Seconds(), true),
		)
		conn.ExpectQuery("SELECT EXISTS").WithArgs("admin").WillReturnRows(pgxmock.NewRows([]string{"schema_type"}).AddRow(true))
		conn.ExpectQuery("SELECT schema_type").WillReturnRows(pgxmock.NewRows([]string{"schema_type"}).AddRow(true))
		for _, m := range metrics.GetDefaultBuiltInMetrics() {
			conn.ExpectExec("SELECT admin.ensure_dummy_metrics_table").WithArgs(m).WillReturnResult(pgxmock.NewResult("EXECUTE", 1))
		}

		pgw, err := NewWriterFromPostgresConn(ctx, conn, opts)
		a.NoError(err)
		a.NotNil(pgw)
		a.NoError(conn.ExpectationsWereMet())
	})

	t.Run("InitFail", func(*testing.T) {
		conn, err := pgxmock.NewPool()
		a.NoError(err)

		conn.ExpectPing().WillReturnError(assert.AnError)

		pgw, err := NewWriterFromPostgresConn(ctx, conn, opts)
		a.Error(err)
		a.Nil(pgw)
		a.NoError(conn.ExpectationsWereMet())
	})

	t.Run("ReadMetricSchemaTypeFail", func(*testing.T) {
		conn, err := pgxmock.NewPool()
		a.NoError(err)

		conn.ExpectPing()
		conn.ExpectQuery("SELECT extract").WithArgs("1 day", "1 day", "1 hour").WillReturnRows(
			pgxmock.NewRows([]string{"col1", "col2", "col3"}).AddRow((24 * time.Hour).Seconds(), (24 * time.Hour).Seconds(), true),
		)
		conn.ExpectQuery("SELECT EXISTS").WithArgs("admin").WillReturnRows(pgxmock.NewRows([]string{"schema_type"}).AddRow(true))
		conn.ExpectQuery("SELECT schema_type").WillReturnError(assert.AnError)

		pgw, err := NewWriterFromPostgresConn(ctx, conn, opts)
		a.Error(err)
		a.Nil(pgw)
		a.NoError(conn.ExpectationsWereMet())
	})

	t.Run("EnsureBuiltinMetricDummiesFail", func(*testing.T) {
		conn, err := pgxmock.NewPool()
		a.NoError(err)

		conn.ExpectPing()
		conn.ExpectQuery("SELECT extract").WithArgs("1 day", "1 day", "1 hour").WillReturnRows(
			pgxmock.NewRows([]string{"col1", "col2", "col3"}).AddRow((24 * time.Hour).Seconds(), (24 * time.Hour).Seconds(), true),
		)
		conn.ExpectQuery("SELECT EXISTS").WithArgs("admin").WillReturnRows(pgxmock.NewRows([]string{"schema_type"}).AddRow(true))
		conn.ExpectQuery("SELECT schema_type").WillReturnRows(pgxmock.NewRows([]string{"schema_type"}).AddRow(true))
		conn.ExpectExec("SELECT admin.ensure_dummy_metrics_table").WithArgs(pgxmock.AnyArg()).WillReturnError(assert.AnError)

		pgw, err := NewWriterFromPostgresConn(ctx, conn, opts)
		a.Error(err)
		a.Nil(pgw)
		a.NoError(conn.ExpectationsWereMet())
	})
}

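// TestSyncMetric checks that AddOp registers the database/metric pair and
// ensures the metric table exists, while an unknown operation is ignored.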
func TestSyncMetric(t *testing.T) {
	conn, err := pgxmock.NewPool()
	assert.NoError(t, err)
	pgw := PostgresWriter{
		ctx:    ctx,
		sinkDb: conn,
	}
	dbUnique := "mydb"
	metricName := "mymetric"
	op := AddOp
	conn.ExpectExec("INSERT INTO admin\\.all_distinct_dbname_metrics").WithArgs(dbUnique, metricName).WillReturnResult(pgxmock.NewResult("EXECUTE", 1))
	conn.ExpectExec("SELECT admin\\.ensure_dummy_metrics_table").WithArgs(metricName).WillReturnResult(pgxmock.NewResult("EXECUTE", 1))
	err = pgw.SyncMetric(dbUnique, metricName, op)
	assert.NoError(t, err)
	assert.NoError(t, conn.ExpectationsWereMet())

	op = InvalidOp
	err = pgw.SyncMetric(dbUnique, metricName, op)
	assert.NoError(t, err, "ignore unknown operation")
}

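// TestWrite verifies the non-blocking write path: with a zero highLoadTimeout
// messages are silently dropped when the input channel cannot accept them,
// with a buffered channel the write succeeds, and a cancelled context turns
// into an error.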
func TestWrite(t *testing.T) {
	conn, err := pgxmock.NewPool()
	assert.NoError(t, err)
	ctx, cancel := context.WithCancel(ctx)
	pgw := PostgresWriter{
		ctx:    ctx,
		sinkDb: conn,
	}
	message := metrics.MeasurementEnvelope{
		MetricName: "test_metric",
		Data: metrics.Measurements{
			{"number": 1, "string": "test_data"},
		},
		DBName:     "test_db",
		CustomTags: map[string]string{"foo": "boo"},
	}

	highLoadTimeout = 0
	err = pgw.Write(message)
	assert.NoError(t, err, "messages skipped due to high load")

	highLoadTimeout = time.Second * 5
	pgw.input = make(chan metrics.MeasurementEnvelope, cacheLimit)
	err = pgw.Write(message)
	assert.NoError(t, err, "write successful")

	cancel()
	err = pgw.Write(message)
	assert.Error(t, err, "context canceled")
}

func TestCopyFromMeasurements_Basic(t *testing.T) {
	// Test basic iteration through a single envelope with multiple measurements
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{"env": "test"},
			Data: metrics.Measurements{
				{"epoch_ns": int64(1000), "value": 1},
				{"epoch_ns": int64(2000), "value": 2},
				{"epoch_ns": int64(3000), "value": 3},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)

	// Test Next() and Values() for each measurement
	assert.Equal(t, "metric1", cfm.MetricName()[0], "metric name should be available before the first Next()")
	assert.True(t, cfm.Next(), "Should have first measurement")
	values, err := cfm.Values()
	assert.NoError(t, err)
	assert.Len(t, values, 4) // time, dbname, data, tag_data
	assert.Equal(t, "db1", values[1])

	assert.True(t, cfm.Next(), "Should have second measurement")
	values, err = cfm.Values()
	assert.NoError(t, err)
	assert.Equal(t, "db1", values[1])

	assert.True(t, cfm.Next(), "Should have third measurement")
	values, err = cfm.Values()
	assert.NoError(t, err)
	assert.Equal(t, "db1", values[1])

	assert.False(t, cfm.Next(), "Should not have more measurements")
	assert.True(t, cfm.EOF(), "Should be at end")
}

func TestCopyFromMeasurements_MultipleEnvelopes(t *testing.T) {
	// Test iteration through multiple envelopes of the same metric
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{"env": "test1"},
			Data: metrics.Measurements{
				{"epoch_ns": int64(1000), "value": 1},
				{"epoch_ns": int64(2000), "value": 2},
			},
		},
		{
			MetricName: "metric1",
			DBName:     "db2",
			CustomTags: map[string]string{"env": "test2"},
			Data: metrics.Measurements{
				{"epoch_ns": int64(3000), "value": 3},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)

	// First envelope, first measurement
	assert.True(t, cfm.Next())
	values, err := cfm.Values()
	assert.NoError(t, err)
	assert.Equal(t, "db1", values[1])
	// First envelope, second measurement
	assert.True(t, cfm.Next())
	values, err = cfm.Values()
	assert.NoError(t, err)
	assert.Equal(t, "db1", values[1])

	// Second envelope, first measurement
	assert.Equal(t, "metric1", cfm.MetricName()[0])
	assert.True(t, cfm.Next())
	values, err = cfm.Values()
	assert.NoError(t, err)
	assert.Equal(t, "db2", values[1])

	assert.False(t, cfm.Next())
}

func TestCopyFromMeasurements_MetricBoundaries(t *testing.T) {
	// Test metric boundary detection with different metrics
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{},
			Data: metrics.Measurements{
				{"epoch_ns": int64(1000), "value": 1},
				{"epoch_ns": int64(2000), "value": 2},
			},
		},
		{
			MetricName: "metric2", // Different metric
			DBName:     "db1",
			CustomTags: map[string]string{},
			Data: metrics.Measurements{
				{"epoch_ns": int64(3000), "value": 3},
			},
		},
		{
			MetricName: "metric2",
			DBName:     "db2",
			CustomTags: map[string]string{},
			Data: metrics.Measurements{
				{"epoch_ns": int64(4000), "value": 4},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)

	// Process metric1 completely
	assert.Equal(t, "metric1", cfm.MetricName()[0])
	assert.True(t, cfm.Next())
	assert.True(t, cfm.Next())

	// Should stop at metric boundary
	assert.False(t, cfm.Next())
	assert.False(t, cfm.EOF(), "Should not be at EOF yet, there's more data")

	assert.Equal(t, "metric2", cfm.MetricName()[0])
	assert.True(t, cfm.Next())
	assert.True(t, cfm.Next())

	assert.False(t, cfm.Next())
	assert.True(t, cfm.EOF(), "Should be at EOF after processing all measurements")
}

func TestCopyFromMeasurements_EmptyData(t *testing.T) {
	// Test with empty envelopes slice
	cfm := newCopyFromMeasurements([]metrics.MeasurementEnvelope{})
	assert.False(t, cfm.Next())
	assert.True(t, cfm.EOF())
}

func TestCopyFromMeasurements_EmptyMeasurements(t *testing.T) {
	// Test with an envelope containing no measurements
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{},
			Data:       metrics.Measurements{}, // Empty measurements
		},
		{
			MetricName: "metric1",
			DBName:     "db2",
			CustomTags: map[string]string{},
			Data: metrics.Measurements{
				{"epoch_ns": int64(1000), "value": 1},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)

	// Should skip the empty envelope and go to the second one
	assert.True(t, cfm.Next())
	values, err := cfm.Values()
	assert.NoError(t, err)
	assert.Equal(t, "db2", values[1])

	assert.False(t, cfm.Next())
	assert.True(t, cfm.EOF())
}

func TestCopyFromMeasurements_TagProcessing(t *testing.T) {
	// Test that tag_ prefixed fields are moved to CustomTags
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{"existing": "tag"},
			Data: metrics.Measurements{
				{
					"epoch_ns":     int64(1000),
					"value":        1,
					"tag_env":      "production",
					"tag_version":  "1.0",
					"normal_field": "stays",
				},
			},
		},
		{
			MetricName: "metric1",
			DBName:     "db2",
			CustomTags: nil,
			Data: metrics.Measurements{
				{
					"epoch_ns":     int64(1000),
					"value":        1,
					"tag_env":      "production",
					"tag_version":  "1.0",
					"normal_field": "stays",
				},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)
	assert.True(t, cfm.Next())

	values, err := cfm.Values()
	assert.NoError(t, err)
	assert.Len(t, values, 4) // Verify structure: time, dbname, data, tag_data

	// Check data JSON (should contain normal fields but not tag_ fields)
	dataJSON, ok := values[2].(string)
	assert.True(t, ok, "Data should be JSON string")

	var dataMap map[string]any
	err = jsoniter.ConfigFastest.UnmarshalFromString(dataJSON, &dataMap)
	assert.NoError(t, err)
	assert.Contains(t, dataMap, "normal_field")
	assert.NotContains(t, dataMap, "tag_env", "tag_env should not be in data")
	assert.NotContains(t, dataMap, "tag_version", "tag_version should not be in data")

	// Check tag JSON (should contain the converted tags merged with CustomTags)
	tagJSON, ok := values[3].(string)
	assert.True(t, ok, "Tag data should be JSON string")

	var tagMap map[string]string
	err = jsoniter.ConfigFastest.UnmarshalFromString(tagJSON, &tagMap)
	assert.NoError(t, err)
	assert.Contains(t, tagMap, "existing")
	assert.Contains(t, tagMap, "env", "tag_env should be converted to env")
	assert.Contains(t, tagMap, "version", "tag_version should be converted to version")
	assert.Equal(t, "production", tagMap["env"])
	assert.Equal(t, "1.0", tagMap["version"])

	assert.True(t, cfm.Next())
	_, err = cfm.Values()
	assert.NoError(t, err, "should process nil CustomTags without error")
}

func TestCopyFromMeasurements_JsonMarshaling(t *testing.T) {
	// Test that JSON marshaling works correctly
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{"env": "test"},
			Data: metrics.Measurements{
				{
					"epoch_ns": int64(1000),
					"value":    42,
					"name":     "test_measurement",
				},
				{
					"epoch_ns": int64(1000),
					"value": func() string {
						return "should produce error while marshaled"
					},
					"name": "test_measurement",
				},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)
	assert.True(t, cfm.Next())

	values, err := cfm.Values()
	assert.NoError(t, err)
	assert.Len(t, values, 4)

	// Values should be: [time, dbname, data_json, tag_data_json]
	assert.Equal(t, "db1", values[1])

	// Check that JSON strings are valid
	dataJSON, ok := values[2].(string)
	assert.True(t, ok, "Data should be JSON string")
	assert.Contains(t, dataJSON, `"value":42`)
	assert.Contains(t, dataJSON, `"name":"test_measurement"`)

	tagJSON, ok := values[3].(string)
	assert.True(t, ok, "Tag data should be JSON string")
	assert.Contains(t, tagJSON, `"env":"test"`)

	assert.True(t, cfm.Next())
	_, err = cfm.Values()
	assert.Error(t, err, "cannot marshal function value to JSON")

	cfm.NextEnvelope()
	assert.NotPanics(t, func() { _ = cfm.MetricName() })
}

func TestCopyFromMeasurements_ErrorHandling(t *testing.T) {
	// Test Err() method
	cfm := newCopyFromMeasurements([]metrics.MeasurementEnvelope{})
	assert.NoError(t, cfm.Err(), "Err() should always return nil")
}

func TestCopyFromMeasurements_StateManagement(t *testing.T) {
	// Test that internal state is managed correctly during iteration
	data := []metrics.MeasurementEnvelope{
		{
			MetricName: "metric1",
			DBName:     "db1",
			CustomTags: map[string]string{},
			Data: metrics.Measurements{
				{"epoch_ns": int64(1000), "value": 1},
			},
		},
		{
			MetricName: "metric2", // Different metric
			DBName:     "db1",
			CustomTags: map[string]string{},
			Data: metrics.Measurements{
				{"epoch_ns": int64(2000), "value": 2},
			},
		},
	}

	cfm := newCopyFromMeasurements(data)

	// Initial state
	assert.Equal(t, -1, cfm.envelopeIdx)
	assert.Equal(t, -1, cfm.measurementIdx)
	assert.Equal(t, "", cfm.metricName)

	// After first Next()
	assert.True(t, cfm.Next())
	assert.Equal(t, 0, cfm.envelopeIdx)
	assert.Equal(t, 0, cfm.measurementIdx)
	assert.Equal(t, "metric1", cfm.metricName)

	// After hitting metric boundary
	assert.False(t, cfm.Next())
	// State should be positioned to restart on next metric
	assert.Equal(t, "", cfm.metricName)
}

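// TestCopyFromMeasurements_CopyFail runs against a real PostgreSQL container
// and verifies that CopyFrom fails on values that cannot be marshaled to JSON,
// and that NextEnvelope() lets the caller skip past the offending envelope.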
func TestCopyFromMeasurements_CopyFail(t *testing.T) {
	a := assert.New(t)
	r := require.New(t)

	pgContainer, pgTearDown, err := testutil.SetupPostgresContainer()
	r.NoError(err)
	defer pgTearDown()

	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	r.NoError(err)
	conn, err := pgx.Connect(ctx, connStr)
	r.NoError(err)

	_, err = conn.Exec(ctx, `CREATE TABLE IF NOT EXISTS test_metric (
		time timestamptz not null default now(),
		dbname text NOT NULL,
		data jsonb,
		tag_data jsonb)`)
	r.NoError(err)

	msgs := []metrics.MeasurementEnvelope{
		{
			MetricName: "test_metric",
			Data: metrics.Measurements{
				{"epoch_ns": int64(2000), "value": func() {}},
				{"epoch_ns": int64(2000), "value": struct{}{}},
			},
			DBName: "test_db",
		},
	}

	cfm := newCopyFromMeasurements(msgs)

	for !cfm.EOF() {
		_, err = conn.CopyFrom(context.Background(), cfm.MetricName(), targetColumns[:], cfm)
		a.Error(err)
		if err != nil {
			if !cfm.NextEnvelope() {
				break
			}
		}
	}
}

// TestIntervalValidation tests interval string validation for all
// CLI flags that expect a PostgreSQL interval string.
func TestIntervalValidation(t *testing.T) {
	a := assert.New(t)

	pgContainer, pgTearDown, err := testutil.SetupPostgresContainer()
	a.NoError(err)
	defer pgTearDown()

	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	a.NoError(err)

	opts := &CmdOpts{
		PartitionInterval:   "1 minute",
		MaintenanceInterval: "-1 hours",
		RetentionInterval:   "00:01:30",
		BatchingDelay:       time.Second,
	}

	_, err = NewPostgresWriter(ctx, connStr, opts)
	a.EqualError(err, "--partition-interval must be at least 1 hour, got: 1 minute")
	opts.PartitionInterval = "1 hour"

	_, err = NewPostgresWriter(ctx, connStr, opts)
	a.EqualError(err, "--maintenance-interval must be a positive PostgreSQL interval or 0 to disable it")
	opts.MaintenanceInterval = "0 hours"

	// with the partition and maintenance intervals fixed, the "00:01:30"
	// retention interval set above is still rejected
	_, err = NewPostgresWriter(ctx, connStr, opts)
	a.Error(err)

	invalidIntervals := []string{
		"not an interval", "3 dayss",
		"four hours",
	}

	for _, interval := range invalidIntervals {
		opts.PartitionInterval = interval
		_, err = NewPostgresWriter(ctx, connStr, opts)
		a.Error(err)
		opts.PartitionInterval = "1 hour"

		opts.MaintenanceInterval = interval
		_, err = NewPostgresWriter(ctx, connStr, opts)
		a.Error(err)
		opts.MaintenanceInterval = "1 hour"

		opts.RetentionInterval = interval
		_, err = NewPostgresWriter(ctx, connStr, opts)
		a.Error(err)
		opts.RetentionInterval = "1 hour"
	}

	validIntervals := []string{
		"3 days 4 hours", "1 year",
		"P3D", "PT3H", "0-02", "1 00:00:00",
		"P0-02", "P1", "2 weeks",
	}

	for _, interval := range validIntervals {
		opts.PartitionInterval = interval
		opts.MaintenanceInterval = interval
		opts.RetentionInterval = interval

		_, err = NewPostgresWriter(ctx, connStr, opts)
		a.NoError(err)
	}
}

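// TestPartitionInterval verifies that EnsureMetricDbnameTime creates the
// expected partition tree and that the time partition bounds honor the
// configured 3-week partition interval.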
func TestPartitionInterval(t *testing.T) {
	a := assert.New(t)
	r := require.New(t)

	pgContainer, pgTearDown, err := testutil.SetupPostgresContainer()
	r.NoError(err)
	defer pgTearDown()

	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	r.NoError(err)

	opts := &CmdOpts{
		PartitionInterval:   "3 weeks",
		RetentionInterval:   "14 days",
		MaintenanceInterval: "12 hours",
		BatchingDelay:       time.Second,
	}

	pgw, err := NewPostgresWriter(ctx, connStr, opts)
	r.NoError(err)

	conn, err := pgx.Connect(ctx, connStr)
	r.NoError(err)

	m := map[string]map[string]ExistingPartitionInfo{
		"test_metric": {
			"test_db": {
				time.Now(), time.Now().Add(time.Hour),
			},
		},
	}
	err = pgw.EnsureMetricDbnameTime(m)
	r.NoError(err)

	var partitionsNum int
	err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM pg_partition_tree('test_metric');").Scan(&partitionsNum)
	a.NoError(err)
	// 1 for the metric table itself + 1 dbname partition
	// + 4 time partitions (1 we asked for + 3 precreated)
	a.Equal(6, partitionsNum)

	part := pgw.partitionMapMetricDbname["test_metric"]["test_db"]
	// partition bounds should have a difference of 3 weeks
	a.Equal(part.StartTime.Add(3*7*24*time.Hour), part.EndTime)
}

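// Test_Maintain exercises the maintenance routines against a real PostgreSQL
// container: pruning and repopulating admin.all_distinct_dbname_metrics,
// skipping maintenance while the advisory lock is held, dropping expired time
// partitions, and converting interval strings to time.Duration values.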
func Test_Maintain(t *testing.T) {
	a := assert.New(t)
	r := require.New(t)

	pgContainer, pgTearDown, err := testutil.SetupPostgresContainer()
	r.NoError(err)
	defer pgTearDown()

	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	r.NoError(err)
	conn, err := pgx.Connect(ctx, connStr)
	r.NoError(err)

	opts := &CmdOpts{
		PartitionInterval:   "1 hour",
		RetentionInterval:   "1 hour",
		MaintenanceInterval: "0 days",
		BatchingDelay:       time.Hour,
	}

	pgw, err := NewPostgresWriter(ctx, connStr, opts)
	r.NoError(err)

	t.Run("MaintainUniqueSources", func(_ *testing.T) {
		// adds an entry to `admin.all_distinct_dbname_metrics`
		err = pgw.SyncMetric("test", "test_metric_1", AddOp)
		r.NoError(err)

		var numOfEntries int
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics;").Scan(&numOfEntries)
		a.NoError(err)
		a.Equal(1, numOfEntries)

		// manually call the maintenance routine
		pgw.MaintainUniqueSources()

		// the entry should have been deleted, because it has no corresponding rows in the `test_metric_1` table
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics;").Scan(&numOfEntries)
		a.NoError(err)
		a.Equal(0, numOfEntries)

		message := []metrics.MeasurementEnvelope{
			{
				MetricName: "test_metric_1",
				Data: metrics.Measurements{
					{"number": 1, "string": "test_data"},
				},
				DBName: "test_db",
			},
		}
		pgw.flush(message)

		// manually call the maintenance routine
		pgw.MaintainUniqueSources()

		// the entry should have been re-added, because a corresponding row was just written to the `test_metric_1` table
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics;").Scan(&numOfEntries)
		a.NoError(err)
		a.Equal(1, numOfEntries)

		_, err = conn.Exec(ctx, "DROP TABLE test_metric_1;")
		r.NoError(err)

		// the corresponding entry should be deleted
		pgw.MaintainUniqueSources()
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics;").Scan(&numOfEntries)
		a.NoError(err)
		a.Equal(0, numOfEntries)
	})

	t.Run("MaintainUniqueSources_MultipleMetricsAndSources", func(_ *testing.T) {
		// Create metric tables with partitions
		err = pgw.EnsureMetricDummy("test_metric_a")
		r.NoError(err)
		err = pgw.EnsureMetricDummy("test_metric_b")
		r.NoError(err)

		// Create partitions for each dbname
		_, err = conn.Exec(ctx, `
			CREATE TABLE subpartitions.test_metric_a_db1 PARTITION OF public.test_metric_a FOR VALUES IN ('db1');
			CREATE TABLE subpartitions.test_metric_a_db2 PARTITION OF public.test_metric_a FOR VALUES IN ('db2');
			CREATE TABLE subpartitions.test_metric_a_db3 PARTITION OF public.test_metric_a FOR VALUES IN ('db3');
			CREATE TABLE subpartitions.test_metric_b_db1 PARTITION OF public.test_metric_b FOR VALUES IN ('db1');
			CREATE TABLE subpartitions.test_metric_b_db2 PARTITION OF public.test_metric_b FOR VALUES IN ('db2');
		`)
		r.NoError(err)

		// Directly insert test data with different dbnames
		_, err = conn.Exec(ctx, `
			INSERT INTO test_metric_a (time, dbname, data) VALUES
				(now(), 'db1', '{}'::jsonb),
				(now(), 'db2', '{}'::jsonb),
				(now(), 'db3', '{}'::jsonb)
		`)
		r.NoError(err)

		_, err = conn.Exec(ctx, `
			INSERT INTO test_metric_b (time, dbname, data) VALUES
				(now(), 'db1', '{}'::jsonb),
				(now(), 'db2', '{}'::jsonb)
		`)
		r.NoError(err)

		// Run maintenance
		pgw.MaintainUniqueSources()

		// Should have 3 entries for test_metric_a and 2 for test_metric_b
		var count int
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_a';").Scan(&count)
		a.NoError(err)
		a.Equal(3, count)

		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_b';").Scan(&count)
		a.NoError(err)
		a.Equal(2, count)

		// Cleanup
		_, err = conn.Exec(ctx, "DROP TABLE test_metric_a, test_metric_b;")
		r.NoError(err)
		pgw.MaintainUniqueSources()
	})

	t.Run("MaintainUniqueSources_StaleEntriesCleanup", func(_ *testing.T) {
		// Create metric table
		err = pgw.EnsureMetricDummy("test_metric_c")
		r.NoError(err)

		// Create partition for db_active
		_, err = conn.Exec(ctx, `
			CREATE TABLE subpartitions.test_metric_c_db_active PARTITION OF public.test_metric_c FOR VALUES IN ('db_active');
		`)
		r.NoError(err)

		// Directly insert test data with one active dbname
		_, err = conn.Exec(ctx, `
			INSERT INTO test_metric_c (time, dbname, data) VALUES
				(now(), 'db_active', '{}'::jsonb)
		`)
		r.NoError(err)

		// Manually add the active entry and stale entries to the listing table
		_, err = conn.Exec(ctx, "INSERT INTO admin.all_distinct_dbname_metrics (dbname, metric) VALUES ('db_active', 'test_metric_c'), ('db_stale1', 'test_metric_c'), ('db_stale2', 'test_metric_c');")
		r.NoError(err)

		var count int
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_c';").Scan(&count)
		a.NoError(err)
		a.Equal(3, count) // 1 active + 2 stale

		// Run maintenance - should remove stale entries
		pgw.MaintainUniqueSources()

		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_c';").Scan(&count)
		a.NoError(err)
		a.Equal(1, count) // only the active one remains

		var dbname string
		err = conn.QueryRow(ctx, "SELECT dbname FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_c';").Scan(&dbname)
		a.NoError(err)
		a.Equal("db_active", dbname)

		// Cleanup
		_, err = conn.Exec(ctx, "DROP TABLE test_metric_c;")
		r.NoError(err)
		pgw.MaintainUniqueSources()
	})

	t.Run("MaintainUniqueSources_AdvisoryLock", func(_ *testing.T) {
		// Create a second connection to simulate concurrent maintenance
		conn2, err := pgx.Connect(ctx, connStr)
		r.NoError(err)
		defer conn2.Close(ctx)

		// Create metric table and partition
		err = pgw.EnsureMetricDummy("test_metric_d")
		r.NoError(err)

		_, err = conn.Exec(ctx, `
			CREATE TABLE subpartitions.test_metric_d_db1 PARTITION OF public.test_metric_d FOR VALUES IN ('db1');
		`)
		r.NoError(err)

		// Directly insert test data for only db1
		_, err = conn.Exec(ctx, `
			INSERT INTO test_metric_d (time, dbname, data) VALUES
				(now(), 'db1', '{}'::jsonb)
		`)
		r.NoError(err)

		// Add both active and stale entries to the listing table
		_, err = conn.Exec(ctx, "INSERT INTO admin.all_distinct_dbname_metrics (dbname, metric) VALUES ('db1', 'test_metric_d'), ('db_stale', 'test_metric_d');")
		r.NoError(err)

		var count int
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_d';").Scan(&count)
		a.NoError(err)
		a.Equal(2, count, "Should have 2 entries initially (1 active + 1 stale)")

		// Acquire the advisory lock using a session-level lock in conn2.
		// This will block transaction-level locks on the same lock ID.
		var lockAcquired bool
		err = conn2.QueryRow(ctx, "SELECT pg_try_advisory_lock(1571543679778230000);").Scan(&lockAcquired)
		r.NoError(err)
		a.True(lockAcquired, "Should acquire advisory lock")

		// Try to run maintenance - should skip because the lock is held by conn2
		pgw.MaintainUniqueSources()

		// Stale entry should still exist because maintenance was skipped
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_d';").Scan(&count)
		a.NoError(err)
		a.Equal(2, count, "Stale entry should remain because maintenance was skipped due to lock")

		// Release the lock from conn2
		_, err = conn2.Exec(ctx, "SELECT pg_advisory_unlock(1571543679778230000);")
		r.NoError(err)

		// Now maintenance should work and clean up the stale entry
		pgw.MaintainUniqueSources()

		// Should only have the active entry, stale one removed
		err = conn.QueryRow(ctx, "SELECT count(*) FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_d';").Scan(&count)
		a.NoError(err)
		a.Equal(1, count, "Only active entry should remain after maintenance runs")

		var dbname string
		err = conn.QueryRow(ctx, "SELECT dbname FROM admin.all_distinct_dbname_metrics WHERE metric = 'test_metric_d';").Scan(&dbname)
		a.NoError(err)
		a.Equal("db1", dbname, "Remaining entry should be the active one")

		// Cleanup
		_, err = conn.Exec(ctx, "DROP TABLE test_metric_d;")
		r.NoError(err)
		pgw.MaintainUniqueSources()
	})

	t.Run("DeleteOldPartitions", func(_ *testing.T) {
		// Creates a new top-level table for `test_metric_2`
		err = pgw.SyncMetric("test", "test_metric_2", AddOp)
		r.NoError(err)

		// create the 2nd level dbname partition
		_, err = conn.Exec(ctx, "CREATE TABLE subpartitions.test_metric_2_dbname PARTITION OF public.test_metric_2 FOR VALUES IN ('test') PARTITION BY RANGE (time)")
		a.NoError(err)

		boundStart := time.Now().Add(-1 * 2 * 24 * time.Hour).Format("2006-01-02")
		boundEnd := time.Now().Add(-1 * 24 * time.Hour).Format("2006-01-02")

		// create the 3rd level time partition with its end bound yesterday
		_, err = conn.Exec(ctx,
			fmt.Sprintf(
				`CREATE TABLE subpartitions.test_metric_2_dbname_time
			PARTITION OF subpartitions.test_metric_2_dbname
			FOR VALUES FROM ('%s') TO ('%s')`,
				boundStart, boundEnd),
		)
		a.NoError(err)
		_, err = conn.Exec(ctx, "COMMENT ON TABLE subpartitions.test_metric_2_dbname_time IS $$pgwatch-generated-metric-dbname-time-lvl$$")
		a.NoError(err)

		var partitionsNum int
		err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM pg_partition_tree('test_metric_2');").Scan(&partitionsNum)
		a.NoError(err)
		a.Equal(3, partitionsNum)

		pgw.opts.RetentionInterval = "2 days"
		pgw.DeleteOldPartitions() // 1 day < 2 days, shouldn't delete anything

		err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM pg_partition_tree('test_metric_2');").Scan(&partitionsNum)
		a.NoError(err)
		a.Equal(3, partitionsNum)

		pgw.opts.RetentionInterval = "1 hour"
		pgw.DeleteOldPartitions() // 1 day > 1 hour, should delete the partition

		err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM pg_partition_tree('test_metric_2');").Scan(&partitionsNum)
		a.NoError(err)
		a.Equal(2, partitionsNum)
	})

	t.Run("Epoch to Duration Conversion", func(_ *testing.T) {
		table := map[string]time.Duration{
			"1 hour":   time.Hour,
			"2 hours":  2 * time.Hour,
			"4 days":   4 * 24 * time.Hour,
			"1 day":    24 * time.Hour,
			"1 year":   365.25 * 24 * time.Hour,
			"1 week":   7 * 24 * time.Hour,
			"3 weeks":  3 * 7 * 24 * time.Hour,
			"2 months": 2 * 30 * 24 * time.Hour,
			"1 month":  30 * 24 * time.Hour,
		}

		for k, v := range table {
			opts := &CmdOpts{
				PartitionInterval:   "1 hour",
				RetentionInterval:   k,
				MaintenanceInterval: k,
				BatchingDelay:       time.Hour,
			}

			pgw, err := NewPostgresWriter(ctx, connStr, opts)
			a.NoError(err)
			a.Equal(pgw.retentionInterval, v)
			a.Equal(pgw.maintenanceInterval, v)
		}
	})
}
   978