diff options
Diffstat (limited to 'main.go')
| -rw-r--r-- | main.go | 360 |
1 files changed, 128 insertions, 232 deletions
@@ -5,15 +5,14 @@ import ( "crypto/rand" "embed" "encoding/base64" - "encoding/json" "errors" "flag" "fmt" + "html/template" "io" "io/fs" "log" "net/http" - "net/http/httputil" "net/url" "os" "path/filepath" @@ -35,20 +34,14 @@ var ( logger = log.New(os.Stderr, "", 0) ID_BYTES = 8 - embeddedPrefix = "ui/build" // Keep in sync with staticEmbedded below - //go:embed all:ui/build/* + //embeddedPrefix = "ui/build" // Keep in sync with staticEmbedded below + // go:embed all:ui/build/* + + embeddedPrefix = "htmx" // Keep in sync with staticEmbedded below + //go:embed all:htmx/* staticEmbedded embed.FS ) -type App struct { - static fs.FS - users UsersMap - jwtKey string - sessionHours int - storage string // path to where the paste files are actually stored - staticHandler http.Handler -} - // EnvFlagString is a convent way to set value from environment variable, // and allow override when a command line flag is set. It's assumed `p` is // not nil. @@ -59,6 +52,12 @@ func EnvFlagString(fl *flag.FlagSet, p *string, name, envvar, usage string) { fl.StringVar(p, name, *p, fmt.Sprintf("%s (Environ: '%s')", usage, envvar)) } +func logIfErr(err error) { + if err != nil { + logger.Println(err) + } +} + // Response encapsulates all the JSON responses from this server, Code // is a copy of the HTTP status code, Msg is a human-readable statement // about the particular response. Data is optional @@ -80,7 +79,6 @@ func main() { genhash = false storage = "" fsdir = "" - proxyURL = "" jwtKey = "" sessionHours = "12" rSessionHours = 12 @@ -104,8 +102,6 @@ func main() { "Path of directory to serve") EnvFlagString(fl, &fsdir, "fs", "FS_DIR", "Path which static assets are located, empty to use embedded") - EnvFlagString(fl, &proxyURL, "sp", "STATIC_PROXY", - "What server do we proxy to for static content?") EnvFlagString(fl, &jwtKey, "jwt", "JWT_KEY", "If supplied use the value as the JWT key instead of a random value\n"+ "Please understand this is not a good idea. Use a random key if "+ @@ -177,25 +173,9 @@ func main() { dumpFStree(static) } - app := &App{ - static: static, - storage: storage, - jwtKey: jwtKey, - sessionHours: rSessionHours, - users: LoadUsersFromEnviron(), - } - - rp, err := getProxyHandler(proxyURL) - if proxyURL != "" && err == nil { - app.staticHandler = rp - logger.Println("Proxying static requests to: ", proxyURL) - } else if err != nil { - logger.Printf("Warning, invalid url: '%s': %s", proxyURL, err) - } - logger.Println("listening on: ", listen) - handler := app.Handler() + handler := Handler(static, storage, LoadUsersFromEnviron(), jwtKey, rSessionHours) if logRequests { handler = httpRequestLogger(handler) @@ -226,17 +206,6 @@ func dumpFStree(f fs.FS) { }) } -// getProxyHandler returns a httputil.NewSingleHostReverseProxy for the -// url provided, and errors parsing the URL, if any. -func getProxyHandler(proxyURL string) (http.Handler, error) { - pu, err := url.Parse(proxyURL) - if err != nil { - return nil, err - } - rp := httputil.NewSingleHostReverseProxy(pu) - return rp, nil -} - func httpRequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%s %s %s \"%s\" \"%s\"\n", @@ -271,54 +240,39 @@ func interactiveHashGen() { fmt.Printf("hash: %s\n", string(passwd)) } -func (a *App) Handler() http.Handler { +func Handler( + static fs.FS, + storage string, + users UsersMap, + jwtKey string, + sessionHours int, +) http.Handler { mux := http.NewServeMux() secHandlers := map[string]http.Handler{ - "/api/v0/new": a.HandleNew(), - "/api/v0/list": a.HandleList(), - "/api/v0/del/": http.StripPrefix( - "/api/v0/del/", - a.HandleDel()), - - "/api/v1/new": a.HandleNewJSON(), - "/api/v1/list": a.HandleListJSON(), - "/api/v1/del/": http.StripPrefix( - "/api/v1/del/", - a.HandleDelJSON()), + "/plain/new": HandleNew(storage), + "/plain/list": HandleList(storage), + "/plain/del/": http.StripPrefix("/plain/del/", HandleDel(storage)), } handlers := map[string]http.Handler{ - "/api/v1/view/": http.StripPrefix( - "/api/v1/view/", - a.HandleViewJSON()), + "/plain/view/": http.StripPrefix("/plain/view/", HandleViewPlain(storage)), - "/api/v0/view/": http.StripPrefix( - "/api/v0/view/", - a.HandleViewPlain()), + "/login": loginHandler(static, users, jwtKey, sessionHours), + "/logout": logoutHandler(), - "/api/v1/login": loginHandler(a.users, a.jwtKey, a.sessionHours), - "/api/v1/logout": logoutHandler(), + "/static/": http.FileServer(http.FS(static)), - "/_app/": http.FileServer(http.FS(a.static)), - "/": a.HandleIndex(a.static), + "/": HandleIndex(static), } - if a.staticHandler != nil { - delete(handlers, "/_app") - handlers["/"] = a.staticHandler - } - - if len(a.users) > 0 { - for user := range a.users { + if len(users) > 0 { + for user := range users { logger.Println("Found user:", user) } for pth, handler := range secHandlers { - mux.Handle(pth, a.users.AuthHandler( - handler, - a.jwtKey, - )) + mux.Handle(pth, UserAuthHandler(users, handler, jwtKey)) } } else { _, _ = fmt.Fprintf(os.Stderr, @@ -361,15 +315,18 @@ func LoadUsersFromEnviron() UsersMap { return users } -func (a *App) HandleIndex(f fs.FS) http.Handler { +func HandleIndex(f fs.FS) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pth := filepath.Clean(r.URL.Path) fh, err := f.Open(pth) if err != nil { fh, err = f.Open("index.html") if err != nil { + logger.Println(err) sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) + return } } @@ -388,23 +345,40 @@ func genTokenKey() string { return base64.RawURLEncoding.EncodeToString(r) } -// sendJSON Sends down the response to the writer specified, automatically -// encoding the response status code in the header, JSON, and a friendly -// message as well as any other exported members of the Response struct. -func sendJSON(er Response) { - er.w.WriteHeader(er.Code) - er.w.Header().Add("Content-type", "application/json") - enc := json.NewEncoder(er.w) - enc.SetIndent("", " ") - _ = enc.Encode(er) +func isHTMX(r *http.Request) bool { + return r.Header.Get("HX-Request") == "true" +} + +func renderHX(fs fs.FS, isHX bool, wr io.Writer, + data any, templates ...string) error { + if isHX { + templates = append([]string{"htmx.html"}, templates...) + } else { + templates = append([]string{"base.html"}, templates...) + } + + return renderTemplate(fs, wr, data, templates...) +} + +func renderErrorTemplate(fs fs.FS, wr io.Writer, data any) error { + return renderTemplate(fs, wr, data, "base.html", "error.html") +} + +func renderTemplate(fs fs.FS, wr io.Writer, data any, templates ...string) error { + tpl, err := template.ParseFS(fs, templates...) + if err != nil { + return err + } + + return tpl.Execute(wr, data) } // sendPlain takes in a response struct, and writes the msg string to the // output verbatim unless the rdr is not nil, in which case the rdr // is used instead func sendPlain(r Response, rdr io.Reader) { + r.w.Header().Add("Content-type", "text/plain") r.w.WriteHeader(r.Code) - r.w.Header().Add("Content-type", "application/json") if rdr != nil { _, _ = io.Copy(r.w, rdr) } else { @@ -430,7 +404,7 @@ func GenId() string { // as well as JWTs type UsersMap map[string]string -func (users UsersMap) HasValidJWT(tokenString, jwtKey string) bool { +func HasValidJWT(users UsersMap, tokenString, jwtKey string) bool { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { @@ -462,7 +436,7 @@ func (users UsersMap) HasValidJWT(tokenString, jwtKey string) bool { // IsValidLogin lets you check directly whether a particular username and // password pair is valid for the UsersMap in question -func (users UsersMap) IsValidLogin(username, password string) bool { +func IsValidLogin(users UsersMap, username, password string) bool { if _, haveUser := users[username]; !haveUser { return false } @@ -482,13 +456,13 @@ func (users UsersMap) IsValidLogin(username, password string) bool { // HasValidPlainAuth will return true if the request contains a basic auth // header that validates one of the users in the UsersMap -func (users UsersMap) HasValidPlainAuth(r *http.Request) bool { +func HasValidPlainAuth(users UsersMap, r *http.Request) bool { username, passwd, ok := r.BasicAuth() if !ok { return false } - return users.IsValidLogin(username, passwd) + return IsValidLogin(users, username, passwd) } // getCookie simply returns a cookie value if any, or an empty string if @@ -503,28 +477,19 @@ func getCookie(r *http.Request, name string) string { return c.Value } -// respAskForBasicLogin simply adds a header telling web browsers we need -// authentication -func respAskForBasicLogin(w http.ResponseWriter) { - w.Header().Add("WWW-Authenticate", - "Basic realm=\"Login\", charset=\"UTF-8\"") -} - // AuthHandler is a fairly basic authentication middleware that validates // credentials based on either a valid JWT in the 'Auth' cookie or // Username/Password via Basic Auth -func (users UsersMap) AuthHandler(next http.Handler, jwtKey string, +func UserAuthHandler(users UsersMap, next http.Handler, jwtKey string, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - respUnAuth := func() { - // respAskForBasicLogin(w) - sendJSON(Response{w, http.StatusUnauthorized, "Unauthorized", nil}) - } - if !users.HasValidPlainAuth(r) && - !users.HasValidJWT(getCookie(r, "Auth"), jwtKey) { + if !(HasValidPlainAuth(users, r) || + HasValidJWT(users, getCookie(r, "Auth"), jwtKey)) { - respUnAuth() + // Ask for basic auth + w.Header().Add("WWW-Authenticate", `Basic realm=<realm>, charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) return } @@ -532,56 +497,79 @@ func (users UsersMap) AuthHandler(next http.Handler, jwtKey string, }) } -func loginHandler(users UsersMap, jwtKey string, sessionHours int, +func loginHandler(f fs.FS, users UsersMap, jwtKey string, sessionHours int, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + logIfErr(renderHX(f, isHTMX(r), w, nil, "login.html")) + return + } + if r.Method != "POST" { - sendJSON(Response{w, http.StatusBadRequest, - "Invalid type. POST only", nil}) + w.WriteHeader(http.StatusBadRequest) + logIfErr(renderErrorTemplate(f, w, "Invalid request method"+r.Method)) return } err := r.ParseForm() if err != nil { - sendJSON(Response{w, http.StatusBadRequest, "Invalid form", nil}) + w.WriteHeader(http.StatusBadRequest) + logIfErr(renderErrorTemplate(f, w, "Invalid form")) return } username := r.PostFormValue("username") password := r.PostFormValue("password") - if !users.IsValidLogin(username, password) { - sendJSON(Response{w, http.StatusUnauthorized, - "Invalid username or password", nil}) + if !IsValidLogin(users, username, password) { + data := map[string]string{"username": username} + logIfErr(renderHX(f, isHTMX(r), w, data, "login.html")) return } - expires := time.Now().Add(time.Hour * time.Duration(sessionHours)) - claims := &jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expires), - ID: username, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) - ss, err := token.SignedString([]byte(jwtKey)) + c, err := getLoginCookie(username, jwtKey, sessionHours) if err != nil { - sendJSON(Response{w, http.StatusInternalServerError, - "Invalid username or password", nil}) + w.WriteHeader(http.StatusInternalServerError) + logIfErr(renderErrorTemplate(f, w, "Internal server error")) + logIfErr(err) return } - http.SetCookie(w, &http.Cookie{ - Name: "Auth", - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - Value: ss, - Secure: true, - }) + http.SetCookie(w, c) + + if isHTMX(r) { + w.Header().Set("HX-Redirect", "/") + return + } http.Redirect(w, r, "/", http.StatusFound) }) } +func getLoginCookie(username string, jwtKey string, sessionHours int, +) (*http.Cookie, error) { + expires := time.Now().Add(time.Hour * time.Duration(sessionHours)) + + claims := &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expires), + ID: username, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) + ss, err := token.SignedString([]byte(jwtKey)) + if err != nil { + return nil, err + } + + return &http.Cookie{ + Name: "Auth", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Value: ss, + Secure: true, + }, err +} + func logoutHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ @@ -589,7 +577,7 @@ func logoutHandler() http.Handler { HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: "logout", - Expires: time.Now().Add(time.Second), //nolint + Expires: time.Now().Add(time.Second), }) http.Redirect(w, r, "/", http.StatusFound) @@ -600,11 +588,11 @@ func OpenW(filename string) (*os.File, error) { return os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) } -func (a *App) HandleNew() http.Handler { +func HandleNew(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() - fh, err := OpenW(filepath.Join(a.storage, id)) + fh, err := OpenW(filepath.Join(storagePath, id)) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) @@ -626,71 +614,14 @@ type NewPost struct { Content string `json:"content"` } -func (a *App) HandleNewJSON() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := GenId() - - p := &NewPost{} - - dec := json.NewDecoder(r.Body) - - err := dec.Decode(p) - if err != nil || p.Content == "" { - sendJSON(Response{w, http.StatusBadRequest, - "Malformed JSON", nil}) - return - } - - fh, err := OpenW(filepath.Join(a.storage, id)) - if err != nil { - sendJSON(Response{w, http.StatusInternalServerError, - "Internal server error", nil}) - return - } - - _, err = fh.Write([]byte(p.Content)) - if err != nil { - sendJSON(Response{w, http.StatusInternalServerError, - "Internal server error", nil}) - return - } - - sendJSON(Response{w, http.StatusOK, "OK", - struct{ Id string }{id}}) - }) -} - func getPastePathFromRawURL(storage, u string) string { return filepath.Join(storage, strings.ReplaceAll( filepath.Clean(u), "/", "")) } -func (a *App) HandleViewJSON() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pth := getPastePathFromRawURL(a.storage, r.URL.Path) - fh, err := os.Open(pth) - if err != nil { - sendJSON(Response{w, http.StatusNotFound, - "ID not found", nil}) - return - } - - b, err := io.ReadAll(fh) - if err != nil { - sendJSON(Response{w, http.StatusInternalServerError, - "Internal server error", nil}) - return - } - - sendJSON(Response{w, http.StatusOK, "Ok", struct{ Content string }{ - string(b), - }}) - }) -} - -func (a *App) HandleViewPlain() http.Handler { +func HandleViewPlain(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pth := getPastePathFromRawURL(a.storage, r.URL.Path) + pth := getPastePathFromRawURL(storagePath, r.URL.Path) fh, err := os.Open(pth) if err != nil { sendPlain(Response{w, http.StatusNotFound, @@ -703,9 +634,9 @@ func (a *App) HandleViewPlain() http.Handler { }) } -func (a *App) HandleDel() http.Handler { +func HandleDel(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pth := getPastePathFromRawURL(a.storage, r.URL.Path) + pth := getPastePathFromRawURL(storagePath, r.URL.Path) err := os.Remove(pth) if err != nil { @@ -724,29 +655,9 @@ func (a *App) HandleDel() http.Handler { }) } -func (a *App) HandleDelJSON() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Println("Made it to delete...") - pth := getPastePathFromRawURL(a.storage, r.URL.Path) - logger.Println("Deleting path: ", pth) - err := os.Remove(pth) - if err != nil { - if os.IsNotExist(err) { - sendJSON(Response{w, http.StatusNotFound, "Not found", nil}) - return - } - sendJSON(Response{w, http.StatusInternalServerError, - "Internal server error", nil}) - return - } - - sendJSON(Response{w, http.StatusOK, "Ok", nil}) - }) -} - -func (a *App) HandleList() http.Handler { +func HandleList(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pl, err := LoadPastes(a.storage) + pl, err := LoadPastes(storagePath) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) @@ -764,21 +675,6 @@ func (a *App) HandleList() http.Handler { }) } -func (a *App) HandleListJSON() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pl, err := LoadPastes(a.storage) - if err != nil { - sendJSON(Response{w, http.StatusInternalServerError, - "Internal server error", nil}) - return - } - - pl = handleSkipLimitSort(pl, r.URL) - - sendJSON(Response{w, http.StatusOK, "Ok", pl}) - }) -} - type Paste struct { Id string `json:"id"` Created time.Time `json:"created"` |
