diff options
| author | Mitchell Riedstra <mitch@riedstra.dev> | 2023-01-03 19:28:25 -0500 |
|---|---|---|
| committer | Mitchell Riedstra <mitch@riedstra.dev> | 2023-01-03 19:28:25 -0500 |
| commit | 7c01e12341f79a7bdf085a141e303d209fd8b3c5 (patch) | |
| tree | 272d33cf2ca840377322f97e8f728f1e7e68ed85 | |
| parent | 9fbe8b79f7bc12a71b62722b06f7e93334da1a52 (diff) | |
| download | paste-7c01e12341f79a7bdf085a141e303d209fd8b3c5.tar.gz paste-7c01e12341f79a7bdf085a141e303d209fd8b3c5.tar.xz | |
Add JWT support for logins. Add delete option.
| -rw-r--r-- | go.mod | 5 | ||||
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | main.go | 365 |
3 files changed, 306 insertions, 66 deletions
@@ -7,4 +7,7 @@ require ( golang.org/x/term v0.3.0 ) -require golang.org/x/sys v0.3.0 // indirect +require ( + github.com/golang-jwt/jwt/v4 v4.4.3 // indirect + golang.org/x/sys v0.3.0 // indirect +) @@ -1,3 +1,5 @@ +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= @@ -22,6 +22,7 @@ import ( "strings" "time" + jwt "github.com/golang-jwt/jwt/v4" "golang.org/x/crypto/bcrypt" "golang.org/x/term" ) @@ -42,6 +43,8 @@ var ( type App struct { static fs.FS users map[string]string + jwtKey string + sessionHours int storage string // path to where the paste files are actually stored staticHandler http.Handler } @@ -56,32 +59,37 @@ func EnvFlagString(fl *flag.FlagSet, p *string, name, envvar, usage string) { fl.StringVar(p, name, *p, fmt.Sprintf("%s (Environ: '%s')", usage, envvar)) } -type errResp struct { +// Response encapsulates all the JSON responses from this server, Code +// is a copy of the HTTP status code, Msg is a human-readable statement +// about the particular response. Data is optional +type Response struct { w http.ResponseWriter Code int Msg string + Data any `json:",omitempty"` } func main() { var ( - listen = ":6130" - idBytes = "8" - debugF = "false" - debug = false - genhash = false - storage = "" - fsdir = "" - proxyURL = "" - static fs.FS - err error + listen = ":6130" + idBytes = "8" + debugF = "false" + debug bool + genhash = false + storage = "" + fsdir = "" + proxyURL = "" + jwtKey = "" + sessionHours = "12" + rSessionHours = 12 + static fs.FS + err error ) - dumpFStree(staticEmbedded) static, err = fs.Sub(staticEmbedded, embeddedPrefix) if err != nil { logger.Fatal("Embedding failed no static directory") } - dumpFStree(static) fl := flag.NewFlagSet("Simple pastebin", flag.ExitOnError) EnvFlagString(fl, &listen, "listen", "LISTEN_ADDR", @@ -89,13 +97,17 @@ func main() { EnvFlagString(fl, &idBytes, "b", "ID_BYTES", "How many bytes long should the IDs be") EnvFlagString(fl, &debugF, "d", "DEBUG", - "Additoinal debugging if set to true") + "Additional debugging if set to true") EnvFlagString(fl, &storage, "s", "STORAGE_DIR", "Path of directory to serve") EnvFlagString(fl, &fsdir, "fs", "FS_DIR", "Path which static assets are located, empty to use embedded") EnvFlagString(fl, &proxyURL, "sp", "STATIC_PROXY", "What server do we proxy to for static content?") + EnvFlagString(fl, &jwtKey, "jwt", "JWT_KEY", + "If supplied use the value as the JWT key instead of a random value") + EnvFlagString(fl, &sessionHours, "hours", "SESSION_HOURS", + "How many hours should login sessions last?") fl.BoolVar(&genhash, "genhash", genhash, "Interactively prompt for a password and spit out a hash\n") version := fl.Bool("v", false, "Print version then exit") @@ -126,7 +138,8 @@ func main() { static = os.DirFS(fsdir) } - if b, err := strconv.Atoi(idBytes); err != nil && b > 4 { + if b, err := strconv.Atoi(idBytes); err == nil && b > 4 { + logger.Printf("Setting ID_BYTES: %d\n", b) ID_BYTES = b } @@ -135,6 +148,15 @@ func main() { "`-s` flag or STORAGE_DIR environment variable") } + if jwtKey == "" { + jwtKey = genTokenKey() + } + + if h, err := strconv.Atoi(sessionHours); err == nil && h > 0 { + logger.Printf("Setting SESSION_HOURS: %d\n", h) + rSessionHours = h + } + err = os.MkdirAll(storage, 0777) if err != nil { logger.Fatal("Failed to create storage directory: ", err) @@ -145,15 +167,17 @@ func main() { } app := &App{ - static: static, - storage: storage, - users: getUsersFromEnviron(), + static: static, + storage: storage, + jwtKey: jwtKey, + sessionHours: rSessionHours, + users: getUsersFromEnviron(), } rp, err := getProxyHandler(proxyURL) if proxyURL != "" && err == nil { app.staticHandler = rp - logger.Println("Proxying static requests to: ", rp) + logger.Println("Proxying static requests to: ", proxyURL) } else if err != nil { logger.Printf("Warning, invalid url: '%s': %s", proxyURL, err) } @@ -171,7 +195,7 @@ func main() { func dumpFStree(f fs.FS) { logger.Println("dumping fs tree....") - fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error { + _ = fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -236,9 +260,15 @@ func (a *App) Handler() http.Handler { secHandlers := map[string]http.Handler{ "/api/v0/new": a.HandleNew(), "/api/v0/list": a.HandleList(), + "/api/v0/del/": http.StripPrefix( + "/api/v0/del/", + a.HandleDel()), "/api/v1/new": a.HandleNewJSON(), "/api/v1/list": a.HandleListJSON(), + "/api/v1/del/": http.StripPrefix( + "/api/v1/del/", + a.HandleDelJSON()), } handlers := map[string]http.Handler{ @@ -250,6 +280,9 @@ func (a *App) Handler() http.Handler { "/api/v0/view/", a.HandleViewPlain()), + "/api/v1/login": loginHandler(a.users, a.jwtKey, a.sessionHours), + "/api/v1/logout": logoutHandler(), + "/": http.FileServer(http.FS(a.static)), } @@ -266,6 +299,7 @@ func (a *App) Handler() http.Handler { mux.Handle(pth, authHandler( handler, a.users, + a.jwtKey, )) } } else { @@ -302,13 +336,37 @@ func getUsersFromEnviron() map[string]string { return users } -func sendErr(er errResp) { +func genTokenKey() string { + r := make([]byte, 16) // 128 bits + _, err := rand.Read(r) + + if err != nil { + panic(fmt.Errorf("reading random bytes: %w", err)) + } + + return base64.RawURLEncoding.EncodeToString(r) +} + +func sendJSON(er Response) { er.w.WriteHeader(er.Code) er.w.Header().Add("Content-type", "application/json") enc := json.NewEncoder(er.w) _ = enc.Encode(er) } +// sendPlain takes in a response struct, and writes the msg string to the +// output verbatim unless the rdr is not nil, in which case the rdr +// is used instead +func sendPlain(r Response, rdr io.Reader) { + r.w.WriteHeader(r.Code) + r.w.Header().Add("Content-type", "application/json") + if rdr != nil { + _, _ = io.Copy(r.w, rdr) + } else { + _, _ = r.w.Write([]byte(r.Msg)) + } +} + func GenId() string { r := make([]byte, ID_BYTES) _, err := rand.Read(r) @@ -319,64 +377,188 @@ func GenId() string { return base64.RawURLEncoding.EncodeToString(r) } -func authHandler(next http.Handler, users map[string]string) http.Handler { +func hasValidJWT(users map[string]string, tokenString, jwtKey string) bool { + token, err := jwt.ParseWithClaims(tokenString, + &jwt.RegisteredClaims{}, + func(token *jwt.Token) (interface{}, error) { + return []byte(jwtKey), nil + }) + if err != nil { + logger.Printf("failed parsing token: %s", err) + return false + } + + if !token.Valid { + return false + } + + // logger.Printf("Type of token.Claims: %s\n", reflect.TypeOf(token.Claims)) + claims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok { + logger.Println("Invalid claims for token") + return false + } + + _, ok = users[claims.ID] + if !ok { + return false + } + + return err == nil +} + +func isValidLogin(users map[string]string, username, password string) bool { + if _, haveUser := users[username]; !haveUser { + return false + } + + err := bcrypt.CompareHashAndPassword( + []byte(users[username]), + []byte(password), + ) + + if errors.Is(err, bcrypt.ErrHashTooShort) { + logger.Println("Hash too short for username: ", + username) + } + + return err == nil +} + +func hasValidPlainAuth(r *http.Request, users map[string]string) bool { + username, passwd, ok := r.BasicAuth() + if !ok { + return false + } + + return isValidLogin(users, username, passwd) +} + +// getCookie simply returns a cookie value if any, or an empty string if +// it's invalid or there are any errors. +func getCookie(r *http.Request, name string) string { + c, err := r.Cookie(name) + // Normally I'd also check c.Valid()... Expires is optional, but + // apparently makes it invalid? + if err != nil { + return "" + } + fmt.Printf("Cookie value: '%s'\n", c.Value) + return c.Value +} + +// respAskForBasicLogin simply adds a header telling web browsers we need +// authentication +func respAskForBasicLogin(w http.ResponseWriter) { + w.Header().Add("WWW-Authenticate", + "Basic realm=\"Login\", charset=\"UTF-8\"") +} + +func authHandler(next http.Handler, users map[string]string, jwtKey string, +) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { respUnAuth := func() { - logger.Println("Unauthed") - w.Header().Add("Content-type", "text/plain") - w.Header().Add("WWW-Authenticate", - "Basic realm=\"Login\", charset=\"UTF-8\"") - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte("Unauthorized\n")) + // respAskForBasicLogin(w) + sendJSON(Response{w, http.StatusUnauthorized, "Unauthorized", nil}) } - username, passwd, ok := r.BasicAuth() - if _, haveUser := users[username]; !ok || !haveUser { + if !hasValidPlainAuth(r, users) && + !hasValidJWT(users, getCookie(r, "Auth"), jwtKey) { + respUnAuth() return } - err := bcrypt.CompareHashAndPassword( - []byte(users[username]), - []byte(passwd), - ) + next.ServeHTTP(w, r) + }) +} + +func loginHandler(users map[string]string, jwtKey string, sessionHours int, +) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + sendJSON(Response{w, http.StatusBadRequest, + "Invalid type. POST only", nil}) + return + } + + err := r.ParseForm() + if err != nil { + sendJSON(Response{w, http.StatusBadRequest, "Invalid form", nil}) + return + } - if errors.Is(err, bcrypt.ErrHashTooShort) { - logger.Println("Hash too short for username: ", - username) + username := r.PostFormValue("username") + password := r.PostFormValue("password") + + if !isValidLogin(users, username, password) { + sendJSON(Response{w, http.StatusUnauthorized, + "Invalid username or password", nil}) + return } + expires := time.Now().Add(time.Hour * time.Duration(sessionHours)) + claims := &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expires), + ID: username, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) + ss, err := token.SignedString([]byte(jwtKey)) if err != nil { - respUnAuth() + sendJSON(Response{w, http.StatusInternalServerError, + "Invalid username or password", nil}) return } - next.ServeHTTP(w, r) + http.SetCookie(w, &http.Cookie{ + Name: "Auth", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Value: ss, + }) + + http.Redirect(w, r, "/", http.StatusFound) + }) +} + +func logoutHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{ + Name: "Auth", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Value: "logout", + Expires: time.Now().Add(time.Second), //nolint + }) + + http.Redirect(w, r, "/", http.StatusFound) }) } +func OpenW(filename string) (*os.File, error) { + return os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) +} + func (a *App) HandleNew() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() - fh, err := os.OpenFile(id, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) + fh, err := OpenW(filepath.Join(a.storage, id)) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) + sendPlain(Response{w, http.StatusInternalServerError, + "Internal server error", nil}, nil) return } _, err = io.Copy(fh, r.Body) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) + sendPlain(Response{w, http.StatusInternalServerError, + "Internal server error", nil}, nil) return } - w.WriteHeader(http.StatusOK) - w.Header().Add("Content-type", "text/plain") - enc := json.NewEncoder(w) - _ = enc.Encode(struct { - Id string - }{id}) + sendPlain(Response{w, http.StatusOK, id + "\n", nil}, nil) }) } @@ -394,38 +576,49 @@ func (a *App) HandleNewJSON() http.Handler { err := dec.Decode(p) if err != nil || p.Content == "" { - sendErr(errResp{w, http.StatusBadRequest, "Malformed JSON"}) + sendJSON(Response{w, http.StatusBadRequest, + "Malformed JSON", nil}) return } - fh, err := os.OpenFile(id, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) + fh, err := OpenW(filepath.Join(a.storage, id)) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) + sendJSON(Response{w, http.StatusInternalServerError, + "Internal server error", nil}) return } _, err = fh.Write([]byte(p.Content)) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) + sendJSON(Response{w, http.StatusInternalServerError, + "Internal server error", nil}) return } - sendErr(errResp{w, http.StatusOK, "OK"}) + sendJSON(Response{w, http.StatusOK, "OK", + struct{ Id string }{id}}) }) } +func getPastePathFromRawURL(storage, u string) string { + return filepath.Join(storage, strings.ReplaceAll( + filepath.Clean(u), "/", "")) +} + func (a *App) HandleViewJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pth := filepath.Join(a.storage, strings.ReplaceAll(filepath.Clean(r.URL.Path), "/", "")) + pth := getPastePathFromRawURL(a.storage, r.URL.Path) fh, err := os.Open(pth) if err != nil { - sendErr(errResp{w, http.StatusNotFound, "ID not found"}) + sendJSON(Response{w, http.StatusNotFound, + "ID not found", nil}) return } b, err := io.ReadAll(fh) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) + sendJSON(Response{w, http.StatusInternalServerError, + "Internal server error", nil}) return } @@ -440,16 +633,57 @@ func (a *App) HandleViewJSON() http.Handler { func (a *App) HandleViewPlain() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pth := filepath.Join(a.storage, strings.ReplaceAll(filepath.Clean(r.URL.Path), "/", "")) + pth := getPastePathFromRawURL(a.storage, r.URL.Path) fh, err := os.Open(pth) if err != nil { - sendErr(errResp{w, http.StatusNotFound, "ID not found"}) + sendPlain(Response{w, http.StatusNotFound, + "ID not found", nil}, nil) return } + defer fh.Close() - w.WriteHeader(http.StatusOK) - w.Header().Add("Content-type", "text/plain") - _, _ = io.Copy(w, fh) + sendPlain(Response{w, http.StatusOK, "", nil}, fh) + }) +} + +func (a *App) HandleDel() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pth := getPastePathFromRawURL(a.storage, r.URL.Path) + + err := os.Remove(pth) + if err != nil { + if os.IsNotExist(err) { + sendPlain( + Response{w, http.StatusNotFound, "Not found", nil}, + nil) + return + } + sendPlain(Response{w, http.StatusInternalServerError, + "Internal server error", nil}, nil) + return + } + + sendPlain(Response{w, http.StatusOK, "Ok", nil}, nil) + }) +} + +func (a *App) HandleDelJSON() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Println("Made it to delete...") + pth := getPastePathFromRawURL(a.storage, r.URL.Path) + logger.Println("Deleting path: ", pth) + err := os.Remove(pth) + if err != nil { + if os.IsNotExist(err) { + sendJSON(Response{w, http.StatusNotFound, "Not found", nil}) + return + } + sendJSON(Response{w, http.StatusInternalServerError, + "Internal server error", nil}) + return + } + + sendJSON(Response{w, http.StatusOK, "Ok", nil}) }) } @@ -457,7 +691,8 @@ func (a *App) HandleList() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(a.storage) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) + sendPlain(Response{w, http.StatusInternalServerError, + "Internal server error", nil}, nil) return } @@ -476,8 +711,8 @@ func (a *App) HandleListJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(a.storage) if err != nil { - sendErr(errResp{w, http.StatusInternalServerError, - "Internal server error"}) + sendJSON(Response{w, http.StatusInternalServerError, + "Internal server error", nil}) return } @@ -541,7 +776,7 @@ func (pl PasteListing) SortDate() { func getInt(u *url.URL, param string) (int, error) { i := u.Query().Get(param) if i == "" { - return 0, fmt.Errorf("No param '%s' supplied", param) + return 0, fmt.Errorf("no param '%s' supplied", param) } n, err := strconv.Atoi(i) |
