SSH connection status indicators

- Add Connected bool field to vfs.Entry and RemoteMount
- Track connection status in sshState.connectedHosts
- Show status icon (connected/disconnected) in pane header when browsing remote host
- Async SSH connection test with cancel support for Add Host dialog
- Colored labels and styled help text in SSH dialogs
- Confirmation dialog when deleting manually-added SSH hosts
This commit is contained in:
vrubelroman 2026-04-29 03:11:53 +03:00
parent df4df6b8f6
commit 1ed2d3defb
224 changed files with 33447 additions and 236 deletions

View file

@ -74,18 +74,21 @@ var (
)
type Entry struct {
Name string
Path string
Extension string
Mode fs.FileMode
Size int64
ModifiedAt time.Time
CreatedAt time.Time
CreatedKnown bool
IsDir bool
IsParent bool
IsHidden bool
DirSizeKnown bool
Name string
Path string
Extension string
Mode fs.FileMode
Size int64
ModifiedAt time.Time
CreatedAt time.Time
CreatedKnown bool
IsDir bool
IsParent bool
IsHidden bool
IsRemote bool
Connected bool
DirSizeKnown bool
RemoteHostName string
}
func (e Entry) DisplayName() string {
@ -114,6 +117,8 @@ func (e Entry) Category() string {
switch {
case e.IsParent:
return "parent"
case e.IsRemote:
return "remote"
case e.IsDir:
return "directory"
case hasExt(configExtensions, e.Extension):

View file

@ -0,0 +1,573 @@
package remote
import (
"fmt"
"io"
"net"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// SSHClient wraps an SSH connection and SFTP client for remote filesystem access.
type SSHClient struct {
// Host is the SSH host configuration used to establish the connection.
Host SSHHost
sshConn *ssh.Client
sftpCli *sftp.Client
}
// Connect establishes an SSH connection to the remote host and opens an SFTP session.
// It uses key-based authentication if IdentityFile is set, otherwise falls back to password auth.
func Connect(host SSHHost) (*SSHClient, error) {
authMethods, err := authMethodsForHost(host)
if err != nil {
return nil, fmt.Errorf("ssh auth: %w", err)
}
user := host.User
if user == "" {
user = os.Getenv("USER")
}
config := &ssh.ClientConfig{
User: user,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: support known_hosts verification
Timeout: 15 * time.Second,
}
addr := host.Addr()
sshConn, err := ssh.Dial("tcp", addr, config)
if err != nil {
return nil, fmt.Errorf("ssh dial %s: %w", addr, err)
}
sftpCli, err := sftp.NewClient(sshConn)
if err != nil {
sshConn.Close()
return nil, fmt.Errorf("sftp client: %w", err)
}
return &SSHClient{
Host: host,
sshConn: sshConn,
sftpCli: sftpCli,
}, nil
}
// authMethodsForHost returns the appropriate SSH auth methods for the given host.
// For SSH config hosts with IdentityFile, it uses public key authentication.
// For custom hosts with a password, it uses password authentication.
func authMethodsForHost(host SSHHost) ([]ssh.AuthMethod, error) {
var methods []ssh.AuthMethod
// Try key-based auth if identity file is specified
if host.IdentityFile != "" {
key, err := os.ReadFile(host.IdentityFile)
if err == nil {
signer, err := ssh.ParsePrivateKey(key)
if err == nil {
methods = append(methods, ssh.PublicKeys(signer))
} else {
// If the key is encrypted, try with empty passphrase or common ones
// For simplicity, we try without passphrase first
// In a real implementation, we might prompt for a passphrase
}
}
}
// Try password auth if password is set
if host.Password != "" {
methods = append(methods, ssh.Password(host.Password))
}
// Always include keyboard-interactive as a fallback (it wraps password)
if host.Password != "" {
methods = append(methods, ssh.KeyboardInteractive(
func(user, instruction string, questions []string, echos []bool) ([]string, error) {
answers := make([]string, len(questions))
for i := range questions {
answers[i] = host.Password
}
return answers, nil
},
))
}
// Always try default SSH agent and default keys as a last resort
// This covers the case where the user has an SSH agent running with loaded keys
// but no IdentityFile is specified in the config.
if host.IdentityFile == "" && host.Password == "" {
// Add default key paths
home, err := os.UserHomeDir()
if err == nil {
defaultKeys := []string{
home + "/.ssh/id_rsa",
home + "/.ssh/id_ed25519",
home + "/.ssh/id_ecdsa",
home + "/.ssh/id_ecdsa_sk",
home + "/.ssh/id_ed25519_sk",
home + "/.ssh/identity",
}
for _, keyPath := range defaultKeys {
if key, err := os.ReadFile(keyPath); err == nil {
if signer, err := ssh.ParsePrivateKey(key); err == nil {
methods = append(methods, ssh.PublicKeys(signer))
}
}
}
}
}
if len(methods) == 0 {
return nil, fmt.Errorf("no authentication methods available for host %q", host.Name)
}
return methods, nil
}
// ReadDir reads the contents of a remote directory and returns os.FileInfo entries.
func (c *SSHClient) ReadDir(dirPath string) ([]os.FileInfo, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.ReadDir(dirPath)
}
// Lstat returns file information without following symlinks.
func (c *SSHClient) Lstat(path string) (os.FileInfo, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Lstat(path)
}
// Stat returns file information following symlinks.
func (c *SSHClient) Stat(path string) (os.FileInfo, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Stat(path)
}
// ReadLink reads the target of a symbolic link.
func (c *SSHClient) ReadLink(linkPath string) (string, error) {
if c.sftpCli == nil {
return "", fmt.Errorf("not connected")
}
return c.sftpCli.ReadLink(linkPath)
}
// RealPath resolves a path to its absolute form on the remote server.
func (c *SSHClient) RealPath(p string) (string, error) {
if c.sftpCli == nil {
return "", fmt.Errorf("not connected")
}
return c.sftpCli.RealPath(p)
}
// ReadFile opens a remote file for reading.
func (c *SSHClient) ReadFile(filePath string) (io.ReadCloser, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Open(filePath)
}
// CreateFile opens a remote file for writing, creating it if it doesn't exist.
func (c *SSHClient) CreateFile(filePath string) (io.WriteCloser, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Create(filePath)
}
// MkdirAll creates a remote directory and any necessary parents.
// If the directory already exists, it returns nil (no error).
func (c *SSHClient) MkdirAll(dirPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
// sftp doesn't have MkdirAll, so we implement it manually
// First check if the path already exists
_, err := c.sftpCli.Stat(dirPath)
if err == nil {
return nil // already exists
}
if !os.IsNotExist(err) {
return err
}
// Ensure parent exists first
parent := path.Dir(dirPath)
if parent != dirPath && parent != "." {
if err := c.MkdirAll(parent); err != nil {
return err
}
}
return c.sftpCli.Mkdir(dirPath)
}
// Mkdir creates a single remote directory.
func (c *SSHClient) Mkdir(dirPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.Mkdir(dirPath)
}
// Remove deletes a remote file.
func (c *SSHClient) Remove(filePath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.Remove(filePath)
}
// RemoveDirectory removes a remote directory (must be empty).
func (c *SSHClient) RemoveDirectory(dirPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.RemoveDirectory(dirPath)
}
// Rename moves/renames a remote file or directory.
func (c *SSHClient) Rename(oldPath, newPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.Rename(oldPath, newPath)
}
// Close closes the SFTP session and SSH connection.
func (c *SSHClient) Close() error {
var firstErr error
if c.sftpCli != nil {
if err := c.sftpCli.Close(); err != nil {
firstErr = err
}
c.sftpCli = nil
}
if c.sshConn != nil {
if err := c.sshConn.Close(); err != nil && firstErr == nil {
firstErr = err
}
c.sshConn = nil
}
return firstErr
}
// IsConnected returns true if the client has an active connection.
func (c *SSHClient) IsConnected() bool {
return c.sftpCli != nil && c.sshConn != nil
}
// RemoveRecursive recursively deletes a remote file or directory.
// For directories, it walks and removes all children first.
func (c *SSHClient) RemoveRecursive(path string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
info, err := c.sftpCli.Stat(path)
if err != nil {
return err
}
if !info.IsDir() {
return c.sftpCli.Remove(path)
}
// Walk directory and collect all paths (files first, then dirs)
var files []string
var dirs []string
err = c.Walk(path, func(walkPath string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return walkErr
}
if walkPath == path {
return nil // skip root
}
if info.IsDir() {
dirs = append(dirs, walkPath)
} else {
files = append(files, walkPath)
}
return nil
})
if err != nil {
return err
}
// Remove files first, then directories (reverse order for deepest first)
for _, f := range files {
if err := c.sftpCli.Remove(f); err != nil {
return err
}
}
for i := len(dirs) - 1; i >= 0; i-- {
if err := c.sftpCli.RemoveDirectory(dirs[i]); err != nil {
return err
}
}
// Finally remove the root directory
return c.sftpCli.RemoveDirectory(path)
}
// CopyFileToRemote copies a local file to a remote destination via SFTP.
// It creates parent directories as needed.
func (c *SSHClient) CopyFileToRemote(localPath, remotePath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
localFile, err := os.Open(localPath)
if err != nil {
return fmt.Errorf("open local: %w", err)
}
defer localFile.Close()
// Ensure parent directory exists
parent := path.Dir(remotePath)
if err := c.MkdirAll(parent); err != nil {
return fmt.Errorf("mkdir remote: %w", err)
}
remoteFile, err := c.sftpCli.Create(remotePath)
if err != nil {
return fmt.Errorf("create remote: %w", err)
}
defer remoteFile.Close()
_, err = io.Copy(remoteFile, localFile)
if err != nil {
return fmt.Errorf("copy to remote: %w", err)
}
return nil
}
// CopyFileFromRemote copies a remote file to a local destination via SFTP.
// It creates parent directories as needed.
func (c *SSHClient) CopyFileFromRemote(remotePath, localPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
remoteFile, err := c.sftpCli.Open(remotePath)
if err != nil {
return fmt.Errorf("open remote: %w", err)
}
defer remoteFile.Close()
// Ensure parent directory exists
parent := filepath.Dir(localPath)
if err := os.MkdirAll(parent, 0o755); err != nil {
return fmt.Errorf("mkdir local: %w", err)
}
localFile, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("create local: %w", err)
}
defer localFile.Close()
_, err = io.Copy(localFile, remoteFile)
if err != nil {
return fmt.Errorf("copy from remote: %w", err)
}
return nil
}
// CopyDirToRemote recursively copies a local directory to a remote path.
func (c *SSHClient) CopyDirToRemote(localDir, remoteDir string) error {
return filepath.Walk(localDir, func(localPath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, _ := filepath.Rel(localDir, localPath)
remotePath := path.Join(remoteDir, relPath)
if info.IsDir() {
return c.MkdirAll(remotePath)
}
return c.CopyFileToRemote(localPath, remotePath)
})
}
// CopyDirFromRemote recursively copies a remote directory to a local path.
func (c *SSHClient) CopyDirFromRemote(remoteDir, localDir string) error {
return c.Walk(remoteDir, func(remotePath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, _ := filepath.Rel(remoteDir, remotePath)
localPath := filepath.Join(localDir, relPath)
if info.IsDir() {
return os.MkdirAll(localPath, 0o755)
}
return c.CopyFileFromRemote(remotePath, localPath)
})
}
// CopyFileBetweenRemotes copies a single file from one remote host to another
// by streaming the file contents through the local machine. Both SFTP connections
// must be active (connected).
func CopyFileBetweenRemotes(srcClient, dstClient *SSHClient, srcPath, dstPath string) error {
if srcClient.sftpCli == nil {
return fmt.Errorf("source client not connected")
}
if dstClient.sftpCli == nil {
return fmt.Errorf("destination client not connected")
}
srcFile, err := srcClient.sftpCli.Open(srcPath)
if err != nil {
return fmt.Errorf("open remote source %s: %w", srcPath, err)
}
defer srcFile.Close()
// Ensure parent directory exists on the destination
parent := path.Dir(dstPath)
if err := dstClient.MkdirAll(parent); err != nil {
return fmt.Errorf("mkdir remote dest %s: %w", parent, err)
}
dstFile, err := dstClient.sftpCli.Create(dstPath)
if err != nil {
return fmt.Errorf("create remote dest %s: %w", dstPath, err)
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return fmt.Errorf("copy remote to remote %s → %s: %w", srcPath, dstPath, err)
}
return nil
}
// CopyDirBetweenRemotes recursively copies a directory from one remote host to another.
// Directories are created on the destination first, then all files are streamed through
// the local machine.
func CopyDirBetweenRemotes(srcClient, dstClient *SSHClient, srcDir, dstDir string) error {
return srcClient.Walk(srcDir, func(remotePath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, _ := filepath.Rel(srcDir, remotePath)
dstPath := path.Join(dstDir, relPath)
if info.IsDir() {
return dstClient.MkdirAll(dstPath)
}
return CopyFileBetweenRemotes(srcClient, dstClient, remotePath, dstPath)
})
}
// Walk walks the remote filesystem tree rooted at root, calling walkFn for each file/dir.
// This is a simplified version of filepath.Walk for SFTP.
type walkFunc func(path string, info os.FileInfo, err error) error
func (c *SSHClient) Walk(root string, walkFn walkFunc) error {
return c.walk(root, walkFn, nil)
}
func (c *SSHClient) walk(dirPath string, walkFn walkFunc, info os.FileInfo) error {
if info == nil {
var err error
info, err = c.sftpCli.Stat(dirPath)
if err != nil {
return walkFn(dirPath, nil, err)
}
}
err := walkFn(dirPath, info, nil)
if err != nil {
if err == filepathSkipDir {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
entries, err := c.sftpCli.ReadDir(dirPath)
if err != nil {
return walkFn(dirPath, info, err)
}
for _, entry := range entries {
childPath := path.Join(dirPath, entry.Name())
if entry.IsDir() {
err = c.walk(childPath, walkFn, entry)
} else {
err = walkFn(childPath, entry, nil)
}
if err != nil {
return err
}
}
return nil
}
// filepathSkipDir is used as a return value from Walk to skip a directory.
var filepathSkipDir = fmt.Errorf("skip this directory")
// SftpToFileInfo converts an os.FileInfo to a vfs-compatible file info.
// This is used for consistent file information handling across local and remote.
func SftpToFileInfo(name string, info os.FileInfo) (os.FileInfo, error) {
return info, nil
}
// WalkDirEntry wraps os.FileInfo with the file name for directory listings.
type WalkDirEntry struct {
os.FileInfo
entryName string
}
func (e *WalkDirEntry) Name() string {
return e.entryName
}
// NewWalkDirEntry creates a new WalkDirEntry with an overridden name.
func NewWalkDirEntry(info os.FileInfo, name string) *WalkDirEntry {
return &WalkDirEntry{FileInfo: info, entryName: name}
}
// DialTimeout is the timeout for establishing SSH connections.
const DialTimeout = 15 * time.Second
// DefaultPort is the default SSH port.
const DefaultPort = "22"
// ResolveAddr returns the SSH address for the given host, applying the default port if needed.
func ResolveAddr(hostname, port string) string {
host := strings.TrimSpace(hostname)
if port == "" || port == "0" {
port = DefaultPort
}
return net.JoinHostPort(host, port)
}

View file

@ -0,0 +1,209 @@
package remote
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
)
// ParseSSHConfig parses ~/.ssh/config and returns a list of SSH hosts.
// It handles the most common SSH config directives: Host, HostName, Port, User, IdentityFile.
func ParseSSHConfig() []SSHHost {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".ssh", "config")
return parseSSHConfigFile(configPath)
}
func parseSSHConfigFile(path string) []SSHHost {
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
var hosts []SSHHost
var current *SSHHost
var currentNames []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Remove inline comments (everything after # that's not in quotes)
if idx := strings.Index(line, "#"); idx >= 0 {
line = strings.TrimSpace(line[:idx])
if line == "" {
continue
}
}
parts := strings.Fields(line)
if len(parts) < 2 {
continue
}
keyword := strings.ToLower(parts[0])
value := strings.Join(parts[1:], " ")
switch keyword {
case "host":
// Save previous host block
if current != nil && len(currentNames) > 0 {
for _, name := range currentNames {
if !isWildcardPattern(name) {
host := *current
host.Name = name
hosts = append(hosts, host)
}
}
}
// Start new host block
current = &SSHHost{
Port: "22",
FromSSHConfig: true,
}
currentNames = strings.Fields(value)
case "hostname":
if current != nil {
current.HostName = value
}
case "port":
if current != nil {
current.Port = value
}
case "user":
if current != nil {
current.User = value
}
case "identityfile":
if current != nil {
// Handle ~ expansion and relative paths
resolved := resolveIdentityPath(value)
if resolved != "" {
current.IdentityFile = resolved
}
}
}
}
// Save last host block
if current != nil && len(currentNames) > 0 {
for _, name := range currentNames {
if !isWildcardPattern(name) {
host := *current
host.Name = name
hosts = append(hosts, host)
}
}
}
return hosts
}
// isWildcardPattern returns true if the pattern contains wildcard characters.
func isWildcardPattern(pattern string) bool {
return strings.ContainsAny(pattern, "*?")
}
// resolveIdentityPath resolves a path from SSH config (handles ~ and relative paths).
func resolveIdentityPath(path string) string {
if path == "" {
return ""
}
// Handle ~/ or $HOME/
if strings.HasPrefix(path, "~/") {
home, err := os.UserHomeDir()
if err != nil {
return path
}
path = filepath.Join(home, path[2:])
}
// Handle relative paths (relative to ~/.ssh/)
if !filepath.IsAbs(path) {
home, err := os.UserHomeDir()
if err != nil {
return path
}
path = filepath.Join(home, ".ssh", path)
}
return filepath.Clean(path)
}
// SSHConfigPath returns the path to the user's SSH config file.
func SSHConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".ssh", "config")
}
// HostsFilePath returns the path to the custom hosts data file.
func HostsFilePath() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".config", "vcom", "hosts.dat")
}
// GetSSHDir returns the path to the .ssh directory.
func GetSSHDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".ssh")
}
// KnownHostsPath returns the path to known_hosts.
func KnownHostsPath() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".ssh", "known_hosts")
}
// ConfigFileExists checks if the SSH config file exists.
func ConfigFileExists() bool {
path := SSHConfigPath()
if path == "" {
return false
}
_, err := os.Stat(path)
return err == nil
}
// ValidateHost checks if a host entry has the minimum required fields.
func ValidateHost(host SSHHost) error {
if strings.TrimSpace(host.Name) == "" {
return fmt.Errorf("host name is required")
}
if strings.TrimSpace(host.HostName) == "" {
return fmt.Errorf("hostname/address is required")
}
if strings.TrimSpace(host.User) == "" {
return fmt.Errorf("username is required")
}
return nil
}

