From b9bb17044a8c2b47c7e96660e27ab645f82bec9d Mon Sep 17 00:00:00 2001 From: Mitch Riedstra Date: Thu, 4 Mar 2021 19:44:02 -0500 Subject: Further refactoring. --- cmd/web/util.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 cmd/web/util.go (limited to 'cmd/web/util.go') diff --git a/cmd/web/util.go b/cmd/web/util.go new file mode 100644 index 0000000..1252c66 --- /dev/null +++ b/cmd/web/util.go @@ -0,0 +1,89 @@ +package main + +import ( + "io" + "os" + "net/http" + "net" + "strings" +) + +func UnauthorizedIfNotLocal(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !isLocal(r.RemoteAddr) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + Logger.Printf("Unauthorized request from: %s for %s", + r.RemoteAddr, r.RequestURI) + return + } + h.ServeHTTP(w, r) + }) +} + +func isLocal(addr string) bool { + _, localNet, _ := net.ParseCIDR("127.0.0.1/8") + return localNet.Contains(net.ParseIP(strings.Split(addr, ":")[0])) +} + +// getHostIP attempts to guess the IP address of the current machine and +// returns that. Simply bails at the first non sane looking IP and returns it. +// Not ideal but it should work well enough most of the time +func getHostIP() string { + iFaces, err := net.Interfaces() + if err != nil { + return "127.0.0.1" + } + + // RFC 3927 + _, ipv4LinkLocal, _ := net.ParseCIDR("169.254.0.0/16") + + for _, iFace := range iFaces { + addrs, err := iFace.Addrs() + if err != nil { + return "127.0.0.1" + } + + for _, a := range addrs { + n, ok := a.(*net.IPNet) + if !ok { + continue + } + + if n.IP.To4() != nil && !n.IP.IsLoopback() && !ipv4LinkLocal.Contains(n.IP.To4()) { + return n.IP.String() + } + } + } + + return "127.0.0.1" +} + +func getPort() string { + s := strings.Split(Listen, ":") + + if len(s) != 2 { + return Listen + } + + return s[1] +} + +func serveSelf(w http.ResponseWriter, r *http.Request) { + s, err := os.Executable() + if err != nil { + Logger.Println("While trying to get my executable path: ", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + fh, err := os.Open(s) + if err != nil { + Logger.Println("While opening my own executable for reading: ", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + _, err = io.Copy(w, fh) + fh.Close() + return +} -- cgit v1.2.3