1 package webserver
2
import (
	"encoding/json"
	"io"
	"io/fs"
	"net/http"
	"net/http/httptest"
	"os"
	"path"
	"strings"
	"testing"

	"github.com/cybertec-postgresql/pgwatch/v5/internal/metrics"
	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
)
17
// mockFS is a test double for fs.FS whose Open behavior is supplied per test
// through the OpenFunc callback.
type mockFS struct {
	OpenFunc func(string) (fs.File, error)
}

// Compile-time check that mockFS can stand in wherever an fs.FS is expected.
var _ fs.FS = mockFS{}

// Open forwards name unchanged to OpenFunc and returns its results as-is.
// A nil OpenFunc panics on call.
func (m mockFS) Open(name string) (fs.File, error) {
	open := m.OpenFunc
	return open(name)
}
25
26 func TestServer_handleStatic(t *testing.T) {
27 tempFile := path.Join(t.TempDir(), "file.ext")
28 assert.NoError(t, os.WriteFile(tempFile, []byte(`{"foo": {"bar": 1}}`), 0644))
29
30 indexHTML := []byte(`<!DOCTYPE html><html><head><script>window.__PGWATCH_BASE_PATH__='';</script></head><body>{"foo": {"bar": 1}}</body></html>`)
31
32
33 origUIFS := uiFS
34 defer func() { uiFS = origUIFS }()
35
36 uiFS = mockFS{
37 OpenFunc: func(name string) (fs.File, error) {
38 switch name {
39 case "index.html", "static/file.ext":
40 return os.Open(tempFile)
41 case "badfile.ext":
42 return nil, fs.ErrInvalid
43 default:
44 return nil, fs.ErrNotExist
45 }
46 },
47 }
48
49 ts := &WebUIServer{
50 Logger: logrus.StandardLogger(),
51 indexHTML: indexHTML,
52 }
53
54 t.Run("not GET", func(t *testing.T) {
55 r := httptest.NewRequest(http.MethodPost, "/static/file.ext", nil)
56 w := httptest.NewRecorder()
57 ts.handleStatic(w, r)
58 resp := w.Result()
59 defer resp.Body.Close()
60 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
61 body, _ := io.ReadAll(resp.Body)
62 assert.Equal(t, "Method Not Allowed\n", string(body))
63 })
64
65 t.Run("some static file", func(t *testing.T) {
66 r := httptest.NewRequest(http.MethodGet, "/static/file.ext", nil)
67 w := httptest.NewRecorder()
68 ts.handleStatic(w, r)
69 resp := w.Result()
70 defer resp.Body.Close()
71 assert.Equal(t, http.StatusOK, resp.StatusCode)
72 body, _ := io.ReadAll(resp.Body)
73 var got map[string]metrics.Metric
74 assert.NoError(t, json.Unmarshal(body, &got))
75 assert.Contains(t, got, "foo")
76 })
77
78 t.Run("predefined route", func(t *testing.T) {
79 r := httptest.NewRequest(http.MethodGet, "/metrics", nil)
80 w := httptest.NewRecorder()
81 ts.handleStatic(w, r)
82 resp := w.Result()
83 defer resp.Body.Close()
84 assert.Equal(t, http.StatusOK, resp.StatusCode)
85 assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
86 body, _ := io.ReadAll(resp.Body)
87 bodyStr := string(body)
88 assert.Contains(t, bodyStr, "<!DOCTYPE html>")
89 assert.Contains(t, bodyStr, "window.__PGWATCH_BASE_PATH__")
90 })
91
92 t.Run("file not found", func(t *testing.T) {
93 r := httptest.NewRequest(http.MethodGet, "/static/notfound.ext", nil)
94 w := httptest.NewRecorder()
95 ts.handleStatic(w, r)
96 resp := w.Result()
97 defer resp.Body.Close()
98 assert.Equal(t, http.StatusNotFound, resp.StatusCode)
99 body, _ := io.ReadAll(resp.Body)
100 assert.Equal(t, "404 page not found\n", string(body))
101 })
102
103 t.Run("file cannot be read", func(t *testing.T) {
104 r := httptest.NewRequest(http.MethodGet, "/badfile.ext", nil)
105 w := httptest.NewRecorder()
106 ts.handleStatic(w, r)
107 resp := w.Result()
108 defer resp.Body.Close()
109 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
110 })
111 }
112
113 func TestServer_handleTestConnect(t *testing.T) {
114 ts := &WebUIServer{
115 Logger: logrus.StandardLogger(),
116 }
117
118 t.Run("POST", func(t *testing.T) {
119 r := httptest.NewRequest(http.MethodPost, "/testconnect", strings.NewReader("bad connection string"))
120 w := httptest.NewRecorder()
121 ts.handleTestConnect(w, r)
122 resp := w.Result()
123 defer resp.Body.Close()
124 assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
125 })
126
127 t.Run("failed reader", func(t *testing.T) {
128 r := httptest.NewRequest(http.MethodPost, "/testconnect", &errorReader{})
129 w := httptest.NewRecorder()
130 ts.handleTestConnect(w, r)
131 resp := w.Result()
132 defer resp.Body.Close()
133 assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
134 })
135
136 t.Run("GET", func(t *testing.T) {
137 r := httptest.NewRequest(http.MethodGet, "/testconnect", nil)
138 w := httptest.NewRecorder()
139 ts.handleTestConnect(w, r)
140 resp := w.Result()
141 defer resp.Body.Close()
142 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
143 body, _ := io.ReadAll(resp.Body)
144 assert.Equal(t, "Method Not Allowed\n", string(body))
145 })
146 }
147