1 package db_test
2
3 import (
4 "encoding/binary"
5 "os"
6 "path/filepath"
7 "reflect"
8 "testing"
9
10 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
11 "github.com/pashagolub/pgxmock/v4"
12 "github.com/stretchr/testify/require"
13 )
14
15 func TestMarshallParam(t *testing.T) {
16 tests := []struct {
17 name string
18 v any
19 want any
20 }{
21 {
22 name: "nil",
23 v: nil,
24 want: nil,
25 },
26 {
27 name: "empty map",
28 v: map[string]string{},
29 want: nil,
30 },
31 {
32 name: "empty slice",
33 v: []string{},
34 want: nil,
35 },
36 {
37 name: "empty struct",
38 v: struct{}{},
39 want: nil,
40 },
41 {
42 name: "non-empty map",
43 v: map[string]string{"key": "value"},
44 want: `{"key":"value"}`,
45 },
46 {
47 name: "non-empty slice",
48 v: []string{"value"},
49 want: `["value"]`,
50 },
51 {
52 name: "non-empty struct",
53 v: struct{ Key string }{Key: "value"},
54 want: `{"Key":"value"}`,
55 },
56 {
57 name: "non-marshallable",
58 v: make(chan struct{}),
59 want: nil,
60 },
61 }
62 for _, tt := range tests {
63 t.Run(tt.name, func(t *testing.T) {
64 if got := db.MarshallParamToJSONB(tt.v); !reflect.DeepEqual(got, tt.want) {
65 t.Errorf("MarshallParamToJSONB() = %v, want %v", got, tt.want)
66 }
67 })
68 }
69 }
70
71 func TestIsClientOnSameHost(t *testing.T) {
72
73 mock, err := pgxmock.NewPool()
74 if err != nil {
75 t.Fatalf("failed to create pgxmock pool: %v", err)
76 }
77 defer mock.Close()
78 dataDir := t.TempDir()
79 pgControl := filepath.Join(dataDir, "global")
80 require.NoError(t, os.MkdirAll(pgControl, 0755))
81 file, err := os.OpenFile(filepath.Join(pgControl, "pg_control"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
82 require.NoError(t, err)
83 err = binary.Write(file, binary.LittleEndian, uint64(12345))
84 require.NoError(t, err)
85 require.NoError(t, file.Close())
86
87
88 tests := []struct {
89 name string
90 setupMock func()
91 want bool
92 wantErr bool
93 }{
94 {
95 name: "UNIX socket connection",
96 setupMock: func() {
97 mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
98 pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(true),
99 )
100 },
101 want: true,
102 wantErr: false,
103 },
104 {
105 name: "Matching system identifier",
106 setupMock: func() {
107 mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
108 pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
109 )
110 mock.ExpectQuery(`SHOW`).WillReturnRows(
111 pgxmock.NewRows([]string{"data_directory"}).AddRow(dataDir),
112 )
113 mock.ExpectQuery(`SELECT`).WillReturnRows(
114 pgxmock.NewRows([]string{"system_identifier"}).AddRow(uint64(12345)),
115 )
116 },
117 want: true,
118 wantErr: false,
119 },
120 {
121 name: "Non-matching system identifier",
122 setupMock: func() {
123 mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
124 pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
125 )
126 mock.ExpectQuery(`SHOW`).WillReturnRows(
127 pgxmock.NewRows([]string{"data_directory"}).AddRow(dataDir),
128 )
129 mock.ExpectQuery(`SELECT`).WillReturnRows(
130 pgxmock.NewRows([]string{"system_identifier"}).AddRow(uint64(42)),
131 )
132 },
133 want: false,
134 wantErr: false,
135 },
136 {
137 name: "Error on COALESCE query",
138 setupMock: func() {
139 mock.ExpectQuery(`SELECT COALESCE`).WillReturnError(os.ErrInvalid)
140 },
141 want: false,
142 wantErr: true,
143 },
144 {
145 name: "Error on SHOW query",
146 setupMock: func() {
147 mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
148 pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
149 )
150 mock.ExpectQuery(`SHOW`).WillReturnError(os.ErrInvalid)
151 },
152 want: false,
153 wantErr: true,
154 },
155 {
156 name: "Error on SELECT system_identifier query",
157 setupMock: func() {
158 mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
159 pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
160 )
161 mock.ExpectQuery(`SHOW`).WillReturnRows(
162 pgxmock.NewRows([]string{"data_directory"}).AddRow(dataDir),
163 )
164 mock.ExpectQuery(`SELECT`).WillReturnError(os.ErrInvalid)
165 },
166 want: false,
167 wantErr: true,
168 },
169 {
170 name: "Error on os.Open",
171 setupMock: func() {
172 mock.ExpectQuery(`SELECT COALESCE`).WillReturnRows(
173 pgxmock.NewRows([]string{"is_unix_socket"}).AddRow(false),
174 )
175 mock.ExpectQuery(`SHOW`).WillReturnRows(
176 pgxmock.NewRows([]string{"data_directory"}).AddRow("invalid/path"),
177 )
178 mock.ExpectQuery(`SELECT`).WillReturnRows(
179 pgxmock.NewRows([]string{"system_identifier"}).AddRow(uint64(12345)),
180 )
181 },
182 want: false,
183 wantErr: true,
184 },
185 }
186
187 for _, tt := range tests {
188 t.Run(tt.name, func(t *testing.T) {
189 tt.setupMock()
190 got, err := db.IsClientOnSameHost(mock)
191 if (err != nil) != tt.wantErr {
192 t.Errorf("IsClientOnSameHost() error = %v, wantErr %v", err, tt.wantErr)
193 }
194 if got != tt.want {
195 t.Errorf("IsClientOnSameHost() = %v, want %v", got, tt.want)
196 }
197 })
198 }
199 }
200