1 package webserver 2 3 import ( 4 "io" 5 "io/fs" 6 "net/http" 7 "net/http/httptest" 8 "os" 9 "path" 10 "strings" 11 "testing" 12 13 "github.com/cybertec-postgresql/pgwatch/v3/internal/metrics" 14 "github.com/sirupsen/logrus" 15 "github.com/stretchr/testify/assert" 16 ) 17 18 type mockFS struct { 19 OpenFunc func(name string) (fs.File, error) 20 } 21 22 func (m mockFS) Open(name string) (fs.File, error) { 23 return m.OpenFunc(name) 24 } 25 26 func TestServer_handleStatic(t *testing.T) { 27 tempFile := path.Join(t.TempDir(), "file.ext") 28 assert.NoError(t, os.WriteFile(tempFile, []byte(`{"foo": {"bar": 1}}`), 0644)) 29 ts := &WebUIServer{ 30 Logger: logrus.StandardLogger(), 31 uiFS: mockFS{ 32 OpenFunc: func(name string) (fs.File, error) { 33 switch name { 34 case "index.html", "static/file.ext": 35 return os.Open(tempFile) 36 case "badfile.ext": 37 return nil, fs.ErrInvalid 38 default: 39 return nil, fs.ErrNotExist 40 } 41 }, 42 }, 43 } 44 45 t.Run("not GET", func(t *testing.T) { 46 r := httptest.NewRequest(http.MethodPost, "/static/file.ext", nil) 47 w := httptest.NewRecorder() 48 ts.handleStatic(w, r) 49 resp := w.Result() 50 defer resp.Body.Close() 51 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) 52 body, _ := io.ReadAll(resp.Body) 53 assert.Equal(t, "Method Not Allowed\n", string(body)) 54 }) 55 56 t.Run("some static file", func(t *testing.T) { 57 r := httptest.NewRequest(http.MethodGet, "/static/file.ext", nil) 58 w := httptest.NewRecorder() 59 ts.handleStatic(w, r) 60 resp := w.Result() 61 defer resp.Body.Close() 62 assert.Equal(t, http.StatusOK, resp.StatusCode) 63 body, _ := io.ReadAll(resp.Body) 64 var got map[string]metrics.Metric 65 assert.NoError(t, json.Unmarshal(body, &got)) 66 assert.Contains(t, got, "foo") 67 }) 68 69 t.Run("predefined route", func(t *testing.T) { 70 r := httptest.NewRequest(http.MethodGet, "/metrics", nil) 71 w := httptest.NewRecorder() 72 ts.handleStatic(w, r) 73 resp := w.Result() 74 defer resp.Body.Close() 75 assert.Equal(t, http.StatusOK, resp.StatusCode) 76 body, _ := io.ReadAll(resp.Body) 77 var got map[string]metrics.Metric 78 assert.NoError(t, json.Unmarshal(body, &got)) 79 assert.Contains(t, got, "foo") 80 }) 81 82 t.Run("file not found", func(t *testing.T) { 83 r := httptest.NewRequest(http.MethodGet, "/static/notfound.ext", nil) 84 w := httptest.NewRecorder() 85 ts.handleStatic(w, r) 86 resp := w.Result() 87 defer resp.Body.Close() 88 assert.Equal(t, http.StatusNotFound, resp.StatusCode) 89 body, _ := io.ReadAll(resp.Body) 90 assert.Equal(t, "404 page not found\n", string(body)) 91 }) 92 93 t.Run("file cannot be read", func(t *testing.T) { 94 r := httptest.NewRequest(http.MethodGet, "/badfile.ext", nil) 95 w := httptest.NewRecorder() 96 ts.handleStatic(w, r) 97 resp := w.Result() 98 defer resp.Body.Close() 99 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 100 }) 101 } 102 103 func TestServer_handleTestConnect(t *testing.T) { 104 ts := &WebUIServer{ 105 Logger: logrus.StandardLogger(), 106 } 107 108 t.Run("POST", func(t *testing.T) { 109 r := httptest.NewRequest(http.MethodPost, "/testconnect", strings.NewReader("bad connection string")) 110 w := httptest.NewRecorder() 111 ts.handleTestConnect(w, r) 112 resp := w.Result() 113 defer resp.Body.Close() 114 assert.Equal(t, http.StatusBadRequest, resp.StatusCode) 115 }) 116 117 t.Run("failed reader", func(t *testing.T) { 118 r := httptest.NewRequest(http.MethodPost, "/testconnect", &errorReader{}) 119 w := httptest.NewRecorder() 120 ts.handleTestConnect(w, r) 121 resp := w.Result() 122 defer resp.Body.Close() 123 assert.Equal(t, http.StatusBadRequest, resp.StatusCode) 124 }) 125 126 t.Run("GET", func(t *testing.T) { 127 r := httptest.NewRequest(http.MethodGet, "/testconnect", nil) 128 w := httptest.NewRecorder() 129 ts.handleTestConnect(w, r) 130 resp := w.Result() 131 defer resp.Body.Close() 132 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) 133 body, _ := io.ReadAll(resp.Body) 134 assert.Equal(t, "Method Not Allowed\n", string(body)) 135 }) 136 } 137