diff --git a/app.go b/app.go index 39aba5c7..6fa8ef34 100644 --- a/app.go +++ b/app.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "fmt" stdlog "log" "net" @@ -17,7 +16,6 @@ import ( "github.com/gobuffalo/packr" "github.com/hmage/golibs/log" - "golang.org/x/crypto/ssh/terminal" ) // VersionString will be set through ldflags, contains current version @@ -72,13 +70,10 @@ func run(args options) { log.Printf("AdGuard Home is running as a service") } - err := askUsernamePasswordIfPossible() - if err != nil { - log.Fatal(err) - } + config.firstRun = detectFirstRun() // Do the upgrade if necessary - err = upgradeConfig() + err := upgradeConfig() if err != nil { log.Fatal(err) } @@ -145,7 +140,9 @@ func run(args options) { // Initialize and run the admin Web interface box := packr.NewBox("build/static") - http.Handle("/", optionalAuthHandler(http.FileServer(box))) + // if not configured, redirect / to /install.html, otherwise redirect /install.html to / + http.Handle("/", postInstallHandler(optionalAuthHandler(http.FileServer(box)))) + http.Handle("/install.html", preInstallHandler(http.FileServer(box))) registerControlHandlers() address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) @@ -222,14 +219,6 @@ func cleanup() { } } -func getInput() (string, error) { - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() - text := scanner.Text() - err := scanner.Err() - return text, err -} - // command-line arguments type options struct { verbose bool // is verbose logging enabled @@ -318,79 +307,3 @@ func loadOptions() options { return o } - -func promptAndGet(prompt string) (string, error) { - for { - fmt.Print(prompt) - input, err := getInput() - if err != nil { - log.Printf("Failed to get input, aborting: %s", err) - return "", err - } - if len(input) != 0 { - return input, nil - } - // try again - } -} - -func promptAndGetPassword(prompt string) (string, error) { - for { - fmt.Print(prompt) - password, err := terminal.ReadPassword(int(os.Stdin.Fd())) - fmt.Print("\n") - if err != nil { - log.Printf("Failed to get input, aborting: %s", err) - return "", err - } - if len(password) != 0 { - return string(password), nil - } - // try again - } -} - -func askUsernamePasswordIfPossible() error { - configFile := config.getConfigFilename() - _, err := os.Stat(configFile) - if !os.IsNotExist(err) { - // do nothing, file exists - return nil - } - if !terminal.IsTerminal(int(os.Stdin.Fd())) { - return nil // do nothing - } - if !terminal.IsTerminal(int(os.Stdout.Fd())) { - return nil // do nothing - } - fmt.Printf("Would you like to set user/password for the web interface authentication (yes/no)?\n") - yesno, err := promptAndGet("Please type 'yes' or 'no': ") - if err != nil { - return err - } - if yesno[0] != 'y' && yesno[0] != 'Y' { - return nil - } - username, err := promptAndGet("Please enter the username: ") - if err != nil { - return err - } - - password, err := promptAndGetPassword("Please enter the password: ") - if err != nil { - return err - } - - password2, err := promptAndGetPassword("Please enter password again: ") - if err != nil { - return err - } - if password2 != password { - fmt.Printf("Passwords do not match! Aborting\n") - os.Exit(1) - } - - config.AuthName = username - config.AuthPass = password - return nil -} diff --git a/config.go b/config.go index 33d717ef..aa746285 100644 --- a/config.go +++ b/config.go @@ -29,6 +29,7 @@ type logSettings struct { type configuration struct { ourConfigFilename string // Config filename (can be overridden via the command line arguments) ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else + firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html BindHost string `yaml:"bind_host"` BindPort int `yaml:"bind_port"` @@ -152,6 +153,10 @@ func readConfigFile() ([]byte, error) { func (c *configuration) write() error { c.Lock() defer c.Unlock() + if config.firstRun { + log.Tracef("Silently refusing to write config because first run and not configured yet") + return nil + } configFile := config.getConfigFilename() log.Printf("Writing YAML file: %s", configFile) yamlText, err := yaml.Marshal(&config) diff --git a/control.go b/control.go index f869a8c0..51b0d77c 100644 --- a/control.go +++ b/control.go @@ -694,24 +694,43 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { } } -func handleGetDefaultAddresses(w http.ResponseWriter, r *http.Request) { - type ipport struct { - IP string `json:"ip"` - Port int `json:"port"` - } - data := struct { - Web ipport `json:"web"` - DNS ipport `json:"dns"` - }{} +type ipport struct { + IP string `json:"ip"` + Port int `json:"port"` +} - // TODO: replace mockup with actual data - data.Web.IP = "192.168.104.104" - data.Web.Port = 3000 - data.DNS.IP = "192.168.104.104" - data.DNS.Port = 53 +type firstRunData struct { + Web ipport `json:"web"` + DNS ipport `json:"dns"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` +} + +func handleGetDefaultAddresses(w http.ResponseWriter, r *http.Request) { + data := firstRunData{} + ifaces, err := getValidNetInterfaces() + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + return + } + if len(ifaces) == 0 { + httpError(w, http.StatusServiceUnavailable, "Couldn't find any legible interface, plase try again later") + return + } + + // find an interface with an ipv4 address + addr := findIPv4IfaceAddr(ifaces) + if len(addr) == 0 { + httpError(w, http.StatusServiceUnavailable, "Couldn't find any interface with IPv4, plase try again later") + return + } + data.Web.IP = addr + data.DNS.IP = addr + data.Web.Port = 3000 // TODO: find out if port 80 is available -- if not, fall back to 3000 + data.DNS.Port = 53 // TODO: find out if port 53 is available -- if not, show a big warning w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(data) + err = json.NewEncoder(w).Encode(data) if err != nil { httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) return @@ -719,7 +738,7 @@ func handleGetDefaultAddresses(w http.ResponseWriter, r *http.Request) { } func handleSetAllSettings(w http.ResponseWriter, r *http.Request) { - newSettings := map[string]interface{}{} + newSettings := firstRunData{} err := json.NewDecoder(r.Body).Decode(&newSettings) if err != nil { httpError(w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err) @@ -727,48 +746,57 @@ func handleSetAllSettings(w http.ResponseWriter, r *http.Request) { } spew.Dump(newSettings) + config.firstRun = false + config.BindHost = newSettings.Web.IP + config.BindPort = newSettings.Web.Port + config.DNS.BindHost = newSettings.DNS.IP + config.DNS.Port = newSettings.DNS.Port + config.AuthName = newSettings.Username + config.AuthPass = newSettings.Password + + httpUpdateConfigReloadDNSReturnOK(w, r) } func registerControlHandlers() { - http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus))) - http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable))) - http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable))) - http.HandleFunc("/control/querylog", optionalAuth(ensureGET(dnsforward.HandleQueryLog))) - http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable))) - http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable))) - http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS))) - http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS))) - http.HandleFunc("/control/i18n/change_language", optionalAuth(ensurePOST(handleI18nChangeLanguage))) - http.HandleFunc("/control/i18n/current_language", optionalAuth(ensureGET(handleI18nCurrentLanguage))) - http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(dnsforward.HandleStatsTop))) - http.HandleFunc("/control/stats", optionalAuth(ensureGET(dnsforward.HandleStats))) - http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(dnsforward.HandleStatsHistory))) - http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(dnsforward.HandleStatsReset))) - http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON)) - http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable))) - http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable))) - http.HandleFunc("/control/filtering/add_url", optionalAuth(ensurePUT(handleFilteringAddURL))) - http.HandleFunc("/control/filtering/remove_url", optionalAuth(ensureDELETE(handleFilteringRemoveURL))) - http.HandleFunc("/control/filtering/enable_url", optionalAuth(ensurePOST(handleFilteringEnableURL))) - http.HandleFunc("/control/filtering/disable_url", optionalAuth(ensurePOST(handleFilteringDisableURL))) - http.HandleFunc("/control/filtering/refresh", optionalAuth(ensurePOST(handleFilteringRefresh))) - http.HandleFunc("/control/filtering/status", optionalAuth(ensureGET(handleFilteringStatus))) - http.HandleFunc("/control/filtering/set_rules", optionalAuth(ensurePUT(handleFilteringSetRules))) - http.HandleFunc("/control/safebrowsing/enable", optionalAuth(ensurePOST(handleSafeBrowsingEnable))) - http.HandleFunc("/control/safebrowsing/disable", optionalAuth(ensurePOST(handleSafeBrowsingDisable))) - http.HandleFunc("/control/safebrowsing/status", optionalAuth(ensureGET(handleSafeBrowsingStatus))) - http.HandleFunc("/control/parental/enable", optionalAuth(ensurePOST(handleParentalEnable))) - http.HandleFunc("/control/parental/disable", optionalAuth(ensurePOST(handleParentalDisable))) - http.HandleFunc("/control/parental/status", optionalAuth(ensureGET(handleParentalStatus))) - http.HandleFunc("/control/safesearch/enable", optionalAuth(ensurePOST(handleSafeSearchEnable))) - http.HandleFunc("/control/safesearch/disable", optionalAuth(ensurePOST(handleSafeSearchDisable))) - http.HandleFunc("/control/safesearch/status", optionalAuth(ensureGET(handleSafeSearchStatus))) - http.HandleFunc("/control/dhcp/status", optionalAuth(ensureGET(handleDHCPStatus))) - http.HandleFunc("/control/dhcp/interfaces", optionalAuth(ensureGET(handleDHCPInterfaces))) - http.HandleFunc("/control/dhcp/set_config", optionalAuth(ensurePOST(handleDHCPSetConfig))) - http.HandleFunc("/control/dhcp/find_active_dhcp", optionalAuth(ensurePOST(handleDHCPFindActiveServer))) + http.HandleFunc("/control/status", postInstall(optionalAuth(ensureGET(handleStatus)))) + http.HandleFunc("/control/enable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionEnable)))) + http.HandleFunc("/control/disable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionDisable)))) + http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(dnsforward.HandleQueryLog)))) + http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable)))) + http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable)))) + http.HandleFunc("/control/set_upstream_dns", postInstall(optionalAuth(ensurePOST(handleSetUpstreamDNS)))) + http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS)))) + http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage)))) + http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage)))) + http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsTop)))) + http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(dnsforward.HandleStats)))) + http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsHistory)))) + http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(dnsforward.HandleStatsReset)))) + http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) + http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable)))) + http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable)))) + http.HandleFunc("/control/filtering/add_url", postInstall(optionalAuth(ensurePUT(handleFilteringAddURL)))) + http.HandleFunc("/control/filtering/remove_url", postInstall(optionalAuth(ensureDELETE(handleFilteringRemoveURL)))) + http.HandleFunc("/control/filtering/enable_url", postInstall(optionalAuth(ensurePOST(handleFilteringEnableURL)))) + http.HandleFunc("/control/filtering/disable_url", postInstall(optionalAuth(ensurePOST(handleFilteringDisableURL)))) + http.HandleFunc("/control/filtering/refresh", postInstall(optionalAuth(ensurePOST(handleFilteringRefresh)))) + http.HandleFunc("/control/filtering/status", postInstall(optionalAuth(ensureGET(handleFilteringStatus)))) + http.HandleFunc("/control/filtering/set_rules", postInstall(optionalAuth(ensurePUT(handleFilteringSetRules)))) + http.HandleFunc("/control/safebrowsing/enable", postInstall(optionalAuth(ensurePOST(handleSafeBrowsingEnable)))) + http.HandleFunc("/control/safebrowsing/disable", postInstall(optionalAuth(ensurePOST(handleSafeBrowsingDisable)))) + http.HandleFunc("/control/safebrowsing/status", postInstall(optionalAuth(ensureGET(handleSafeBrowsingStatus)))) + http.HandleFunc("/control/parental/enable", postInstall(optionalAuth(ensurePOST(handleParentalEnable)))) + http.HandleFunc("/control/parental/disable", postInstall(optionalAuth(ensurePOST(handleParentalDisable)))) + http.HandleFunc("/control/parental/status", postInstall(optionalAuth(ensureGET(handleParentalStatus)))) + http.HandleFunc("/control/safesearch/enable", postInstall(optionalAuth(ensurePOST(handleSafeSearchEnable)))) + http.HandleFunc("/control/safesearch/disable", postInstall(optionalAuth(ensurePOST(handleSafeSearchDisable)))) + http.HandleFunc("/control/safesearch/status", postInstall(optionalAuth(ensureGET(handleSafeSearchStatus)))) + http.HandleFunc("/control/dhcp/status", postInstall(optionalAuth(ensureGET(handleDHCPStatus)))) + http.HandleFunc("/control/dhcp/interfaces", postInstall(optionalAuth(ensureGET(handleDHCPInterfaces)))) + http.HandleFunc("/control/dhcp/set_config", postInstall(optionalAuth(ensurePOST(handleDHCPSetConfig)))) + http.HandleFunc("/control/dhcp/find_active_dhcp", postInstall(optionalAuth(ensurePOST(handleDHCPFindActiveServer)))) // TODO: move to registerInstallHandlers() - http.HandleFunc("/control/install/get_default_addresses", ensureGET(handleGetDefaultAddresses)) - http.HandleFunc("/control/install/set_all_settings", ensurePOST(handleSetAllSettings)) + http.HandleFunc("/control/install/get_default_addresses", preInstall(ensureGET(handleGetDefaultAddresses))) + http.HandleFunc("/control/install/set_all_settings", preInstall(ensurePOST(handleSetAllSettings))) } diff --git a/dhcp.go b/dhcp.go index 744e7a35..3e2cb31e 100644 --- a/dhcp.go +++ b/dhcp.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io/ioutil" - "net" "net/http" "strings" "time" @@ -70,50 +69,14 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { response := map[string]interface{}{} - ifaces, err := net.Interfaces() + ifaces, err := getValidNetInterfaces() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't get list of interfaces: %s", err) + httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return } - type responseInterface struct { - Name string `json:"name"` - MTU int `json:"mtu"` - HardwareAddr string `json:"hardware_address"` - Addresses []string `json:"ip_addresses"` - } - for i := range ifaces { - if ifaces[i].Flags&net.FlagLoopback != 0 { - // it's a loopback, skip it - continue - } - if ifaces[i].Flags&net.FlagBroadcast == 0 { - // this interface doesn't support broadcast, skip it - continue - } - if ifaces[i].Flags&net.FlagPointToPoint != 0 { - // this interface is ppp, don't do dhcp over it - continue - } - iface := responseInterface{ - Name: ifaces[i].Name, - MTU: ifaces[i].MTU, - HardwareAddr: ifaces[i].HardwareAddr.String(), - } - addrs, errAddrs := ifaces[i].Addrs() - if errAddrs != nil { - httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, errAddrs) - return - } - for _, addr := range addrs { - iface.Addresses = append(iface.Addresses, addr.String()) - } - if len(iface.Addresses) == 0 { - // this interface has no addresses, skip it - continue - } - response[ifaces[i].Name] = iface + response[ifaces[i].Name] = ifaces[i] } err = json.NewEncoder(w).Encode(response) diff --git a/helpers.go b/helpers.go index 28412a58..53e6f32d 100644 --- a/helpers.go +++ b/helpers.go @@ -3,14 +3,18 @@ package main import ( "bufio" "errors" + "fmt" "io" "io/ioutil" + "net" "net/http" "os" "path" "path/filepath" "runtime" "strings" + + "github.com/hmage/golibs/log" ) // ---------------------------------- @@ -84,24 +88,78 @@ type authHandler struct { } func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if config.AuthName == "" || config.AuthPass == "" { - a.handler.ServeHTTP(w, r) - return - } - user, pass, ok := r.BasicAuth() - if !ok || user != config.AuthName || pass != config.AuthPass { - w.Header().Set("WWW-Authenticate", `Basic realm="dnsfilter"`) - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Unauthorised.\n")) - return - } - a.handler.ServeHTTP(w, r) + optionalAuth(a.handler.ServeHTTP)(w, r) } func optionalAuthHandler(handler http.Handler) http.Handler { return &authHandler{handler} } +// ------------------- +// first run / install +// ------------------- +func detectFirstRun() bool { + configfile := config.ourConfigFilename + if !filepath.IsAbs(configfile) { + configfile = filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + } + _, err := os.Stat(configfile) + if !os.IsNotExist(err) { + // do nothing, file exists + return false + } + return true +} + +// preInstall lets the handler run only if firstRun is true, no redirects +func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if !config.firstRun { + // if it's not first run, don't let users access it (for example /install.html when configuration is done) + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + handler(w, r) + } +} + +// preInstallStruct wraps preInstall into a struct that can be returned as an interface where neccessary +type preInstallHandlerStruct struct { + handler http.Handler +} + +func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { + preInstall(p.handler.ServeHTTP)(w, r) +} + +// preInstallHandler returns http.Handler interface for preInstall wrapper +func preInstallHandler(handler http.Handler) http.Handler { + return &preInstallHandlerStruct{handler} +} + +// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise +func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if config.firstRun && !strings.HasPrefix(r.URL.Path, "/install.") { + http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable + return + } + handler(w, r) + } +} + +type postInstallHandlerStruct struct { + handler http.Handler +} + +func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { + postInstall(p.handler.ServeHTTP)(w, r) +} + +func postInstallHandler(handler http.Handler) http.Handler { + return &postInstallHandlerStruct{handler} +} + // ------------------------------------------------- // helper functions for parsing parameters from body // ------------------------------------------------- @@ -125,6 +183,81 @@ func parseParametersFromBody(r io.Reader) (map[string]string, error) { return parameters, nil } +// ------------------ +// network interfaces +// ------------------ +type netInterface struct { + Name string `json:"name"` + MTU int `json:"mtu"` + HardwareAddr string `json:"hardware_address"` + Addresses []string `json:"ip_addresses"` +} + +// getValidNetInterfaces() returns interfaces that are eligible for DNS and/or DHCP +// invalid interface is either a loopback, ppp interface, or the one that doesn't allow broadcasts +func getValidNetInterfaces() ([]netInterface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) + } + + netIfaces := []netInterface{} + + for i := range ifaces { + if ifaces[i].Flags&net.FlagLoopback != 0 { + // it's a loopback, skip it + continue + } + if ifaces[i].Flags&net.FlagBroadcast == 0 { + // this interface doesn't support broadcast, skip it + continue + } + if ifaces[i].Flags&net.FlagPointToPoint != 0 { + // this interface is ppp, don't do dhcp over it + continue + } + + iface := netInterface{ + Name: ifaces[i].Name, + MTU: ifaces[i].MTU, + HardwareAddr: ifaces[i].HardwareAddr.String(), + } + + addrs, err := ifaces[i].Addrs() + if err != nil { + return nil, fmt.Errorf("Failed to get addresses for interface %v: %s", ifaces[i].Name, err) + } + for _, addr := range addrs { + iface.Addresses = append(iface.Addresses, addr.String()) + } + if len(iface.Addresses) == 0 { + // this interface has no addresses, skip it + continue + } + netIfaces = append(netIfaces, iface) + } + + return netIfaces, nil +} + +func findIPv4IfaceAddr(ifaces []netInterface) string { + for _, iface := range ifaces { + for _, addr := range iface.Addresses { + ip, _, err := net.ParseCIDR(addr) + if err != nil { + log.Printf("SHOULD NOT HAPPEN: got iface.Addresses element that's not a parseable CIDR: %s", addr) + continue + } + if ip.To4() == nil { + log.Tracef("Ignoring IP that isn't IPv4: %s", ip) + continue + } + return ip.To4().String() + } + } + return "" +} + // --------------------- // debug logging helpers // ---------------------