...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/webserver/webserver_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/webserver

     1  package webserver
     2  
     3  import (
     4  	"io"
     5  	"io/fs"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"os"
     9  	"path"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/cybertec-postgresql/pgwatch/v5/internal/metrics"
    14  	"github.com/sirupsen/logrus"
    15  	"github.com/stretchr/testify/assert"
    16  )
    17  
    18  type mockFS struct {
    19  	OpenFunc func(name string) (fs.File, error)
    20  }
    21  
    22  func (m mockFS) Open(name string) (fs.File, error) {
    23  	return m.OpenFunc(name)
    24  }
    25  
    26  func TestServer_handleStatic(t *testing.T) {
    27  	tempFile := path.Join(t.TempDir(), "file.ext")
    28  	assert.NoError(t, os.WriteFile(tempFile, []byte(`{"foo": {"bar": 1}}`), 0644))
    29  
    30  	indexHTML := []byte(`<!DOCTYPE html><html><head><script>window.__PGWATCH_BASE_PATH__='';</script></head><body>{"foo": {"bar": 1}}</body></html>`)
    31  
    32  	// Save original uiFS and restore after test
    33  	origUIFS := uiFS
    34  	defer func() { uiFS = origUIFS }()
    35  
    36  	uiFS = mockFS{
    37  		OpenFunc: func(name string) (fs.File, error) {
    38  			switch name {
    39  			case "index.html", "static/file.ext":
    40  				return os.Open(tempFile)
    41  			case "badfile.ext":
    42  				return nil, fs.ErrInvalid
    43  			default:
    44  				return nil, fs.ErrNotExist
    45  			}
    46  		},
    47  	}
    48  
    49  	ts := &WebUIServer{
    50  		Logger:    logrus.StandardLogger(),
    51  		indexHTML: indexHTML,
    52  	}
    53  
    54  	t.Run("not GET", func(t *testing.T) {
    55  		r := httptest.NewRequest(http.MethodPost, "/static/file.ext", nil)
    56  		w := httptest.NewRecorder()
    57  		ts.handleStatic(w, r)
    58  		resp := w.Result()
    59  		defer resp.Body.Close()
    60  		assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
    61  		body, _ := io.ReadAll(resp.Body)
    62  		assert.Equal(t, "Method Not Allowed\n", string(body))
    63  	})
    64  
    65  	t.Run("some static file", func(t *testing.T) {
    66  		r := httptest.NewRequest(http.MethodGet, "/static/file.ext", nil)
    67  		w := httptest.NewRecorder()
    68  		ts.handleStatic(w, r)
    69  		resp := w.Result()
    70  		defer resp.Body.Close()
    71  		assert.Equal(t, http.StatusOK, resp.StatusCode)
    72  		body, _ := io.ReadAll(resp.Body)
    73  		var got map[string]metrics.Metric
    74  		assert.NoError(t, json.Unmarshal(body, &got))
    75  		assert.Contains(t, got, "foo")
    76  	})
    77  
    78  	t.Run("predefined route", func(t *testing.T) {
    79  		r := httptest.NewRequest(http.MethodGet, "/metrics", nil)
    80  		w := httptest.NewRecorder()
    81  		ts.handleStatic(w, r)
    82  		resp := w.Result()
    83  		defer resp.Body.Close()
    84  		assert.Equal(t, http.StatusOK, resp.StatusCode)
    85  		assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
    86  		body, _ := io.ReadAll(resp.Body)
    87  		bodyStr := string(body)
    88  		assert.Contains(t, bodyStr, "<!DOCTYPE html>")
    89  		assert.Contains(t, bodyStr, "window.__PGWATCH_BASE_PATH__")
    90  	})
    91  
    92  	t.Run("file not found", func(t *testing.T) {
    93  		r := httptest.NewRequest(http.MethodGet, "/static/notfound.ext", nil)
    94  		w := httptest.NewRecorder()
    95  		ts.handleStatic(w, r)
    96  		resp := w.Result()
    97  		defer resp.Body.Close()
    98  		assert.Equal(t, http.StatusNotFound, resp.StatusCode)
    99  		body, _ := io.ReadAll(resp.Body)
   100  		assert.Equal(t, "404 page not found\n", string(body))
   101  	})
   102  
   103  	t.Run("file cannot be read", func(t *testing.T) {
   104  		r := httptest.NewRequest(http.MethodGet, "/badfile.ext", nil)
   105  		w := httptest.NewRecorder()
   106  		ts.handleStatic(w, r)
   107  		resp := w.Result()
   108  		defer resp.Body.Close()
   109  		assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   110  	})
   111  }
   112  
   113  func TestServer_handleTestConnect(t *testing.T) {
   114  	ts := &WebUIServer{
   115  		Logger: logrus.StandardLogger(),
   116  	}
   117  
   118  	t.Run("POST", func(t *testing.T) {
   119  		r := httptest.NewRequest(http.MethodPost, "/testconnect", strings.NewReader("bad connection string"))
   120  		w := httptest.NewRecorder()
   121  		ts.handleTestConnect(w, r)
   122  		resp := w.Result()
   123  		defer resp.Body.Close()
   124  		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   125  	})
   126  
   127  	t.Run("failed reader", func(t *testing.T) {
   128  		r := httptest.NewRequest(http.MethodPost, "/testconnect", &errorReader{})
   129  		w := httptest.NewRecorder()
   130  		ts.handleTestConnect(w, r)
   131  		resp := w.Result()
   132  		defer resp.Body.Close()
   133  		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   134  	})
   135  
   136  	t.Run("GET", func(t *testing.T) {
   137  		r := httptest.NewRequest(http.MethodGet, "/testconnect", nil)
   138  		w := httptest.NewRecorder()
   139  		ts.handleTestConnect(w, r)
   140  		resp := w.Result()
   141  		defer resp.Body.Close()
   142  		assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
   143  		body, _ := io.ReadAll(resp.Body)
   144  		assert.Equal(t, "Method Not Allowed\n", string(body))
   145  	})
   146  }
   147