305
internal/fs/remote/host.go Normal file
View file

@ -0,0 +1,305 @@
package remote
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// SSHHost represents a single SSH host configuration.
type SSHHost struct {
// Name is the host alias (e.g. "myserver").
Name string `json:"name"`
// HostName is the actual hostname or IP address.
HostName string `json:"hostname"`
// Port is the SSH port (default 22).
Port string `json:"port,omitempty"`
// User is the SSH username.
User string `json:"user,omitempty"`
// IdentityFile is the path to the private key file (for key-based auth).
IdentityFile string `json:"identity_file,omitempty"`
// Password is stored encrypted (for password-based auth, user-added hosts).
Password string `json:"password,omitempty"`
// FromSSHConfig indicates this host came from ~/.ssh/config.
FromSSHConfig bool `json:"from_ssh_config"`
}
// DisplayName returns the host display name.
func (h SSHHost) DisplayName() string {
addr := h.HostName
if h.Port != "" && h.Port != "22" {
addr = fmt.Sprintf("%s:%s", addr, h.Port)
}
if h.User != "" {
return fmt.Sprintf("%s (%s@%s)", h.Name, h.User, addr)
}
return fmt.Sprintf("%s (%s)", h.Name, addr)
}
// Addr returns the SSH address string (host:port).
func (h SSHHost) Addr() string {
if h.Port == "" || h.Port == "22" {
return h.HostName + ":22"
}
return h.HostName + ":" + h.Port
}
// HostStore manages SSH hosts from both ~/.ssh/config and user-added hosts.
type HostStore struct {
customHosts []SSHHost
configPath string
cipherKey []byte
}
// NewHostStore creates a new HostStore.
func NewHostStore() (*HostStore, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("home dir: %w", err)
}
store := &HostStore{
configPath: filepath.Join(home, ".config", "vcom", "hosts.dat"),
}
// Load or create encryption key
keyPath := filepath.Join(home, ".config", "vcom", ".hosts-key")
store.cipherKey, err = loadOrCreateKey(keyPath)
if err != nil {
return nil, fmt.Errorf("encryption key: %w", err)
}
// Load custom hosts
if err := store.load(); err != nil {
// Ignore load errors for missing file
if !os.IsNotExist(err) {
return nil, err
}
}
return store, nil
}
// loadOrCreateKey loads an existing AES key or creates a new one.
func loadOrCreateKey(path string) ([]byte, error) {
if data, err := os.ReadFile(path); err == nil {
key, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(data)))
if err != nil {
return nil, err
}
if len(key) == 32 {
return key, nil
}
}
// Generate new 32-byte key for AES-256
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o700); err != nil {
return nil, fmt.Errorf("mkdir: %w", err)
}
encoded := base64.StdEncoding.EncodeToString(key)
if err := os.WriteFile(path, []byte(encoded), 0o600); err != nil {
return nil, fmt.Errorf("write key: %w", err)
}
return key, nil
}
type storedHosts struct {
Hosts []storedHost `json:"hosts"`
}
type storedHost struct {
Name string `json:"name"`
HostName string `json:"hostname"`
Port string `json:"port,omitempty"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"` // encrypted
IdentityFile string `json:"identity_file,omitempty"`
}
func (s *HostStore) load() error {
data, err := os.ReadFile(s.configPath)
if err != nil {
return err
}
// Decrypt
decrypted, err := decrypt(data, s.cipherKey)
if err != nil {
return fmt.Errorf("decrypt hosts: %w", err)
}
var stored storedHosts
if err := json.Unmarshal(decrypted, &stored); err != nil {
return fmt.Errorf("parse hosts: %w", err)
}
s.customHosts = make([]SSHHost, len(stored.Hosts))
for i, h := range stored.Hosts {
password := ""
if h.Password != "" {
pwd, err := decrypt([]byte(h.Password), s.cipherKey)
if err == nil {
password = string(pwd)
}
}
s.customHosts[i] = SSHHost{
Name: h.Name,
HostName: h.HostName,
Port: h.Port,
User: h.User,
Password: password,
IdentityFile: h.IdentityFile,
FromSSHConfig: false,
}
}
return nil
}
// Save persists custom hosts to disk (encrypted).
func (s *HostStore) Save() error {
stored := storedHosts{
Hosts: make([]storedHost, len(s.customHosts)),
}
for i, h := range s.customHosts {
password := ""
if h.Password != "" {
enc, err := encrypt([]byte(h.Password), s.cipherKey)
if err == nil {
password = string(enc)
}
}
stored.Hosts[i] = storedHost{
Name: h.Name,
HostName: h.HostName,
Port: h.Port,
User: h.User,
Password: password,
IdentityFile: h.IdentityFile,
}
}
data, err := json.Marshal(stored)
if err != nil {
return fmt.Errorf("marshal hosts: %w", err)
}
encrypted, err := encrypt(data, s.cipherKey)
if err != nil {
return fmt.Errorf("encrypt hosts: %w", err)
}
dir := filepath.Dir(s.configPath)
if err := os.MkdirAll(dir, 0o700); err != nil {
return fmt.Errorf("mkdir: %w", err)
}
return os.WriteFile(s.configPath, encrypted, 0o600)
}
// AddHost adds a custom host and saves.
func (s *HostStore) AddHost(host SSHHost) error {
host.FromSSHConfig = false
s.customHosts = append(s.customHosts, host)
return s.Save()
}
// RemoveHost removes a custom host by name.
func (s *HostStore) RemoveHost(name string) error {
for i, h := range s.customHosts {
if h.Name == name {
s.customHosts = append(s.customHosts[:i], s.customHosts[i+1:]...)
return s.Save()
}
}
return fmt.Errorf("host %q not found", name)
}
// AllHosts returns all hosts (from ssh config + custom).
func (s *HostStore) AllHosts() []SSHHost {
sshConfigHosts := ParseSSHConfig()
result := make([]SSHHost, 0, len(sshConfigHosts)+len(s.customHosts))
// Build a set of names from ssh config to avoid duplicates
seen := make(map[string]bool)
for _, h := range sshConfigHosts {
lower := strings.ToLower(h.Name)
seen[lower] = true
result = append(result, h)
}
for _, h := range s.customHosts {
lower := strings.ToLower(h.Name)
if !seen[lower] {
result = append(result, h)
seen[lower] = true
}
}
return result
}
// FindByName looks up a host by its Name field. Returns nil if not found.
func (s *HostStore) FindByName(name string) *SSHHost {
all := s.AllHosts()
for i := range all {
if strings.EqualFold(all[i].Name, name) {
return &all[i]
}
}
return nil
}
func encrypt(plaintext []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
func decrypt(ciphertext []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(ciphertext) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
}