diff options
| -rw-r--r-- | Makefile | 2 | ||||
| -rw-r--r-- | go.mod | 2 | ||||
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | server.go | 135 | ||||
| -rw-r--r-- | util.go | 85 |
5 files changed, 178 insertions, 48 deletions
@@ -2,7 +2,7 @@ build_dev: tidy format go build -o devserver serve_dev: build_dev - ./devserver --dev + ./devserver --chroot ./ --root /htdocs/ --host localhost build: tidy format GOOS=openbsd GOARCH=amd64 go build -o webserver @@ -1,3 +1,5 @@ module git.samanthony.xyz/samanthony.xyz go 1.18 + +require golang.org/x/sys v0.0.0-20220412211240-33da011f77ad @@ -0,0 +1,2 @@ +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -3,9 +3,11 @@ package main import ( "flag" "fmt" - tmpl "html/template" + "golang.org/x/sys/unix" + "html/template" "io/fs" "log" + "net" "net/http" "os" "path" @@ -13,50 +15,88 @@ import ( "strings" ) +// Flags var ( - host = "samanthony.xyz" - port = "443" - htdocs = "/var/www/htdocs/samanthony.xyz" + host = "localhost" + port = "80" + chroot = "/var/www/" + user = "www" + group = "www" + root = "/htdocs/samanthony.xyz/" ) -const ( - acmeDocs = "/var/www/acme/" - certFile = "/etc/ssl/samanthony.xyz.fullchain.pem" - keyFile = "/etc/ssl/private/samanthony.xyz.key" -) +func init() { + flag.StringVar(&host, "host", host, "") + flag.StringVar(&port, "port", port, "") + flag.StringVar(&chroot, "chroot", chroot, "") + flag.StringVar(&user, "user", user, "") + flag.StringVar(&group, "group", group, "") + flag.StringVar(&root, "root", root, "") -const ( - devHost = "localhost" - devPort = "6969" - devHtdocs = "htdocs/" -) + flag.Parse() +} -var devMode bool +// Must lookup the hostname before entering the chroot. +var addr = "" func init() { - flag.BoolVar(&devMode, "dev", false, - "Run server in debug/development mode (on localhost without tls)") + // host is an ip address + if ip := net.ParseIP(host); ip != nil { + addr = ip.String() + } else { // host is a domain name + addrs, err := net.LookupHost(host) + if err != nil { + log.Fatal(err) + } + for _, a := range addrs { + if ip := net.ParseIP(a); ip != nil { + if v4 := ip.To4(); v4 != nil { + addr = v4.String() + } + } + } + if addr == "" { + log.Fatalf("No ipv4 address bound to %s", host) + } + } +} - flag.Parse() +var ( + uid int + gid int +) - if devMode { - host = devHost - port = devPort - htdocs = devHtdocs +func init() { + var err error + uid, err = uidOf(user) + if err != nil { + log.Fatal(err) + } + gid, err = gidOf(group) + if err != nil { + log.Fatal(err) } } -var tmpls = make(map[string]*tmpl.Template) +// Enter chroot +func init() { + if err := unix.Chroot(chroot); err != nil { + log.Fatalf("chroot: %s: %v", chroot, err) + } +} + +// Build templates +var tmpl = make(map[string]*template.Template) func init() { - err := fp.WalkDir(htdocs, func(path string, d fs.DirEntry, err error) error { - if fp.Clean(path) == fp.Clean(htdocs) || + err := fp.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if fp.Clean(path) == fp.Clean(root) || fp.Ext(path) != ".html" || - path == fp.Join(htdocs, "base.html") { + path == fp.Join(root, "base.html") { return nil } - label := path[len(fp.Clean(htdocs)):] - tmpls[label] = tmpl.Must(tmpl.ParseFiles(fp.Join(htdocs, "base.html"), path)) + label := path[len(fp.Clean(root)):] + tmpl[label] = template.Must(template.ParseFiles(fp.Join(root, "base.html"), path)) return nil }) if err != nil { @@ -64,6 +104,7 @@ func init() { } } +// Template data type Page struct { Nav Nav } @@ -86,11 +127,18 @@ var nav = Nav{ } func rootHandler(w http.ResponseWriter, r *http.Request) { + if err := dropPerms(uid, gid); err != nil { + log.Println(err) + code := http.StatusInternalServerError + http.Error(w, http.StatusText(code), code) + return + } + reqPath := r.URL.Path // If request directory, serve index.html. // ie. /software -> /software/index.html - if info, err := os.Stat(fp.Join(htdocs, reqPath)); err == nil { + if info, err := os.Stat(fp.Join(root, reqPath)); err == nil { if info.IsDir() { reqPath = path.Join(reqPath, "index.html") } @@ -104,7 +152,7 @@ func rootHandler(w http.ResponseWriter, r *http.Request) { return } - if t, ok := tmpls[reqPath]; ok { + if t, ok := tmpl[reqPath]; ok { thisSection := "" for _, link := range nav.Links { if strings.HasPrefix(reqPath, link.Href) { @@ -123,26 +171,19 @@ func rootHandler(w http.ResponseWriter, r *http.Request) { return } } else { - http.ServeFile(w, r, fp.Join(htdocs, reqPath)) + http.ServeFile(w, r, fp.Join(root, reqPath)) } } func main() { http.HandleFunc("/", rootHandler) - if !devMode { - http.Handle("/.well-known/acme-challenge/", - http.StripPrefix( - "/.well-known/acme-challenge/", - http.FileServer(http.Dir(acmeDocs)), - ), - ) - } - - if devMode { - log.Printf("Listening on %s:%s\n", devHost, devPort) - log.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%s", devHost, devPort), nil)) - } else { - log.Printf("Listening on %s:%s\n", host, port) - log.Fatal(http.ListenAndServeTLS(fmt.Sprintf("%s:%s", host, port), certFile, keyFile, nil)) - } + http.Handle("/.well-known/acme-challenge/", + http.StripPrefix( + "/.well-known/acme-challenge/", + http.FileServer(http.Dir("/acme/")), + ), + ) + + log.Printf("Listening on %s:%s\n", addr, port) + log.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%s", addr, port), nil)) } @@ -0,0 +1,85 @@ +package main + +import ( + "bufio" + "errors" + "fmt" + "golang.org/x/sys/unix" + "log" + "os" + "runtime" + "strconv" + "strings" +) + +func uidOf(user string) (int, error) { + passwdFile, err := os.Open("/etc/passwd") + if err != nil { + return -1, err + } + defer passwdFile.Close() + + scanner := bufio.NewScanner(passwdFile) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + line := scanner.Text() + + parsed := strings.Split(line, ":") + + name := parsed[0] + + if name == user { + uid, err := strconv.Atoi(parsed[2]) + if err != nil { + return -1, err + } + return uid, nil + } + } + return -1, errors.New(fmt.Sprintf("user '%s' not in /etc/passwd", user)) +} + +func gidOf(group string) (int, error) { + groupFile, err := os.Open("/etc/group") + if err != nil { + return -1, err + } + defer groupFile.Close() + + scanner := bufio.NewScanner(groupFile) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + line := scanner.Text() + + parsed := strings.Split(line, ":") + + name := parsed[0] + + if name == group { + gid, err := strconv.Atoi(parsed[2]) + if err != nil { + return -1, err + } + return gid, nil + } + } + return -1, errors.New(fmt.Sprintf("group '%s' not in /etc/group", group)) +} + +func dropPerms(uid, gid int) error { + if runtime.GOOS != "linux" { + if err := unix.Setgid(gid); err != nil { + return errors.New(fmt.Sprintf("setgid(%d): %v", gid, err)) + } + if err := unix.Setuid(uid); err != nil { + return errors.New(fmt.Sprintf("setuid(%d): %v", uid, err)) + } + return nil + } else { + // setuid/setgid has supposedly been fully supported on Linux + // since go 1.16 but I can't seem to get it to work properly. + log.Print("setgid not supported on Linux, skipping.") + log.Print("setuid not supported on Linux, skipping.") + return nil + } +} |