package main import ( "bytes" "crypto/rand" "embed" "encoding/base64" "encoding/json" "errors" "flag" "fmt" "io" "io/fs" "log" "net/http" "net/http/httputil" "net/url" "os" "path/filepath" "sort" "strconv" "strings" "time" "golang.org/x/crypto/bcrypt" "golang.org/x/term" ) // Default to the system umask const OPEN_MODE = 0777 var ( Version = "git or development" logger = log.New(os.Stderr, "", 0) ID_BYTES = 8 embeddedPrefix = "ui/build" // Keep in sync with staticEmbedded below //go:embed all:ui/build/* staticEmbedded embed.FS ) type App struct { static fs.FS users map[string]string storage string // path to where the paste files are actually stored staticHandler http.Handler } // EnvFlagString is a convent way to set value from environment variable, // and allow override when a command line flag is set. It's assumed `p` is // not nil. func EnvFlagString(fl *flag.FlagSet, p *string, name, envvar, usage string) { if v := os.Getenv(envvar); v != "" { *p = v } fl.StringVar(p, name, *p, fmt.Sprintf("%s (Environ: '%s')", usage, envvar)) } type errResp struct { w http.ResponseWriter Code int Msg string } func main() { var ( listen = ":6130" idBytes = "8" debugF = "false" debug = false genhash = false storage = "" fsdir = "" proxyURL = "" static fs.FS err error ) dumpFStree(staticEmbedded) static, err = fs.Sub(staticEmbedded, embeddedPrefix) if err != nil { logger.Fatal("Embedding failed no static directory") } dumpFStree(static) fl := flag.NewFlagSet("Simple pastebin", flag.ExitOnError) EnvFlagString(fl, &listen, "listen", "LISTEN_ADDR", "Address to bind to") EnvFlagString(fl, &idBytes, "b", "ID_BYTES", "How many bytes long should the IDs be") EnvFlagString(fl, &debugF, "d", "DEBUG", "Additoinal debugging if set to true") EnvFlagString(fl, &storage, "s", "STORAGE_DIR", "Path of directory to serve") EnvFlagString(fl, &fsdir, "fs", "FS_DIR", "Path which static assets are located, empty to use embedded") EnvFlagString(fl, &proxyURL, "sp", "STATIC_PROXY", "What server do we proxy to for static content?") fl.BoolVar(&genhash, "genhash", genhash, "Interactively prompt for a password and spit out a hash\n") version := fl.Bool("v", false, "Print version then exit") _ = fl.Parse(os.Args[1:]) if *version { log.Println(Version) os.Exit(0) } if genhash { interactiveHashGen() os.Exit(0) } debug, err = strconv.ParseBool(debugF) if err != nil { logger.Println("Warning invalid value for debug: ", debugF) } if debug { logger.SetFlags(log.LstdFlags | log.Llongfile) } if fsdir != "" { logger.Println("Using filesystem directory for assets: ", fsdir) static = os.DirFS(fsdir) } if b, err := strconv.Atoi(idBytes); err != nil && b > 4 { ID_BYTES = b } if storage == "" { logger.Fatal("Cannot continue without storage directory, set " + "`-s` flag or STORAGE_DIR environment variable") } err = os.MkdirAll(storage, 0777) if err != nil { logger.Fatal("Failed to create storage directory: ", err) } if debug { dumpFStree(static) } app := &App{ static: static, storage: storage, users: getUsersFromEnviron(), } rp, err := getProxyHandler(proxyURL) if proxyURL != "" && err == nil { app.staticHandler = rp logger.Println("Proxying static requests to: ", rp) } else if err != nil { logger.Printf("Warning, invalid url: '%s': %s", proxyURL, err) } logger.Println("listening on: ", listen) srv := &http.Server{ Handler: logRequests(app.Handler()), Addr: listen, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, } logger.Fatal(srv.ListenAndServe()) } func dumpFStree(f fs.FS) { logger.Println("dumping fs tree....") fs.WalkDir(f, ".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } if d.IsDir() { return nil } logger.Println(path) return nil }) } // getProxyHandler returns a httputil.NewSingleHostReverseProxy for the // url provided, and errors parsing the URL, if any. func getProxyHandler(proxyURL string) (http.Handler, error) { pu, err := url.Parse(proxyURL) if err != nil { return nil, err } rp := httputil.NewSingleHostReverseProxy(pu) return rp, nil } func logRequests(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%s %s %s \"%s\" \"%s\"\n", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent(), r.Referer()) next.ServeHTTP(w, r) }) } func interactiveHashGen() { fmt.Print("Enter password: ") passwd, err := term.ReadPassword(0) if err != nil { logger.Fatal("\nFailed: ", err) } fmt.Printf("\nAgain: ") passwd2, err := term.ReadPassword(0) if err != nil { logger.Fatal("\nFailed: ", err) } fmt.Println("") if !bytes.Equal(passwd, passwd2) { logger.Fatal("Passwords do not match") } passwd, err = bcrypt.GenerateFromPassword(passwd, bcrypt.DefaultCost) if err != nil { logger.Fatal("Failed: ", err) } fmt.Printf("hash: %s\n", string(passwd)) } func (a *App) Handler() http.Handler { mux := http.NewServeMux() secHandlers := map[string]http.Handler{ "/api/v0/new": a.HandleNew(), "/api/v0/list": a.HandleList(), "/api/v1/new": a.HandleNewJSON(), "/api/v1/list": a.HandleListJSON(), } handlers := map[string]http.Handler{ "/api/v1/view/": http.StripPrefix( "/api/v1/view/", a.HandleViewJSON()), "/api/v0/view/": http.StripPrefix( "/api/v0/view/", a.HandleViewPlain()), "/": http.FileServer(http.FS(a.static)), } if a.staticHandler != nil { handlers["/"] = a.staticHandler } if len(a.users) > 0 { for user := range a.users { logger.Println("Found user:", user) } for pth, handler := range secHandlers { mux.Handle(pth, authHandler( handler, a.users, )) } } else { _, _ = fmt.Fprintf(os.Stderr, "\033[1;31mWARNING: RUNNING WITH NO AUTHENTICATION\033[0m\n") for pth, handler := range secHandlers { mux.Handle(pth, handler) } } for pth, handler := range handlers { mux.Handle(pth, handler) } return mux } func getUsersFromEnviron() map[string]string { users := map[string]string{} for _, entry := range os.Environ() { if !strings.HasPrefix(entry, "USER_") { continue } e := strings.SplitN(entry, "=", 2) key := e[0] val := e[1] username := strings.TrimPrefix(key, "USER_") users[username] = val } return users } func sendErr(er errResp) { er.w.WriteHeader(er.Code) er.w.Header().Add("Content-type", "application/json") enc := json.NewEncoder(er.w) _ = enc.Encode(er) } func GenId() string { r := make([]byte, ID_BYTES) _, err := rand.Read(r) if err != nil { logger.Fatal(err) } return base64.RawURLEncoding.EncodeToString(r) } func authHandler(next http.Handler, users map[string]string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { respUnAuth := func() { logger.Println("Unauthed") w.Header().Add("Content-type", "text/plain") w.Header().Add("WWW-Authenticate", "Basic realm=\"Login\", charset=\"UTF-8\"") w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte("Unauthorized\n")) } username, passwd, ok := r.BasicAuth() if _, haveUser := users[username]; !ok || !haveUser { respUnAuth() return } err := bcrypt.CompareHashAndPassword( []byte(users[username]), []byte(passwd), ) if errors.Is(err, bcrypt.ErrHashTooShort) { logger.Println("Hash too short for username: ", username) } if err != nil { respUnAuth() return } next.ServeHTTP(w, r) }) } func (a *App) HandleNew() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() fh, err := os.OpenFile(id, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } _, err = io.Copy(fh, r.Body) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } w.WriteHeader(http.StatusOK) w.Header().Add("Content-type", "text/plain") enc := json.NewEncoder(w) _ = enc.Encode(struct { Id string }{id}) }) } type NewPost struct { Content string `json:"content"` } func (a *App) HandleNewJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := GenId() p := &NewPost{} dec := json.NewDecoder(r.Body) err := dec.Decode(p) if err != nil || p.Content == "" { sendErr(errResp{w, http.StatusBadRequest, "Malformed JSON"}) return } fh, err := os.OpenFile(id, os.O_CREATE|os.O_EXCL|os.O_RDWR, OPEN_MODE) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } _, err = fh.Write([]byte(p.Content)) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } sendErr(errResp{w, http.StatusOK, "OK"}) }) } func (a *App) HandleViewJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := filepath.Join(a.storage, strings.ReplaceAll(filepath.Clean(r.URL.Path), "/", "")) fh, err := os.Open(pth) if err != nil { sendErr(errResp{w, http.StatusNotFound, "ID not found"}) return } b, err := io.ReadAll(fh) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } w.WriteHeader(http.StatusOK) w.Header().Add("Content-type", "application/json") enc := json.NewEncoder(w) _ = enc.Encode(struct { Content string }{string(b)}) }) } func (a *App) HandleViewPlain() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pth := filepath.Join(a.storage, strings.ReplaceAll(filepath.Clean(r.URL.Path), "/", "")) fh, err := os.Open(pth) if err != nil { sendErr(errResp{w, http.StatusNotFound, "ID not found"}) return } w.WriteHeader(http.StatusOK) w.Header().Add("Content-type", "text/plain") _, _ = io.Copy(w, fh) }) } func (a *App) HandleList() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(a.storage) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } pl = handleSkipLimitSort(pl, r.URL) for _, e := range pl { _, _ = w.Write([]byte(fmt.Sprintf("%d\t%s\t%s\n", e.Size, e.Created.Format("2006-01-02 15:04 MST"), e.Id))) } }) } func (a *App) HandleListJSON() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pl, err := LoadPastes(a.storage) if err != nil { sendErr(errResp{w, http.StatusInternalServerError, "Internal server error"}) return } pl = handleSkipLimitSort(pl, r.URL) enc := json.NewEncoder(w) enc.SetIndent("", " ") _ = enc.Encode(pl) }) } type Paste struct { Id string `json:"id"` Created time.Time `json:"created"` Size int64 `json:"size"` } type PasteListing []*Paste func LoadPastes(pth string) (PasteListing, error) { out := PasteListing{} de, err := os.ReadDir(pth) if err != nil { return nil, err } for _, e := range de { if e.IsDir() { continue } info, err := e.Info() if err != nil { return nil, err } out = append(out, &Paste{ Id: e.Name(), Created: info.ModTime(), Size: info.Size(), }) } out.SortDate() return out, nil } func (pl PasteListing) SortDateReverse() { sort.Slice(pl, func(i, j int) bool { return pl[i].Created.Before(pl[j].Created) }) } func (pl PasteListing) SortDate() { sort.Slice(pl, func(i, j int) bool { return pl[i].Created.After(pl[j].Created) }) } func getInt(u *url.URL, param string) (int, error) { i := u.Query().Get(param) if i == "" { return 0, fmt.Errorf("No param '%s' supplied", param) } n, err := strconv.Atoi(i) if err != nil { return 0, err } return n, nil } func limitSlice[T any](s []T, nElem int) []T { if nElem > len(s) { return s } return s[:nElem] } func skipSlice[T any](s []T, skip int) []T { if skip > len(s) { return []T{} } return s[skip:] } func handleSkipLimitSort(pl PasteListing, URL *url.URL) PasteListing { if _, ok := URL.Query()["reverse"]; ok { pl.SortDateReverse() } skip, err := getInt(URL, "skip") if err == nil { pl = skipSlice(pl, skip) } limit, err := getInt(URL, "limit") if err == nil { pl = limitSlice(pl, limit) } return pl }