package main import ( "bytes" "crypto/rand" "embed" "encoding/base64" "encoding/json" "errors" "flag" "fmt" "io" "io/fs" "log" "net/http" "net/http/httputil" "net/url" "os" "path/filepath" "sort" "strconv" "strings" "time" jwt "github.com/golang-jwt/jwt/v4" "golang.org/x/crypto/bcrypt" "golang.org/x/term" ) // Default to the system umask const OPEN_MODE = 0777 var ( Version = "git or development" logger = log.New(os.Stderr, "", 0) ID_BYTES = 8 embeddedPrefix = "ui/build" // Keep in sync with staticEmbedded below //go:embed all:ui/build/* staticEmbedded embed.FS ) type App struct { static fs.FS users map[string]string jwtKey string sessionHours int storage string // path to where the paste files are actually stored staticHandler http.Handler } // EnvFlagString is a convent way to set value from environment variable, // and allow override when a command line flag is set. It's assumed `p` is // not nil. func EnvFlagString(fl *flag.FlagSet, p *string, name, envvar, usage string) { if v := os.Getenv(envvar); v != "" { *p = v } fl.StringVar(p, name, *p, fmt.Sprintf("%s (Environ: '%s')", usage, envvar)) } // Response encapsulates all the JSON responses from this server, Code // is a copy of the HTTP status code, Msg is a human-readable statement // about the particular response. Data is optional type Response struct { w http.ResponseWriter Code int Msg string Data any `json:",omitempty"` } func main() { var ( listen = ":6130" idBytes = "8" debugF = "false" debug bool genhash = false storage = "" fsdir = "" proxyURL = "" jwtKey = "" sessionHours = "12" rSessionHours = 12 static fs.FS err error ) static, err = fs.Sub(staticEmbedded, embeddedPrefix) if err != nil { logger.Fatal("Embedding failed no static directory") } fl := flag.NewFlagSet("Simple pastebin", flag.ExitOnError) EnvFlagString(fl, &listen, "listen", "LISTEN_ADDR", "Address to bind to") EnvFlagString(fl, &idBytes, "b", "ID_BYTES", "How many bytes long should the IDs be") EnvFlagString(fl, &debugF, "d", "DEBUG", "Additional debugging if set to true") EnvFlagString(fl, &storage, "s", "STORAGE_DIR", "Path of directory to serve") EnvFlagString(fl, &fsdir, "fs", "FS_DIR", "Path which static assets are located, empty to use embedded") EnvFlagString(fl, &proxyURL, "sp", "STATIC_PROXY", "What server do we proxy to for static content?") EnvFlagString(fl, &jwtKey, "jwt", "JWT_KEY", "If supplied use the value as the JWT key instead of a random value") EnvFlagString(fl, &sessionHours, "hours", "SESSION_HOURS", "How many hours should login sessions last?") fl.BoolVar(&genhash, "genhash", genhash, "Interactively prompt for a password and spit out a hash\n") version := fl.Bool("v", false, "Print version then exit") _ = fl.Parse(os.Args[1:]) if *version { log.Println(Version) os.Exit(0) } if genhash { interactiveHashGen() os.Exit(0) } debug, err = strconv.ParseBool(debugF) if err != nil { logger.Println("Warning invalid value for debug: ", debugF) } if debug { logger.SetFlags(log.LstdFlags | log.Llongfile) } if fsdir != "" { logger.Println("Using filesystem directory for assets: ", fsdir) static = os.DirFS(fsdir) } if b, err := strconv.Atoi(idBytes); err == nil && b > 4 { logger.Printf("Setting ID_BYTES: %d\n", b) ID_BYTES = b } if storage == "" { logger.Fatal("Cannot continue without storage directory, set " + "`-s` flag or STORAGE_DIR environment variable") } if jwtKey == "" { jwtKey = genTokenKey() } if h, err := strconv.Atoi(sessionHours); err == nil && h > 0 { logger.Printf("Setting SESSION_HOURS: %d\n", h) rSessionHours = h } err = os.MkdirAll(storage, 0777) if err != nil { logger.Fatal("Failed to create storage directory: ", err) } if debug { dumpFStree(static) } app := &App{ static: static, storage: storage, jwtKey: jwtKey, sessionHours: rSessionHours, users: getUsersFromEnviron(), } rp, err := getProxyHandler(proxyURL) if proxyURL != "" && err == nil { app.staticHandler = rp logger.Println("Proxying static requests to: ", proxyURL) } else if err != nil { logger.Printf("Warning, invalid url: '%s': %s", proxyURL, err) } logger.Println("listening on: ", listen) srv := &http.Server{ Handler: logRequests(app.Handler()), Addr: listen, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, } logger.Fatal(srv.ListenAndServe()) } func dumpFStree(f fs.FS) { logger.Println("dumping fs tree....") _ = fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } if d.IsDir() { return nil } logger.Println(path) return nil }) } // getProxyHandler returns a httputil.NewSingleHostReverseProxy for the // url provided, and errors parsing the URL, if any. func getProxyHandler(proxyURL string) (http.Handler, error) { pu, err := url.Parse(proxyURL) if err != nil { return nil, err } rp := httputil.NewSingleHostReverseProxy(pu) return rp, nil } func logRequests(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%s %s %s \"%s\" \"%s\"\n", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent(), r.Referer()) next.ServeHTTP(w, r) }) } func interactiveHashGen() { fmt.Print("Enter password: ") passwd, err := term.ReadPassword(0) if err != nil { logger.Fatal("\nFailed: ", err) } fmt.Printf("\nAgain: ") passwd2, err := term.ReadPassword(0) if err != nil { logger.Fatal("\nFailed: ", err) } fmt.Println("") if !bytes.Equal(passwd, passwd2) { logger.Fatal("Passwords do not match") } passwd, err = bcrypt.GenerateFromPassword(passwd, bcrypt.DefaultCost) if err != nil { logger.Fatal("Failed: ", err) } fmt.Printf("hash: %s\n", string(passwd)) } func (a *App) Handler() http.Handler { mux := http.NewServeMux() secHandlers := map[string]http.Handler{ "/api/v0/new": a.HandleNew(), "/api/v0/list": a.HandleList(), "/api/v0/del/": http.StripPrefix( "/api/v0/del/", a.HandleDel()), "/api/v1/new": a.HandleNewJSON(), "/api/v1/list": a.HandleListJSON(), "/api/v1/del/": http.StripPrefix( "/api/v1/del/", a.HandleDelJSON()), } handlers := map[string]http.Handler{ "/api/v1/view/": http.StripPrefix( "/api/v1/view/", a.HandleViewJSON()), "/api/v0/view/": http.StripPrefix( "/api/v0/view/", a.HandleViewPlain()), "/api/v1/login": loginHandler(a.users, a.jwtKey, a.sessionHours), "/api/v1/logout": logoutHandler(), "/": http.FileServer(http.FS(a.static)), } if a.staticHandler != nil { handlers["/"] = a.staticHandler } if len(a.users) > 0 { for user := range a.users { logger.Println("Found user:", user) } for pth, handler := range secHandlers { mux.Handle(pth, authHandler( handler, a.users, a.jwtKey, )) } } else { _, _ = fmt.Fprintf(os.Stderr, "\033[1;31mWARNING: RUNNING WITH NO AUTHENTICATION\033[0m\n") for pth, handler := range secHandlers { mux.Handle(pth, handler) } } for pth, handler := range handlers { mux.Handle(pth, handler) } return mux } func getUsersFromEnviron() map[string]string { users := map[string]string{} for _, entry := range os.Environ() { if !strings.HasPrefix(entry, "USER_") { continue } e := strings.SplitN(entry, "=", 2) key := e[0] val := e[1] username := strings.TrimPrefix(key, "USER_") users[username] = val } return users } func genTokenKey() string { r := make([]byte, 16) // 128 bits _, err := rand.Read(r) if err != nil { panic(fmt.Errorf("reading random bytes: %w", err)) } return base64.RawURLEncoding.EncodeToString(r) } func sendJSON(er Response) { er.w.WriteHeader(er.Code) er.w.Header().Add("Content-type", "application/json") enc := json.NewEncoder(er.w) _ = enc.Encode(er) } // sendPlain takes in a response struct, and writes the msg string to the // output verbatim unless the rdr is not nil, in which case the rdr // is used instead func sendPlain(r Response, rdr io.Reader) { r.w.WriteHeader(r.Code) r.w.Header().Add("Content-type", "application/json") if rdr != nil { _, _ = io.Copy(r.w, rdr) } else { _, _ = r.w.Write([]byte(r.Msg)) } } func GenId() string { r := make([]byte, ID_BYTES) _, err := rand.Read(r) if err != nil { logger.Fatal(err) } return base64.RawURLEncoding.EncodeToString(r) } func hasValidJWT(users map[string]string, tokenString, jwtKey string) bool { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte(jwtKey), nil }) if err != nil { logger.Printf("failed parsing token: %s", err) return false } if !token.Valid { return false } // logger.Printf("Type of token.Claims: %s\n", reflect.TypeOf(token.Claims)) claims, ok := token.Claims.(*jwt.RegisteredClaims) if !ok { logger.Println("Invalid claims for token") return false } _, ok = users[claims.ID] if !ok { return false } return err == nil } func isValidLogin(users map[string]string, username, password string) bool { if _, haveUser := users[username]; !haveUser { return false } err := bcrypt.CompareHashAndPassword( []byte(users[username]), []byte(password), ) if errors.Is(err, bcrypt.ErrHashTooShort) { logger.Println("Hash too short for username: ", username) } return err == nil } func hasValidPlainAuth(r *http.Request, users map[string]string) bool { username, passwd, ok := r.BasicAuth() if !ok { return false } return isValidLogin(users, username, passwd) } // getCookie simply returns a cookie value if any, or an empty string if // it's invalid or there are any errors. func getCookie(r *http.Request, name string) string { c, err := r.Cookie(name) // Normally I'd also check c.Valid()... Expires is optional, but // apparently makes it invalid? if err != nil { return "" } fmt.Printf("Cookie value: '%s'\n", c.Value) return c.Value } // respAskForBasicLogin simply adds a header telling web browsers we need // authentication func respAskForBasicLogin(w http.ResponseWriter) { w.Header().Add("WWW-Authenticate", "Basic realm=\"Login\", charset=\"UTF-8\"") } func authHandler(next http.Handler, users map[string]string, jwtKey string, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { respUnAuth := func() { // respAskForBasicLogin(w) sendJSON(Response{w, http.StatusUnauthorized, "Unauthorized", nil}) } if !hasValidPlainAuth(r, users) && !hasValidJWT(users, getCookie(r, "Auth"), jwtKey) { respUnAuth() return } next.ServeHTTP(w, r) }) } func loginHandler(users map[string]string, jwtKey string, sessionHours int, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { sendJSON(Response{w, http.StatusBadRequest, "Invalid type. POST only", nil}) return } err := r.ParseForm() if err != nil { sendJSON(Response{w, http.StatusBadRequest, "Invalid form", nil}) return } username := r.PostFormValue("username") password := r.PostFormValue("password") if !isValidLogin(users, username, password) { sendJSON(Response{w, http.StatusUnauthorized, "Invalid username or password", nil}) return } expires := time.Now().Add(time.Hour * time.Duration(sessionHours)) claims := &jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expires), ID: username, } token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) ss, err := token.SignedString([]byte(jwtKey)) if err != nil { sendJSON(Response{w, http.StatusInternalServerError, "Invalid username or password", nil}) return } http.SetCookie(w, &http.Cookie{ Name: "Auth", HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: ss, }) http.Redirect(w, r, "/", http.StatusFound) }) } func logoutHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: "Auth", HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: "logout", Expires: time.Now().Add(time.Second), //nolint }) http.Redirect(w, r, "/", http.StatusFound) }) } func OpenW(filename string) (*os.File, error) { return os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) } func (a *App) HandleNew() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() fh, err := OpenW(filepath.Join(a.storage, id)) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } _, err = io.Copy(fh, r.Body) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } sendPlain(Response{w, http.StatusOK, id + "\n", nil}, nil) }) } type NewPost struct { Content string `json:"content"` } func (a *App) HandleNewJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() p := &NewPost{} dec := json.NewDecoder(r.Body) err := dec.Decode(p) if err != nil || p.Content == "" { sendJSON(Response{w, http.StatusBadRequest, "Malformed JSON", nil}) return } fh, err := OpenW(filepath.Join(a.storage, id)) if err != nil { sendJSON(Response{w, http.StatusInternalServerError, "Internal server error", nil}) return } _, err = fh.Write([]byte(p.Content)) if err != nil { sendJSON(Response{w, http.StatusInternalServerError, "Internal server error", nil}) return } sendJSON(Response{w, http.StatusOK, "OK", struct{ Id string }{id}}) }) } func getPastePathFromRawURL(storage, u string) string { return filepath.Join(storage, strings.ReplaceAll( filepath.Clean(u), "/", "")) } func (a *App) HandleViewJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := getPastePathFromRawURL(a.storage, r.URL.Path) fh, err := os.Open(pth) if err != nil { sendJSON(Response{w, http.StatusNotFound, "ID not found", nil}) return } b, err := io.ReadAll(fh) if err != nil { sendJSON(Response{w, http.StatusInternalServerError, "Internal server error", nil}) return } w.WriteHeader(http.StatusOK) w.Header().Add("Content-type", "application/json") enc := json.NewEncoder(w) _ = enc.Encode(struct { Content string }{string(b)}) }) } func (a *App) HandleViewPlain() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := getPastePathFromRawURL(a.storage, r.URL.Path) fh, err := os.Open(pth) if err != nil { sendPlain(Response{w, http.StatusNotFound, "ID not found", nil}, nil) return } defer fh.Close() sendPlain(Response{w, http.StatusOK, "", nil}, fh) }) } func (a *App) HandleDel() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := getPastePathFromRawURL(a.storage, r.URL.Path) err := os.Remove(pth) if err != nil { if os.IsNotExist(err) { sendPlain( Response{w, http.StatusNotFound, "Not found", nil}, nil) return } sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } sendPlain(Response{w, http.StatusOK, "Ok", nil}, nil) }) } func (a *App) HandleDelJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Println("Made it to delete...") pth := getPastePathFromRawURL(a.storage, r.URL.Path) logger.Println("Deleting path: ", pth) err := os.Remove(pth) if err != nil { if os.IsNotExist(err) { sendJSON(Response{w, http.StatusNotFound, "Not found", nil}) return } sendJSON(Response{w, http.StatusInternalServerError, "Internal server error", nil}) return } sendJSON(Response{w, http.StatusOK, "Ok", nil}) }) } func (a *App) HandleList() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(a.storage) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } pl = handleSkipLimitSort(pl, r.URL) for _, e := range pl { _, _ = w.Write([]byte(fmt.Sprintf("%d\t%s\t%s\n", e.Size, e.Created.Format("2006-01-02 15:04 MST"), e.Id))) } }) } func (a *App) HandleListJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(a.storage) if err != nil { sendJSON(Response{w, http.StatusInternalServerError, "Internal server error", nil}) return } pl = handleSkipLimitSort(pl, r.URL) enc := json.NewEncoder(w) enc.SetIndent("", " ") _ = enc.Encode(pl) }) } type Paste struct { Id string `json:"id"` Created time.Time `json:"created"` Size int64 `json:"size"` } type PasteListing []*Paste func LoadPastes(pth string) (PasteListing, error) { out := PasteListing{} de, err := os.ReadDir(pth) if err != nil { return nil, err } for _, e := range de { if e.IsDir() { continue } info, err := e.Info() if err != nil { return nil, err } out = append(out, &Paste{ Id: e.Name(), Created: info.ModTime(), Size: info.Size(), }) } out.SortDate() return out, nil } func (pl PasteListing) SortDateReverse() { sort.Slice(pl, func(i, j int) bool { return pl[i].Created.Before(pl[j].Created) }) } func (pl PasteListing) SortDate() { sort.Slice(pl, func(i, j int) bool { return pl[i].Created.After(pl[j].Created) }) } func getInt(u *url.URL, param string) (int, error) { i := u.Query().Get(param) if i == "" { return 0, fmt.Errorf("no param '%s' supplied", param) } n, err := strconv.Atoi(i) if err != nil { return 0, err } return n, nil } func limitSlice[T any](s []T, nElem int) []T { if nElem > len(s) { return s } return s[:nElem] } func skipSlice[T any](s []T, skip int) []T { if skip > len(s) { return []T{} } return s[skip:] } func handleSkipLimitSort(pl PasteListing, URL *url.URL) PasteListing { if _, ok := URL.Query()["reverse"]; ok { pl.SortDateReverse() } skip, err := getInt(URL, "skip") if err == nil { pl = skipSlice(pl, skip) } limit, err := getInt(URL, "limit") if err == nil { pl = limitSlice(pl, limit) } return pl }