package main import ( "bytes" "crypto/rand" "embed" "encoding/base64" "errors" "flag" "fmt" "html/template" "io" "io/fs" "log" "net/http" "net/url" "os" "path/filepath" "sort" "strconv" "strings" "time" jwt "github.com/golang-jwt/jwt/v4" "golang.org/x/crypto/bcrypt" "golang.org/x/term" ) // Default to the system umask const OPEN_MODE = 0777 var ( Version = "git or development" logger = log.New(os.Stderr, "", 0) ID_BYTES = 8 //embeddedPrefix = "ui/build" // Keep in sync with staticEmbedded below // go:embed all:ui/build/* embeddedPrefix = "htmx" // Keep in sync with staticEmbedded below //go:embed all:htmx/* staticEmbedded embed.FS ) // EnvFlagString is a convent way to set value from environment variable, // and allow override when a command line flag is set. It's assumed `p` is // not nil. func EnvFlagString(fl *flag.FlagSet, p *string, name, envvar, usage string) { if v := os.Getenv(envvar); v != "" { *p = v } fl.StringVar(p, name, *p, fmt.Sprintf("%s (Environ: '%s')", usage, envvar)) } func logIfErr(err error) { if err != nil { logger.Println(err) } } // Response encapsulates all the JSON responses from this server, Code // is a copy of the HTTP status code, Msg is a human-readable statement // about the particular response. Data is optional type Response struct { w http.ResponseWriter Code int Msg string Data any `json:",omitempty"` } func main() { var ( listen = ":6130" idBytes = "8" logRequestsF = "false" logRequests bool debugF = "false" debug bool genhash = false storage = "" fsdir = "" jwtKey = "" sessionHours = "12" rSessionHours = 12 static fs.FS err error ) static, err = fs.Sub(staticEmbedded, embeddedPrefix) if err != nil { logger.Fatal("Embedding failed no static directory") } fl := flag.NewFlagSet("Simple pastebin", flag.ExitOnError) EnvFlagString(fl, &listen, "listen", "LISTEN_ADDR", "Address to bind to") EnvFlagString(fl, &idBytes, "b", "ID_BYTES", "How many bytes long should the IDs be") EnvFlagString(fl, &debugF, "d", "DEBUG", "Additional debugging if set to true") EnvFlagString(fl, &storage, "s", "STORAGE_DIR", "Path of directory to serve") EnvFlagString(fl, &fsdir, "fs", "FS_DIR", "Path which static assets are located, empty to use embedded") EnvFlagString(fl, &jwtKey, "jwt", "JWT_KEY", "If supplied use the value as the JWT key instead of a random value\n"+ "Please understand this is not a good idea. Use a random key if "+ "possible\n") EnvFlagString(fl, &sessionHours, "hours", "SESSION_HOURS", "How many hours should login sessions last?") EnvFlagString(fl, &logRequestsF, "log", "LOG_REQUESTS", "Do we do full request logs, or are we quiet?") fl.BoolVar(&genhash, "genhash", genhash, "Interactively prompt for a password and spit out a hash\n") version := fl.Bool("v", false, "Print version then exit") _ = fl.Parse(os.Args[1:]) if *version { log.Println(Version) os.Exit(0) } if genhash { interactiveHashGen() os.Exit(0) } debug, err = strconv.ParseBool(debugF) if err != nil { logger.Println("Warning invalid value for debug: ", debugF) } logRequests, err = strconv.ParseBool(logRequestsF) if err != nil { logger.Println("Warning invalid value for logRequests: ", logRequestsF) } if debug { logger.SetFlags(log.LstdFlags | log.Llongfile) } if fsdir != "" { logger.Println("Using filesystem directory for assets: ", fsdir) static = os.DirFS(fsdir) } if b, err := strconv.Atoi(idBytes); err == nil && b > 4 { logger.Printf("Setting ID_BYTES: %d\n", b) ID_BYTES = b } if storage == "" { logger.Fatal("Cannot continue without storage directory, set " + "`-s` flag or STORAGE_DIR environment variable") } if jwtKey == "" { jwtKey = genTokenKey() } if h, err := strconv.Atoi(sessionHours); err == nil && h > 0 { logger.Printf("Setting SESSION_HOURS: %d\n", h) rSessionHours = h } err = os.MkdirAll(storage, 0777) if err != nil { logger.Fatal("Failed to create storage directory: ", err) } if debug { dumpFStree(static) } logger.Println("listening on: ", listen) handler := Handler(static, storage, LoadUsersFromEnviron(), jwtKey, rSessionHours) if logRequests { handler = httpRequestLogger(handler) } srv := &http.Server{ Handler: handler, Addr: listen, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, } logger.Fatal(srv.ListenAndServe()) } func dumpFStree(f fs.FS) { logger.Println("dumping fs tree....") _ = fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } if d.IsDir() { return nil } logger.Println(path) return nil }) } func httpRequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%s %s %s \"%s\" \"%s\"\n", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent(), r.Referer()) next.ServeHTTP(w, r) }) } func interactiveHashGen() { fmt.Print("Enter password: ") passwd, err := term.ReadPassword(0) if err != nil { logger.Fatal("\nFailed: ", err) } fmt.Printf("\nAgain: ") passwd2, err := term.ReadPassword(0) if err != nil { logger.Fatal("\nFailed: ", err) } fmt.Println("") if !bytes.Equal(passwd, passwd2) { logger.Fatal("Passwords do not match") } passwd, err = bcrypt.GenerateFromPassword(passwd, bcrypt.DefaultCost) if err != nil { logger.Fatal("Failed: ", err) } fmt.Printf("hash: %s\n", string(passwd)) } func Handler( static fs.FS, storage string, users UsersMap, jwtKey string, sessionHours int, ) http.Handler { mux := http.NewServeMux() secHandlers := map[string]http.Handler{ "/plain/new": HandleNew(storage), "/plain/list": HandleList(storage), "/plain/del/": http.StripPrefix("/plain/del/", HandleDel(storage)), } handlers := map[string]http.Handler{ "/plain/view/": http.StripPrefix("/plain/view/", HandleViewPlain(storage)), "/login": loginHandler(static, users, jwtKey, sessionHours), "/logout": logoutHandler(), "/static/": http.FileServer(http.FS(static)), "/": HandleIndex(static), } if len(users) > 0 { for user := range users { logger.Println("Found user:", user) } for pth, handler := range secHandlers { mux.Handle(pth, UserAuthHandler(users, handler, jwtKey)) } } else { _, _ = fmt.Fprintf(os.Stderr, "\033[1;31mWARNING: RUNNING WITH NO AUTHENTICATION\033[0m\n") for pth, handler := range secHandlers { mux.Handle(pth, handler) } } for pth, handler := range handlers { mux.Handle(pth, handler) } return mux } // LoadUsersFromEnviron lets you build a very basic user store from scraping // the environment variables alone. Environment variables are in the form of // `USER_=` For example in a shell script: // // export USER_mitch='$2a$10$MdpHOxqyaxVwX7tBmch/MOnuq5jgcy7ciCUGwixVR43SchyDtxLVW' // // Will then be picked up and put into the map as mitch: func LoadUsersFromEnviron() UsersMap { users := map[string]string{} for _, entry := range os.Environ() { if !strings.HasPrefix(entry, "USER_") { continue } e := strings.SplitN(entry, "=", 2) key := e[0] val := e[1] username := strings.TrimPrefix(key, "USER_") users[username] = val } return users } func HandleIndex(f fs.FS) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logIfErr(renderHX(f, isHTMX(r), w, nil, "index.html")) }) } func genTokenKey() string { r := make([]byte, 16) // 128 bits _, err := rand.Read(r) if err != nil { panic(fmt.Errorf("reading random bytes: %w", err)) } return base64.RawURLEncoding.EncodeToString(r) } func isHTMX(r *http.Request) bool { return r.Header.Get("HX-Request") == "true" } func renderHX(fs fs.FS, isHX bool, wr io.Writer, data any, templates ...string) error { if isHX { templates = append([]string{"htmx.html"}, templates...) } else { templates = append([]string{"base.html"}, templates...) } return renderTemplate(fs, wr, data, templates...) } func renderErrorTemplate(fs fs.FS, wr io.Writer, data any) error { return renderTemplate(fs, wr, data, "base.html", "error.html") } func renderTemplate(fs fs.FS, wr io.Writer, data any, templates ...string) error { tpl, err := template.ParseFS(fs, templates...) if err != nil { return err } return tpl.Execute(wr, data) } // sendPlain takes in a response struct, and writes the msg string to the // output verbatim unless the rdr is not nil, in which case the rdr // is used instead func sendPlain(r Response, rdr io.Reader) { r.w.Header().Add("Content-type", "text/plain") r.w.WriteHeader(r.Code) if rdr != nil { _, _ = io.Copy(r.w, rdr) } else { _, _ = r.w.Write([]byte(r.Msg)) } } // GenId grabs cryptographically random data, and dumps out a URL encoded base64 // string of it. Suitable for identifiers. Bytes used to create the string // is controlled by ID_BYTES func GenId() string { r := make([]byte, ID_BYTES) _, err := rand.Read(r) if err != nil { logger.Fatal(err) } return base64.RawURLEncoding.EncodeToString(r) } // UsersMap is simply a convience wrapp around a map[string]string, // with a few methods to handle validation of usernames/passwords // as well as JWTs type UsersMap map[string]string func HasValidJWT(users UsersMap, tokenString, jwtKey string) bool { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte(jwtKey), nil }) if err != nil { logger.Printf("failed parsing token: %s", err) return false } if !token.Valid { return false } // logger.Printf("Type of token.Claims: %s\n", reflect.TypeOf(token.Claims)) claims, ok := token.Claims.(*jwt.RegisteredClaims) if !ok { logger.Println("Invalid claims for token") return false } _, ok = users[claims.ID] if !ok { return false } return err == nil } // IsValidLogin lets you check directly whether a particular username and // password pair is valid for the UsersMap in question func IsValidLogin(users UsersMap, username, password string) bool { if _, haveUser := users[username]; !haveUser { return false } err := bcrypt.CompareHashAndPassword( []byte(users[username]), []byte(password), ) if errors.Is(err, bcrypt.ErrHashTooShort) { logger.Println("Hash too short for username: ", username) } return err == nil } // HasValidPlainAuth will return true if the request contains a basic auth // header that validates one of the users in the UsersMap func HasValidPlainAuth(users UsersMap, r *http.Request) bool { username, passwd, ok := r.BasicAuth() if !ok { return false } return IsValidLogin(users, username, passwd) } // getCookie simply returns a cookie value if any, or an empty string if // it's invalid or there are any errors. func getCookie(r *http.Request, name string) string { c, err := r.Cookie(name) // Normally I'd also check c.Valid()... Expires is optional, but apparently // makes it invalid? Either way duration is handled by the JWT if err != nil { return "" } return c.Value } // AuthHandler is a fairly basic authentication middleware that validates // credentials based on either a valid JWT in the 'Auth' cookie or // Username/Password via Basic Auth func UserAuthHandler(users UsersMap, next http.Handler, jwtKey string, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !(HasValidPlainAuth(users, r) || HasValidJWT(users, getCookie(r, "Auth"), jwtKey)) { // Ask for basic auth w.Header().Add("WWW-Authenticate", `Basic realm=, charset="UTF-8"`) w.WriteHeader(http.StatusUnauthorized) return } next.ServeHTTP(w, r) }) } func loginHandler(f fs.FS, users UsersMap, jwtKey string, sessionHours int, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { logIfErr(renderHX(f, isHTMX(r), w, nil, "login.html")) return } if r.Method != "POST" { w.WriteHeader(http.StatusBadRequest) logIfErr(renderErrorTemplate(f, w, "Invalid request method"+r.Method)) return } err := r.ParseForm() if err != nil { w.WriteHeader(http.StatusBadRequest) logIfErr(renderErrorTemplate(f, w, "Invalid form")) return } username := r.PostFormValue("username") password := r.PostFormValue("password") if !IsValidLogin(users, username, password) { data := map[string]string{"username": username} logIfErr(renderHX(f, isHTMX(r), w, data, "login.html")) return } c, err := getLoginCookie(username, jwtKey, sessionHours) if err != nil { w.WriteHeader(http.StatusInternalServerError) logIfErr(renderErrorTemplate(f, w, "Internal server error")) logIfErr(err) return } http.SetCookie(w, c) if isHTMX(r) { w.Header().Set("HX-Redirect", "/") return } http.Redirect(w, r, "/", http.StatusFound) }) } func getLoginCookie(username string, jwtKey string, sessionHours int, ) (*http.Cookie, error) { expires := time.Now().Add(time.Hour * time.Duration(sessionHours)) claims := &jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expires), ID: username, } token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) ss, err := token.SignedString([]byte(jwtKey)) if err != nil { return nil, err } return &http.Cookie{ Name: "Auth", HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: ss, Secure: true, }, err } func logoutHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: "Auth", HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: "logout", Expires: time.Now().Add(time.Second), }) http.Redirect(w, r, "/", http.StatusFound) }) } func OpenW(filename string) (*os.File, error) { return os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) } func HandleNew(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() fh, err := OpenW(filepath.Join(storagePath, id)) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } _, err = io.Copy(fh, r.Body) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } sendPlain(Response{w, http.StatusOK, id + "\n", nil}, nil) }) } type NewPost struct { Content string `json:"content"` } func getPastePathFromRawURL(storage, u string) string { return filepath.Join(storage, strings.ReplaceAll( filepath.Clean(u), "/", "")) } func HandleViewPlain(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := getPastePathFromRawURL(storagePath, r.URL.Path) fh, err := os.Open(pth) if err != nil { sendPlain(Response{w, http.StatusNotFound, "ID not found", nil}, nil) return } defer fh.Close() sendPlain(Response{w, http.StatusOK, "", nil}, fh) }) } func HandleDel(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := getPastePathFromRawURL(storagePath, r.URL.Path) err := os.Remove(pth) if err != nil { if os.IsNotExist(err) { sendPlain( Response{w, http.StatusNotFound, "Not found", nil}, nil) return } sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } sendPlain(Response{w, http.StatusOK, "Ok", nil}, nil) }) } func HandleList(storagePath string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(storagePath) if err != nil { sendPlain(Response{w, http.StatusInternalServerError, "Internal server error", nil}, nil) return } pl = handleSkipLimitSort(pl, r.URL) for _, e := range pl { _, _ = w.Write([]byte(fmt.Sprintf("%d\t%s\t%s\n", e.Size, e.Created.Format("2006-01-02 15:04 MST"), e.Id))) } }) } type Paste struct { Id string `json:"id"` Created time.Time `json:"created"` Size int64 `json:"size"` } type PasteListing []*Paste func LoadPastes(pth string) (PasteListing, error) { out := PasteListing{} de, err := os.ReadDir(pth) if err != nil { return nil, err } for _, e := range de { if e.IsDir() { continue } info, err := e.Info() if err != nil { return nil, err } out = append(out, &Paste{ Id: e.Name(), Created: info.ModTime(), Size: info.Size(), }) } out.SortDate() return out, nil } func (pl PasteListing) SortDateReverse() { sort.Slice(pl, func(i, j int) bool { return pl[i].Created.Before(pl[j].Created) }) } func (pl PasteListing) SortDate() { sort.Slice(pl, func(i, j int) bool { return pl[i].Created.After(pl[j].Created) }) } func getInt(u *url.URL, param string) (int, error) { i := u.Query().Get(param) if i == "" { return 0, fmt.Errorf("no param '%s' supplied", param) } n, err := strconv.Atoi(i) if err != nil { return 0, err } return n, nil } func limitSlice[T any](s []T, nElem int) []T { if nElem > len(s) { return s } return s[:nElem] } func skipSlice[T any](s []T, skip int) []T { if skip > len(s) { return []T{} } return s[skip:] } func handleSkipLimitSort(pl PasteListing, URL *url.URL) PasteListing { if _, ok := URL.Query()["reverse"]; ok { pl.SortDateReverse() } skip, err := getInt(URL, "skip") if err == nil { pl = skipSlice(pl, skip) } limit, err := getInt(URL, "limit") if err == nil { pl = limitSlice(pl, limit) } return pl }