...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/conn_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/jackc/pgx/v5"
    10  	"github.com/jackc/pgx/v5/pgconn"
    11  	"github.com/jackc/pgx/v5/pgxpool"
    12  	"github.com/pashagolub/pgxmock/v4"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  
    16  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
    17  	"github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
    18  )
    19  
    20  func TestSourceConn_Connect(t *testing.T) {
    21  
    22  	t.Run("failed config parsing", func(t *testing.T) {
    23  		md := &sources.SourceConn{}
    24  		md.ConnStr = "invalid connection string"
    25  		err := md.Connect(ctx, sources.CmdOpts{})
    26  		assert.Error(t, err)
    27  	})
    28  
    29  	t.Run("failed connection", func(t *testing.T) {
    30  		md := &sources.SourceConn{}
    31  		sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
    32  			return nil, assert.AnError
    33  		}
    34  		err := md.Connect(ctx, sources.CmdOpts{})
    35  		assert.ErrorIs(t, err, assert.AnError)
    36  	})
    37  
    38  	t.Run("successful connection to pgbouncer", func(t *testing.T) {
    39  		mock, err := pgxmock.NewPool()
    40  		require.NoError(t, err)
    41  		sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
    42  			return mock, nil
    43  		}
    44  
    45  		md := &sources.SourceConn{}
    46  		md.Kind = sources.SourcePgBouncer
    47  
    48  		opts := sources.CmdOpts{}
    49  		opts.MaxParallelConnectionsPerDb = 3
    50  
    51  		mock.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
    52  
    53  		err = md.Connect(ctx, opts)
    54  		assert.NoError(t, err)
    55  
    56  		assert.NoError(t, mock.ExpectationsWereMet())
    57  	})
    58  }
    59  
    60  func TestSourceConn_ParseConfig(t *testing.T) {
    61  	md := &sources.SourceConn{}
    62  	assert.NoError(t, md.ParseConfig())
    63  	//cached config
    64  	assert.NoError(t, md.ParseConfig())
    65  }
    66  
    67  func TestSourceConn_GetDatabaseName(t *testing.T) {
    68  	md := &sources.SourceConn{}
    69  	md.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
    70  	expected := "mydatabase"
    71  	// check pgx.ConnConfig related code
    72  	got := md.GetDatabaseName()
    73  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    74  	// check ConnStr parsing
    75  	got = md.Source.GetDatabaseName()
    76  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    77  
    78  	md = &sources.SourceConn{}
    79  	md.ConnStr = "foo boo"
    80  	expected = ""
    81  	got = md.GetDatabaseName()
    82  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    83  }
    84  
    85  func TestSourceConn_SetDatabaseName(t *testing.T) {
    86  	md := &sources.SourceConn{}
    87  	md.ConnStr = "postgres://user:password@localhost:5432/mydatabase"
    88  	expected := "mydatabase"
    89  	// check ConnStr parsing
    90  	md.SetDatabaseName(expected)
    91  	got := md.GetDatabaseName()
    92  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    93  	// check pgx.ConnConfig related code
    94  	expected = "newdatabase"
    95  	md.SetDatabaseName(expected)
    96  	got = md.GetDatabaseName()
    97  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
    98  
    99  	md = &sources.SourceConn{}
   100  	md.ConnStr = "foo boo"
   101  	expected = ""
   102  	md.SetDatabaseName("ingored due to invalid ConnStr")
   103  	got = md.GetDatabaseName()
   104  	assert.Equal(t, expected, got, "GetDatabaseName() = %v, want %v", got, expected)
   105  }
   106  
   107  func TestSourceConn_DiscoverPlatform(t *testing.T) {
   108  	ctx := context.Background()
   109  	mock, err := pgxmock.NewPool()
   110  	require.NoError(t, err)
   111  	md := &sources.SourceConn{Conn: mock}
   112  
   113  	mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"exec_env"}).AddRow("AZURE_SINGLE"))
   114  	md.ExecEnv = md.DiscoverPlatform(ctx)
   115  	assert.Equal(t, "AZURE_SINGLE", md.ExecEnv)
   116  	assert.Equal(t, "AZURE_SINGLE", md.DiscoverPlatform(ctx)) // cached
   117  	assert.NoError(t, mock.ExpectationsWereMet())
   118  }
   119  
   120  func TestSourceConn_GetApproxSize(t *testing.T) {
   121  	mock, err := pgxmock.NewPool()
   122  	require.NoError(t, err)
   123  	md := &sources.SourceConn{Conn: mock}
   124  
   125  	mock.ExpectQuery("select").WillReturnRows(pgxmock.NewRows([]string{"size"}).AddRow(42))
   126  
   127  	assert.EqualValues(t, 42, md.FetchApproxSize(ctx))
   128  	assert.NoError(t, err)
   129  	assert.NoError(t, mock.ExpectationsWereMet())
   130  }
   131  
   132  func TestSourceConn_FunctionExists(t *testing.T) {
   133  	mock, err := pgxmock.NewPool()
   134  	require.NoError(t, err)
   135  	md := &sources.SourceConn{Conn: mock}
   136  
   137  	mock.ExpectQuery("select").WithArgs("get_foo").WillReturnRows(pgxmock.NewRows([]string{"exists"}))
   138  
   139  	assert.False(t, md.FunctionExists(ctx, "get_foo"))
   140  	assert.NoError(t, mock.ExpectationsWereMet())
   141  }
   142  
   143  func TestSourceConn_IsPostgresSource(t *testing.T) {
   144  	md := &sources.SourceConn{}
   145  	md.Kind = sources.SourcePostgres
   146  	assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true")
   147  
   148  	md.Kind = sources.SourcePgBouncer
   149  	assert.False(t, md.IsPostgresSource(), "IsPostgresSource() = true, want false")
   150  
   151  	md.Kind = sources.SourcePgPool
   152  	assert.False(t, md.IsPostgresSource(), "IsPostgresSource() = true, want false")
   153  
   154  	md.Kind = sources.SourcePatroni
   155  	assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true")
   156  }
   157  
   158  func TestSourceConn_Ping(t *testing.T) {
   159  	db, err := pgxmock.NewPool()
   160  	require.NoError(t, err)
   161  	md := &sources.SourceConn{Conn: db}
   162  
   163  	db.ExpectPing()
   164  	md.Kind = sources.SourcePostgres
   165  	assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")
   166  
   167  	db.ExpectExec("SHOW VERSION").WillReturnResult(pgconn.NewCommandTag("SELECT 1"))
   168  	md.Conn = db
   169  	md.Kind = sources.SourcePgBouncer
   170  	assert.NoError(t, md.Ping(ctx), "Ping() = error, want nil")
   171  }
   172  
   173  func TestSourceConn_GetMetricInterval(t *testing.T) {
   174  	md := &sources.SourceConn{
   175  		Source: sources.Source{
   176  			Metrics:        map[string]float64{"foo": 1.5, "bar": 2.5},
   177  			MetricsStandby: map[string]float64{"foo": 3.5},
   178  		},
   179  	}
   180  
   181  	t.Run("primary uses Metrics", func(t *testing.T) {
   182  		md.IsInRecovery = false
   183  		assert.Equal(t, 1.5, md.GetMetricInterval("foo"))
   184  		assert.Equal(t, 2.5, md.GetMetricInterval("bar"))
   185  	})
   186  
   187  	t.Run("standby uses MetricsStandby if present", func(t *testing.T) {
   188  		md.IsInRecovery = true
   189  		assert.Equal(t, 3.5, md.GetMetricInterval("foo"))
   190  		assert.Equal(t, 0.0, md.GetMetricInterval("bar"))
   191  	})
   192  
   193  	t.Run("standby with empty MetricsStandby falls back to Metrics", func(t *testing.T) {
   194  		md.IsInRecovery = true
   195  		md.MetricsStandby = map[string]float64{}
   196  		assert.Equal(t, 1.5, md.GetMetricInterval("foo"))
   197  	})
   198  }
   199  
   200  func TestVersionToInt(t *testing.T) {
   201  	tests := []struct {
   202  		arg  string
   203  		want int
   204  	}{
   205  		{"", 0},
   206  		{"foo", 0},
   207  		{"13", 13_00_00},
   208  		{"3.0", 3_00_00},
   209  		{"9.6.3", 9_06_03},
   210  		{"v9.6-beta2", 9_06_00},
   211  	}
   212  	for _, tt := range tests {
   213  		if got := sources.VersionToInt(tt.arg); got != tt.want {
   214  			t.Errorf("VersionToInt() = %v, want %v", got, tt.want)
   215  		}
   216  	}
   217  }
   218  
   219  func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
   220  	ctx := context.Background()
   221  
   222  	t.Run("cancelled context", func(t *testing.T) {
   223  		ctxNew, cancel := context.WithCancel(ctx)
   224  		cancel()
   225  		err := (&sources.SourceConn{}).FetchRuntimeInfo(ctxNew, true)
   226  		assert.Error(t, err)
   227  	})
   228  
   229  	t.Run("cached version", func(t *testing.T) {
   230  		md := &sources.SourceConn{
   231  			RuntimeInfo: sources.RuntimeInfo{
   232  				LastCheckedOn: time.Now().Add(-time.Minute),
   233  				Version:       42,
   234  			},
   235  		}
   236  		err := md.FetchRuntimeInfo(ctx, false)
   237  		assert.NoError(t, err)
   238  		assert.Equal(t, 42, md.Version)
   239  	})
   240  
   241  	t.Run("pgbouncer version fetch", func(t *testing.T) {
   242  		mock, err := pgxmock.NewPool()
   243  		require.NoError(t, err)
   244  		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgBouncer})
   245  		md.Conn = mock
   246  		mock.ExpectQuery("SHOW VERSION").
   247  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   248  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("PgBouncer 1.12.0"))
   249  		err = md.FetchRuntimeInfo(ctx, true)
   250  		assert.NoError(t, err)
   251  		assert.Contains(t, md.VersionStr, "PgBouncer")
   252  		assert.True(t, md.Version > 0)
   253  		assert.NoError(t, mock.ExpectationsWereMet())
   254  	})
   255  
   256  	t.Run("pgpool version fetch", func(t *testing.T) {
   257  		mock, err := pgxmock.NewPool()
   258  		require.NoError(t, err)
   259  		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgPool})
   260  		md.Conn = mock
   261  		mock.ExpectQuery("SHOW POOL_VERSION").
   262  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   263  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("4.1.2"))
   264  		err = md.FetchRuntimeInfo(ctx, true)
   265  		assert.NoError(t, err)
   266  		assert.Contains(t, md.VersionStr, "4.1.2")
   267  		assert.True(t, md.Version > 0)
   268  		assert.NoError(t, mock.ExpectationsWereMet())
   269  	})
   270  
   271  	t.Run("postgres version and extensions", func(t *testing.T) {
   272  		mock, err := pgxmock.NewPool()
   273  		require.NoError(t, err)
   274  		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePostgres})
   275  		md.Conn = mock
   276  		mock.ExpectQuery("select").WillReturnRows(
   277  			pgxmock.NewRows([]string{"ver", "version", "pg_is_in_recovery", "current_database", "system_identifier", "is_superuser"}).
   278  				AddRow(13, "PostgreSQL 13.3", false, "testdb", "42424242", true),
   279  		)
   280  		mock.ExpectQuery("select").WillReturnRows(
   281  			pgxmock.NewRows([]string{"exec_env"}).AddRow("UNKNOWN"),
   282  		)
   283  		mock.ExpectQuery("select").WillReturnRows(
   284  			pgxmock.NewRows([]string{"approx_size"}).AddRow(42),
   285  		)
   286  
   287  		mock.ExpectQuery("select").WillReturnRows(
   288  			pgxmock.NewRows([]string{"extname", "extversion"}).AddRow("pg_stat_statements", "1.8"),
   289  		)
   290  		err = md.FetchRuntimeInfo(ctx, true)
   291  		assert.NoError(t, err)
   292  		assert.Equal(t, 13, md.Version)
   293  		assert.Equal(t, "testdb", md.RealDbname)
   294  		assert.Contains(t, md.Extensions, "pg_stat_statements")
   295  		assert.NoError(t, mock.ExpectationsWereMet())
   296  	})
   297  
   298  	t.Run("query error", func(t *testing.T) {
   299  		mock, err := pgxmock.NewPool()
   300  		require.NoError(t, err)
   301  		md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgBouncer})
   302  		md.Conn = mock
   303  		mock.ExpectQuery("SHOW VERSION").
   304  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   305  			WillReturnError(fmt.Errorf("db error"))
   306  		err = md.FetchRuntimeInfo(ctx, true)
   307  		assert.Error(t, err)
   308  		assert.NoError(t, mock.ExpectationsWereMet())
   309  	})
   310  }
   311  
   312  func TestSourceConn_FetchVersion(t *testing.T) {
   313  	ctx := context.Background()
   314  
   315  	t.Run("valid version string", func(t *testing.T) {
   316  		mock, err := pgxmock.NewPool()
   317  		require.NoError(t, err)
   318  		md := &sources.SourceConn{Conn: mock}
   319  		mock.ExpectQuery("SHOW VERSION").
   320  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   321  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("FooBar 1.12.0"))
   322  		verStr, verInt, err := md.FetchVersion(ctx, "SHOW VERSION")
   323  		assert.NoError(t, err)
   324  		assert.Equal(t, "FooBar 1.12.0", verStr)
   325  		assert.Equal(t, 1_12_00, verInt)
   326  		assert.NoError(t, mock.ExpectationsWereMet())
   327  	})
   328  
   329  	t.Run("invalid version string", func(t *testing.T) {
   330  		mock, err := pgxmock.NewPool()
   331  		require.NoError(t, err)
   332  		md := &sources.SourceConn{Conn: mock}
   333  		mock.ExpectQuery("SHOW VERSION").
   334  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   335  			WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("invalid version"))
   336  		_, verInt, err := md.FetchVersion(ctx, "SHOW VERSION")
   337  		assert.Equal(t, 0, verInt)
   338  		assert.NoError(t, err)
   339  		assert.NoError(t, mock.ExpectationsWereMet())
   340  	})
   341  
   342  	t.Run("query error", func(t *testing.T) {
   343  		mock, err := pgxmock.NewPool()
   344  		require.NoError(t, err)
   345  		md := &sources.SourceConn{Conn: mock}
   346  		mock.ExpectQuery("SHOW VERSION").
   347  			WithArgs(pgx.QueryExecModeSimpleProtocol).
   348  			WillReturnError(assert.AnError)
   349  		_, _, err = md.FetchVersion(ctx, "SHOW VERSION")
   350  		assert.Error(t, err)
   351  		assert.NoError(t, mock.ExpectationsWereMet())
   352  	})
   353  }
   354  
   355  func TestSourceConn_GetClusterIdentifier(t *testing.T) {
   356  	md := &sources.SourceConn{
   357  		Source: sources.Source{
   358  			Name:    "test",
   359  			Kind:    sources.SourcePostgres,
   360  			ConnStr: "postgres://user:password@localhost:5432/mydatabase",
   361  		},
   362  		RuntimeInfo: sources.RuntimeInfo{
   363  			SystemIdentifier: "42424242",
   364  		},
   365  	}
   366  	assert.Equal(t, "42424242:localhost:5432", md.GetClusterIdentifier())
   367  
   368  	md = &sources.SourceConn{
   369  		Source: sources.Source{
   370  			Name:    "test",
   371  			Kind:    sources.SourcePostgres,
   372  			ConnStr: "foo boo",
   373  		},
   374  	}
   375  	assert.Equal(t, "", md.GetClusterIdentifier())
   376  }
   377