1 package webserver
2
3 import (
4 "context"
5 "fmt"
6 "io"
7 "io/fs"
8 "mime"
9 "net"
10 "net/http"
11 "os"
12 "path/filepath"
13 "slices"
14 "strings"
15 "time"
16
17 "github.com/cybertec-postgresql/pgwatch/v3/internal/db"
18 "github.com/cybertec-postgresql/pgwatch/v3/internal/log"
19 "github.com/cybertec-postgresql/pgwatch/v3/internal/metrics"
20 "github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
21 )
22
// ReadyChecker reports whether the application is ready to serve traffic.
// The /readiness endpoint answers 200 when Ready returns true and 503
// otherwise.
type ReadyChecker interface {
	// Ready returns true when the component is ready.
	Ready() bool
}
26
27 type WebUIServer struct {
28 CmdOpts
29 http.Server
30 log.Logger
31 ctx context.Context
32 uiFS fs.FS
33 metricsReaderWriter metrics.ReaderWriter
34 sourcesReaderWriter sources.ReaderWriter
35 readyChecker ReadyChecker
36 }
37
38 func Init(ctx context.Context, opts CmdOpts, webuifs fs.FS, mrw metrics.ReaderWriter, srw sources.ReaderWriter, rc ReadyChecker) (*WebUIServer, error) {
39 if opts.WebDisable == WebDisableAll {
40 return nil, nil
41 }
42 mux := http.NewServeMux()
43 s := &WebUIServer{
44 Server: http.Server{
45 Addr: opts.WebAddr,
46 ReadTimeout: 10 * time.Second,
47 WriteTimeout: 10 * time.Second,
48 MaxHeaderBytes: 1 << 20,
49 Handler: corsMiddleware(mux),
50 },
51 ctx: ctx,
52 Logger: log.GetLogger(ctx),
53 CmdOpts: opts,
54 uiFS: webuifs,
55 metricsReaderWriter: mrw,
56 sourcesReaderWriter: srw,
57 readyChecker: rc,
58 }
59
60 mux.Handle("/source", NewEnsureAuth(s.handleSources))
61 mux.Handle("/test-connect", NewEnsureAuth(s.handleTestConnect))
62 mux.Handle("/metric", NewEnsureAuth(s.handleMetrics))
63 mux.Handle("/preset", NewEnsureAuth(s.handlePresets))
64 mux.Handle("/log", NewEnsureAuth(s.serveWsLog))
65 mux.HandleFunc("/login", s.handleLogin)
66 mux.HandleFunc("/liveness", s.handleLiveness)
67 mux.HandleFunc("/readiness", s.handleReadiness)
68 if opts.WebDisable != WebDisableUI {
69 mux.HandleFunc("/", s.handleStatic)
70 }
71
72 ln, err := net.Listen("tcp", s.Addr)
73 if err != nil {
74 return nil, err
75 }
76
77 go func() { panic(s.Serve(ln)) }()
78
79 return s, nil
80 }
81
82 func (Server *WebUIServer) handleStatic(w http.ResponseWriter, r *http.Request) {
83 if r.Method != "GET" {
84 http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
85 return
86 }
87 routes := []string{"/", "/sources", "/metrics", "/presets", "/logs"}
88 path := r.URL.Path
89 if slices.Contains(routes, path) {
90 path = "index.html"
91 } else {
92 path = strings.TrimPrefix(path, "/")
93 }
94
95 file, err := Server.uiFS.Open(path)
96 if err != nil {
97 if os.IsNotExist(err) {
98 Server.Println("file", path, "not found:", err)
99 http.NotFound(w, r)
100 return
101 }
102 Server.Println("file", path, "cannot be read:", err)
103 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
104 return
105 }
106 defer file.Close()
107
108 contentType := mime.TypeByExtension(filepath.Ext(path))
109 w.Header().Set("Content-Type", contentType)
110 if strings.HasPrefix(path, "static/") {
111 w.Header().Set("Cache-Control", "public, max-age=31536000")
112 }
113 stat, err := file.Stat()
114 if err == nil && stat.Size() > 0 {
115 w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
116 }
117
118 n, _ := io.Copy(w, file)
119 Server.Debug("file", path, "copied", n, "bytes")
120 }
121
122 func (Server *WebUIServer) handleLiveness(w http.ResponseWriter, _ *http.Request) {
123 if Server.ctx.Err() != nil {
124 w.WriteHeader(http.StatusServiceUnavailable)
125 _, _ = w.Write([]byte(`{"status": "unavailable"}`))
126 return
127 }
128 w.WriteHeader(http.StatusOK)
129 _, _ = w.Write([]byte(`{"status": "ok"}`))
130 }
131
132 func (Server *WebUIServer) handleReadiness(w http.ResponseWriter, _ *http.Request) {
133 if Server.readyChecker.Ready() {
134 w.WriteHeader(http.StatusOK)
135 _, _ = w.Write([]byte(`{"status": "ok"}`))
136 return
137 }
138 w.WriteHeader(http.StatusServiceUnavailable)
139 _, _ = w.Write([]byte(`{"status": "busy"}`))
140 }
141
142 func (Server *WebUIServer) handleTestConnect(w http.ResponseWriter, r *http.Request) {
143 switch r.Method {
144 case http.MethodPost:
145
146 p, err := io.ReadAll(r.Body)
147 if err != nil {
148 http.Error(w, err.Error(), http.StatusBadRequest)
149 return
150 }
151 if err := db.Ping(context.TODO(), string(p)); err != nil {
152 http.Error(w, err.Error(), http.StatusBadRequest)
153 }
154 default:
155 w.Header().Set("Allow", "POST")
156 http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
157 return
158 }
159 }
160
// corsMiddleware wraps next so that every response carries CORS headers
// allowing cross-origin requests from http://localhost:4000 (presumably the
// UI development server — confirm). Preflight OPTIONS requests are answered
// with an empty 200 and are never forwarded to next.
func corsMiddleware(next http.Handler) http.Handler {
	const (
		allowedOrigin  = "http://localhost:4000"
		allowedMethods = "POST, GET, OPTIONS, PUT, DELETE"
		allowedHeaders = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, token"
	)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", allowedOrigin)
		hdr.Set("Access-Control-Allow-Methods", allowedMethods)
		hdr.Set("Access-Control-Allow-Headers", allowedHeaders)
		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
}
173