aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMitchell Riedstra <mitch@riedstra.dev>2023-01-03 19:28:25 -0500
committerMitchell Riedstra <mitch@riedstra.dev>2023-01-03 19:28:25 -0500
commit7c01e12341f79a7bdf085a141e303d209fd8b3c5 (patch)
tree272d33cf2ca840377322f97e8f728f1e7e68ed85
parent9fbe8b79f7bc12a71b62722b06f7e93334da1a52 (diff)
downloadpaste-7c01e12341f79a7bdf085a141e303d209fd8b3c5.tar.gz
paste-7c01e12341f79a7bdf085a141e303d209fd8b3c5.tar.xz
Add JWT support for logins. Add delete option.
-rw-r--r--go.mod5
-rw-r--r--go.sum2
-rw-r--r--main.go365
3 files changed, 306 insertions, 66 deletions
diff --git a/go.mod b/go.mod
index 9f3323b..150eecf 100644
--- a/go.mod
+++ b/go.mod
@@ -7,4 +7,7 @@ require (
golang.org/x/term v0.3.0
)
-require golang.org/x/sys v0.3.0 // indirect
+require (
+ github.com/golang-jwt/jwt/v4 v4.4.3 // indirect
+ golang.org/x/sys v0.3.0 // indirect
+)
diff --git a/go.sum b/go.sum
index 4742e51..a0552d8 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
+github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
diff --git a/main.go b/main.go
index e9c6020..c03a0d7 100644
--- a/main.go
+++ b/main.go
@@ -22,6 +22,7 @@ import (
"strings"
"time"
+ jwt "github.com/golang-jwt/jwt/v4"
"golang.org/x/crypto/bcrypt"
"golang.org/x/term"
)
@@ -42,6 +43,8 @@ var (
type App struct {
static fs.FS
users map[string]string
+ jwtKey string
+ sessionHours int
storage string // path to where the paste files are actually stored
staticHandler http.Handler
}
@@ -56,32 +59,37 @@ func EnvFlagString(fl *flag.FlagSet, p *string, name, envvar, usage string) {
fl.StringVar(p, name, *p, fmt.Sprintf("%s (Environ: '%s')", usage, envvar))
}
-type errResp struct {
+// Response encapsulates all the JSON responses from this server, Code
+// is a copy of the HTTP status code, Msg is a human-readable statement
+// about the particular response. Data is optional
+type Response struct {
w http.ResponseWriter
Code int
Msg string
+ Data any `json:",omitempty"`
}
func main() {
var (
- listen = ":6130"
- idBytes = "8"
- debugF = "false"
- debug = false
- genhash = false
- storage = ""
- fsdir = ""
- proxyURL = ""
- static fs.FS
- err error
+ listen = ":6130"
+ idBytes = "8"
+ debugF = "false"
+ debug bool
+ genhash = false
+ storage = ""
+ fsdir = ""
+ proxyURL = ""
+ jwtKey = ""
+ sessionHours = "12"
+ rSessionHours = 12
+ static fs.FS
+ err error
)
- dumpFStree(staticEmbedded)
static, err = fs.Sub(staticEmbedded, embeddedPrefix)
if err != nil {
logger.Fatal("Embedding failed no static directory")
}
- dumpFStree(static)
fl := flag.NewFlagSet("Simple pastebin", flag.ExitOnError)
EnvFlagString(fl, &listen, "listen", "LISTEN_ADDR",
@@ -89,13 +97,17 @@ func main() {
EnvFlagString(fl, &idBytes, "b", "ID_BYTES",
"How many bytes long should the IDs be")
EnvFlagString(fl, &debugF, "d", "DEBUG",
- "Additoinal debugging if set to true")
+ "Additional debugging if set to true")
EnvFlagString(fl, &storage, "s", "STORAGE_DIR",
"Path of directory to serve")
EnvFlagString(fl, &fsdir, "fs", "FS_DIR",
"Path which static assets are located, empty to use embedded")
EnvFlagString(fl, &proxyURL, "sp", "STATIC_PROXY",
"What server do we proxy to for static content?")
+ EnvFlagString(fl, &jwtKey, "jwt", "JWT_KEY",
+ "If supplied use the value as the JWT key instead of a random value")
+ EnvFlagString(fl, &sessionHours, "hours", "SESSION_HOURS",
+ "How many hours should login sessions last?")
fl.BoolVar(&genhash, "genhash", genhash,
"Interactively prompt for a password and spit out a hash\n")
version := fl.Bool("v", false, "Print version then exit")
@@ -126,7 +138,8 @@ func main() {
static = os.DirFS(fsdir)
}
- if b, err := strconv.Atoi(idBytes); err != nil && b > 4 {
+ if b, err := strconv.Atoi(idBytes); err == nil && b > 4 {
+ logger.Printf("Setting ID_BYTES: %d\n", b)
ID_BYTES = b
}
@@ -135,6 +148,15 @@ func main() {
"`-s` flag or STORAGE_DIR environment variable")
}
+ if jwtKey == "" {
+ jwtKey = genTokenKey()
+ }
+
+ if h, err := strconv.Atoi(sessionHours); err == nil && h > 0 {
+ logger.Printf("Setting SESSION_HOURS: %d\n", h)
+ rSessionHours = h
+ }
+
err = os.MkdirAll(storage, 0777)
if err != nil {
logger.Fatal("Failed to create storage directory: ", err)
@@ -145,15 +167,17 @@ func main() {
}
app := &App{
- static: static,
- storage: storage,
- users: getUsersFromEnviron(),
+ static: static,
+ storage: storage,
+ jwtKey: jwtKey,
+ sessionHours: rSessionHours,
+ users: getUsersFromEnviron(),
}
rp, err := getProxyHandler(proxyURL)
if proxyURL != "" && err == nil {
app.staticHandler = rp
- logger.Println("Proxying static requests to: ", rp)
+ logger.Println("Proxying static requests to: ", proxyURL)
} else if err != nil {
logger.Printf("Warning, invalid url: '%s': %s", proxyURL, err)
}
@@ -171,7 +195,7 @@ func main() {
func dumpFStree(f fs.FS) {
logger.Println("dumping fs tree....")
- fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error {
+ _ = fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
@@ -236,9 +260,15 @@ func (a *App) Handler() http.Handler {
secHandlers := map[string]http.Handler{
"/api/v0/new": a.HandleNew(),
"/api/v0/list": a.HandleList(),
+ "/api/v0/del/": http.StripPrefix(
+ "/api/v0/del/",
+ a.HandleDel()),
"/api/v1/new": a.HandleNewJSON(),
"/api/v1/list": a.HandleListJSON(),
+ "/api/v1/del/": http.StripPrefix(
+ "/api/v1/del/",
+ a.HandleDelJSON()),
}
handlers := map[string]http.Handler{
@@ -250,6 +280,9 @@ func (a *App) Handler() http.Handler {
"/api/v0/view/",
a.HandleViewPlain()),
+ "/api/v1/login": loginHandler(a.users, a.jwtKey, a.sessionHours),
+ "/api/v1/logout": logoutHandler(),
+
"/": http.FileServer(http.FS(a.static)),
}
@@ -266,6 +299,7 @@ func (a *App) Handler() http.Handler {
mux.Handle(pth, authHandler(
handler,
a.users,
+ a.jwtKey,
))
}
} else {
@@ -302,13 +336,37 @@ func getUsersFromEnviron() map[string]string {
return users
}
-func sendErr(er errResp) {
+func genTokenKey() string {
+ r := make([]byte, 16) // 128 bits
+ _, err := rand.Read(r)
+
+ if err != nil {
+ panic(fmt.Errorf("reading random bytes: %w", err))
+ }
+
+ return base64.RawURLEncoding.EncodeToString(r)
+}
+
+func sendJSON(er Response) {
er.w.WriteHeader(er.Code)
er.w.Header().Add("Content-type", "application/json")
enc := json.NewEncoder(er.w)
_ = enc.Encode(er)
}
+// sendPlain takes in a response struct, and writes the msg string to the
+// output verbatim unless the rdr is not nil, in which case the rdr
+// is used instead
+func sendPlain(r Response, rdr io.Reader) {
+ r.w.WriteHeader(r.Code)
+ r.w.Header().Add("Content-type", "application/json")
+ if rdr != nil {
+ _, _ = io.Copy(r.w, rdr)
+ } else {
+ _, _ = r.w.Write([]byte(r.Msg))
+ }
+}
+
func GenId() string {
r := make([]byte, ID_BYTES)
_, err := rand.Read(r)
@@ -319,64 +377,188 @@ func GenId() string {
return base64.RawURLEncoding.EncodeToString(r)
}
-func authHandler(next http.Handler, users map[string]string) http.Handler {
+func hasValidJWT(users map[string]string, tokenString, jwtKey string) bool {
+ token, err := jwt.ParseWithClaims(tokenString,
+ &jwt.RegisteredClaims{},
+ func(token *jwt.Token) (interface{}, error) {
+ return []byte(jwtKey), nil
+ })
+ if err != nil {
+ logger.Printf("failed parsing token: %s", err)
+ return false
+ }
+
+ if !token.Valid {
+ return false
+ }
+
+ // logger.Printf("Type of token.Claims: %s\n", reflect.TypeOf(token.Claims))
+ claims, ok := token.Claims.(*jwt.RegisteredClaims)
+ if !ok {
+ logger.Println("Invalid claims for token")
+ return false
+ }
+
+ _, ok = users[claims.ID]
+ if !ok {
+ return false
+ }
+
+ return err == nil
+}
+
+func isValidLogin(users map[string]string, username, password string) bool {
+ if _, haveUser := users[username]; !haveUser {
+ return false
+ }
+
+ err := bcrypt.CompareHashAndPassword(
+ []byte(users[username]),
+ []byte(password),
+ )
+
+ if errors.Is(err, bcrypt.ErrHashTooShort) {
+ logger.Println("Hash too short for username: ",
+ username)
+ }
+
+ return err == nil
+}
+
+func hasValidPlainAuth(r *http.Request, users map[string]string) bool {
+ username, passwd, ok := r.BasicAuth()
+ if !ok {
+ return false
+ }
+
+ return isValidLogin(users, username, passwd)
+}
+
+// getCookie simply returns a cookie value if any, or an empty string if
+// it's invalid or there are any errors.
+func getCookie(r *http.Request, name string) string {
+ c, err := r.Cookie(name)
+ // Normally I'd also check c.Valid()... Expires is optional, but
+ // apparently makes it invalid?
+ if err != nil {
+ return ""
+ }
+ fmt.Printf("Cookie value: '%s'\n", c.Value)
+ return c.Value
+}
+
+// respAskForBasicLogin simply adds a header telling web browsers we need
+// authentication
+func respAskForBasicLogin(w http.ResponseWriter) {
+ w.Header().Add("WWW-Authenticate",
+ "Basic realm=\"Login\", charset=\"UTF-8\"")
+}
+
+func authHandler(next http.Handler, users map[string]string, jwtKey string,
+) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
respUnAuth := func() {
- logger.Println("Unauthed")
- w.Header().Add("Content-type", "text/plain")
- w.Header().Add("WWW-Authenticate",
- "Basic realm=\"Login\", charset=\"UTF-8\"")
- w.WriteHeader(http.StatusUnauthorized)
- _, _ = w.Write([]byte("Unauthorized\n"))
+ // respAskForBasicLogin(w)
+ sendJSON(Response{w, http.StatusUnauthorized, "Unauthorized", nil})
}
- username, passwd, ok := r.BasicAuth()
- if _, haveUser := users[username]; !ok || !haveUser {
+ if !hasValidPlainAuth(r, users) &&
+ !hasValidJWT(users, getCookie(r, "Auth"), jwtKey) {
+
respUnAuth()
return
}
- err := bcrypt.CompareHashAndPassword(
- []byte(users[username]),
- []byte(passwd),
- )
+ next.ServeHTTP(w, r)
+ })
+}
+
+func loginHandler(users map[string]string, jwtKey string, sessionHours int,
+) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "POST" {
+ sendJSON(Response{w, http.StatusBadRequest,
+ "Invalid type. POST only", nil})
+ return
+ }
+
+ err := r.ParseForm()
+ if err != nil {
+ sendJSON(Response{w, http.StatusBadRequest, "Invalid form", nil})
+ return
+ }
- if errors.Is(err, bcrypt.ErrHashTooShort) {
- logger.Println("Hash too short for username: ",
- username)
+ username := r.PostFormValue("username")
+ password := r.PostFormValue("password")
+
+ if !isValidLogin(users, username, password) {
+ sendJSON(Response{w, http.StatusUnauthorized,
+ "Invalid username or password", nil})
+ return
}
+ expires := time.Now().Add(time.Hour * time.Duration(sessionHours))
+ claims := &jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(expires),
+ ID: username,
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
+ ss, err := token.SignedString([]byte(jwtKey))
if err != nil {
- respUnAuth()
+ sendJSON(Response{w, http.StatusInternalServerError,
+ "Invalid username or password", nil})
return
}
- next.ServeHTTP(w, r)
+ http.SetCookie(w, &http.Cookie{
+ Name: "Auth",
+ HttpOnly: true,
+ SameSite: http.SameSiteStrictMode,
+ Value: ss,
+ })
+
+ http.Redirect(w, r, "/", http.StatusFound)
+ })
+}
+
+func logoutHandler() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.SetCookie(w, &http.Cookie{
+ Name: "Auth",
+ HttpOnly: true,
+ SameSite: http.SameSiteStrictMode,
+ Value: "logout",
+ Expires: time.Now().Add(time.Second), //nolint
+ })
+
+ http.Redirect(w, r, "/", http.StatusFound)
})
}
+func OpenW(filename string) (*os.File, error) {
+ return os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE)
+}
+
func (a *App) HandleNew() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := GenId()
- fh, err := os.OpenFile(id, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE)
+ fh, err := OpenW(filepath.Join(a.storage, id))
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"})
+ sendPlain(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil}, nil)
return
}
_, err = io.Copy(fh, r.Body)
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"})
+ sendPlain(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil}, nil)
return
}
- w.WriteHeader(http.StatusOK)
- w.Header().Add("Content-type", "text/plain")
- enc := json.NewEncoder(w)
- _ = enc.Encode(struct {
- Id string
- }{id})
+ sendPlain(Response{w, http.StatusOK, id + "\n", nil}, nil)
})
}
@@ -394,38 +576,49 @@ func (a *App) HandleNewJSON() http.Handler {
err := dec.Decode(p)
if err != nil || p.Content == "" {
- sendErr(errResp{w, http.StatusBadRequest, "Malformed JSON"})
+ sendJSON(Response{w, http.StatusBadRequest,
+ "Malformed JSON", nil})
return
}
- fh, err := os.OpenFile(id, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE)
+ fh, err := OpenW(filepath.Join(a.storage, id))
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"})
+ sendJSON(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil})
return
}
_, err = fh.Write([]byte(p.Content))
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"})
+ sendJSON(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil})
return
}
- sendErr(errResp{w, http.StatusOK, "OK"})
+ sendJSON(Response{w, http.StatusOK, "OK",
+ struct{ Id string }{id}})
})
}
+func getPastePathFromRawURL(storage, u string) string {
+ return filepath.Join(storage, strings.ReplaceAll(
+ filepath.Clean(u), "/", ""))
+}
+
func (a *App) HandleViewJSON() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- pth := filepath.Join(a.storage, strings.ReplaceAll(filepath.Clean(r.URL.Path), "/", ""))
+ pth := getPastePathFromRawURL(a.storage, r.URL.Path)
fh, err := os.Open(pth)
if err != nil {
- sendErr(errResp{w, http.StatusNotFound, "ID not found"})
+ sendJSON(Response{w, http.StatusNotFound,
+ "ID not found", nil})
return
}
b, err := io.ReadAll(fh)
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"})
+ sendJSON(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil})
return
}
@@ -440,16 +633,57 @@ func (a *App) HandleViewJSON() http.Handler {
func (a *App) HandleViewPlain() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- pth := filepath.Join(a.storage, strings.ReplaceAll(filepath.Clean(r.URL.Path), "/", ""))
+ pth := getPastePathFromRawURL(a.storage, r.URL.Path)
fh, err := os.Open(pth)
if err != nil {
- sendErr(errResp{w, http.StatusNotFound, "ID not found"})
+ sendPlain(Response{w, http.StatusNotFound,
+ "ID not found", nil}, nil)
return
}
+ defer fh.Close()
- w.WriteHeader(http.StatusOK)
- w.Header().Add("Content-type", "text/plain")
- _, _ = io.Copy(w, fh)
+ sendPlain(Response{w, http.StatusOK, "", nil}, fh)
+ })
+}
+
+func (a *App) HandleDel() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ pth := getPastePathFromRawURL(a.storage, r.URL.Path)
+
+ err := os.Remove(pth)
+ if err != nil {
+ if os.IsNotExist(err) {
+ sendPlain(
+ Response{w, http.StatusNotFound, "Not found", nil},
+ nil)
+ return
+ }
+ sendPlain(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil}, nil)
+ return
+ }
+
+ sendPlain(Response{w, http.StatusOK, "Ok", nil}, nil)
+ })
+}
+
+func (a *App) HandleDelJSON() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ logger.Println("Made it to delete...")
+ pth := getPastePathFromRawURL(a.storage, r.URL.Path)
+ logger.Println("Deleting path: ", pth)
+ err := os.Remove(pth)
+ if err != nil {
+ if os.IsNotExist(err) {
+ sendJSON(Response{w, http.StatusNotFound, "Not found", nil})
+ return
+ }
+ sendJSON(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil})
+ return
+ }
+
+ sendJSON(Response{w, http.StatusOK, "Ok", nil})
})
}
@@ -457,7 +691,8 @@ func (a *App) HandleList() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pl, err := LoadPastes(a.storage)
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"})
+ sendPlain(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil}, nil)
return
}
@@ -476,8 +711,8 @@ func (a *App) HandleListJSON() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pl, err := LoadPastes(a.storage)
if err != nil {
- sendErr(errResp{w, http.StatusInternalServerError,
- "Internal server error"})
+ sendJSON(Response{w, http.StatusInternalServerError,
+ "Internal server error", nil})
return
}
@@ -541,7 +776,7 @@ func (pl PasteListing) SortDate() {
func getInt(u *url.URL, param string) (int, error) {
i := u.Query().Get(param)
if i == "" {
- return 0, fmt.Errorf("No param '%s' supplied", param)
+ return 0, fmt.Errorf("no param '%s' supplied", param)
}
n, err := strconv.Atoi(i)