SSH connection status indicators

- Add Connected bool field to vfs.Entry and RemoteMount
- Track connection status in sshState.connectedHosts
- Show status icon (connected/disconnected) in pane header when browsing remote host
- Async SSH connection test with cancel support for Add Host dialog
- Colored labels and styled help text in SSH dialogs
- Confirmation dialog when deleting manually-added SSH hosts
This commit is contained in:
vrubelroman 2026-04-29 03:11:53 +03:00
parent df4df6b8f6
commit 1ed2d3defb
224 changed files with 33447 additions and 236 deletions

View file

@ -0,0 +1,573 @@
package remote
import (
"fmt"
"io"
"net"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// SSHClient wraps an SSH connection and SFTP client for remote filesystem access.
type SSHClient struct {
// Host is the SSH host configuration used to establish the connection.
Host SSHHost
sshConn *ssh.Client
sftpCli *sftp.Client
}
// Connect establishes an SSH connection to the remote host and opens an SFTP session.
// It uses key-based authentication if IdentityFile is set, otherwise falls back to password auth.
func Connect(host SSHHost) (*SSHClient, error) {
authMethods, err := authMethodsForHost(host)
if err != nil {
return nil, fmt.Errorf("ssh auth: %w", err)
}
user := host.User
if user == "" {
user = os.Getenv("USER")
}
config := &ssh.ClientConfig{
User: user,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: support known_hosts verification
Timeout: 15 * time.Second,
}
addr := host.Addr()
sshConn, err := ssh.Dial("tcp", addr, config)
if err != nil {
return nil, fmt.Errorf("ssh dial %s: %w", addr, err)
}
sftpCli, err := sftp.NewClient(sshConn)
if err != nil {
sshConn.Close()
return nil, fmt.Errorf("sftp client: %w", err)
}
return &SSHClient{
Host: host,
sshConn: sshConn,
sftpCli: sftpCli,
}, nil
}
// authMethodsForHost returns the appropriate SSH auth methods for the given host.
// For SSH config hosts with IdentityFile, it uses public key authentication.
// For custom hosts with a password, it uses password authentication.
func authMethodsForHost(host SSHHost) ([]ssh.AuthMethod, error) {
var methods []ssh.AuthMethod
// Try key-based auth if identity file is specified
if host.IdentityFile != "" {
key, err := os.ReadFile(host.IdentityFile)
if err == nil {
signer, err := ssh.ParsePrivateKey(key)
if err == nil {
methods = append(methods, ssh.PublicKeys(signer))
} else {
// If the key is encrypted, try with empty passphrase or common ones
// For simplicity, we try without passphrase first
// In a real implementation, we might prompt for a passphrase
}
}
}
// Try password auth if password is set
if host.Password != "" {
methods = append(methods, ssh.Password(host.Password))
}
// Always include keyboard-interactive as a fallback (it wraps password)
if host.Password != "" {
methods = append(methods, ssh.KeyboardInteractive(
func(user, instruction string, questions []string, echos []bool) ([]string, error) {
answers := make([]string, len(questions))
for i := range questions {
answers[i] = host.Password
}
return answers, nil
},
))
}
// Always try default SSH agent and default keys as a last resort
// This covers the case where the user has an SSH agent running with loaded keys
// but no IdentityFile is specified in the config.
if host.IdentityFile == "" && host.Password == "" {
// Add default key paths
home, err := os.UserHomeDir()
if err == nil {
defaultKeys := []string{
home + "/.ssh/id_rsa",
home + "/.ssh/id_ed25519",
home + "/.ssh/id_ecdsa",
home + "/.ssh/id_ecdsa_sk",
home + "/.ssh/id_ed25519_sk",
home + "/.ssh/identity",
}
for _, keyPath := range defaultKeys {
if key, err := os.ReadFile(keyPath); err == nil {
if signer, err := ssh.ParsePrivateKey(key); err == nil {
methods = append(methods, ssh.PublicKeys(signer))
}
}
}
}
}
if len(methods) == 0 {
return nil, fmt.Errorf("no authentication methods available for host %q", host.Name)
}
return methods, nil
}
// ReadDir reads the contents of a remote directory and returns os.FileInfo entries.
func (c *SSHClient) ReadDir(dirPath string) ([]os.FileInfo, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.ReadDir(dirPath)
}
// Lstat returns file information without following symlinks.
func (c *SSHClient) Lstat(path string) (os.FileInfo, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Lstat(path)
}
// Stat returns file information following symlinks.
func (c *SSHClient) Stat(path string) (os.FileInfo, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Stat(path)
}
// ReadLink reads the target of a symbolic link.
func (c *SSHClient) ReadLink(linkPath string) (string, error) {
if c.sftpCli == nil {
return "", fmt.Errorf("not connected")
}
return c.sftpCli.ReadLink(linkPath)
}
// RealPath resolves a path to its absolute form on the remote server.
func (c *SSHClient) RealPath(p string) (string, error) {
if c.sftpCli == nil {
return "", fmt.Errorf("not connected")
}
return c.sftpCli.RealPath(p)
}
// ReadFile opens a remote file for reading.
func (c *SSHClient) ReadFile(filePath string) (io.ReadCloser, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Open(filePath)
}
// CreateFile opens a remote file for writing, creating it if it doesn't exist.
func (c *SSHClient) CreateFile(filePath string) (io.WriteCloser, error) {
if c.sftpCli == nil {
return nil, fmt.Errorf("not connected")
}
return c.sftpCli.Create(filePath)
}
// MkdirAll creates a remote directory and any necessary parents.
// If the directory already exists, it returns nil (no error).
func (c *SSHClient) MkdirAll(dirPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
// sftp doesn't have MkdirAll, so we implement it manually
// First check if the path already exists
_, err := c.sftpCli.Stat(dirPath)
if err == nil {
return nil // already exists
}
if !os.IsNotExist(err) {
return err
}
// Ensure parent exists first
parent := path.Dir(dirPath)
if parent != dirPath && parent != "." {
if err := c.MkdirAll(parent); err != nil {
return err
}
}
return c.sftpCli.Mkdir(dirPath)
}
// Mkdir creates a single remote directory.
func (c *SSHClient) Mkdir(dirPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.Mkdir(dirPath)
}
// Remove deletes a remote file.
func (c *SSHClient) Remove(filePath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.Remove(filePath)
}
// RemoveDirectory removes a remote directory (must be empty).
func (c *SSHClient) RemoveDirectory(dirPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.RemoveDirectory(dirPath)
}
// Rename moves/renames a remote file or directory.
func (c *SSHClient) Rename(oldPath, newPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
return c.sftpCli.Rename(oldPath, newPath)
}
// Close closes the SFTP session and SSH connection.
func (c *SSHClient) Close() error {
var firstErr error
if c.sftpCli != nil {
if err := c.sftpCli.Close(); err != nil {
firstErr = err
}
c.sftpCli = nil
}
if c.sshConn != nil {
if err := c.sshConn.Close(); err != nil && firstErr == nil {
firstErr = err
}
c.sshConn = nil
}
return firstErr
}
// IsConnected returns true if the client has an active connection.
func (c *SSHClient) IsConnected() bool {
return c.sftpCli != nil && c.sshConn != nil
}
// RemoveRecursive recursively deletes a remote file or directory.
// For directories, it walks and removes all children first.
func (c *SSHClient) RemoveRecursive(path string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
info, err := c.sftpCli.Stat(path)
if err != nil {
return err
}
if !info.IsDir() {
return c.sftpCli.Remove(path)
}
// Walk directory and collect all paths (files first, then dirs)
var files []string
var dirs []string
err = c.Walk(path, func(walkPath string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return walkErr
}
if walkPath == path {
return nil // skip root
}
if info.IsDir() {
dirs = append(dirs, walkPath)
} else {
files = append(files, walkPath)
}
return nil
})
if err != nil {
return err
}
// Remove files first, then directories (reverse order for deepest first)
for _, f := range files {
if err := c.sftpCli.Remove(f); err != nil {
return err
}
}
for i := len(dirs) - 1; i >= 0; i-- {
if err := c.sftpCli.RemoveDirectory(dirs[i]); err != nil {
return err
}
}
// Finally remove the root directory
return c.sftpCli.RemoveDirectory(path)
}
// CopyFileToRemote copies a local file to a remote destination via SFTP.
// It creates parent directories as needed.
func (c *SSHClient) CopyFileToRemote(localPath, remotePath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
localFile, err := os.Open(localPath)
if err != nil {
return fmt.Errorf("open local: %w", err)
}
defer localFile.Close()
// Ensure parent directory exists
parent := path.Dir(remotePath)
if err := c.MkdirAll(parent); err != nil {
return fmt.Errorf("mkdir remote: %w", err)
}
remoteFile, err := c.sftpCli.Create(remotePath)
if err != nil {
return fmt.Errorf("create remote: %w", err)
}
defer remoteFile.Close()
_, err = io.Copy(remoteFile, localFile)
if err != nil {
return fmt.Errorf("copy to remote: %w", err)
}
return nil
}
// CopyFileFromRemote copies a remote file to a local destination via SFTP.
// It creates parent directories as needed.
func (c *SSHClient) CopyFileFromRemote(remotePath, localPath string) error {
if c.sftpCli == nil {
return fmt.Errorf("not connected")
}
remoteFile, err := c.sftpCli.Open(remotePath)
if err != nil {
return fmt.Errorf("open remote: %w", err)
}
defer remoteFile.Close()
// Ensure parent directory exists
parent := filepath.Dir(localPath)
if err := os.MkdirAll(parent, 0o755); err != nil {
return fmt.Errorf("mkdir local: %w", err)
}
localFile, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("create local: %w", err)
}
defer localFile.Close()
_, err = io.Copy(localFile, remoteFile)
if err != nil {
return fmt.Errorf("copy from remote: %w", err)
}
return nil
}
// CopyDirToRemote recursively copies a local directory to a remote path.
func (c *SSHClient) CopyDirToRemote(localDir, remoteDir string) error {
return filepath.Walk(localDir, func(localPath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, _ := filepath.Rel(localDir, localPath)
remotePath := path.Join(remoteDir, relPath)
if info.IsDir() {
return c.MkdirAll(remotePath)
}
return c.CopyFileToRemote(localPath, remotePath)
})
}
// CopyDirFromRemote recursively copies a remote directory to a local path.
func (c *SSHClient) CopyDirFromRemote(remoteDir, localDir string) error {
return c.Walk(remoteDir, func(remotePath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, _ := filepath.Rel(remoteDir, remotePath)
localPath := filepath.Join(localDir, relPath)
if info.IsDir() {
return os.MkdirAll(localPath, 0o755)
}
return c.CopyFileFromRemote(remotePath, localPath)
})
}
// CopyFileBetweenRemotes copies a single file from one remote host to another
// by streaming the file contents through the local machine. Both SFTP connections
// must be active (connected).
func CopyFileBetweenRemotes(srcClient, dstClient *SSHClient, srcPath, dstPath string) error {
if srcClient.sftpCli == nil {
return fmt.Errorf("source client not connected")
}
if dstClient.sftpCli == nil {
return fmt.Errorf("destination client not connected")
}
srcFile, err := srcClient.sftpCli.Open(srcPath)
if err != nil {
return fmt.Errorf("open remote source %s: %w", srcPath, err)
}
defer srcFile.Close()
// Ensure parent directory exists on the destination
parent := path.Dir(dstPath)
if err := dstClient.MkdirAll(parent); err != nil {
return fmt.Errorf("mkdir remote dest %s: %w", parent, err)
}
dstFile, err := dstClient.sftpCli.Create(dstPath)
if err != nil {
return fmt.Errorf("create remote dest %s: %w", dstPath, err)
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return fmt.Errorf("copy remote to remote %s → %s: %w", srcPath, dstPath, err)
}
return nil
}
// CopyDirBetweenRemotes recursively copies a directory from one remote host to another.
// Directories are created on the destination first, then all files are streamed through
// the local machine.
func CopyDirBetweenRemotes(srcClient, dstClient *SSHClient, srcDir, dstDir string) error {
return srcClient.Walk(srcDir, func(remotePath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, _ := filepath.Rel(srcDir, remotePath)
dstPath := path.Join(dstDir, relPath)
if info.IsDir() {
return dstClient.MkdirAll(dstPath)
}
return CopyFileBetweenRemotes(srcClient, dstClient, remotePath, dstPath)
})
}
// Walk walks the remote filesystem tree rooted at root, calling walkFn for each file/dir.
// This is a simplified version of filepath.Walk for SFTP.
type walkFunc func(path string, info os.FileInfo, err error) error
func (c *SSHClient) Walk(root string, walkFn walkFunc) error {
return c.walk(root, walkFn, nil)
}
func (c *SSHClient) walk(dirPath string, walkFn walkFunc, info os.FileInfo) error {
if info == nil {
var err error
info, err = c.sftpCli.Stat(dirPath)
if err != nil {
return walkFn(dirPath, nil, err)
}
}
err := walkFn(dirPath, info, nil)
if err != nil {
if err == filepathSkipDir {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
entries, err := c.sftpCli.ReadDir(dirPath)
if err != nil {
return walkFn(dirPath, info, err)
}
for _, entry := range entries {
childPath := path.Join(dirPath, entry.Name())
if entry.IsDir() {
err = c.walk(childPath, walkFn, entry)
} else {
err = walkFn(childPath, entry, nil)
}
if err != nil {
return err
}
}
return nil
}
// filepathSkipDir is used as a return value from Walk to skip a directory.
var filepathSkipDir = fmt.Errorf("skip this directory")
// SftpToFileInfo converts an os.FileInfo to a vfs-compatible file info.
// This is used for consistent file information handling across local and remote.
func SftpToFileInfo(name string, info os.FileInfo) (os.FileInfo, error) {
return info, nil
}
// WalkDirEntry wraps os.FileInfo with the file name for directory listings.
type WalkDirEntry struct {
os.FileInfo
entryName string
}
func (e *WalkDirEntry) Name() string {
return e.entryName
}
// NewWalkDirEntry creates a new WalkDirEntry with an overridden name.
func NewWalkDirEntry(info os.FileInfo, name string) *WalkDirEntry {
return &WalkDirEntry{FileInfo: info, entryName: name}
}
// DialTimeout is the timeout for establishing SSH connections.
const DialTimeout = 15 * time.Second
// DefaultPort is the default SSH port.
const DefaultPort = "22"
// ResolveAddr returns the SSH address for the given host, applying the default port if needed.
func ResolveAddr(hostname, port string) string {
host := strings.TrimSpace(hostname)
if port == "" || port == "0" {
port = DefaultPort
}
return net.JoinHostPort(host, port)
}

View file

@ -0,0 +1,209 @@
package remote
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
)
// ParseSSHConfig parses ~/.ssh/config and returns a list of SSH hosts.
// It handles the most common SSH config directives: Host, HostName, Port, User, IdentityFile.
func ParseSSHConfig() []SSHHost {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".ssh", "config")
return parseSSHConfigFile(configPath)
}
func parseSSHConfigFile(path string) []SSHHost {
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
var hosts []SSHHost
var current *SSHHost
var currentNames []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Remove inline comments (everything after # that's not in quotes)
if idx := strings.Index(line, "#"); idx >= 0 {
line = strings.TrimSpace(line[:idx])
if line == "" {
continue
}
}
parts := strings.Fields(line)
if len(parts) < 2 {
continue
}
keyword := strings.ToLower(parts[0])
value := strings.Join(parts[1:], " ")
switch keyword {
case "host":
// Save previous host block
if current != nil && len(currentNames) > 0 {
for _, name := range currentNames {
if !isWildcardPattern(name) {
host := *current
host.Name = name
hosts = append(hosts, host)
}
}
}
// Start new host block
current = &SSHHost{
Port: "22",
FromSSHConfig: true,
}
currentNames = strings.Fields(value)
case "hostname":
if current != nil {
current.HostName = value
}
case "port":
if current != nil {
current.Port = value
}
case "user":
if current != nil {
current.User = value
}
case "identityfile":
if current != nil {
// Handle ~ expansion and relative paths
resolved := resolveIdentityPath(value)
if resolved != "" {
current.IdentityFile = resolved
}
}
}
}
// Save last host block
if current != nil && len(currentNames) > 0 {
for _, name := range currentNames {
if !isWildcardPattern(name) {
host := *current
host.Name = name
hosts = append(hosts, host)
}
}
}
return hosts
}
// isWildcardPattern returns true if the pattern contains wildcard characters.
func isWildcardPattern(pattern string) bool {
return strings.ContainsAny(pattern, "*?")
}
// resolveIdentityPath resolves a path from SSH config (handles ~ and relative paths).
func resolveIdentityPath(path string) string {
if path == "" {
return ""
}
// Handle ~/ or $HOME/
if strings.HasPrefix(path, "~/") {
home, err := os.UserHomeDir()
if err != nil {
return path
}
path = filepath.Join(home, path[2:])
}
// Handle relative paths (relative to ~/.ssh/)
if !filepath.IsAbs(path) {
home, err := os.UserHomeDir()
if err != nil {
return path
}
path = filepath.Join(home, ".ssh", path)
}
return filepath.Clean(path)
}
// SSHConfigPath returns the path to the user's SSH config file.
func SSHConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".ssh", "config")
}
// HostsFilePath returns the path to the custom hosts data file.
func HostsFilePath() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".config", "vcom", "hosts.dat")
}
// GetSSHDir returns the path to the .ssh directory.
func GetSSHDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".ssh")
}
// KnownHostsPath returns the path to known_hosts.
func KnownHostsPath() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".ssh", "known_hosts")
}
// ConfigFileExists checks if the SSH config file exists.
func ConfigFileExists() bool {
path := SSHConfigPath()
if path == "" {
return false
}
_, err := os.Stat(path)
return err == nil
}
// ValidateHost checks if a host entry has the minimum required fields.
func ValidateHost(host SSHHost) error {
if strings.TrimSpace(host.Name) == "" {
return fmt.Errorf("host name is required")
}
if strings.TrimSpace(host.HostName) == "" {
return fmt.Errorf("hostname/address is required")
}
if strings.TrimSpace(host.User) == "" {
return fmt.Errorf("username is required")
}
return nil
}

