diff options
Diffstat (limited to 'main.go')
| -rw-r--r-- | main.go | 444 |
1 files changed, 393 insertions, 51 deletions
@@ -1,46 +1,59 @@ package main import ( - "compress/gzip" + "crypto/ed25519" "crypto/rand" - _ "embed" + "embed" "encoding/base64" - "encoding/gob" + "encoding/json" + "encoding/pem" "errors" "flag" + "fmt" "html/template" + "io" "log" "net/http" "os" + "path/filepath" + "strings" "time" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" ) var logger = log.New(os.Stderr, "", 0) -// const ID_BYTES = 5 -const ID_BYTES = 16 +var ID_BYTES = 8 -//go:embed index.tpl -var indexTemplateContent string -var indexTemplate = template.Must(template.New("index").Parse(indexTemplateContent)) +//go:embed templates/* +var templateFS embed.FS -//go:embed view.tpl -var viewTemplateContent string -var viewTemplate = template.Must(template.New("view").Parse(viewTemplateContent)) +var tpls = map[string]*template.Template{} -//go:embed error.tpl -var errorTemplateContent string -var errorTemplate = template.Must(template.New("error").Parse(errorTemplateContent)) +func init() { + // List of templates that will be pre-rendered and toss into + // the tpls map above + tplList := []string{"index", "view", "error"} + for _, tpl := range tplList { + tpls[tpl] = template.Must(template.ParseFS(templateFS, + "templates/"+tpl+".tpl", + "templates/base.tpl", + )) + } +} -//go:embed style.css -var stylesheetContent []byte +//go:embed static/* +var staticFS embed.FS type Paste struct { Id string Title string - Content []byte + Tags map[string]struct{} + Content string } func (p Paste) GetContent() string { @@ -48,43 +61,33 @@ func (p Paste) GetContent() string { } func (p *Paste) Load() error { - fh, err := os.Open(p.Id) - if err != nil { - return err - } - - zrdr, err := gzip.NewReader(fh) + fh, err := os.Open(filepath.Join("p", p.Id)) if err != nil { return err } - dec := gob.NewDecoder(zrdr) + dec := json.NewDecoder(fh) err = dec.Decode(&p) if err != nil { return err } - zrdr.Close() - return fh.Close() } func (p Paste) Save() error { - fh, err := os.OpenFile(p.Id, os.O_CREATE|os.O_RDWR, 0666) + fh, err := os.OpenFile(filepath.Join("p", p.Id), os.O_CREATE|os.O_RDWR, 0666) if err != nil { return err } - zwr := gzip.NewWriter(fh) - - enc := gob.NewEncoder(zwr) + enc := json.NewEncoder(fh) + enc.SetIndent("", " ") err = enc.Encode(&p) if err != nil { return err } - zwr.Close() - return fh.Close() } @@ -98,11 +101,135 @@ func GenId() string { return base64.RawURLEncoding.EncodeToString(r) } +type User struct { + Username string `yaml:"Username"` + Password string `yaml:"Password"` + HashedPassword string `yaml:"HashedPassword"` +} + +// CheckPass returns true if passwords match +func (u *User) CheckPass(pass string) bool { + err := bcrypt.CompareHashAndPassword([]byte(u.HashedPassword), []byte(pass)) + if err != nil { + return false + } + return true +} + +type Conf struct { + // Populated after read used for lookups + Users map[string]*User `yaml:"Users"` +} + +func hashPasswd(pass string) (string, error) { + b, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost) + return string(b), err +} + +func readConf(fn string) (*Conf, error) { + fh, err := os.Open(fn) + if err != nil { + return nil, err + } + + dec := yaml.NewDecoder(fh) + dec.KnownFields(true) + + c := &Conf{} + err = dec.Decode(c) + if err != nil { + return nil, err + } + fh.Close() + + changed := false + // Convert any plain passwords to HashedPasswords + for n, u := range c.Users { + u.Username = n + if u.Password != "" { + u.HashedPassword, err = hashPasswd(u.Password) + if err != nil { + return nil, err + } + u.Password = "" + changed = true + } + } + + if changed { + fh, err := os.OpenFile(fn, os.O_TRUNC|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + + enc := yaml.NewEncoder(fh) + enc.SetIndent(2) + err = enc.Encode(c) + if err != nil { + return nil, err + } + fh.Close() + } + + return c, nil +} + +func loadOrGenKeys() (ed25519.PublicKey, ed25519.PrivateKey, error) { + var ( + key ed25519.PrivateKey + pub ed25519.PublicKey + err error + ) + if _, err = os.Stat("key"); err != nil { + pub, key, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + + fh, err := os.OpenFile("key", os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, nil, err + } + + err = pem.Encode(fh, &pem.Block{ + Type: "ED25519 PRIVATE KEY", + Bytes: key, + }) + if err != nil { + return nil, nil, err + } + + fh.Close() + } else { + fh, err := os.Open("key") + if err != nil { + return nil, nil, err + } + + b, err := io.ReadAll(fh) + if err != nil { + return nil, nil, err + } + + blk, _ := pem.Decode(b) + if blk == nil || blk.Type != "ED25519 PRIVATE KEY" { + return nil, nil, errors.New("Failed to decode PEM file on disk") + } + + key = ed25519.PrivateKey(blk.Bytes) + pub = key.Public().(ed25519.PublicKey) + fh.Close() + } + + return pub, key, err +} + func main() { - fl := flag.NewFlagSet("Brutally simple pastebin", flag.ExitOnError) + fl := flag.NewFlagSet("simple pastebin", flag.ExitOnError) listen := fl.String("listen", ":6130", "Address to bind to, LISTEN_ADDR environment variable overrides") debug := fl.Bool("d", false, "debugging add information to the logging output DEBUG=true|false controls this as well") storage := fl.String("s", "", "Directory to serve, must be supplied via flag or STORAGE_DIR environment variable") + fl.IntVar(&ID_BYTES, "b", ID_BYTES, "How many random bytes for the id?") _ = fl.Parse(os.Args[1:]) if addr := os.Getenv("LISTEN_ADDR"); addr != "" { @@ -119,22 +246,42 @@ func main() { logger.Fatal("Cannot continue without storage directory, set `-s` flag or STORAGE_DIR environment variable") } - err := os.Chdir(*storage) + err := os.MkdirAll(filepath.Join(*storage, "p"), 0755) if err != nil { logger.Fatal(err) } - mux := mux.NewRouter() + err = os.Chdir(*storage) + if err != nil { + logger.Fatal(err) + } + + c, err := readConf("config.yml") + if err != nil { + logger.Fatal(err) + } + logger.Println("Config:") + b, _ := json.MarshalIndent(c, "", " ") + logger.Println(string(b)) - mux.HandleFunc("/new", newPaste) - mux.HandleFunc("/view/{id}", loadPaste) - mux.HandleFunc("/style.css", stylesheet) - mux.HandleFunc("/", index) + pubKey, key, err := loadOrGenKeys() + if err != nil { + logger.Fatal(err) + } + + r := mux.NewRouter() + + r.Handle("/new", requireJWT(pubKey, c.Users, newPasteJson())) + r.HandleFunc("/view/{id}", loadPaste) + r.HandleFunc("/view/json/{id}", loadPasteJson) + r.PathPrefix("/static").Handler(http.FileServer(http.FS(staticFS))) + r.Handle("/login", handleLogin(key, c.Users)) + r.HandleFunc("/", index) logger.Println("listening on: ", *listen) srv := &http.Server{ - Handler: mux, + Handler: r, Addr: *listen, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, @@ -142,6 +289,174 @@ func main() { logger.Fatal(srv.ListenAndServe()) } +func jsonResp(w http.ResponseWriter, code int, data interface{}) { + w.Header().Add("Content-type", "application/json") + w.WriteHeader(code) + + enc := json.NewEncoder(w) + err := enc.Encode(data) + if err != nil { + logger.Println("While jsonResp: ", err) + } +} + +func jsonErr(logMsg string, msg string, + w http.ResponseWriter, statusCode int) { + + logger.Println(logMsg) + + w.Header().Add("Content-type", "application/json") + w.WriteHeader(statusCode) + + enc := json.NewEncoder(w) + err := enc.Encode(map[string]string{"error": msg}) + if err != nil { + logger.Println("While logMsg: ", err) + } +} + +func handleLogin(key ed25519.PrivateKey, users map[string]*User) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqU := &User{} + dec := json.NewDecoder(r.Body) + + err := dec.Decode(reqU) + if err != nil { + jsonErr( + fmt.Sprintf("Encountered error decoding user: %s", err), + "invalid json", + w, http.StatusBadRequest) + return + } + + u, ok := users[reqU.Username] + + if !ok || reqU.Username == "" || reqU.Password == "" { + jsonErr( + "Invalid username or password", + "invalid username or password", + w, http.StatusBadRequest) + return + } + + if !u.CheckPass(reqU.Password) { + jsonErr( + fmt.Sprintf("Bad password for: %s", u.Username), + "bad username or password", + w, http.StatusBadRequest) + return + } + + t := jwt.NewWithClaims(jwt.SigningMethodEdDSA, &jwt.StandardClaims{ + ExpiresAt: time.Now().Unix() + (12 * 60 * 60), + Subject: u.Username, + }) + + s, err := t.SignedString(key) + if err != nil { + jsonErr( + fmt.Sprintf("Failed to sign: %s", err), + "internal server error", + w, http.StatusInternalServerError) + return + } + + w.Header().Add("Content-type", "application/json") + w.Header().Add("Authorization", "Bearer "+s) + enc := json.NewEncoder(w) + enc.Encode(map[string]string{ + "token": s, + }) + }) +} + +func requireJWT(key ed25519.PublicKey, users map[string]*User, + next http.Handler) http.Handler { + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + tokenS := r.Header.Get("Authorization") + if tokenS == "" { + jsonErr("Empty token received", "Unauthorized: empty token", + w, http.StatusUnauthorized) + return + } + + tokenS = strings.TrimPrefix(tokenS, "Bearer ") + + claims := &jwt.StandardClaims{} + + token, err := jwt.ParseWithClaims(tokenS, claims, + func(token *jwt.Token) (interface{}, error) { + return key, nil + }) + + if err != nil { + jsonErr(fmt.Sprintf("Error parsing token: %s", err), + "Unauthorized: invalid token", w, http.StatusUnauthorized) + return + } + + if !token.Valid { + jsonErr( + fmt.Sprintf("Token for %s expires at: %v", claims.Subject, + claims.ExpiresAt), + "Unauthorized: invalid token", w, http.StatusUnauthorized) + return + } + + u, ok := users[claims.Subject] + + if !ok { + jsonErr( + fmt.Sprintf("User %s not valid", claims.Subject), + "invalid user", w, http.StatusUnauthorized) + return + } + + logger.Printf("%s -> authed user: %s", r.URL.Path, u.Username) + + next.ServeHTTP(w, r) + }) +} + +func newPasteJson() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + + paste := &Paste{} + + err := dec.Decode(paste) + if err != nil { + jsonErr(fmt.Sprintf("Encountered error decoding: %s", err), + "unable to decode input", w, http.StatusBadRequest) + return + } + + if paste.Content == "" { + jsonErr("Not saving paste with empty content", + "empty content", w, http.StatusBadRequest) + return + } + + paste.Id = GenId() + + err = paste.Save() + + if err != nil { + jsonErr(fmt.Sprintf("Encountered error saving paste: %s", err), + "internal server error", w, http.StatusInternalServerError) + return + } + + jsonResp(w, http.StatusOK, map[string]string{ + "status": "ok", + "id": paste.Id, + }) + }) +} + func newPaste(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { @@ -162,7 +477,7 @@ func newPaste(w http.ResponseWriter, r *http.Request) { p := &Paste{ Id: GenId(), Title: title, - Content: []byte(content), + Content: content, } err = p.Save() @@ -175,13 +490,45 @@ func newPaste(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/view/"+p.Id, http.StatusFound) } +func loadPasteJson(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + + id, ok := vars["id"] + if !ok { + jsonErr("No ID supplied", + "ID must be supplied", w, http.StatusBadRequest) + return + } + + p := &Paste{Id: id} + + err := p.Load() + if err != nil { + logger.Println(err) + if errors.Is(err, os.ErrNotExist) { + jsonErr( + fmt.Sprintf("Snip with id: %s not found", id), + "ID not found", w, http.StatusNotFound) + return + } else { + jsonErr( + fmt.Sprintf("Snip with id faild loading: %s", err), + "Failed to load snippet", w, http.StatusInternalServerError) + return + } + return + } + + jsonResp(w, http.StatusOK, p) +} + func loadPaste(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { w.WriteHeader(http.StatusBadRequest) - err := errorTemplate.Execute(w, map[string]string{ + err := tpls["error"].Execute(w, map[string]string{ "Short": "ID Not found", "Long": "There was no ID supplied", }) @@ -198,13 +545,13 @@ func loadPaste(w http.ResponseWriter, r *http.Request) { logger.Println(err) if errors.Is(err, os.ErrNotExist) { w.WriteHeader(http.StatusNotFound) - err = errorTemplate.Execute(w, map[string]string{ + err = tpls["error"].Execute(w, map[string]string{ "Short": "ID Not found", "Long": "ID: " + p.Id + "Was not found", }) } else { w.WriteHeader(http.StatusInternalServerError) - err = errorTemplate.Execute(w, map[string]string{ + err = tpls["error"].Execute(w, map[string]string{ "Short": "ID Not found", "Long": "There was an issue reading ID: " + p.Id, }) @@ -215,19 +562,14 @@ func loadPaste(w http.ResponseWriter, r *http.Request) { return } - err = viewTemplate.Execute(w, p) + err = tpls["view"].Execute(w, p) if err != nil { logger.Println(err) } } -func stylesheet(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-type", "text/css") - _, _ = w.Write(stylesheetContent) -} - func index(w http.ResponseWriter, r *http.Request) { - err := indexTemplate.Execute(w, nil) + err := tpls["index"].Execute(w, nil) if err != nil { logger.Println(err) http.Error(w, "Internal server error", http.StatusInternalServerError) |
