...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/webserver/source_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/webserver

     1  package webserver
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/cybertec-postgresql/pgwatch/v5/internal/sources"
    13  	"github.com/cybertec-postgresql/pgwatch/v5/internal/testutil"
    14  	jsoniter "github.com/json-iterator/go"
    15  	"github.com/stretchr/testify/assert"
    16  )
    17  
    18  func newTestSourceServer(mrw *testutil.MockSourcesReaderWriter) *WebUIServer {
    19  	return &WebUIServer{
    20  		sourcesReaderWriter: mrw,
    21  	}
    22  }
    23  
    24  func TestHandleSources(t *testing.T) {
    25  	t.Run("GET", func(t *testing.T) {
    26  		t.Run("Success", func(t *testing.T) {
    27  			mock := &testutil.MockSourcesReaderWriter{
    28  				GetSourcesFunc: func() (sources.Sources, error) {
    29  					return sources.Sources{{Name: "foo"}}, nil
    30  				},
    31  			}
    32  			ts := newTestSourceServer(mock)
    33  			r := httptest.NewRequest(http.MethodGet, "/source", nil)
    34  			w := httptest.NewRecorder()
    35  			ts.handleSources(w, r)
    36  			resp := w.Result()
    37  			defer resp.Body.Close()
    38  			assert.Equal(t, http.StatusOK, resp.StatusCode)
    39  			body, _ := io.ReadAll(resp.Body)
    40  			var got []sources.Source
    41  			assert.NoError(t, jsoniter.ConfigFastest.Unmarshal(body, &got))
    42  			assert.Equal(t, "foo", got[0].Name)
    43  		})
    44  
    45  		t.Run("Failure", func(t *testing.T) {
    46  			mock := &testutil.MockSourcesReaderWriter{
    47  				GetSourcesFunc: func() (sources.Sources, error) {
    48  					return nil, errors.New("fail")
    49  				},
    50  			}
    51  			ts := newTestSourceServer(mock)
    52  			r := httptest.NewRequest(http.MethodGet, "/source", nil)
    53  			w := httptest.NewRecorder()
    54  			ts.handleSources(w, r)
    55  			resp := w.Result()
    56  			defer resp.Body.Close()
    57  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
    58  			body, _ := io.ReadAll(resp.Body)
    59  			assert.Contains(t, string(body), "fail")
    60  		})
    61  	})
    62  
    63  	t.Run("POST", func(t *testing.T) {
    64  		t.Run("Success", func(t *testing.T) {
    65  			var createdSource sources.Source
    66  			mock := &testutil.MockSourcesReaderWriter{
    67  				CreateSourceFunc: func(md sources.Source) error {
    68  					createdSource = md
    69  					return nil
    70  				},
    71  			}
    72  			ts := newTestSourceServer(mock)
    73  			src := sources.Source{Name: "bar"}
    74  			b, _ := jsoniter.ConfigFastest.Marshal(src)
    75  			r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b))
    76  			w := httptest.NewRecorder()
    77  			ts.handleSources(w, r)
    78  			resp := w.Result()
    79  			defer resp.Body.Close()
    80  			assert.Equal(t, http.StatusCreated, resp.StatusCode)
    81  			assert.Equal(t, src, createdSource)
    82  		})
    83  
    84  		t.Run("ReaderFailure", func(t *testing.T) {
    85  			mock := &testutil.MockSourcesReaderWriter{
    86  				CreateSourceFunc: func(sources.Source) error {
    87  					return nil
    88  				},
    89  			}
    90  			ts := newTestSourceServer(mock)
    91  			r := httptest.NewRequest(http.MethodPost, "/Source?name=bar", &errorReader{})
    92  			w := httptest.NewRecorder()
    93  			ts.handleSources(w, r)
    94  			resp := w.Result()
    95  			defer resp.Body.Close()
    96  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
    97  			body, _ := io.ReadAll(resp.Body)
    98  			assert.Contains(t, string(body), "mock read error")
    99  		})
   100  
   101  		t.Run("Conflict", func(t *testing.T) {
   102  			mock := &testutil.MockSourcesReaderWriter{
   103  				CreateSourceFunc: func(sources.Source) error {
   104  					return sources.ErrSourceExists
   105  				},
   106  			}
   107  			ts := newTestSourceServer(mock)
   108  			src := sources.Source{Name: "bar"}
   109  			b, _ := jsoniter.ConfigFastest.Marshal(src)
   110  			r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b))
   111  			w := httptest.NewRecorder()
   112  			ts.handleSources(w, r)
   113  			resp := w.Result()
   114  			defer resp.Body.Close()
   115  			assert.Equal(t, http.StatusConflict, resp.StatusCode)
   116  			body, _ := io.ReadAll(resp.Body)
   117  			assert.Contains(t, string(body), "source already exists")
   118  		})
   119  
   120  		t.Run("CreateFailure", func(t *testing.T) {
   121  			mock := &testutil.MockSourcesReaderWriter{
   122  				CreateSourceFunc: func(sources.Source) error {
   123  					return errors.New("fail")
   124  				},
   125  			}
   126  			ts := newTestSourceServer(mock)
   127  			src := sources.Source{Name: "bar"}
   128  			b, _ := jsoniter.ConfigFastest.Marshal(src)
   129  			r := httptest.NewRequest(http.MethodPost, "/source", bytes.NewReader(b))
   130  			w := httptest.NewRecorder()
   131  			ts.handleSources(w, r)
   132  			resp := w.Result()
   133  			defer resp.Body.Close()
   134  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   135  			body, _ := io.ReadAll(resp.Body)
   136  			assert.Contains(t, string(body), "fail")
   137  		})
   138  
   139  		t.Run("ReadAllError", func(t *testing.T) {
   140  			mock := &testutil.MockSourcesReaderWriter{}
   141  			ts := newTestSourceServer(mock)
   142  			r := httptest.NewRequest(http.MethodPost, "/source", &errorReader{})
   143  			w := httptest.NewRecorder()
   144  			ts.handleSources(w, r)
   145  			resp := w.Result()
   146  			defer resp.Body.Close()
   147  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   148  		})
   149  	})
   150  
   151  	t.Run("OPTIONS", func(t *testing.T) {
   152  		mock := &testutil.MockSourcesReaderWriter{}
   153  		ts := newTestSourceServer(mock)
   154  		r := httptest.NewRequest(http.MethodOptions, "/source", nil)
   155  		w := httptest.NewRecorder()
   156  		ts.handleSources(w, r)
   157  		resp := w.Result()
   158  		defer resp.Body.Close()
   159  		assert.Equal(t, http.StatusOK, resp.StatusCode)
   160  		assert.Equal(t, "GET, POST, OPTIONS", resp.Header.Get("Allow"))
   161  	})
   162  
   163  	t.Run("MethodNotAllowed", func(t *testing.T) {
   164  		mock := &testutil.MockSourcesReaderWriter{}
   165  		ts := newTestSourceServer(mock)
   166  		r := httptest.NewRequest(http.MethodPut, "/source", nil)
   167  		w := httptest.NewRecorder()
   168  		ts.handleSources(w, r)
   169  		resp := w.Result()
   170  		defer resp.Body.Close()
   171  		assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
   172  		assert.Equal(t, "GET, POST, OPTIONS", resp.Header.Get("Allow"))
   173  	})
   174  }
   175  
   176  func TestGetSources_Error(t *testing.T) {
   177  	mock := &testutil.MockSourcesReaderWriter{
   178  		GetSourcesFunc: func() (sources.Sources, error) {
   179  			return nil, errors.New("fail")
   180  		},
   181  	}
   182  	ts := newTestSourceServer(mock)
   183  	_, err := ts.GetSources()
   184  	assert.Error(t, err)
   185  }
   186  
   187  func TestUpdateSource_Error(t *testing.T) {
   188  	mock := &testutil.MockSourcesReaderWriter{
   189  		UpdateSourceFunc: func(sources.Source) error {
   190  			return errors.New("fail")
   191  		},
   192  	}
   193  	ts := newTestSourceServer(mock)
   194  	err := ts.UpdateSource([]byte("notjson"))
   195  	assert.Error(t, err)
   196  }
   197  
   198  func TestDeleteSource_Error(t *testing.T) {
   199  	mock := &testutil.MockSourcesReaderWriter{
   200  		DeleteSourceFunc: func(string) error {
   201  			return errors.New("fail")
   202  		},
   203  	}
   204  	ts := newTestSourceServer(mock)
   205  	err := ts.DeleteSource("foo")
   206  	assert.Error(t, err)
   207  }
   208  
   209  // Helper function to create HTTP requests with path values for testing individual source endpoints
   210  func newSourceItemRequest(method, name string, body io.Reader) *http.Request {
   211  	url := "/source/" + name
   212  	r := httptest.NewRequest(method, url, body)
   213  	r.SetPathValue("name", name)
   214  	return r
   215  }
   216  
   217  func TestHandleSourceItem(t *testing.T) {
   218  	t.Run("GET", func(t *testing.T) {
   219  		t.Run("Success", func(t *testing.T) {
   220  			source := sources.Source{Name: "test-source", ConnStr: "postgresql://test"}
   221  			mock := &testutil.MockSourcesReaderWriter{
   222  				GetSourcesFunc: func() (sources.Sources, error) {
   223  					return sources.Sources{source}, nil
   224  				},
   225  			}
   226  			ts := newTestSourceServer(mock)
   227  			r := newSourceItemRequest(http.MethodGet, "test-source", nil)
   228  			w := httptest.NewRecorder()
   229  			ts.handleSourceItem(w, r)
   230  			resp := w.Result()
   231  			defer resp.Body.Close()
   232  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   233  			assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
   234  
   235  			var returnedSource sources.Source
   236  			body, _ := io.ReadAll(resp.Body)
   237  			assert.NoError(t, jsoniter.ConfigFastest.Unmarshal(body, &returnedSource))
   238  			assert.Equal(t, source.Name, returnedSource.Name)
   239  		})
   240  
   241  		t.Run("NotFound", func(t *testing.T) {
   242  			mock := &testutil.MockSourcesReaderWriter{
   243  				GetSourcesFunc: func() (sources.Sources, error) {
   244  					return sources.Sources{}, nil
   245  				},
   246  			}
   247  			ts := newTestSourceServer(mock)
   248  			r := newSourceItemRequest(http.MethodGet, "nonexistent", nil)
   249  			w := httptest.NewRecorder()
   250  			ts.handleSourceItem(w, r)
   251  			resp := w.Result()
   252  			defer resp.Body.Close()
   253  			assert.Equal(t, http.StatusNotFound, resp.StatusCode)
   254  			body, _ := io.ReadAll(resp.Body)
   255  			assert.Contains(t, string(body), "source not found")
   256  		})
   257  
   258  		t.Run("GetSourcesError", func(t *testing.T) {
   259  			mock := &testutil.MockSourcesReaderWriter{
   260  				GetSourcesFunc: func() (sources.Sources, error) {
   261  					return nil, errors.New("database connection failed")
   262  				},
   263  			}
   264  			ts := newTestSourceServer(mock)
   265  			r := newSourceItemRequest(http.MethodGet, "test-source", nil)
   266  			w := httptest.NewRecorder()
   267  			ts.handleSourceItem(w, r)
   268  			resp := w.Result()
   269  			defer resp.Body.Close()
   270  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   271  			body, _ := io.ReadAll(resp.Body)
   272  			assert.Contains(t, string(body), "database connection failed")
   273  		})
   274  	})
   275  
   276  	t.Run("PUT", func(t *testing.T) {
   277  		t.Run("Success", func(t *testing.T) {
   278  			existingSource := sources.Source{Name: "test-source", ConnStr: "postgresql://old"}
   279  			var updatedSource sources.Source
   280  			mock := &testutil.MockSourcesReaderWriter{
   281  				GetSourcesFunc: func() (sources.Sources, error) {
   282  					return sources.Sources{existingSource}, nil
   283  				},
   284  				UpdateSourceFunc: func(md sources.Source) error {
   285  					updatedSource = md
   286  					return nil
   287  				},
   288  			}
   289  			ts := newTestSourceServer(mock)
   290  
   291  			newSource := sources.Source{Name: "test-source", ConnStr: "postgresql://new"}
   292  			b, _ := jsoniter.ConfigFastest.Marshal(newSource)
   293  			r := newSourceItemRequest(http.MethodPut, "test-source", bytes.NewReader(b))
   294  			w := httptest.NewRecorder()
   295  			ts.handleSourceItem(w, r)
   296  			resp := w.Result()
   297  			defer resp.Body.Close()
   298  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   299  			assert.Equal(t, newSource.ConnStr, updatedSource.ConnStr)
   300  		})
   301  
   302  		t.Run("CreateNew", func(t *testing.T) {
   303  			var updatedSource sources.Source
   304  			mock := &testutil.MockSourcesReaderWriter{
   305  				GetSourcesFunc: func() (sources.Sources, error) {
   306  					return sources.Sources{}, nil // No existing sources
   307  				},
   308  				UpdateSourceFunc: func(md sources.Source) error {
   309  					updatedSource = md
   310  					return nil
   311  				},
   312  			}
   313  			ts := newTestSourceServer(mock)
   314  
   315  			source := sources.Source{Name: "new-source", ConnStr: "postgresql://new"}
   316  			b, _ := jsoniter.ConfigFastest.Marshal(source)
   317  			r := newSourceItemRequest(http.MethodPut, "new-source", bytes.NewReader(b))
   318  			w := httptest.NewRecorder()
   319  			ts.handleSourceItem(w, r)
   320  			resp := w.Result()
   321  			defer resp.Body.Close()
   322  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   323  			assert.Equal(t, source.Name, updatedSource.Name)
   324  			assert.Equal(t, source.ConnStr, updatedSource.ConnStr)
   325  		})
   326  
   327  		t.Run("NameMismatch", func(t *testing.T) {
   328  			existingSource := sources.Source{Name: "test-source"}
   329  			mock := &testutil.MockSourcesReaderWriter{
   330  				GetSourcesFunc: func() (sources.Sources, error) {
   331  					return sources.Sources{existingSource}, nil
   332  				},
   333  			}
   334  			ts := newTestSourceServer(mock)
   335  
   336  			// Body has different name than URL path
   337  			source := sources.Source{Name: "different-name"}
   338  			b, _ := jsoniter.ConfigFastest.Marshal(source)
   339  			r := newSourceItemRequest(http.MethodPut, "test-source", bytes.NewReader(b))
   340  			w := httptest.NewRecorder()
   341  			ts.handleSourceItem(w, r)
   342  			resp := w.Result()
   343  			defer resp.Body.Close()
   344  			assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   345  			body, _ := io.ReadAll(resp.Body)
   346  			assert.Contains(t, string(body), "name in URL and body must match")
   347  		})
   348  
   349  		t.Run("InvalidRequestBody", func(t *testing.T) {
   350  			mock := &testutil.MockSourcesReaderWriter{}
   351  			ts := newTestSourceServer(mock)
   352  			r := newSourceItemRequest(http.MethodPut, "test-source", &errorReader{})
   353  			w := httptest.NewRecorder()
   354  			ts.handleSourceItem(w, r)
   355  			resp := w.Result()
   356  			defer resp.Body.Close()
   357  			assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   358  			body, _ := io.ReadAll(resp.Body)
   359  			assert.Contains(t, string(body), "invalid request body")
   360  		})
   361  
   362  		t.Run("InvalidJSON", func(t *testing.T) {
   363  			mock := &testutil.MockSourcesReaderWriter{}
   364  			ts := newTestSourceServer(mock)
   365  			r := newSourceItemRequest(http.MethodPut, "test-source", strings.NewReader("invalid json"))
   366  			w := httptest.NewRecorder()
   367  			ts.handleSourceItem(w, r)
   368  			resp := w.Result()
   369  			defer resp.Body.Close()
   370  			assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   371  			body, _ := io.ReadAll(resp.Body)
   372  			assert.Contains(t, string(body), "invalid JSON format")
   373  		})
   374  
   375  		t.Run("UpdateError", func(t *testing.T) {
   376  			mock := &testutil.MockSourcesReaderWriter{
   377  				UpdateSourceFunc: func(sources.Source) error {
   378  					return errors.New("update operation failed")
   379  				},
   380  			}
   381  			ts := newTestSourceServer(mock)
   382  
   383  			source := sources.Source{Name: "test-source", ConnStr: "postgresql://test"}
   384  			b, _ := jsoniter.ConfigFastest.Marshal(source)
   385  			r := newSourceItemRequest(http.MethodPut, "test-source", bytes.NewReader(b))
   386  			w := httptest.NewRecorder()
   387  			ts.handleSourceItem(w, r)
   388  			resp := w.Result()
   389  			defer resp.Body.Close()
   390  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   391  			body, _ := io.ReadAll(resp.Body)
   392  			assert.Contains(t, string(body), "update operation failed")
   393  		})
   394  	})
   395  
   396  	t.Run("DELETE", func(t *testing.T) {
   397  		t.Run("Success", func(t *testing.T) {
   398  			existingSource := sources.Source{Name: "test-source"}
   399  			var deletedName string
   400  			mock := &testutil.MockSourcesReaderWriter{
   401  				GetSourcesFunc: func() (sources.Sources, error) {
   402  					return sources.Sources{existingSource}, nil
   403  				},
   404  				DeleteSourceFunc: func(name string) error {
   405  					deletedName = name
   406  					return nil
   407  				},
   408  			}
   409  			ts := newTestSourceServer(mock)
   410  			r := newSourceItemRequest(http.MethodDelete, "test-source", nil)
   411  			w := httptest.NewRecorder()
   412  			ts.handleSourceItem(w, r)
   413  			resp := w.Result()
   414  			defer resp.Body.Close()
   415  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   416  			assert.Equal(t, "test-source", deletedName)
   417  		})
   418  
   419  		t.Run("Idempotent", func(t *testing.T) {
   420  			var deletedName string
   421  			mock := &testutil.MockSourcesReaderWriter{
   422  				GetSourcesFunc: func() (sources.Sources, error) {
   423  					return sources.Sources{}, nil // No existing sources
   424  				},
   425  				DeleteSourceFunc: func(name string) error {
   426  					deletedName = name
   427  					return nil // DELETE is idempotent - succeeds even if source doesn't exist
   428  				},
   429  			}
   430  			ts := newTestSourceServer(mock)
   431  			r := newSourceItemRequest(http.MethodDelete, "nonexistent", nil)
   432  			w := httptest.NewRecorder()
   433  			ts.handleSourceItem(w, r)
   434  			resp := w.Result()
   435  			defer resp.Body.Close()
   436  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   437  			assert.Equal(t, "nonexistent", deletedName)
   438  		})
   439  
   440  		t.Run("DeleteError", func(t *testing.T) {
   441  			mock := &testutil.MockSourcesReaderWriter{
   442  				DeleteSourceFunc: func(string) error {
   443  					return errors.New("delete operation failed")
   444  				},
   445  			}
   446  			ts := newTestSourceServer(mock)
   447  			r := newSourceItemRequest(http.MethodDelete, "test-source", nil)
   448  			w := httptest.NewRecorder()
   449  			ts.handleSourceItem(w, r)
   450  			resp := w.Result()
   451  			defer resp.Body.Close()
   452  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   453  			body, _ := io.ReadAll(resp.Body)
   454  			assert.Contains(t, string(body), "delete operation failed")
   455  		})
   456  	})
   457  
   458  	t.Run("EmptyName", func(t *testing.T) {
   459  		mock := &testutil.MockSourcesReaderWriter{}
   460  		ts := newTestSourceServer(mock)
   461  		r := newSourceItemRequest(http.MethodGet, "", nil)
   462  		w := httptest.NewRecorder()
   463  		ts.handleSourceItem(w, r)
   464  		resp := w.Result()
   465  		defer resp.Body.Close()
   466  		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   467  		body, _ := io.ReadAll(resp.Body)
   468  		assert.Contains(t, string(body), "source name is required")
   469  	})
   470  
   471  	t.Run("OPTIONS", func(t *testing.T) {
   472  		mock := &testutil.MockSourcesReaderWriter{}
   473  		ts := newTestSourceServer(mock)
   474  		r := newSourceItemRequest(http.MethodOptions, "test", nil)
   475  		w := httptest.NewRecorder()
   476  		ts.handleSourceItem(w, r)
   477  		resp := w.Result()
   478  		defer resp.Body.Close()
   479  		assert.Equal(t, http.StatusOK, resp.StatusCode)
   480  		assert.Equal(t, "GET, PUT, DELETE, OPTIONS", resp.Header.Get("Allow"))
   481  	})
   482  
   483  	t.Run("MethodNotAllowed", func(t *testing.T) {
   484  		mock := &testutil.MockSourcesReaderWriter{}
   485  		ts := newTestSourceServer(mock)
   486  		r := newSourceItemRequest(http.MethodPost, "test", nil)
   487  		w := httptest.NewRecorder()
   488  		ts.handleSourceItem(w, r)
   489  		resp := w.Result()
   490  		defer resp.Body.Close()
   491  		assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
   492  		assert.Equal(t, "GET, PUT, DELETE, OPTIONS", resp.Header.Get("Allow"))
   493  	})
   494  }
   495