From 96572be71238faee94c9ee76890f921188eebd3b Mon Sep 17 00:00:00 2001 From: Mikhail Chusavitin Date: Sat, 7 Mar 2026 22:14:31 +0300 Subject: [PATCH] Harden local admin and secret storage --- cmd/pfs/main.go | 18 +++-- internal/appstate/path.go | 14 +++- internal/localdb/encryption.go | 97 +++++++++++++++++++++--- internal/localdb/localdb.go | 4 +- internal/middleware/cors.go | 31 +++++++- internal/middleware/origin_protection.go | 45 +++++++++++ 6 files changed, 187 insertions(+), 22 deletions(-) create mode 100644 internal/middleware/origin_protection.go diff --git a/cmd/pfs/main.go b/cmd/pfs/main.go index 555539a..6df9f42 100644 --- a/cmd/pfs/main.go +++ b/cmd/pfs/main.go @@ -76,6 +76,9 @@ func main() { slog.Info("migrated legacy config file", "from", migratedFrom, "to", resolvedConfigPath) } } + if err := appstate.EnsurePrivateFile(resolvedConfigPath); err != nil { + slog.Warn("failed to enforce private permissions on config", "path", resolvedConfigPath, "error", err) + } // Load config for server settings cfg, err := config.Load(resolvedConfigPath) @@ -448,6 +451,7 @@ func setupRouter(cfg *config.Config, configPath string, connMgr *db.ConnectionMa router.MaxMultipartMemory = 26 << 20 // 26MB; stock import handler enforces 25MB payload limit router.Use(gin.Recovery()) router.Use(requestLogger()) + router.Use(middleware.OriginProtection()) router.Use(middleware.CORS()) router.Use(middleware.OfflineDetector(connMgr)) @@ -587,13 +591,13 @@ func setupRouter(cfg *config.Config, configPath string, connMgr *db.ConnectionMa pricingAdmin.POST("/stock/mappings", pricingHandler.UpsertStockMapping) pricingAdmin.DELETE("/stock/mappings/:partnumber", pricingHandler.DeleteStockMapping) pricingAdmin.GET("/vendor-mappings", pricingHandler.ListVendorMappings) - pricingAdmin.GET("/vendor-mappings/detail", pricingHandler.GetVendorMappingDetail) - pricingAdmin.POST("/vendor-mappings", pricingHandler.UpsertVendorMapping) - pricingAdmin.POST("/vendor-mappings/import-csv", pricingHandler.ImportVendorMappingsCSV) - pricingAdmin.GET("/vendor-mappings/export-unmapped-csv", pricingHandler.ExportUnmappedVendorMappingsCSV) - pricingAdmin.DELETE("/vendor-mappings", pricingHandler.DeleteVendorMapping) - pricingAdmin.POST("/vendor-mappings/ignore", pricingHandler.IgnoreVendorMapping) - pricingAdmin.POST("/vendor-mappings/unignore", pricingHandler.UnignoreVendorMapping) + pricingAdmin.GET("/vendor-mappings/detail", pricingHandler.GetVendorMappingDetail) + pricingAdmin.POST("/vendor-mappings", pricingHandler.UpsertVendorMapping) + pricingAdmin.POST("/vendor-mappings/import-csv", pricingHandler.ImportVendorMappingsCSV) + pricingAdmin.GET("/vendor-mappings/export-unmapped-csv", pricingHandler.ExportUnmappedVendorMappingsCSV) + pricingAdmin.DELETE("/vendor-mappings", pricingHandler.DeleteVendorMapping) + pricingAdmin.POST("/vendor-mappings/ignore", pricingHandler.IgnoreVendorMapping) + pricingAdmin.POST("/vendor-mappings/unignore", pricingHandler.UnignoreVendorMapping) pricingAdmin.GET("/alerts", pricingHandler.ListAlerts) pricingAdmin.POST("/alerts/:id/acknowledge", pricingHandler.AcknowledgeAlert) pricingAdmin.POST("/alerts/:id/resolve", pricingHandler.ResolveAlert) diff --git a/internal/appstate/path.go b/internal/appstate/path.go index 3ecb999..a744b4b 100644 --- a/internal/appstate/path.go +++ b/internal/appstate/path.go @@ -67,7 +67,7 @@ func MigrateLegacyDB(targetPath string, legacyPaths []string) (string, error) { return "", nil } - if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(targetPath), 0700); err != nil { return "", fmt.Errorf("creating target db directory: %w", err) } @@ -105,7 +105,7 @@ func MigrateLegacyFile(targetPath string, legacyPaths []string) (string, error) return "", nil } - if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(targetPath), 0700); err != nil { return "", fmt.Errorf("creating target directory: %w", err) } @@ -195,3 +195,13 @@ func copyFile(src, dst string) error { return out.Sync() } + +func EnsurePrivateFile(path string) error { + if path == "" { + return nil + } + if err := os.Chmod(path, 0600); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/internal/localdb/encryption.go b/internal/localdb/encryption.go index 4db1142..4637acb 100644 --- a/internal/localdb/encryption.go +++ b/internal/localdb/encryption.go @@ -7,20 +7,81 @@ import ( "crypto/sha256" "encoding/base64" "errors" + "fmt" "io" "os" + "path/filepath" + "strings" + "sync" ) -// getEncryptionKey derives a 32-byte key from environment variable or machine ID -func getEncryptionKey() []byte { - key := os.Getenv("QUOTEFORGE_ENCRYPTION_KEY") - if key == "" { - // Fallback to a machine-based key (hostname + fixed salt) - hostname, _ := os.Hostname() - key = hostname + "priceforge-salt-2024" +const encryptionKeyFileName = "local_secret.key" + +var ( + encryptionStateMu sync.RWMutex + encryptionKeyDir string +) + +func configureEncryption(dir string) { + encryptionStateMu.Lock() + defer encryptionStateMu.Unlock() + encryptionKeyDir = dir +} + +func getEncryptionKey() ([]byte, error) { + if key := strings.TrimSpace(os.Getenv("PRICEFORGE_ENCRYPTION_KEY")); key != "" { + hash := sha256.Sum256([]byte(key)) + return hash[:], nil } - // Hash to get exactly 32 bytes for AES-256 - hash := sha256.Sum256([]byte(key)) + if key := strings.TrimSpace(os.Getenv("QUOTEFORGE_ENCRYPTION_KEY")); key != "" { + hash := sha256.Sum256([]byte(key)) + return hash[:], nil + } + return loadOrCreateKeyFile() +} + +func loadOrCreateKeyFile() ([]byte, error) { + keyPath, err := encryptionKeyPath() + if err != nil { + return nil, err + } + if err := os.MkdirAll(filepath.Dir(keyPath), 0700); err != nil { + return nil, fmt.Errorf("create encryption key dir: %w", err) + } + if data, err := os.ReadFile(keyPath); err == nil { + trimmed := strings.TrimSpace(string(data)) + if trimmed == "" { + return nil, fmt.Errorf("encryption key file %s is empty", keyPath) + } + hash := sha256.Sum256([]byte(trimmed)) + return hash[:], nil + } + + raw := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, raw); err != nil { + return nil, fmt.Errorf("generate encryption key: %w", err) + } + encoded := base64.StdEncoding.EncodeToString(raw) + if err := os.WriteFile(keyPath, []byte(encoded), 0600); err != nil { + return nil, fmt.Errorf("write encryption key file: %w", err) + } + hash := sha256.Sum256([]byte(encoded)) + return hash[:], nil +} + +func encryptionKeyPath() (string, error) { + encryptionStateMu.RLock() + dir := encryptionKeyDir + encryptionStateMu.RUnlock() + if strings.TrimSpace(dir) == "" { + return "", errors.New("encryption key directory is not configured") + } + return filepath.Join(dir, encryptionKeyFileName), nil +} + +func legacyEncryptionKey() []byte { + hostname, _ := os.Hostname() + hash := sha256.Sum256([]byte(hostname + "priceforge-salt-2024")) return hash[:] } @@ -30,7 +91,10 @@ func Encrypt(plaintext string) (string, error) { return "", nil } - key := getEncryptionKey() + key, err := getEncryptionKey() + if err != nil { + return "", err + } block, err := aes.NewCipher(key) if err != nil { return "", err @@ -56,12 +120,23 @@ func Decrypt(ciphertext string) (string, error) { return "", nil } - key := getEncryptionKey() data, err := base64.StdEncoding.DecodeString(ciphertext) if err != nil { return "", err } + key, err := getEncryptionKey() + if err == nil { + plaintext, decryptErr := decryptWithKey(data, key) + if decryptErr == nil { + return plaintext, nil + } + } + + return decryptWithKey(data, legacyEncryptionKey()) +} + +func decryptWithKey(data []byte, key []byte) (string, error) { block, err := aes.NewCipher(key) if err != nil { return "", err diff --git a/internal/localdb/localdb.go b/internal/localdb/localdb.go index bcb2089..1b8a93b 100644 --- a/internal/localdb/localdb.go +++ b/internal/localdb/localdb.go @@ -43,9 +43,11 @@ type LocalDB struct { // New creates a new LocalDB instance func New(dbPath string) (*LocalDB, error) { + configureEncryption(filepath.Dir(dbPath)) + // Ensure directory exists dir := filepath.Dir(dbPath) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0700); err != nil { return nil, fmt.Errorf("creating data directory: %w", err) } diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 76d2662..bd7c011 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -1,18 +1,31 @@ package middleware import ( + "net" + "net/url" + "strings" + "github.com/gin-gonic/gin" ) func CORS() gin.HandlerFunc { return func(c *gin.Context) { - c.Header("Access-Control-Allow-Origin", "*") + if origin := strings.TrimSpace(c.GetHeader("Origin")); origin != "" { + if isLoopbackOrigin(origin) { + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Vary", "Origin") + } + } c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization") c.Header("Access-Control-Expose-Headers", "Content-Length, Content-Disposition") c.Header("Access-Control-Max-Age", "86400") if c.Request.Method == "OPTIONS" { + if strings.TrimSpace(c.GetHeader("Origin")) != "" && !isLoopbackOrigin(c.GetHeader("Origin")) { + c.AbortWithStatus(403) + return + } c.AbortWithStatus(204) return } @@ -20,3 +33,19 @@ func CORS() gin.HandlerFunc { c.Next() } } + +func isLoopbackOrigin(origin string) bool { + u, err := url.Parse(origin) + if err != nil { + return false + } + host := strings.TrimSpace(u.Hostname()) + if host == "" { + return false + } + if strings.EqualFold(host, "localhost") { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} diff --git a/internal/middleware/origin_protection.go b/internal/middleware/origin_protection.go new file mode 100644 index 0000000..34d5234 --- /dev/null +++ b/internal/middleware/origin_protection.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +// OriginProtection blocks browser-driven cross-site requests to localhost. +// Same-origin UI requests continue to work; CLI clients without Origin/Referer remain allowed. +func OriginProtection() gin.HandlerFunc { + return func(c *gin.Context) { + if isSafeMethod(c.Request.Method) { + c.Next() + return + } + + if secFetchSite := strings.TrimSpace(c.GetHeader("Sec-Fetch-Site")); secFetchSite == "cross-site" { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "cross-site requests are not allowed"}) + return + } + + if origin := strings.TrimSpace(c.GetHeader("Origin")); origin != "" && !isLoopbackOrigin(origin) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin must be localhost"}) + return + } + + if referer := strings.TrimSpace(c.GetHeader("Referer")); referer != "" && !isLoopbackOrigin(referer) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "referer must be localhost"}) + return + } + + c.Next() + } +} + +func isSafeMethod(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return true + default: + return false + } +}