...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/conn_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/jackc/pgx/v5"
    10  	"github.com/jackc/pgx/v5/pgconn"
    11  	"github.com/jackc/pgx/v5/pgxpool"
    12  	"github.com/pashagolub/pgxmock/v4"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  
    16  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
    17  	"github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
    18  )
    19  
    20  const ImageName = "docker.io/postgres:17-alpine"
    21  
    22  func TestSourceConn_Connect(t *testing.T) {
    23  
    24  	t.Run("failed config parsing", func(t *testing.T) {
    25  		md := &sources.SourceConn{}
    26  		md.ConnStr = "invalid connection string"
    27  		err := md.Connect(ctx, sources.CmdOpts{})
    28  		assert.Error(t, err)
    29  	})
    30  
    31  	t.Run("failed connection", func(t *testing.T) {
    32  		md := &sources.SourceConn{}
    33  		sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
    34  			return nil, assert.AnError
    35  		}
    36  		err := md.Connect(ctx, sources.CmdOpts{})
    37  		assert.ErrorIs(t, err, assert.AnError)
    38  	})
    39  
    40  	t.Run("successful connection to pgbouncer", func(t *testing.T) {
    41  		mock, err := pgxmock.NewPool()
    42  		require.NoError(t, err)
    43  		sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
    44  			return mock, nil
    45  		}
    46  
    47  		md := &sources.SourceConn{}
    48  		md.Kind = sources.SourcePgBouncer
    49  
    50  		opts := sources.CmdOpts{}
    51  		opts.MaxParallelConnectionsPerDb = 3
    52  
    53  		mock.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
    54  
    55  		err = md.Connect(ctx, opts)
    56  		assert.NoError(t, err)
    57  
    58  		assert.NoError(t, mock.ExpectationsWereMet())
    59  	})
    60  }
    61  
    62  func TestSourceConn_ParseConfig(t *testing.T) {
    63  	md := &sources.SourceConn{}
    64  	assert.NoError(t, md.ParseConfig())
    65  	//cached config
    66  	assert.NoError(t, md.ParseConfig())
    67  }
    68  
    69  func TestSourceConn_GetDatabaseName(t *testing.T) {
    70  	md := &sources.SourceConn{}
    71  	md.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
    72  	expected := "mydatabase"
    73  	// check pgx.ConnConfig related code
    74  	got := md.GetDatabaseName()
    75  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    76  	// check ConnStr parsing
    77  	got = md.Source.GetDatabaseName()
    78  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    79  
    80  	md = &sources.SourceConn{}
    81  	md.ConnStr = "foo boo"
    82  	expected = ""
    83  	got = md.GetDatabaseName()
    84  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    85  }
    86  
    87  func TestSourceConn_SetDatabaseName(t *testing.T) {
    88  	md := &sources.SourceConn{}
    89  	md.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
    90  	expected := "mydatabase"
    91  	// check ConnStr parsing
    92  	md.SetDatabaseName(expected)
    93  	got := md.GetDatabaseName()
    94  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    95  	// check pgx.ConnConfig related code
    96  	expected = "newdatabase"
    97  	md.SetDatabaseName(expected)
    98  	got = md.GetDatabaseName()
    99  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
   100  
   101  	md = &sources.SourceConn{}
   102  	md.ConnStr = "foo boo"
   103  	expected = ""
   104  	md.SetDatabaseName("ingored due to invalid ConnStr")
   105  	got = md.GetDatabaseName()
   106  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
   107  }
   108  
   109  func TestSourceConn_DiscoverPlatform(t *testing.T) {
   110  	ctx := context.Background()
   111  	mock, err := pgxmock.NewPool()
   112  	require.NoError(t, err)
   113  	md := &sources.SourceConn{Conn: mock}
   114  
   115  	mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"exec_env"}).AddRow("AZURE_SINGLE"))
   116  	md.ExecEnv = md.DiscoverPlatform(ctx)
   117  	assert.Equal(t, "AZURE_SINGLE", md.ExecEnv)
   118  	assert.Equal(t, "AZURE_SINGLE", md.DiscoverPlatform(ctx)) // cached
   119  	assert.NoError(t, mock.ExpectationsWereMet())
   120  }
   121  
   122  func TestSourceConn_GetApproxSize(t *testing.T) {
   123  	mock, err := pgxmock.NewPool()
   124  	require.NoError(t, err)
   125  	md := &sources.SourceConn{Conn: mock}
   126  
   127  	mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"size"}).AddRow(42))
   128  
   129  	assert.EqualValues(t, 42, md.FetchApproxSize(ctx))
   130  	assert.NoError(t, err)
   131  	assert.NoError(t, mock.ExpectationsWereMet())
   132  }
   133  
   134  func TestSourceConn_FunctionExists(t *testing.T) {
   135  	mock, err := pgxmock.NewPool()
   136  	require.NoError(t, err)
   137  	md := &sources.SourceConn{Conn: mock}
   138  
   139  	mock.ExpectQuery("select").WithArgs("get_foo").WillReturnRows(pgxmock.NewRows([]string{"exists"}))
   140  
   141  	assert.False(t, md.FunctionExists(ctx, "get_foo"))
   142  	assert.NoError(t, mock.ExpectationsWereMet())
   143  }
   144  
   145  func TestSourceConn_IsPostgresSource(t *testing.T) {
   146  	md := &sources.SourceConn{}
   147  	md.Kind = sources.SourcePostgres
   148  	assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true")
   149  
   150  	md.Kind = sources.SourcePgBouncer
   151  	assert.False(t, md.IsPostgresSource(), "IsPostgresSource() = true, want false")
   152  
   153  	md.Kind = sources.SourcePgPool
   154  	assert.False(t, md.IsPostgresSource(), "IsPostgresSource() = true, want false")
   155  
   156  	md.Kind = sources.SourcePatroni
   157  	assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true")
   158  }
   159  
   160  func TestSourceConn_Ping(t *testing.T) {
   161  	db, err := pgxmock.NewPool()
   162  	require.NoError(t, err)
   163  	md := &sources.SourceConn{Conn: db}
   164  
   165  	db.ExpectPing()
   166  	md.Kind = sources.SourcePostgres
   167  	assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")
   168  
   169  	db.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
   170  	md.Conn = db
   171  	md.Kind = sources.SourcePgBouncer
   172  	assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")
   173  }
   174  
   175  func TestSourceConn_GetMetricInterval(t *testing.T) {
   176  	md := &sources.SourceConn{
   177  		Source: sources.Source{
   178  			Metrics:        map[string]float64{"foo": 1.5, "bar": 2.5},
   179  			MetricsStandby: map[string]float64{"foo": 3.5},
   180  		},
   181  	}
   182  
   183  	t.Run("primary uses Metrics", func(t *testing.T) {
   184  		md.IsInRecovery = false
   185  		assert.Equal(t, 1.5, md.GetMetricInterval("foo"))
   186  		assert.Equal(t, 2.5, md.GetMetricInterval("bar"))
   187  	})
   188  
   189  	t.Run("standby uses MetricsStandby if present", func(t *testing.T) {
   190  		md.IsInRecovery = true
   191  		assert.Equal(t, 3.5, md.GetMetricInterval("foo"))
   192  		assert.Equal(t, 0.0, md.GetMetricInterval("bar"))
   193  	})
   194  
   195  	t.Run("standby with empty MetricsStandby falls back to Metrics", func(t *testing.T) {
   196  		md.IsInRecovery = true
   197  		md.MetricsStandby = map[string]float64{}
   198  		assert.Equal(t, 1.5, md.GetMetricInterval("foo"))
   199  	})
   200  }
   201  
   202  func TestVersionToInt(t *testing.T) {
   203  	tests := []struct {
   204  		arg  string
   205  		want int
   206  	}{
   207  		{"", 0},
   208  		{"foo", 0},
   209  		{"13", 13_00_00},
   210  		{"3.0", 3_00_00},
   211  		{"9.6.3", 9_06_03},
   212  		{"v9.6-beta2", 9_06_00},
   213  	}
   214  	for _, tt := range tests {
   215  		if got := sources.VersionToInt(tt.arg); got != tt.want {
   216  			t.Errorf("VersionToInt() = %v, want %v", got, tt.want)
   217  		}
   218  	}
   219  }
   220  
   221  func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
   222  	ctx := context.Background()
   223  
   224  	t.Run("cancelled context", func(t *testing.T) {
   225  		ctxNew, cancel := context.WithCancel(ctx)
   226  		cancel()
   227  		err := (&sources.SourceConn{}).FetchRuntimeInfo(ctxNew, true)
   228  		assert.Error(t, err)
   229  	})
   230  
   231  	t.Run("cached version", func(t *testing.T) {
   232  		md := &sources.SourceConn{
   233  			RuntimeInfo: sources.RuntimeInfo{
   234  				LastCheckedOn: time.Now().Add(-time.Minute),
   235  				Version:       42,
   236  			},
   237  		}
   238  		err := md.FetchRuntimeInfo(ctx, false)
   239  		assert.NoError(t, err)
   240  		assert.Equal(t, 42, md.Version)
   241  	})
   242  
   243  	t.Run("pgbouncer version fetch", func(t *testing.T) {
   244  		mock, err := pgxmock.NewPool()
   245  		require.NoError(t, err)
   246  		md := &sources.SourceConn{
   247  			Conn:   mock,
   248  			Source: sources.Source{Kind: sources.SourcePgBouncer},
   249  		}
   250  		mock.ExpectQuery("SHOW VERSION").
   251  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   252  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("PgBouncer 1.12.0"))
   253  		err = md.FetchRuntimeInfo(ctx, true)
   254  		assert.NoError(t, err)
   255  		assert.Contains(t, md.VersionStr, "PgBouncer")
   256  		assert.True(t, md.Version > 0)
   257  		assert.NoError(t, mock.ExpectationsWereMet())
   258  	})
   259  
   260  	t.Run("pgpool version fetch", func(t *testing.T) {
   261  		mock, err := pgxmock.NewPool()
   262  		require.NoError(t, err)
   263  		md := &sources.SourceConn{
   264  			Conn:   mock,
   265  			Source: sources.Source{Kind: sources.SourcePgPool},
   266  		}
   267  		mock.ExpectQuery("SHOW POOL_VERSION").
   268  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   269  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("4.1.2"))
   270  		err = md.FetchRuntimeInfo(ctx, true)
   271  		assert.NoError(t, err)
   272  		assert.Contains(t, md.VersionStr, "4.1.2")
   273  		assert.True(t, md.Version > 0)
   274  		assert.NoError(t, mock.ExpectationsWereMet())
   275  	})
   276  
   277  	t.Run("postgres version and extensions", func(t *testing.T) {
   278  		mock, err := pgxmock.NewPool()
   279  		require.NoError(t, err)
   280  		md := &sources.SourceConn{
   281  			Conn:   mock,
   282  			Source: sources.Source{Kind: sources.SourcePostgres},
   283  		}
   284  		mock.ExpectQuery("select").WillReturnRows(
   285  			pgxmock.NewRows([]string{"ver", "version", "pg_is_in_recovery", "current_database", "system_identifier", "is_superuser"}).
   286  				AddRow(13, "PostgreSQL 13.3", false, "testdb", "42424242", true),
   287  		)
   288  		mock.ExpectQuery("select").WillReturnRows(
   289  			pgxmock.NewRows([]string{"exec_env"}).AddRow("UNKNOWN"),
   290  		)
   291  		mock.ExpectQuery("select").WillReturnRows(
   292  			pgxmock.NewRows([]string{"approx_size"}).AddRow(42),
   293  		)
   294  
   295  		mock.ExpectQuery("select").WillReturnRows(
   296  			pgxmock.NewRows([]string{"extname", "extversion"}).AddRow("pg_stat_statements", "1.8"),
   297  		)
   298  		err = md.FetchRuntimeInfo(ctx, true)
   299  		assert.NoError(t, err)
   300  		assert.Equal(t, 13, md.Version)
   301  		assert.Equal(t, "testdb", md.RealDbname)
   302  		assert.Contains(t, md.Extensions, "pg_stat_statements")
   303  		assert.NoError(t, mock.ExpectationsWereMet())
   304  	})
   305  
   306  	t.Run("query error", func(t *testing.T) {
   307  		mock, err := pgxmock.NewPool()
   308  		require.NoError(t, err)
   309  		md := &sources.SourceConn{
   310  			Conn:   mock,
   311  			Source: sources.Source{Kind: sources.SourcePgBouncer},
   312  		}
   313  		mock.ExpectQuery("SHOW VERSION").
   314  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   315  			WillReturnError(fmt.Errorf("db error"))
   316  		err = md.FetchRuntimeInfo(ctx, true)
   317  		assert.Error(t, err)
   318  		assert.NoError(t, mock.ExpectationsWereMet())
   319  	})
   320  }
   321  
   322  func TestSourceConn_FetchVersion(t *testing.T) {
   323  	ctx := context.Background()
   324  
   325  	t.Run("valid version string", func(t *testing.T) {
   326  		mock, err := pgxmock.NewPool()
   327  		require.NoError(t, err)
   328  		md := &sources.SourceConn{Conn: mock}
   329  		mock.ExpectQuery("SHOW VERSION").
   330  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   331  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("FooBar 1.12.0"))
   332  		verStr, verInt, err := md.FetchVersion(ctx, "SHOW VERSION")
   333  		assert.NoError(t, err)
   334  		assert.Equal(t, "FooBar 1.12.0", verStr)
   335  		assert.Equal(t, 1_12_00, verInt)
   336  		assert.NoError(t, mock.ExpectationsWereMet())
   337  	})
   338  
   339  	t.Run("invalid version string", func(t *testing.T) {
   340  		mock, err := pgxmock.NewPool()
   341  		require.NoError(t, err)
   342  		md := &sources.SourceConn{Conn: mock}
   343  		mock.ExpectQuery("SHOW VERSION").
   344  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   345  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("invalid version"))
   346  		_, verInt, err := md.FetchVersion(ctx, "SHOW VERSION")
   347  		assert.Equal(t, 0, verInt)
   348  		assert.NoError(t, err)
   349  		assert.NoError(t, mock.ExpectationsWereMet())
   350  	})
   351  
   352  	t.Run("query error", func(t *testing.T) {
   353  		mock, err := pgxmock.NewPool()
   354  		require.NoError(t, err)
   355  		md := &sources.SourceConn{Conn: mock}
   356  		mock.ExpectQuery("SHOW VERSION").
   357  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   358  			WillReturnError(assert.AnError)
   359  		_, _, err = md.FetchVersion(ctx, "SHOW VERSION")
   360  		assert.Error(t, err)
   361  		assert.NoError(t, mock.ExpectationsWereMet())
   362  	})
   363  }
   364  
   365  func TestSourceConn_GetClusterIdentifier(t *testing.T) {
   366  	md := &sources.SourceConn{
   367  		Source: sources.Source{
   368  			Name:    "test",
   369  			Kind:    sources.SourcePostgres,
   370  			ConnStr: "postgres://user:password@localhost:5432/mydatabase",
   371  		},
   372  		RuntimeInfo: sources.RuntimeInfo{
   373  			SystemIdentifier: "42424242",
   374  		},
   375  	}
   376  	assert.Equal(t, "42424242:localhost:5432", md.GetClusterIdentifier())
   377  
   378  	md = &sources.SourceConn{
   379  		Source: sources.Source{
   380  			Name:    "test",
   381  			Kind:    sources.SourcePostgres,
   382  			ConnStr: "foo boo",
   383  		},
   384  	}
   385  	assert.Equal(t, "", md.GetClusterIdentifier())
   386  }
   387