Files
PriceForge/internal/localdb/encryption.go
2026-03-07 22:14:31 +03:00

163 lines
3.6 KiB
Go

package localdb
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
)
// encryptionKeyFileName is the name of the key file that lives inside the
// configured encryption key directory.
const encryptionKeyFileName = "local_secret.key"
// Package-level encryption configuration.
var (
// encryptionStateMu guards encryptionKeyDir.
encryptionStateMu sync.RWMutex
// encryptionKeyDir is the directory that holds the key file. It is set by
// configureEncryption and is the empty string until then.
encryptionKeyDir string
)
// configureEncryption records dir as the directory in which the encryption
// key file is looked up and, if necessary, created. The write is performed
// while holding encryptionStateMu.
func configureEncryption(dir string) {
	encryptionStateMu.Lock()
	encryptionKeyDir = dir
	encryptionStateMu.Unlock()
}
// getEncryptionKey returns the 32-byte key used for encryption.
//
// The PRICEFORGE_ENCRYPTION_KEY environment variable is consulted first, then
// QUOTEFORGE_ENCRYPTION_KEY. The first one that is non-empty after trimming
// surrounding whitespace is hashed with SHA-256 and the digest is the key.
// When neither variable is set, the key comes from loadOrCreateKeyFile.
func getEncryptionKey() ([]byte, error) {
	// Order matters: earlier names take precedence over later ones.
	for _, envName := range []string{"PRICEFORGE_ENCRYPTION_KEY", "QUOTEFORGE_ENCRYPTION_KEY"} {
		secret := strings.TrimSpace(os.Getenv(envName))
		if secret == "" {
			continue
		}
		digest := sha256.Sum256([]byte(secret))
		return digest[:], nil
	}
	return loadOrCreateKeyFile()
}
// loadOrCreateKeyFile returns the 32-byte key backed by the key file in the
// configured key directory, creating the directory (mode 0700) and the file
// (mode 0600) on first use.
//
// If the file exists, its whitespace-trimmed contents are hashed with SHA-256
// to produce the key; a file that is empty after trimming is an error. If the
// file does not exist, 32 random bytes are generated, written base64-encoded,
// and hashed the same way.
//
// A read failure other than "file does not exist" is returned to the caller
// instead of being treated as a missing key, so a permission or I/O error can
// never lead to an existing key being overwritten with a fresh one.
func loadOrCreateKeyFile() ([]byte, error) {
	keyPath, err := encryptionKeyPath()
	if err != nil {
		return nil, err
	}
	if err := os.MkdirAll(filepath.Dir(keyPath), 0700); err != nil {
		return nil, fmt.Errorf("create encryption key dir: %w", err)
	}

	data, err := os.ReadFile(keyPath)
	switch {
	case err == nil:
		trimmed := strings.TrimSpace(string(data))
		if trimmed == "" {
			return nil, fmt.Errorf("encryption key file %s is empty", keyPath)
		}
		hash := sha256.Sum256([]byte(trimmed))
		return hash[:], nil
	case !errors.Is(err, os.ErrNotExist):
		// The file may well exist but be unreadable right now; generating a
		// replacement would make previously encrypted data undecryptable.
		return nil, fmt.Errorf("read encryption key file: %w", err)
	}

	// No key file yet: generate one and persist it.
	raw := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, raw); err != nil {
		return nil, fmt.Errorf("generate encryption key: %w", err)
	}
	encoded := base64.StdEncoding.EncodeToString(raw)
	if err := os.WriteFile(keyPath, []byte(encoded), 0600); err != nil {
		return nil, fmt.Errorf("write encryption key file: %w", err)
	}
	// Hash the encoded form so the result matches what a later read of the
	// file produces.
	hash := sha256.Sum256([]byte(encoded))
	return hash[:], nil
}
// encryptionKeyPath returns the full path of the key file: the configured key
// directory joined with encryptionKeyFileName. It returns an error when the
// configured directory is empty or contains only whitespace.
func encryptionKeyPath() (string, error) {
	// Take a snapshot of the directory under the read lock, then work on the copy.
	keyDir := func() string {
		encryptionStateMu.RLock()
		defer encryptionStateMu.RUnlock()
		return encryptionKeyDir
	}()

	if len(strings.TrimSpace(keyDir)) == 0 {
		return "", errors.New("encryption key directory is not configured")
	}
	return filepath.Join(keyDir, encryptionKeyFileName), nil
}
// legacyEncryptionKey derives a 32-byte key as the SHA-256 digest of the
// machine's hostname followed by a fixed salt string. Decrypt uses it as a
// fallback key.
func legacyEncryptionKey() []byte {
	// The lookup error is intentionally discarded: on failure the hostname is
	// the empty string and the key is derived from the salt alone.
	host, _ := os.Hostname()

	// hash.Hash.Write is documented to never return an error.
	digest := sha256.New()
	digest.Write([]byte(host))
	digest.Write([]byte("priceforge-salt-2024"))
	return digest.Sum(nil)
}
// Encrypt seals plaintext with AES-GCM under the 32-byte key returned by
// getEncryptionKey and returns the result in standard base64.
//
// A fresh random nonce is generated on every call and placed in front of the
// sealed bytes, which is the layout decryptWithKey splits apart. An empty
// plaintext yields an empty string and no key is loaded.
func Encrypt(plaintext string) (string, error) {
	if len(plaintext) == 0 {
		return "", nil
	}

	secret, err := getEncryptionKey()
	if err != nil {
		return "", err
	}
	blockCipher, err := aes.NewCipher(secret)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Seal appends to its first argument, so passing the nonce as the
	// destination produces nonce followed by the sealed data.
	sealed := aead.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// Decrypt reverses Encrypt: it base64-decodes ciphertext and opens it with
// AES-GCM.
//
// The key from getEncryptionKey is tried first. If that key cannot be loaded
// or does not open the data, the key from legacyEncryptionKey is tried as a
// fallback. An empty ciphertext yields an empty string.
//
// When both attempts fail, the returned error wraps the current-key failure
// and includes the legacy failure in its message, so the real cause (for
// example an unreadable key file) is not hidden behind the fallback's error.
// Invalid base64 is reported with the decoder's error unchanged.
func Decrypt(ciphertext string) (string, error) {
	if ciphertext == "" {
		return "", nil
	}
	data, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", err
	}

	// primaryErr holds whichever step of the current-key attempt failed.
	key, primaryErr := getEncryptionKey()
	if primaryErr == nil {
		var plaintext string
		if plaintext, primaryErr = decryptWithKey(data, key); primaryErr == nil {
			return plaintext, nil
		}
	}

	// Best-effort fallback to the legacy key.
	plaintext, legacyErr := decryptWithKey(data, legacyEncryptionKey())
	if legacyErr != nil {
		return "", fmt.Errorf("decrypt with current key: %w (legacy key fallback: %v)", primaryErr, legacyErr)
	}
	return plaintext, nil
}
// decryptWithKey opens data with AES-GCM under key and returns the plaintext
// as a string.
//
// data must start with the GCM nonce, followed by the sealed bytes. An error
// is returned when key is not a valid AES key size, when data is shorter than
// the nonce, or when authentication fails.
func decryptWithKey(data []byte, key []byte) (string, error) {
	blockCipher, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", err
	}

	n := aead.NonceSize()
	if len(data) < n {
		return "", errors.New("ciphertext too short")
	}

	// The first n bytes are the nonce; everything after is sealed data.
	opened, err := aead.Open(nil, data[:n], data[n:], nil)
	if err != nil {
		return "", err
	}
	return string(opened), nil
}