...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/webserver/jwt_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/webserver

     1  package webserver
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/golang-jwt/jwt/v5"
    12  	jsoniter "github.com/json-iterator/go"
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  var json = jsoniter.ConfigFastest
    17  
    18  func TestIsCorrectPassword(t *testing.T) {
    19  	ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
    20  	assert.True(t, ts.IsCorrectPassword(loginReq{Username: "user", Password: "pass"}))
    21  	assert.False(t, ts.IsCorrectPassword(loginReq{Username: "user", Password: "wrong"}))
    22  	assert.True(t, (&WebUIServer{}).IsCorrectPassword(loginReq{})) // empty user/pass disables auth
    23  }
    24  
    25  func TestHandleLogin_POST_Success(t *testing.T) {
    26  	ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
    27  	body, _ := json.Marshal(map[string]string{"user": "user", "password": "pass"})
    28  	r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader(body))
    29  	w := httptest.NewRecorder()
    30  	ts.handleLogin(w, r)
    31  	resp := w.Result()
    32  	assert.Equal(t, http.StatusOK, resp.StatusCode)
    33  	token, _ := io.ReadAll(resp.Body)
    34  	assert.NotEmpty(t, string(token))
    35  }
    36  
    37  func TestHandleLogin_POST_Fail(t *testing.T) {
    38  	ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
    39  	body, _ := json.Marshal(map[string]string{"user": "user", "password": "wrong"})
    40  	r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader(body))
    41  	w := httptest.NewRecorder()
    42  	ts.handleLogin(w, r)
    43  	resp := w.Result()
    44  	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
    45  }
    46  
    47  func TestHandleLogin_POST_BadJSON(t *testing.T) {
    48  	ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
    49  	r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader([]byte("notjson")))
    50  	w := httptest.NewRecorder()
    51  	ts.handleLogin(w, r)
    52  	resp := w.Result()
    53  	assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
    54  }
    55  
    56  func TestHandleLogin_GET(t *testing.T) {
    57  	ts := &WebUIServer{}
    58  	r := httptest.NewRequest(http.MethodGet, "/login", nil)
    59  	w := httptest.NewRecorder()
    60  	ts.handleLogin(w, r)
    61  	resp := w.Result()
    62  	assert.Equal(t, http.StatusOK, resp.StatusCode)
    63  	body, _ := io.ReadAll(resp.Body)
    64  	assert.Equal(t, "only POST methods is allowed.", string(body))
    65  }
    66  
    67  func TestGenerateAndValidateJWT(t *testing.T) {
    68  	token, err := generateJWT("user1")
    69  	assert.NoError(t, err)
    70  	r := httptest.NewRequest(http.MethodGet, "/", nil)
    71  	r.Header.Set("Token", token)
    72  	assert.NoError(t, validateToken(r))
    73  }
    74  
    75  func TestValidateToken_MissingToken(t *testing.T) {
    76  	r := httptest.NewRequest(http.MethodGet, "/", nil)
    77  	err := validateToken(r)
    78  	assert.Error(t, err)
    79  	assert.Contains(t, err.Error(), "can not find token")
    80  }
    81  
    82  func TestValidateToken_InvalidToken(t *testing.T) {
    83  	r := httptest.NewRequest(http.MethodGet, "/", nil)
    84  	r.Header.Set("Token", "invalidtoken")
    85  	err := validateToken(r)
    86  	assert.Error(t, err)
    87  }
    88  
    89  func TestEnsureAuth_ServeHTTP(t *testing.T) {
    90  	called := false
    91  	h := func(w http.ResponseWriter, _ *http.Request) {
    92  		called = true
    93  		w.WriteHeader(http.StatusTeapot)
    94  	}
    95  	token, _ := generateJWT("user1")
    96  	r := httptest.NewRequest(http.MethodGet, "/", nil)
    97  	r.Header.Set("Token", token)
    98  	w := httptest.NewRecorder()
    99  	NewEnsureAuth(h).ServeHTTP(w, r)
   100  	resp := w.Result()
   101  	assert.Equal(t, http.StatusTeapot, resp.StatusCode)
   102  	assert.True(t, called)
   103  }
   104  
   105  func TestEnsureAuth_ServeHTTP_InvalidToken(t *testing.T) {
   106  	h := func(w http.ResponseWriter, _ *http.Request) {
   107  		w.WriteHeader(http.StatusTeapot)
   108  	}
   109  	r := httptest.NewRequest(http.MethodGet, "/", nil)
   110  	r.Header.Set("Token", "invalidtoken")
   111  	w := httptest.NewRecorder()
   112  	NewEnsureAuth(h).ServeHTTP(w, r)
   113  	resp := w.Result()
   114  	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
   115  }
   116  
   117  func TestJWT_Expiration(t *testing.T) {
   118  	tok := jwt.New(jwt.SigningMethodHS256)
   119  	claims := tok.Claims.(jwt.MapClaims)
   120  	claims["authorized"] = true
   121  	claims["username"] = "user"
   122  	claims["exp"] = time.Now().Add(-time.Hour).Unix() // expired
   123  	token, _ := tok.SignedString(sampleSecretKey)
   124  	r := httptest.NewRequest(http.MethodGet, "/", nil)
   125  	r.Header.Set("Token", token)
   126  	err := validateToken(r)
   127  	assert.Error(t, err)
   128  	assert.Contains(t, err.Error(), "token is expired")
   129  }
   130