1 package webserver
2
3 import (
4 "bytes"
5 "errors"
6 "io"
7 "net/http"
8 "net/http/httptest"
9 "strings"
10 "testing"
11
12 "github.com/cybertec-postgresql/pgwatch/v5/internal/sources"
13 "github.com/cybertec-postgresql/pgwatch/v5/internal/testutil"
14 jsoniter "github.com/json-iterator/go"
15 "github.com/stretchr/testify/assert"
16 )
17
18 func newTestSourceServer(mrw *testutil.MockSourcesReaderWriter) *WebUIServer {
19 return &WebUIServer{
20 sourcesReaderWriter: mrw,
21 }
22 }
23
24 func TestHandleSources(t *testing.T) {
25 t.Run("GET", func(t *testing.T) {
26 t.Run("Success", func(t *testing.T) {
27 mock := &testutil.MockSourcesReaderWriter{
28 GetSourcesFunc: func() (sources.Sources, error) {
29 return sources.Sources{{Name: "foo"}}, nil
30 },
31 }
32 ts := newTestSourceServer(mock)
33 r := httptest.NewRequest(http.MethodGet, "/source", nil)
34 w := httptest.NewRecorder()
35 ts.handleSources(w, r)
36 resp := w.Result()
37 defer resp.Body.Close()
38 assert.Equal(t, http.StatusOK, resp.StatusCode)
39 body, _ := io.ReadAll(resp.Body)
40 var got []sources.Source
41 assert.NoError(t, jsoniter.ConfigFastest.Unmarshal(body, &got))
42 assert.Equal(t, "foo", got[0].Name)
43 })
44
45 t.Run("Failure", func(t *testing.T) {
46 mock := &testutil.MockSourcesReaderWriter{
47 GetSourcesFunc: func() (sources.Sources, error) {
48 return nil, errors.New("fail")
49 },
50 }
51 ts := newTestSourceServer(mock)
52 r := httptest.NewRequest(http.MethodGet, "/source", nil)
53 w := httptest.NewRecorder()
54 ts.handleSources(w, r)
55 resp := w.Result()
56 defer resp.Body.Close()
57 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
58 body, _ := io.ReadAll(resp.Body)
59 assert.Contains(t, string(body), "fail")
60 })
61 })
62
63 t.Run("POST", func(t *testing.T) {
64 t.Run("Success", func(t *testing.T) {
65 var createdSource sources.Source
66 mock := &testutil.MockSourcesReaderWriter{
67 CreateSourceFunc: func(md sources.Source) error {
68 createdSource = md
69 return nil
70 },
71 }
72 ts := newTestSourceServer(mock)
73 src := sources.Source{Name: "bar"}
74 b, _ := jsoniter.ConfigFastest.Marshal(src)
75 r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b))
76 w := httptest.NewRecorder()
77 ts.handleSources(w, r)
78 resp := w.Result()
79 defer resp.Body.Close()
80 assert.Equal(t, http.StatusCreated, resp.StatusCode)
81 assert.Equal(t, src, createdSource)
82 })
83
84 t.Run("ReaderFailure", func(t *testing.T) {
85 mock := &testutil.MockSourcesReaderWriter{
86 CreateSourceFunc: func(sources.Source) error {
87 return nil
88 },
89 }
90 ts := newTestSourceServer(mock)
91 r := httptest.NewRequest(http.MethodPost, "/Source?name=bar", &errorReader{})
92 w := httptest.NewRecorder()
93 ts.handleSources(w, r)
94 resp := w.Result()
95 defer resp.Body.Close()
96 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
97 body, _ := io.ReadAll(resp.Body)
98 assert.Contains(t, string(body), "mock read error")
99 })
100
101 t.Run("Conflict", func(t *testing.T) {
102 mock := &testutil.MockSourcesReaderWriter{
103 CreateSourceFunc: func(sources.Source) error {
104 return sources.ErrSourceExists
105 },
106 }
107 ts := newTestSourceServer(mock)
108 src := sources.Source{Name: "bar"}
109 b, _ := jsoniter.ConfigFastest.Marshal(src)
110 r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b))
111 w := httptest.NewRecorder()
112 ts.handleSources(w, r)
113 resp := w.Result()
114 defer resp.Body.Close()
115 assert.Equal(t, http.StatusConflict, resp.StatusCode)
116 body, _ := io.ReadAll(resp.Body)
117 assert.Contains(t, string(body), "source already exists")
118 })
119
120 t.Run("CreateFailure", func(t *testing.T) {
121 mock := &testutil.MockSourcesReaderWriter{
122 CreateSourceFunc: func(sources.Source) error {
123 return errors.New("fail")
124 },
125 }
126 ts := newTestSourceServer(mock)
127 src := sources.Source{Name: "bar"}
128 b, _ := jsoniter.ConfigFastest.Marshal(src)
129 r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b))
130 w := httptest.NewRecorder()
131 ts.handleSources(w, r)
132 resp := w.Result()
133 defer resp.Body.Close()
134 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
135 body, _ := io.ReadAll(resp.Body)
136 assert.Contains(t, string(body), "fail")
137 })
138
139 t.Run("ReadAllError", func(t *testing.T) {
140 mock := &testutil.MockSourcesReaderWriter{}
141 ts := newTestSourceServer(mock)
142 r := httptest.NewRequest(http.MethodPost, "/source", &errorReader{})
143 w := httptest.NewRecorder()
144 ts.handleSources(w, r)
145 resp := w.Result()
146 defer resp.Body.Close()
147 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
148 })
149 })
150
151 t.Run("OPTIONS", func(t *testing.T) {
152 mock := &testutil.MockSourcesReaderWriter{}
153 ts := newTestSourceServer(mock)
154 r := httptest.NewRequest(http.MethodOptions, "/source", nil)
155 w := httptest.NewRecorder()
156 ts.handleSources(w, r)
157 resp := w.Result()
158 defer resp.Body.Close()
159 assert.Equal(t, http.StatusOK, resp.StatusCode)
160 assert.Equal(t, "GET, POST, OPTIONS", resp.Header.Get("Allow"))
161 })
162
163 t.Run("MethodNotAllowed", func(t *testing.T) {
164 mock := &testutil.MockSourcesReaderWriter{}
165 ts := newTestSourceServer(mock)
166 r := httptest.NewRequest(http.MethodPut, "/source", nil)
167 w := httptest.NewRecorder()
168 ts.handleSources(w, r)
169 resp := w.Result()
170 defer resp.Body.Close()
171 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
172 assert.Equal(t, "GET, POST, OPTIONS", resp.Header.Get("Allow"))
173 })
174 }
175
176 func TestGetSources_Error(t *testing.T) {
177 mock := &testutil.MockSourcesReaderWriter{
178 GetSourcesFunc: func() (sources.Sources, error) {
179 return nil, errors.New("fail")
180 },
181 }
182 ts := newTestSourceServer(mock)
183 _, err := ts.GetSources()
184 assert.Error(t, err)
185 }
186
187 func TestUpdateSource_Error(t *testing.T) {
188 mock := &testutil.MockSourcesReaderWriter{
189 UpdateSourceFunc: func(sources.Source) error {
190 return errors.New("fail")
191 },
192 }
193 ts := newTestSourceServer(mock)
194 err := ts.UpdateSource([]byte("notjson"))
195 assert.Error(t, err)
196 }
197
198 func TestDeleteSource_Error(t *testing.T) {
199 mock := &testutil.MockSourcesReaderWriter{
200 DeleteSourceFunc: func(string) error {
201 return errors.New("fail")
202 },
203 }
204 ts := newTestSourceServer(mock)
205 err := ts.DeleteSource("foo")
206 assert.Error(t, err)
207 }
208
209
210 func newSourceItemRequest(method, name string, body io.Reader) *http.Request {
211 url := "/source/" + name
212 r := httptest.NewRequest(method, url, body)
213 r.SetPathValue("name", name)
214 return r
215 }
216
217 func TestHandleSourceItem(t *testing.T) {
218 t.Run("GET", func(t *testing.T) {
219 t.Run("Success", func(t *testing.T) {
220 source := sources.Source{Name: "test-source", ConnStr: "postgresql://test"}
221 mock := &testutil.MockSourcesReaderWriter{
222 GetSourcesFunc: func() (sources.Sources, error) {
223 return sources.Sources{source}, nil
224 },
225 }
226 ts := newTestSourceServer(mock)
227 r := newSourceItemRequest(http.MethodGet, "test-source", nil)
228 w := httptest.NewRecorder()
229 ts.handleSourceItem(w, r)
230 resp := w.Result()
231 defer resp.Body.Close()
232 assert.Equal(t, http.StatusOK, resp.StatusCode)
233 assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
234
235 var returnedSource sources.Source
236 body, _ := io.ReadAll(resp.Body)
237 assert.NoError(t, jsoniter.ConfigFastest.Unmarshal(body, &returnedSource))
238 assert.Equal(t, source.Name, returnedSource.Name)
239 })
240
241 t.Run("NotFound", func(t *testing.T) {
242 mock := &testutil.MockSourcesReaderWriter{
243 GetSourcesFunc: func() (sources.Sources, error) {
244 return sources.Sources{}, nil
245 },
246 }
247 ts := newTestSourceServer(mock)
248 r := newSourceItemRequest(http.MethodGet, "nonexistent", nil)
249 w := httptest.NewRecorder()
250 ts.handleSourceItem(w, r)
251 resp := w.Result()
252 defer resp.Body.Close()
253 assert.Equal(t, http.StatusNotFound, resp.StatusCode)
254 body, _ := io.ReadAll(resp.Body)
255 assert.Contains(t, string(body), "source not found")
256 })
257
258 t.Run("GetSourcesError", func(t *testing.T) {
259 mock := &testutil.MockSourcesReaderWriter{
260 GetSourcesFunc: func() (sources.Sources, error) {
261 return nil, errors.New("database connection failed")
262 },
263 }
264 ts := newTestSourceServer(mock)
265 r := newSourceItemRequest(http.MethodGet, "test-source", nil)
266 w := httptest.NewRecorder()
267 ts.handleSourceItem(w, r)
268 resp := w.Result()
269 defer resp.Body.Close()
270 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
271 body, _ := io.ReadAll(resp.Body)
272 assert.Contains(t, string(body), "database connection failed")
273 })
274 })
275
276 t.Run("PUT", func(t *testing.T) {
277 t.Run("Success", func(t *testing.T) {
278 existingSource := sources.Source{Name: "test-source", ConnStr: "postgresql://old"}
279 var updatedSource sources.Source
280 mock := &testutil.MockSourcesReaderWriter{
281 GetSourcesFunc: func() (sources.Sources, error) {
282 return sources.Sources{existingSource}, nil
283 },
284 UpdateSourceFunc: func(md sources.Source) error {
285 updatedSource = md
286 return nil
287 },
288 }
289 ts := newTestSourceServer(mock)
290
291 newSource := sources.Source{Name: "test-source", ConnStr: "postgresql://new"}
292 b, _ := jsoniter.ConfigFastest.Marshal(newSource)
293 r := newSourceItemRequest(http.MethodPut, "test-source", bytes.NewReader(b))
294 w := httptest.NewRecorder()
295 ts.handleSourceItem(w, r)
296 resp := w.Result()
297 defer resp.Body.Close()
298 assert.Equal(t, http.StatusOK, resp.StatusCode)
299 assert.Equal(t, newSource.ConnStr, updatedSource.ConnStr)
300 })
301
302 t.Run("CreateNew", func(t *testing.T) {
303 var updatedSource sources.Source
304 mock := &testutil.MockSourcesReaderWriter{
305 GetSourcesFunc: func() (sources.Sources, error) {
306 return sources.Sources{}, nil
307 },
308 UpdateSourceFunc: func(md sources.Source) error {
309 updatedSource = md
310 return nil
311 },
312 }
313 ts := newTestSourceServer(mock)
314
315 source := sources.Source{Name: "new-source", ConnStr: "postgresql://new"}
316 b, _ := jsoniter.ConfigFastest.Marshal(source)
317 r := newSourceItemRequest(http.MethodPut, "new-source", bytes.NewReader(b))
318 w := httptest.NewRecorder()
319 ts.handleSourceItem(w, r)
320 resp := w.Result()
321 defer resp.Body.Close()
322 assert.Equal(t, http.StatusOK, resp.StatusCode)
323 assert.Equal(t, source.Name, updatedSource.Name)
324 assert.Equal(t, source.ConnStr, updatedSource.ConnStr)
325 })
326
327 t.Run("NameMismatch", func(t *testing.T) {
328 existingSource := sources.Source{Name: "test-source"}
329 mock := &testutil.MockSourcesReaderWriter{
330 GetSourcesFunc: func() (sources.Sources, error) {
331 return sources.Sources{existingSource}, nil
332 },
333 }
334 ts := newTestSourceServer(mock)
335
336
337 source := sources.Source{Name: "different-name"}
338 b, _ := jsoniter.ConfigFastest.Marshal(source)
339 r := newSourceItemRequest(http.MethodPut, "test-source", bytes.NewReader(b))
340 w := httptest.NewRecorder()
341 ts.handleSourceItem(w, r)
342 resp := w.Result()
343 defer resp.Body.Close()
344 assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
345 body, _ := io.ReadAll(resp.Body)
346 assert.Contains(t, string(body), "name in URL and body must match")
347 })
348
349 t.Run("InvalidRequestBody", func(t *testing.T) {
350 mock := &testutil.MockSourcesReaderWriter{}
351 ts := newTestSourceServer(mock)
352 r := newSourceItemRequest(http.MethodPut, "test-source", &errorReader{})
353 w := httptest.NewRecorder()
354 ts.handleSourceItem(w, r)
355 resp := w.Result()
356 defer resp.Body.Close()
357 assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
358 body, _ := io.ReadAll(resp.Body)
359 assert.Contains(t, string(body), "invalid request body")
360 })
361
362 t.Run("InvalidJSON", func(t *testing.T) {
363 mock := &testutil.MockSourcesReaderWriter{}
364 ts := newTestSourceServer(mock)
365 r := newSourceItemRequest(http.MethodPut, "test-source", strings.NewReader("invalid json"))
366 w := httptest.NewRecorder()
367 ts.handleSourceItem(w, r)
368 resp := w.Result()
369 defer resp.Body.Close()
370 assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
371 body, _ := io.ReadAll(resp.Body)
372 assert.Contains(t, string(body), "invalid JSON format")
373 })
374
375 t.Run("UpdateError", func(t *testing.T) {
376 mock := &testutil.MockSourcesReaderWriter{
377 UpdateSourceFunc: func(sources.Source) error {
378 return errors.New("update operation failed")
379 },
380 }
381 ts := newTestSourceServer(mock)
382
383 source := sources.Source{Name: "test-source", ConnStr: "postgresql://test"}
384 b, _ := jsoniter.ConfigFastest.Marshal(source)
385 r := newSourceItemRequest(http.MethodPut, "test-source", bytes.NewReader(b))
386 w := httptest.NewRecorder()
387 ts.handleSourceItem(w, r)
388 resp := w.Result()
389 defer resp.Body.Close()
390 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
391 body, _ := io.ReadAll(resp.Body)
392 assert.Contains(t, string(body), "update operation failed")
393 })
394 })
395
396 t.Run("DELETE", func(t *testing.T) {
397 t.Run("Success", func(t *testing.T) {
398 existingSource := sources.Source{Name: "test-source"}
399 var deletedName string
400 mock := &testutil.MockSourcesReaderWriter{
401 GetSourcesFunc: func() (sources.Sources, error) {
402 return sources.Sources{existingSource}, nil
403 },
404 DeleteSourceFunc: func(name string) error {
405 deletedName = name
406 return nil
407 },
408 }
409 ts := newTestSourceServer(mock)
410 r := newSourceItemRequest(http.MethodDelete, "test-source", nil)
411 w := httptest.NewRecorder()
412 ts.handleSourceItem(w, r)
413 resp := w.Result()
414 defer resp.Body.Close()
415 assert.Equal(t, http.StatusOK, resp.StatusCode)
416 assert.Equal(t, "test-source", deletedName)
417 })
418
419 t.Run("Idempotent", func(t *testing.T) {
420 var deletedName string
421 mock := &testutil.MockSourcesReaderWriter{
422 GetSourcesFunc: func() (sources.Sources, error) {
423 return sources.Sources{}, nil
424 },
425 DeleteSourceFunc: func(name string) error {
426 deletedName = name
427 return nil
428 },
429 }
430 ts := newTestSourceServer(mock)
431 r := newSourceItemRequest(http.MethodDelete, "nonexistent", nil)
432 w := httptest.NewRecorder()
433 ts.handleSourceItem(w, r)
434 resp := w.Result()
435 defer resp.Body.Close()
436 assert.Equal(t, http.StatusOK, resp.StatusCode)
437 assert.Equal(t, "nonexistent", deletedName)
438 })
439
440 t.Run("DeleteError", func(t *testing.T) {
441 mock := &testutil.MockSourcesReaderWriter{
442 DeleteSourceFunc: func(string) error {
443 return errors.New("delete operation failed")
444 },
445 }
446 ts := newTestSourceServer(mock)
447 r := newSourceItemRequest(http.MethodDelete, "test-source", nil)
448 w := httptest.NewRecorder()
449 ts.handleSourceItem(w, r)
450 resp := w.Result()
451 defer resp.Body.Close()
452 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
453 body, _ := io.ReadAll(resp.Body)
454 assert.Contains(t, string(body), "delete operation failed")
455 })
456 })
457
458 t.Run("EmptyName", func(t *testing.T) {
459 mock := &testutil.MockSourcesReaderWriter{}
460 ts := newTestSourceServer(mock)
461 r := newSourceItemRequest(http.MethodGet, "", nil)
462 w := httptest.NewRecorder()
463 ts.handleSourceItem(w, r)
464 resp := w.Result()
465 defer resp.Body.Close()
466 assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
467 body, _ := io.ReadAll(resp.Body)
468 assert.Contains(t, string(body), "source name is required")
469 })
470
471 t.Run("OPTIONS", func(t *testing.T) {
472 mock := &testutil.MockSourcesReaderWriter{}
473 ts := newTestSourceServer(mock)
474 r := newSourceItemRequest(http.MethodOptions, "test", nil)
475 w := httptest.NewRecorder()
476 ts.handleSourceItem(w, r)
477 resp := w.Result()
478 defer resp.Body.Close()
479 assert.Equal(t, http.StatusOK, resp.StatusCode)
480 assert.Equal(t, "GET, PUT, DELETE, OPTIONS", resp.Header.Get("Allow"))
481 })
482
483 t.Run("MethodNotAllowed", func(t *testing.T) {
484 mock := &testutil.MockSourcesReaderWriter{}
485 ts := newTestSourceServer(mock)
486 r := newSourceItemRequest(http.MethodPost, "test", nil)
487 w := httptest.NewRecorder()
488 ts.handleSourceItem(w, r)
489 resp := w.Result()
490 defer resp.Body.Close()
491 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
492 assert.Equal(t, "GET, PUT, DELETE, OPTIONS", resp.Header.Get("Allow"))
493 })
494 }
495