...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/db/conn_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/db

     1  package db_test
     2  
     3  import (
     4  	"encoding/binary"
     5  	"os"
     6  	"path/filepath"
     7  	"reflect"
     8  	"testing"
     9  
    10  	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
    11  	"github.com/pashagolub/pgxmock/v4"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func TestMarshallParam(t *testing.T) {
    16  	tests := []struct {
    17  		name string
    18  		v    any
    19  		want any
    20  	}{
    21  		{
    22  			name: "nil",
    23  			v:    nil,
    24  			want: nil,
    25  		},
    26  		{
    27  			name: "empty map",
    28  			v:    map[string]string{},
    29  			want: nil,
    30  		},
    31  		{
    32  			name: "empty slice",
    33  			v:    []string{},
    34  			want: nil,
    35  		},
    36  		{
    37  			name: "empty struct",
    38  			v:    struct{}{},
    39  			want: nil,
    40  		},
    41  		{
    42  			name: "non-empty map",
    43  			v:    map[string]string{"key": "value"},
    44  			want: `{"key":"value"}`,
    45  		},
    46  		{
    47  			name: "non-empty slice",
    48  			v:    []string{"value"},
    49  			want: `["value"]`,
    50  		},
    51  		{
    52  			name: "non-empty struct",
    53  			v:    struct{ Key string }{Key: "value"},
    54  			want: `{"Key":"value"}`,
    55  		},
    56  		{
    57  			name: "non-marshallable",
    58  			v:    make(chan struct{}),
    59  			want: nil,
    60  		},
    61  	}
    62  	for _, tt := range tests {
    63  		t.Run(tt.name, func(t *testing.T) {
    64  			if got := db.MarshallParamToJSONB(tt.v); !reflect.DeepEqual(got, tt.want) {
    65  				t.Errorf("MarshallParamToJSONB() = %v, want %v", got, tt.want)
    66  			}
    67  		})
    68  	}
    69  }
    70  
    71  func TestIsClientOnSameHost(t *testing.T) {
    72  	// Create a pgxmock pool
    73  	mock, err := pgxmock.NewPool()
    74  	if err != nil {
    75  		t.Fatalf("failed to create pgxmock pool: %v", err)
    76  	}
    77  	defer mock.Close()
    78  	dataDir := t.TempDir()
    79  	pgControl := filepath.Join(dataDir, "global")
    80  	require.NoError(t, os.MkdirAll(pgControl, 0755))
    81  	file, err := os.OpenFile(filepath.Join(pgControl, "pg_control"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
    82  	require.NoError(t, err)
    83  	err = binary.Write(file, binary.LittleEndian, uint64(12345))
    84  	require.NoError(t, err)
    85  	require.NoError(t, file.Close())
    86  
    87  	// Test cases
    88  	tests := []struct {
    89  		name      string
    90  		setupMock func()
    91  		want      bool
    92  		wantErr   bool
    93  	}{
    94  		{
    95  			name: "UNIX socket connection",
    96  			setupMock: func() {
    97  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
    98  					pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(true),
    99  				)
   100  			},
   101  			want:    true,
   102  			wantErr: false,
   103  		},
   104  		{
   105  			name: "Matching system identifier",
   106  			setupMock: func() {
   107  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
   108  					pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
   109  				)
   110  				mock.ExpectQuery(`SHOW`).WillReturnRows(
   111  					pgxmock.NewRows([]string{"data_directory"}).AddRow(dataDir),
   112  				)
   113  				mock.ExpectQuery(`SELECT`).WillReturnRows(
   114  					pgxmock.NewRows([]string{"system_identifier"}).AddRow(uint64(12345)),
   115  				)
   116  			},
   117  			want:    true,
   118  			wantErr: false,
   119  		},
   120  		{
   121  			name: "Non-matching system identifier",
   122  			setupMock: func() {
   123  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
   124  					pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
   125  				)
   126  				mock.ExpectQuery(`SHOW`).WillReturnRows(
   127  					pgxmock.NewRows([]string{"data_directory"}).AddRow(dataDir),
   128  				)
   129  				mock.ExpectQuery(`SELECT`).WillReturnRows(
   130  					pgxmock.NewRows([]string{"system_identifier"}).AddRow(uint64(42)),
   131  				)
   132  			},
   133  			want:    false,
   134  			wantErr: false,
   135  		},
   136  		{
   137  			name: "Error on COALESCE query",
   138  			setupMock: func() {
   139  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnError(os.ErrInvalid)
   140  			},
   141  			want:    false,
   142  			wantErr: true,
   143  		},
   144  		{
   145  			name: "Error on SHOW query",
   146  			setupMock: func() {
   147  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
   148  					pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
   149  				)
   150  				mock.ExpectQuery(`SHOW`).WillReturnError(os.ErrInvalid)
   151  			},
   152  			want:    false,
   153  			wantErr: true,
   154  		},
   155  		{
   156  			name: "Error on SELECT system_identifier query",
   157  			setupMock: func() {
   158  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
   159  					pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
   160  				)
   161  				mock.ExpectQuery(`SHOW`).WillReturnRows(
   162  					pgxmock.NewRows([]string{"data_directory"}).AddRow(dataDir),
   163  				)
   164  				mock.ExpectQuery(`SELECT`).WillReturnError(os.ErrInvalid)
   165  			},
   166  			want:    false,
   167  			wantErr: true,
   168  		},
   169  		{
   170  			name: "Error on os.Open",
   171  			setupMock: func() {
   172  				mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
   173  					pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
   174  				)
   175  				mock.ExpectQuery(`SHOW`).WillReturnRows(
   176  					pgxmock.NewRows([]string{"data_directory"}).AddRow("invalid/path"),
   177  				)
   178  				mock.ExpectQuery(`SELECT`).WillReturnRows(
   179  					pgxmock.NewRows([]string{"system_identifier"}).AddRow(uint64(12345)),
   180  				)
   181  			},
   182  			want:    false,
   183  			wantErr: true,
   184  		},
   185  	}
   186  
   187  	for _, tt := range tests {
   188  		t.Run(tt.name, func(t *testing.T) {
   189  			tt.setupMock()
   190  			got, err := db.IsClientOnSameHost(mock)
   191  			if (err != nil) != tt.wantErr {
   192  				t.Errorf("IsClientOnSameHost() error = %v, wantErr %v", err, tt.wantErr)
   193  			}
   194  			if got != tt.want {
   195  				t.Errorf("IsClientOnSameHost() = %v, want %v", got, tt.want)
   196  			}
   197  		})
   198  	}
   199  }
   200