Harden local runtime safety and error handling

This commit is contained in:
Mikhail Chusavitin
2026-03-15 16:28:32 +03:00
parent f0e6bba7e9
commit c964d66e64
25 changed files with 726 additions and 245 deletions

View File

@@ -64,3 +64,28 @@ logging:
t.Fatalf("migrated config did not preserve logging level:\n%s", text)
}
}
func TestEnsureLoopbackServerHost(t *testing.T) {
t.Parallel()
cases := []struct {
host string
wantErr bool
}{
{host: "127.0.0.1", wantErr: false},
{host: "localhost", wantErr: false},
{host: "::1", wantErr: false},
{host: "0.0.0.0", wantErr: true},
{host: "192.168.1.10", wantErr: true},
}
for _, tc := range cases {
err := ensureLoopbackServerHost(tc.host)
if tc.wantErr && err == nil {
t.Fatalf("expected error for host %q", tc.host)
}
if !tc.wantErr && err != nil {
t.Fatalf("unexpected error for host %q: %v", tc.host, err)
}
}
}

View File

@@ -10,6 +10,7 @@ import (
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"os"
"os/exec"
@@ -43,11 +44,16 @@ import (
// Version is set via ldflags during build
var Version = "dev"
var errVendorImportTooLarge = errors.New("vendor workspace file exceeds 1 GiB limit")
const backgroundSyncInterval = 5 * time.Minute
const onDemandPullCooldown = 30 * time.Second
const startupConsoleWarning = "Не закрывайте консоль иначе приложение не будет работать"
var vendorImportMaxBytes int64 = 1 << 30
const vendorImportMultipartOverheadBytes int64 = 8 << 20
func main() {
showStartupConsoleWarning()
@@ -142,6 +148,10 @@ func main() {
}
}
setConfigDefaults(cfg)
if err := ensureLoopbackServerHost(cfg.Server.Host); err != nil {
slog.Error("invalid server host", "host", cfg.Server.Host, "error", err)
os.Exit(1)
}
if err := migrateConfigFileToRuntimeShape(resolvedConfigPath, cfg); err != nil {
slog.Error("failed to migrate config file format", "path", resolvedConfigPath, "error", err)
os.Exit(1)
@@ -342,6 +352,35 @@ func setConfigDefaults(cfg *config.Config) {
}
}
func ensureLoopbackServerHost(host string) error {
trimmed := strings.TrimSpace(host)
if trimmed == "" {
return fmt.Errorf("server.host must not be empty")
}
if strings.EqualFold(trimmed, "localhost") {
return nil
}
ip := net.ParseIP(strings.Trim(trimmed, "[]"))
if ip != nil && ip.IsLoopback() {
return nil
}
return fmt.Errorf("QuoteForge local client must bind to localhost only")
}
func vendorImportBodyLimit() int64 {
return vendorImportMaxBytes + vendorImportMultipartOverheadBytes
}
func isVendorImportTooLarge(fileSize int64, err error) bool {
if fileSize > vendorImportMaxBytes {
return true
}
var maxBytesErr *http.MaxBytesError
return errors.As(err, &maxBytesErr)
}
func ensureDefaultConfigFile(configPath string) error {
if strings.TrimSpace(configPath) == "" {
return fmt.Errorf("config path is empty")
@@ -747,6 +786,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
pricelistHandler := handlers.NewPricelistHandler(local)
vendorSpecHandler := handlers.NewVendorSpecHandler(local)
partnumberBooksHandler := handlers.NewPartnumberBooksHandler(local)
respondError := handlers.RespondError
syncHandler, err := handlers.NewSyncHandler(local, syncService, connMgr, templatesPath, backgroundSyncInterval)
if err != nil {
return nil, nil, fmt.Errorf("creating sync handler: %w", err)
@@ -766,6 +806,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
// Router
router := gin.New()
router.MaxMultipartMemory = vendorImportBodyLimit()
router.Use(gin.Recovery())
router.Use(requestLogger())
router.Use(middleware.CORS())
@@ -786,17 +827,17 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
})
})
// Restart endpoint (for development purposes)
router.POST("/api/restart", func(c *gin.Context) {
// This will cause the server to restart by exiting
// The restartProcess function will be called to restart the process
slog.Info("Restart requested via API")
go func() {
time.Sleep(100 * time.Millisecond)
restartProcess()
}()
c.JSON(http.StatusOK, gin.H{"message": "restarting..."})
})
// Restart endpoint is intentionally debug-only.
if cfg.Server.Mode == "debug" {
router.POST("/api/restart", func(c *gin.Context) {
slog.Info("Restart requested via API")
go func() {
time.Sleep(100 * time.Millisecond)
restartProcess()
}()
c.JSON(http.StatusOK, gin.H{"message": "restarting..."})
})
}
// DB status endpoint
router.GET("/api/db-status", func(c *gin.Context) {
@@ -928,7 +969,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
cfgs, total, err := configService.ListAllWithStatus(page, perPage, status, search)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
@@ -949,7 +990,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Database is offline"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusOK, result)
@@ -958,13 +999,13 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
configs.POST("", func(c *gin.Context) {
var req services.CreateConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
config, err := configService.Create(dbUsername, &req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
@@ -974,12 +1015,12 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
configs.POST("/preview-article", func(c *gin.Context) {
var req services.ArticlePreviewRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
result, err := configService.BuildArticlePreview(&req)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -1002,7 +1043,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
uuid := c.Param("uuid")
var req services.CreateConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
@@ -1010,13 +1051,13 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err != nil {
switch {
case errors.Is(err, services.ErrConfigNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1027,7 +1068,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
configs.DELETE("/:uuid", func(c *gin.Context) {
uuid := c.Param("uuid")
if err := configService.DeleteNoAuth(uuid); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "archived"})
@@ -1037,7 +1078,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
uuid := c.Param("uuid")
config, err := configService.ReactivateNoAuth(uuid)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -1052,13 +1093,13 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
Name string `json:"name"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
config, err := configService.RenameNoAuth(uuid, req.Name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
@@ -1072,7 +1113,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
FromVersion int `json:"from_version"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
@@ -1082,7 +1123,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
@@ -1093,7 +1134,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
uuid := c.Param("uuid")
config, err := configService.RefreshPricesNoAuth(uuid)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusOK, config)
@@ -1105,20 +1146,20 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
ProjectUUID string `json:"project_uuid"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
updated, err := configService.SetProjectNoAuth(uuid, req.ProjectUUID)
if err != nil {
switch {
case errors.Is(err, services.ErrConfigNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1147,7 +1188,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
case errors.Is(err, services.ErrInvalidVersionNumber):
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid paging params"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1175,7 +1216,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
case errors.Is(err, services.ErrConfigVersionNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1190,7 +1231,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
Note string `json:"note"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
if req.TargetVersion <= 0 {
@@ -1208,7 +1249,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
case errors.Is(err, services.ErrVersionConflict):
c.JSON(http.StatusConflict, gin.H{"error": "version conflict"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1243,12 +1284,12 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
ServerCount int `json:"server_count" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
config, err := configService.UpdateServerCount(uuid, req.ServerCount)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusOK, config)
@@ -1293,7 +1334,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
allProjects, err := projectService.ListByUser(dbUsername, true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
@@ -1427,7 +1468,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
projects.GET("/all", func(c *gin.Context) {
allProjects, err := projectService.ListByUser(dbUsername, true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
@@ -1457,7 +1498,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
projects.POST("", func(c *gin.Context) {
var req services.CreateProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
if strings.TrimSpace(req.Code) == "" {
@@ -1468,9 +1509,9 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err != nil {
switch {
case errors.Is(err, services.ErrProjectCodeExists):
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
respondError(c, http.StatusConflict, "conflict detected", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1482,11 +1523,11 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err != nil {
switch {
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1496,20 +1537,20 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
projects.PUT("/:uuid", func(c *gin.Context) {
var req services.UpdateProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
project, err := projectService.Update(c.Param("uuid"), dbUsername, &req)
if err != nil {
switch {
case errors.Is(err, services.ErrProjectCodeExists):
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
respondError(c, http.StatusConflict, "conflict detected", err)
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1520,11 +1561,11 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err := projectService.Archive(c.Param("uuid"), dbUsername); err != nil {
switch {
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1535,11 +1576,11 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err := projectService.Reactivate(c.Param("uuid"), dbUsername); err != nil {
switch {
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1550,13 +1591,13 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err := projectService.DeleteVariant(c.Param("uuid"), dbUsername); err != nil {
switch {
case errors.Is(err, services.ErrCannotDeleteMainVariant):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1576,11 +1617,11 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err != nil {
switch {
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
case errors.Is(err, services.ErrProjectForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
respondError(c, http.StatusForbidden, "access denied", err)
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
}
return
}
@@ -1593,7 +1634,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
OrderedUUIDs []string `json:"ordered_uuids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
if len(req.OrderedUUIDs) == 0 {
@@ -1605,9 +1646,9 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err != nil {
switch {
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
}
return
}
@@ -1628,7 +1669,7 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
projects.POST("/:uuid/configs", func(c *gin.Context) {
var req services.CreateConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
projectUUID := c.Param("uuid")
@@ -1636,29 +1677,42 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
config, err := configService.Create(dbUsername, &req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusCreated, config)
})
projects.POST("/:uuid/vendor-import", func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, vendorImportBodyLimit())
fileHeader, err := c.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "file is required"})
if isVendorImportTooLarge(0, err) {
respondError(c, http.StatusBadRequest, "vendor workspace file exceeds 1 GiB limit", errVendorImportTooLarge)
return
}
respondError(c, http.StatusBadRequest, "file is required", err)
return
}
if isVendorImportTooLarge(fileHeader.Size, nil) {
respondError(c, http.StatusBadRequest, "vendor workspace file exceeds 1 GiB limit", errVendorImportTooLarge)
return
}
file, err := fileHeader.Open()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to open uploaded file"})
respondError(c, http.StatusBadRequest, "failed to open uploaded file", err)
return
}
defer file.Close()
data, err := io.ReadAll(file)
data, err := io.ReadAll(io.LimitReader(file, vendorImportMaxBytes+1))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read uploaded file"})
respondError(c, http.StatusBadRequest, "failed to read uploaded file", err)
return
}
if int64(len(data)) > vendorImportMaxBytes {
respondError(c, http.StatusBadRequest, "vendor workspace file exceeds 1 GiB limit", errVendorImportTooLarge)
return
}
if !services.IsCFXMLWorkspace(data) {
@@ -1670,9 +1724,9 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
if err != nil {
switch {
case errors.Is(err, services.ErrProjectNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
respondError(c, http.StatusNotFound, "resource not found", err)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
}
return
}
@@ -1688,14 +1742,14 @@ func setupRouter(cfg *config.Config, local *localdb.LocalDB, connMgr *db.Connect
Name string `json:"name"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
respondError(c, http.StatusBadRequest, "invalid request", err)
return
}
projectUUID := c.Param("uuid")
config, err := configService.CloneNoAuthToProject(c.Param("config_uuid"), req.Name, dbUsername, &projectUUID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
respondError(c, http.StatusInternalServerError, "internal server error", err)
return
}
c.JSON(http.StatusCreated, config)
@@ -1769,22 +1823,12 @@ func requestLogger() gin.HandlerFunc {
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
blw := &captureResponseWriter{
ResponseWriter: c.Writer,
body: bytes.NewBuffer(nil),
}
c.Writer = blw
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
if status >= http.StatusBadRequest {
responseBody := strings.TrimSpace(blw.body.String())
if len(responseBody) > 2048 {
responseBody = responseBody[:2048] + "...(truncated)"
}
errText := strings.TrimSpace(c.Errors.String())
slog.Error("request failed",
@@ -1795,7 +1839,6 @@ func requestLogger() gin.HandlerFunc {
"latency", latency,
"ip", c.ClientIP(),
"errors", errText,
"response", responseBody,
)
return
}
@@ -1810,22 +1853,3 @@ func requestLogger() gin.HandlerFunc {
)
}
}
type captureResponseWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
func (w *captureResponseWriter) Write(b []byte) (int, error) {
if len(b) > 0 {
_, _ = w.body.Write(b)
}
return w.ResponseWriter.Write(b)
}
func (w *captureResponseWriter) WriteString(s string) (int, error) {
if s != "" {
_, _ = w.body.WriteString(s)
}
return w.ResponseWriter.WriteString(s)
}

View File

@@ -0,0 +1,48 @@
package main
import (
"bytes"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestRequestLoggerDoesNotLogResponseBody(t *testing.T) {
gin.SetMode(gin.TestMode)
var logBuffer bytes.Buffer
previousLogger := slog.Default()
slog.SetDefault(slog.New(slog.NewTextHandler(&logBuffer, &slog.HandlerOptions{})))
defer slog.SetDefault(previousLogger)
router := gin.New()
router.Use(requestLogger())
router.GET("/fail", func(c *gin.Context) {
_ = c.Error(errors.New("root cause"))
c.JSON(http.StatusBadRequest, gin.H{"error": "do not log this body"})
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/fail?debug=1", nil)
router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", rec.Code)
}
logOutput := logBuffer.String()
if !strings.Contains(logOutput, "request failed") {
t.Fatalf("expected request failure log, got %q", logOutput)
}
if strings.Contains(logOutput, "do not log this body") {
t.Fatalf("response body leaked into logs: %q", logOutput)
}
if !strings.Contains(logOutput, "root cause") {
t.Fatalf("expected error details in logs, got %q", logOutput)
}
}

View File

@@ -3,10 +3,12 @@ package main
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"git.mchus.pro/mchus/quoteforge/internal/config"
@@ -290,6 +292,88 @@ func TestConfigMoveToProjectEndpoint(t *testing.T) {
}
}
func TestVendorImportRejectsOversizedUpload(t *testing.T) {
moveToRepoRoot(t)
prevLimit := vendorImportMaxBytes
vendorImportMaxBytes = 128
defer func() { vendorImportMaxBytes = prevLimit }()
local, connMgr, _ := newAPITestStack(t)
cfg := &config.Config{}
setConfigDefaults(cfg)
router, _, err := setupRouter(cfg, local, connMgr, "tester", nil)
if err != nil {
t.Fatalf("setup router: %v", err)
}
createProjectReq := httptest.NewRequest(http.MethodPost, "/api/projects", bytes.NewReader([]byte(`{"name":"Import Project","code":"IMP"}`)))
createProjectReq.Header.Set("Content-Type", "application/json")
createProjectRec := httptest.NewRecorder()
router.ServeHTTP(createProjectRec, createProjectReq)
if createProjectRec.Code != http.StatusCreated {
t.Fatalf("create project status=%d body=%s", createProjectRec.Code, createProjectRec.Body.String())
}
var project models.Project
if err := json.Unmarshal(createProjectRec.Body.Bytes(), &project); err != nil {
t.Fatalf("unmarshal project: %v", err)
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
part, err := writer.CreateFormFile("file", "huge.xml")
if err != nil {
t.Fatalf("create form file: %v", err)
}
payload := "<CFXML>" + strings.Repeat("A", int(vendorImportMaxBytes)+1) + "</CFXML>"
if _, err := part.Write([]byte(payload)); err != nil {
t.Fatalf("write multipart payload: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close multipart writer: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/api/projects/"+project.UUID+"/vendor-import", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for oversized upload, got %d body=%s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "1 GiB") {
t.Fatalf("expected size limit message, got %s", rec.Body.String())
}
}
func TestCreateConfigMalformedJSONReturnsGenericError(t *testing.T) {
moveToRepoRoot(t)
local, connMgr, _ := newAPITestStack(t)
cfg := &config.Config{}
setConfigDefaults(cfg)
router, _, err := setupRouter(cfg, local, connMgr, "tester", nil)
if err != nil {
t.Fatalf("setup router: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/api/configs", bytes.NewReader([]byte(`{"name":`)))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for malformed json, got %d body=%s", rec.Code, rec.Body.String())
}
if strings.Contains(strings.ToLower(rec.Body.String()), "unexpected eof") {
t.Fatalf("expected sanitized error body, got %s", rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "invalid request") {
t.Fatalf("expected generic invalid request message, got %s", rec.Body.String())
}
}
func newAPITestStack(t *testing.T) (*localdb.LocalDB, *db.ConnectionManager, *services.LocalConfigurationService) {
t.Helper()