diff options
Diffstat (limited to 'cmd/web/main.go')
| -rw-r--r-- | cmd/web/main.go | 72 |
1 files changed, 69 insertions, 3 deletions
diff --git a/cmd/web/main.go b/cmd/web/main.go index ec0704d..0f76670 100644 --- a/cmd/web/main.go +++ b/cmd/web/main.go @@ -5,8 +5,10 @@ import ( "fmt" "log" "math/rand" + "net" "net/http" "os" + "strings" "sync" "time" @@ -16,8 +18,8 @@ import ( var ( Version = "Development" - Logger = log.New(os.Stderr, "", log.LstdFlags) - Listen = ":8899" + Logger = log.New(os.Stderr, "", log.LstdFlags) + Listen = ":8899" libMu = &sync.RWMutex{} Lib *steam.Library @@ -38,6 +40,10 @@ func reloadLib() { } func setLibHandler(w http.ResponseWriter, r *http.Request) { + if unauthorizedIfNotLocal(w, r) { + return + } + err := r.ParseForm() if err != nil { Logger.Printf("Setlib: While parsing form: %s", err) @@ -51,16 +57,75 @@ func setLibHandler(w http.ResponseWriter, r *http.Request) { } func quitHandler(w http.ResponseWriter, r *http.Request) { + if unauthorizedIfNotLocal(w, r) { + return + } + Logger.Println("Quit was called, exiting") w.Header().Add("Content-type", "text/plain") w.Write([]byte("Shutting down...")) go func() { - time.Sleep(time.Second*2) + time.Sleep(time.Second * 2) os.Exit(0) }() return } +func unauthorizedIfNotLocal(w http.ResponseWriter, r *http.Request) bool { + if !isLocal(r.RemoteAddr) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + Logger.Printf("Unauthorized request from: %s for %s", + r.RemoteAddr, r.RequestURI) + return true + } + return false +} + +func isLocal(addr string) bool { + _, localNet, _ := net.ParseCIDR("127.0.0.1/8") + return localNet.Contains(net.ParseIP(strings.Split(addr, ":")[0])) +} + +// getHostIP attempts to guess the IP address of the current machine and +// returns that. Simply bails at the first non loopback IP returning that. +// not ideal but it should work well enough most of the time +func getHostIP() string { + iFaces, err := net.Interfaces() + if err != nil { + return "127.0.0.1" + } + + for _, iFace := range iFaces { + addrs, err := iFace.Addrs() + if err != nil { + return "127.0.0.1" + } + + for _, a := range addrs { + n, ok := a.(*net.IPNet) + if !ok { + continue + } + + if n.IP.To4() != nil && !n.IP.IsLoopback() { + return n.IP.String() + } + } + } + + return "127.0.0.1" +} + +func getPort() string { + s := strings.Split(Listen, ":") + + if len(s) != 2 { + return Listen + } + + return s[1] +} + func main() { fl := flag.NewFlagSet("steam-export", flag.ExitOnError) debug := fl.Bool("d", false, "Print line numbers in log") @@ -86,6 +151,7 @@ func main() { r.HandleFunc("/setLib", setLibHandler) r.HandleFunc("/delete", gameDelete) r.HandleFunc("/install", gameInstaller) + r.HandleFunc("/steam-export-web.exe", serveSelf) r.HandleFunc("/download/{game}", gameDownloader) r.HandleFunc("/style.css", cssHandler) r.HandleFunc("/", index) |