305
internal/fs/remote/host.go Normal file
View file

@ -0,0 +1,305 @@
package remote
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// SSHHost represents a single SSH host configuration.
type SSHHost struct {
// Name is the host alias (e.g. "myserver").
Name string `json:"name"`
// HostName is the actual hostname or IP address.
HostName string `json:"hostname"`
// Port is the SSH port (default 22).
Port string `json:"port,omitempty"`
// User is the SSH username.
User string `json:"user,omitempty"`
// IdentityFile is the path to the private key file (for key-based auth).
IdentityFile string `json:"identity_file,omitempty"`
// Password is stored encrypted (for password-based auth, user-added hosts).
Password string `json:"password,omitempty"`
// FromSSHConfig indicates this host came from ~/.ssh/config.
FromSSHConfig bool `json:"from_ssh_config"`
}
// DisplayName returns the host display name.
func (h SSHHost) DisplayName() string {
addr := h.HostName
if h.Port != "" && h.Port != "22" {
addr = fmt.Sprintf("%s:%s", addr, h.Port)
}
if h.User != "" {
return fmt.Sprintf("%s (%s@%s)", h.Name, h.User, addr)
}
return fmt.Sprintf("%s (%s)", h.Name, addr)
}
// Addr returns the SSH address string (host:port).
func (h SSHHost) Addr() string {
if h.Port == "" || h.Port == "22" {
return h.HostName + ":22"
}
return h.HostName + ":" + h.Port
}
// HostStore manages SSH hosts from both ~/.ssh/config and user-added hosts.
type HostStore struct {
customHosts []SSHHost
configPath string
cipherKey []byte
}
// NewHostStore creates a new HostStore.
func NewHostStore() (*HostStore, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("home dir: %w", err)
}
store := &HostStore{
configPath: filepath.Join(home, ".config", "vcom", "hosts.dat"),
}
// Load or create encryption key
keyPath := filepath.Join(home, ".config", "vcom", ".hosts-key")
store.cipherKey, err = loadOrCreateKey(keyPath)
if err != nil {
return nil, fmt.Errorf("encryption key: %w", err)
}
// Load custom hosts
if err := store.load(); err != nil {
// Ignore load errors for missing file
if !os.IsNotExist(err) {
return nil, err
}
}
return store, nil
}
// loadOrCreateKey loads an existing AES key or creates a new one.
func loadOrCreateKey(path string) ([]byte, error) {
if data, err := os.ReadFile(path); err == nil {
key, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(data)))
if err != nil {
return nil, err
}
if len(key) == 32 {
return key, nil
}
}
// Generate new 32-byte key for AES-256
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o700); err != nil {
return nil, fmt.Errorf("mkdir: %w", err)
}
encoded := base64.StdEncoding.EncodeToString(key)
if err := os.WriteFile(path, []byte(encoded), 0o600); err != nil {
return nil, fmt.Errorf("write key: %w", err)
}
return key, nil
}
type storedHosts struct {
Hosts []storedHost `json:"hosts"`
}
type storedHost struct {
Name string `json:"name"`
HostName string `json:"hostname"`
Port string `json:"port,omitempty"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"` // encrypted
IdentityFile string `json:"identity_file,omitempty"`
}
func (s *HostStore) load() error {
data, err := os.ReadFile(s.configPath)
if err != nil {
return err
}
// Decrypt
decrypted, err := decrypt(data, s.cipherKey)
if err != nil {
return fmt.Errorf("decrypt hosts: %w", err)
}
var stored storedHosts
if err := json.Unmarshal(decrypted, &stored); err != nil {
return fmt.Errorf("parse hosts: %w", err)
}
s.customHosts = make([]SSHHost, len(stored.Hosts))
for i, h := range stored.Hosts {
password := ""
if h.Password != "" {
pwd, err := decrypt([]byte(h.Password), s.cipherKey)
if err == nil {
password = string(pwd)
}
}
s.customHosts[i] = SSHHost{
Name: h.Name,
HostName: h.HostName,
Port: h.Port,
User: h.User,
Password: password,
IdentityFile: h.IdentityFile,
FromSSHConfig: false,
}
}
return nil
}
// Save persists custom hosts to disk (encrypted).
func (s *HostStore) Save() error {
stored := storedHosts{
Hosts: make([]storedHost, len(s.customHosts)),
}
for i, h := range s.customHosts {
password := ""
if h.Password != "" {
enc, err := encrypt([]byte(h.Password), s.cipherKey)
if err == nil {
password = string(enc)
}
}
stored.Hosts[i] = storedHost{
Name: h.Name,
HostName: h.HostName,
Port: h.Port,
User: h.User,
Password: password,
IdentityFile: h.IdentityFile,
}
}
data, err := json.Marshal(stored)
if err != nil {
return fmt.Errorf("marshal hosts: %w", err)
}
encrypted, err := encrypt(data, s.cipherKey)
if err != nil {
return fmt.Errorf("encrypt hosts: %w", err)
}
dir := filepath.Dir(s.configPath)
if err := os.MkdirAll(dir, 0o700); err != nil {
return fmt.Errorf("mkdir: %w", err)
}
return os.WriteFile(s.configPath, encrypted, 0o600)
}
// AddHost adds a custom host and saves.
func (s *HostStore) AddHost(host SSHHost) error {
host.FromSSHConfig = false
s.customHosts = append(s.customHosts, host)
return s.Save()
}
// RemoveHost removes a custom host by name.
func (s *HostStore) RemoveHost(name string) error {
for i, h := range s.customHosts {
if h.Name == name {
s.customHosts = append(s.customHosts[:i], s.customHosts[i+1:]...)
return s.Save()
}
}
return fmt.Errorf("host %q not found", name)
}
// AllHosts returns all hosts (from ssh config + custom).
func (s *HostStore) AllHosts() []SSHHost {
sshConfigHosts := ParseSSHConfig()
result := make([]SSHHost, 0, len(sshConfigHosts)+len(s.customHosts))
// Build a set of names from ssh config to avoid duplicates
seen := make(map[string]bool)
for _, h := range sshConfigHosts {
lower := strings.ToLower(h.Name)
seen[lower] = true
result = append(result, h)
}
for _, h := range s.customHosts {
lower := strings.ToLower(h.Name)
if !seen[lower] {
result = append(result, h)
seen[lower] = true
}
}
return result
}
// FindByName looks up a host by its Name field. Returns nil if not found.
func (s *HostStore) FindByName(name string) *SSHHost {
all := s.AllHosts()
for i := range all {
if strings.EqualFold(all[i].Name, name) {
return &all[i]
}
}
return nil
}
func encrypt(plaintext []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
func decrypt(ciphertext []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(ciphertext) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
}