1 package webserver
2
3 import (
4 "bytes"
5 "io"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9 "time"
10
11 "github.com/golang-jwt/jwt/v5"
12 jsoniter "github.com/json-iterator/go"
13 "github.com/stretchr/testify/assert"
14 )
15
16 var json = jsoniter.ConfigFastest
17
18 func TestIsCorrectPassword(t *testing.T) {
19 ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
20 assert.True(t, ts.IsCorrectPassword(loginReq{Username: "user", Password: "pass"}))
21 assert.False(t, ts.IsCorrectPassword(loginReq{Username: "user", Password: "wrong"}))
22 assert.True(t, (&WebUIServer{}).IsCorrectPassword(loginReq{}))
23 }
24
25 func TestHandleLogin_POST_Success(t *testing.T) {
26 ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
27 body, _ := json.Marshal(map[string]string{"user": "user", "password": "pass"})
28 r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader(body))
29 w := httptest.NewRecorder()
30 ts.handleLogin(w, r)
31 resp := w.Result()
32 assert.Equal(t, http.StatusOK, resp.StatusCode)
33 token, _ := io.ReadAll(resp.Body)
34 assert.NotEmpty(t, string(token))
35 }
36
37 func TestHandleLogin_POST_Fail(t *testing.T) {
38 ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
39 body, _ := json.Marshal(map[string]string{"user": "user", "password": "wrong"})
40 r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader(body))
41 w := httptest.NewRecorder()
42 ts.handleLogin(w, r)
43 resp := w.Result()
44 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
45 }
46
47 func TestHandleLogin_POST_BadJSON(t *testing.T) {
48 ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
49 r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader([]byte("notjson")))
50 w := httptest.NewRecorder()
51 ts.handleLogin(w, r)
52 resp := w.Result()
53 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
54 }
55
56 func TestHandleLogin_GET(t *testing.T) {
57 ts := &WebUIServer{}
58 r := httptest.NewRequest(http.MethodGet, "/login", nil)
59 w := httptest.NewRecorder()
60 ts.handleLogin(w, r)
61 resp := w.Result()
62 assert.Equal(t, http.StatusOK, resp.StatusCode)
63 body, _ := io.ReadAll(resp.Body)
64 assert.Equal(t, "only POST methods is allowed.", string(body))
65 }
66
67 func TestGenerateAndValidateJWT(t *testing.T) {
68 token, err := generateJWT("user1")
69 assert.NoError(t, err)
70 r := httptest.NewRequest(http.MethodGet, "/", nil)
71 r.Header.Set("Token", token)
72 assert.NoError(t, validateToken(r))
73 }
74
75 func TestValidateToken_MissingToken(t *testing.T) {
76 r := httptest.NewRequest(http.MethodGet, "/", nil)
77 err := validateToken(r)
78 assert.Error(t, err)
79 assert.Contains(t, err.Error(), "can not find token")
80 }
81
82 func TestValidateToken_InvalidToken(t *testing.T) {
83 r := httptest.NewRequest(http.MethodGet, "/", nil)
84 r.Header.Set("Token", "invalidtoken")
85 err := validateToken(r)
86 assert.Error(t, err)
87 }
88
89 func TestEnsureAuth_ServeHTTP(t *testing.T) {
90 called := false
91 h := func(w http.ResponseWriter, _ *http.Request) {
92 called = true
93 w.WriteHeader(http.StatusTeapot)
94 }
95 token, _ := generateJWT("user1")
96 r := httptest.NewRequest(http.MethodGet, "/", nil)
97 r.Header.Set("Token", token)
98 w := httptest.NewRecorder()
99 NewEnsureAuth(h).ServeHTTP(w, r)
100 resp := w.Result()
101 assert.Equal(t, http.StatusTeapot, resp.StatusCode)
102 assert.True(t, called)
103 }
104
105 func TestEnsureAuth_ServeHTTP_InvalidToken(t *testing.T) {
106 h := func(w http.ResponseWriter, _ *http.Request) {
107 w.WriteHeader(http.StatusTeapot)
108 }
109 r := httptest.NewRequest(http.MethodGet, "/", nil)
110 r.Header.Set("Token", "invalidtoken")
111 w := httptest.NewRecorder()
112 NewEnsureAuth(h).ServeHTTP(w, r)
113 resp := w.Result()
114 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
115 }
116
117 func TestJWT_Expiration(t *testing.T) {
118 tok := jwt.New(jwt.SigningMethodHS256)
119 claims := tok.Claims.(jwt.MapClaims)
120 claims["authorized"] = true
121 claims["username"] = "user"
122 claims["exp"] = time.Now().Add(-time.Hour).Unix()
123 token, _ := tok.SignedString(sampleSecretKey)
124 r := httptest.NewRequest(http.MethodGet, "/", nil)
125 r.Header.Set("Token", token)
126 err := validateToken(r)
127 assert.Error(t, err)
128 assert.Contains(t, err.Error(), "token is expired")
129 }
130