...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/webserver/webserver_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/webserver

     1  package webserver
     2  
     3  import (
     4  	"io"
     5  	"io/fs"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"os"
     9  	"path"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/cybertec-postgresql/pgwatch/v3/internal/metrics"
    14  	"github.com/sirupsen/logrus"
    15  	"github.com/stretchr/testify/assert"
    16  )
    17  
    18  type mockFS struct {
    19  	OpenFunc func(name string) (fs.File, error)
    20  }
    21  
    22  func (m mockFS) Open(name string) (fs.File, error) {
    23  	return m.OpenFunc(name)
    24  }
    25  
    26  func TestServer_handleStatic(t *testing.T) {
    27  	tempFile := path.Join(t.TempDir(), "file.ext")
    28  	assert.NoError(t, os.WriteFile(tempFile, []byte(`{"foo": {"bar": 1}}`), 0644))
    29  	ts := &WebUIServer{
    30  		Logger: logrus.StandardLogger(),
    31  		uiFS: mockFS{
    32  			OpenFunc: func(name string) (fs.File, error) {
    33  				switch name {
    34  				case "index.html", "static/file.ext":
    35  					return os.Open(tempFile)
    36  				case "badfile.ext":
    37  					return nil, fs.ErrInvalid
    38  				default:
    39  					return nil, fs.ErrNotExist
    40  				}
    41  			},
    42  		},
    43  	}
    44  
    45  	t.Run("not GET", func(t *testing.T) {
    46  		r := httptest.NewRequest(http.MethodPost, "/static/file.ext", nil)
    47  		w := httptest.NewRecorder()
    48  		ts.handleStatic(w, r)
    49  		resp := w.Result()
    50  		defer resp.Body.Close()
    51  		assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
    52  		body, _ := io.ReadAll(resp.Body)
    53  		assert.Equal(t, "Method Not Allowed\n", string(body))
    54  	})
    55  
    56  	t.Run("some static file", func(t *testing.T) {
    57  		r := httptest.NewRequest(http.MethodGet, "/static/file.ext", nil)
    58  		w := httptest.NewRecorder()
    59  		ts.handleStatic(w, r)
    60  		resp := w.Result()
    61  		defer resp.Body.Close()
    62  		assert.Equal(t, http.StatusOK, resp.StatusCode)
    63  		body, _ := io.ReadAll(resp.Body)
    64  		var got map[string]metrics.Metric
    65  		assert.NoError(t, json.Unmarshal(body, &got))
    66  		assert.Contains(t, got, "foo")
    67  	})
    68  
    69  	t.Run("predefined route", func(t *testing.T) {
    70  		r := httptest.NewRequest(http.MethodGet, "/metrics", nil)
    71  		w := httptest.NewRecorder()
    72  		ts.handleStatic(w, r)
    73  		resp := w.Result()
    74  		defer resp.Body.Close()
    75  		assert.Equal(t, http.StatusOK, resp.StatusCode)
    76  		body, _ := io.ReadAll(resp.Body)
    77  		var got map[string]metrics.Metric
    78  		assert.NoError(t, json.Unmarshal(body, &got))
    79  		assert.Contains(t, got, "foo")
    80  	})
    81  
    82  	t.Run("file not found", func(t *testing.T) {
    83  		r := httptest.NewRequest(http.MethodGet, "/static/notfound.ext", nil)
    84  		w := httptest.NewRecorder()
    85  		ts.handleStatic(w, r)
    86  		resp := w.Result()
    87  		defer resp.Body.Close()
    88  		assert.Equal(t, http.StatusNotFound, resp.StatusCode)
    89  		body, _ := io.ReadAll(resp.Body)
    90  		assert.Equal(t, "404 page not found\n", string(body))
    91  	})
    92  
    93  	t.Run("file cannot be read", func(t *testing.T) {
    94  		r := httptest.NewRequest(http.MethodGet, "/badfile.ext", nil)
    95  		w := httptest.NewRecorder()
    96  		ts.handleStatic(w, r)
    97  		resp := w.Result()
    98  		defer resp.Body.Close()
    99  		assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   100  	})
   101  }
   102  
   103  func TestServer_handleTestConnect(t *testing.T) {
   104  	ts := &WebUIServer{
   105  		Logger: logrus.StandardLogger(),
   106  	}
   107  
   108  	t.Run("POST", func(t *testing.T) {
   109  		r := httptest.NewRequest(http.MethodPost, "/testconnect", strings.NewReader("bad connection string"))
   110  		w := httptest.NewRecorder()
   111  		ts.handleTestConnect(w, r)
   112  		resp := w.Result()
   113  		defer resp.Body.Close()
   114  		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   115  	})
   116  
   117  	t.Run("failed reader", func(t *testing.T) {
   118  		r := httptest.NewRequest(http.MethodPost, "/testconnect", &errorReader{})
   119  		w := httptest.NewRecorder()
   120  		ts.handleTestConnect(w, r)
   121  		resp := w.Result()
   122  		defer resp.Body.Close()
   123  		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   124  	})
   125  
   126  	t.Run("GET", func(t *testing.T) {
   127  		r := httptest.NewRequest(http.MethodGet, "/testconnect", nil)
   128  		w := httptest.NewRecorder()
   129  		ts.handleTestConnect(w, r)
   130  		resp := w.Result()
   131  		defer resp.Body.Close()
   132  		assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
   133  		body, _ := io.ReadAll(resp.Body)
   134  		assert.Equal(t, "Method Not Allowed\n", string(body))
   135  	})
   136  }
   137