diff --git a/main.go b/main.go index 45b07b0..739bd62 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,15 @@ package main import ( + "flag" "fmt" "os" "os/exec" - "github.com/moby/sys/mountinfo" + "path/filepath" + "strings" + "github.com/kevinburke/ssh_config" - "flag" + "github.com/moby/sys/mountinfo" ) var eFlag = flag.Bool("e", false, "open mountpoint in your editor") @@ -16,65 +19,81 @@ func main() { args := flag.Args() if len(args) == 0 { - fmt.Println("No hostname specified.") - os.Exit(80) - } else { - hostname := ssh_config.Get(args[0], "HostName") - user := ssh_config.Get(args[0], "User") - port := ssh_config.Get(args[0], "Port") - ifile := ssh_config.Get(args[0], "IdentityFile") - - if len(hostname) == 0 || len(user) == 0 || len(ifile) == 0 { - fmt.Println("Hostname not found in ~/.ssh_config") - os.Exit(3) - } else { - mount := verify_mount_dir(hostname) - - fmt.Println("Hostname: ",hostname) - fmt.Println("User: ", user) - fmt.Println("Port: ", port) - fmt.Println("Ifile: ", ifile) - fmt.Println("Mount: ", mount) - fmt.Println("---") - - chkmount, chkmount_err := mountinfo.Mounted(mount) - if chkmount_err != nil { - fmt.Println("mountinfo.Mounted() failed with %s\n", chkmount_err) - } - if chkmount == false { - mount_sshfs(hostname, user, ifile, port, mount) - } else { - fmt.Println("!!! Already mounted") - } - run_editor(mount) - } + fmt.Fprintln(os.Stderr, "No hostname specified.") + os.Exit(2) } + + hostname := ssh_config.Get(args[0], "HostName") + user := ssh_config.Get(args[0], "User") + port := ssh_config.Get(args[0], "Port") + ifile := ssh_config.Get(args[0], "IdentityFile") + + if len(hostname) == 0 || len(user) == 0 || len(ifile) == 0 { + fmt.Fprintln(os.Stderr, "Hostname not found in ~/.ssh/config") + os.Exit(3) + } + if len(port) == 0 { + port = "22" + } + + mount, err := verify_mount_dir(hostname) + if err != nil { + fmt.Fprintln(os.Stderr, "verify_mount_dir() failed:", err) + os.Exit(4) + } + + fmt.Println("Hostname: ", hostname) + fmt.Println("User: ", user) + fmt.Println("Port: ", port) + fmt.Println("Ifile: ", ifile) + fmt.Println("Mount: ", mount) + fmt.Println("---") + + chkmount, chkmount_err := mountinfo.Mounted(mount) + if chkmount_err != nil { + fmt.Fprintf(os.Stderr, "mountinfo.Mounted() failed with %s\n", chkmount_err) + os.Exit(5) + } + if !chkmount { + if err := mount_sshfs(hostname, user, ifile, port, mount); err != nil { + fmt.Fprintln(os.Stderr, "mount_sshfs() failed:", err) + os.Exit(6) + } + } else { + fmt.Println("!!! Already mounted") + } + run_editor(mount) } func run_editor(mount string) { - if(*eFlag == true) { + if *eFlag { cmd := exec.Command("subl", mount) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { - fmt.Println("run_editor() failed with\n",err) + fmt.Fprintln(os.Stderr, "run_editor() failed with", err) } } } -func verify_mount_dir(hostname string)(mount string) { - homedir, homedirerr := os.UserHomeDir() - if homedirerr != nil { - fmt.Println( homedirerr ) +func verify_mount_dir(hostname string) (string, error) { + homedir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("resolve home dir: %w", err) } - mount = homedir+"/Servers/"+hostname - os.MkdirAll(mount, os.ModePerm) - - return + base := filepath.Join(homedir, "Servers") + mount := filepath.Clean(filepath.Join(base, hostname)) + if !strings.HasPrefix(mount, base+string(os.PathSeparator)) { + return "", fmt.Errorf("hostname %q escapes mount base %q", hostname, base) + } + if err := os.MkdirAll(mount, 0700); err != nil { + return "", fmt.Errorf("create mount dir %q: %w", mount, err) + } + return mount, nil } -func mount_sshfs(hostname string, user string, ifile string, port string, mount string) { +func mount_sshfs(hostname string, user string, ifile string, port string, mount string) error { cmd := exec.Command("sshfs", "-p", port, "-o", "IdentityFile="+ifile, "-o", "idmap=user", @@ -91,8 +110,5 @@ func mount_sshfs(hostname string, user string, ifile string, port string, mount user+"@"+hostname+":/", mount) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - err := cmd.Run() - if err != nil { - fmt.Println("mount_sshfs() failed with\n",err) - } + return cmd.Run() }