1 package webserver 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "net/http" 8 "net/http/httptest" 9 "testing" 10 11 "github.com/cybertec-postgresql/pgwatch/v3/internal/sources" 12 "github.com/stretchr/testify/assert" 13 ) 14 15 type mockSourcesReaderWriter struct { 16 GetSourcesFunc func() (sources.Sources, error) 17 UpdateSourceFunc func(md sources.Source) error 18 DeleteSourceFunc func(name string) error 19 WriteSourcesFunc func(sources.Sources) error 20 } 21 22 func (m *mockSourcesReaderWriter) GetSources() (sources.Sources, error) { 23 return m.GetSourcesFunc() 24 } 25 func (m *mockSourcesReaderWriter) UpdateSource(md sources.Source) error { 26 return m.UpdateSourceFunc(md) 27 } 28 func (m *mockSourcesReaderWriter) DeleteSource(name string) error { 29 return m.DeleteSourceFunc(name) 30 } 31 func (m *mockSourcesReaderWriter) WriteSources(srcs sources.Sources) error { 32 return m.WriteSourcesFunc(srcs) 33 } 34 35 func newTestSourceServer(mrw *mockSourcesReaderWriter) *WebUIServer { 36 return &WebUIServer{ 37 sourcesReaderWriter: mrw, 38 } 39 } 40 41 func TestHandleSources_GET(t *testing.T) { 42 mock := &mockSourcesReaderWriter{ 43 GetSourcesFunc: func() (sources.Sources, error) { 44 return sources.Sources{{Name: "foo"}}, nil 45 }, 46 } 47 ts := newTestSourceServer(mock) 48 r := httptest.NewRequest(http.MethodGet, "/source", nil) 49 w := httptest.NewRecorder() 50 ts.handleSources(w, r) 51 resp := w.Result() 52 defer resp.Body.Close() 53 assert.Equal(t, http.StatusOK, resp.StatusCode) 54 body, _ := io.ReadAll(resp.Body) 55 var got []sources.Source 56 assert.NoError(t, json.Unmarshal(body, &got)) 57 assert.Equal(t, "foo", got[0].Name) 58 } 59 60 func TestHandleSources_GET_Fail(t *testing.T) { 61 mock := &mockSourcesReaderWriter{ 62 GetSourcesFunc: func() (sources.Sources, error) { 63 return nil, errors.New("fail") 64 }, 65 } 66 ts := newTestSourceServer(mock) 67 r := httptest.NewRequest(http.MethodGet, "/source", nil) 68 w := httptest.NewRecorder() 69 ts.handleSources(w, r) 70 resp := w.Result() 71 defer resp.Body.Close() 72 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 73 body, _ := io.ReadAll(resp.Body) 74 assert.Contains(t, string(body), "fail") 75 } 76 77 func TestHandleSources_POST(t *testing.T) { 78 var updatedSource sources.Source 79 mock := &mockSourcesReaderWriter{ 80 UpdateSourceFunc: func(md sources.Source) error { 81 updatedSource = md 82 return nil 83 }, 84 } 85 ts := newTestSourceServer(mock) 86 src := sources.Source{Name: "bar"} 87 b, _ := json.Marshal(src) 88 r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b)) 89 w := httptest.NewRecorder() 90 ts.handleSources(w, r) 91 resp := w.Result() 92 defer resp.Body.Close() 93 assert.Equal(t, http.StatusOK, resp.StatusCode) 94 assert.Equal(t, src, updatedSource) 95 } 96 97 func TestHandleSources_POST_ReaderFail(t *testing.T) { 98 mock := &mockSourcesReaderWriter{ 99 UpdateSourceFunc: func(sources.Source) error { 100 return nil 101 }, 102 } 103 ts := newTestSourceServer(mock) 104 r := httptest.NewRequest(http.MethodPost, "/Source?name=bar", &errorReader{}) 105 w := httptest.NewRecorder() 106 ts.handleSources(w, r) 107 resp := w.Result() 108 defer resp.Body.Close() 109 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 110 body, _ := io.ReadAll(resp.Body) 111 assert.Contains(t, string(body), "mock read error") 112 } 113 114 func TestHandleSources_POST_Fail(t *testing.T) { 115 mock := &mockSourcesReaderWriter{ 116 UpdateSourceFunc: func(sources.Source) error { 117 return errors.New("fail") 118 }, 119 } 120 ts := newTestSourceServer(mock) 121 src := sources.Source{Name: "bar"} 122 b, _ := json.Marshal(src) 123 r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b)) 124 w := httptest.NewRecorder() 125 ts.handleSources(w, r) 126 resp := w.Result() 127 defer resp.Body.Close() 128 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 129 body, _ := io.ReadAll(resp.Body) 130 assert.Contains(t, string(body), "fail") 131 } 132 133 func TestHandleSources_DELETE(t *testing.T) { 134 var deletedName string 135 mock := &mockSourcesReaderWriter{ 136 DeleteSourceFunc: func(name string) error { 137 deletedName = name 138 return nil 139 }, 140 } 141 ts := newTestSourceServer(mock) 142 r := httptest.NewRequest(http.MethodDelete, "/source?name=foo", nil) 143 w := httptest.NewRecorder() 144 ts.handleSources(w, r) 145 resp := w.Result() 146 defer resp.Body.Close() 147 assert.Equal(t, http.StatusOK, resp.StatusCode) 148 assert.Equal(t, "foo", deletedName) 149 } 150 151 func TestHandleSources_Options(t *testing.T) { 152 mock := &mockSourcesReaderWriter{} 153 ts := newTestSourceServer(mock) 154 r := httptest.NewRequest(http.MethodOptions, "/source", nil) 155 w := httptest.NewRecorder() 156 ts.handleSources(w, r) 157 resp := w.Result() 158 defer resp.Body.Close() 159 assert.Equal(t, http.StatusNoContent, resp.StatusCode) 160 assert.Equal(t, "GET, POST, DELETE, OPTIONS", resp.Header.Get("Allow")) 161 } 162 163 func TestHandleSources_MethodNotAllowed(t *testing.T) { 164 mock := &mockSourcesReaderWriter{} 165 ts := newTestSourceServer(mock) 166 r := httptest.NewRequest(http.MethodPut, "/source", nil) 167 w := httptest.NewRecorder() 168 ts.handleSources(w, r) 169 resp := w.Result() 170 defer resp.Body.Close() 171 assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) 172 assert.Equal(t, "GET, POST, DELETE, OPTIONS", resp.Header.Get("Allow")) 173 } 174 175 func TestGetSources_Error(t *testing.T) { 176 mock := &mockSourcesReaderWriter{ 177 GetSourcesFunc: func() (sources.Sources, error) { 178 return nil, errors.New("fail") 179 }, 180 } 181 ts := newTestSourceServer(mock) 182 _, err := ts.GetSources() 183 assert.Error(t, err) 184 } 185 186 func TestUpdateSource_Error(t *testing.T) { 187 mock := &mockSourcesReaderWriter{ 188 UpdateSourceFunc: func(sources.Source) error { 189 return errors.New("fail") 190 }, 191 } 192 ts := newTestSourceServer(mock) 193 err := ts.UpdateSource([]byte("notjson")) 194 assert.Error(t, err) 195 } 196 197 func TestDeleteSource_Error(t *testing.T) { 198 mock := &mockSourcesReaderWriter{ 199 DeleteSourceFunc: func(string) error { 200 return errors.New("fail") 201 }, 202 } 203 ts := newTestSourceServer(mock) 204 err := ts.DeleteSource("foo") 205 assert.Error(t, err) 206 } 207 208 func TestHandleSources_ReadAllError(t *testing.T) { 209 mock := &mockSourcesReaderWriter{} 210 ts := newTestSourceServer(mock) 211 r := httptest.NewRequest(http.MethodPost, "/source", &errorReader{}) 212 w := httptest.NewRecorder() 213 ts.handleSources(w, r) 214 resp := w.Result() 215 defer resp.Body.Close() 216 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 217 } 218