diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 76d2662..85311dc 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -1,22 +1,55 @@ package middleware import ( + "net" + "net/http" + "net/url" + "strings" + "github.com/gin-gonic/gin" ) func CORS() gin.HandlerFunc { return func(c *gin.Context) { - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") - c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization") - c.Header("Access-Control-Expose-Headers", "Content-Length, Content-Disposition") - c.Header("Access-Control-Max-Age", "86400") + origin := strings.TrimSpace(c.GetHeader("Origin")) + if origin != "" { + if isLoopbackOrigin(origin) { + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Vary", "Origin") + c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization") + c.Header("Access-Control-Expose-Headers", "Content-Length, Content-Disposition") + c.Header("Access-Control-Max-Age", "86400") + } else if c.Request.Method == http.MethodOptions { + c.AbortWithStatus(http.StatusForbidden) + return + } + } - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(204) + if c.Request.Method == http.MethodOptions { + c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } + +func isLoopbackOrigin(origin string) bool { + u, err := url.Parse(origin) + if err != nil { + return false + } + if u.Scheme != "http" && u.Scheme != "https" { + return false + } + host := strings.TrimSpace(u.Hostname()) + if host == "" { + return false + } + if strings.EqualFold(host, "localhost") { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} diff --git a/internal/repository/pricelist.go b/internal/repository/pricelist.go index 2152084..d885a92 100644 --- a/internal/repository/pricelist.go +++ b/internal/repository/pricelist.go @@ -3,10 +3,12 @@ package repository import ( "errors" "fmt" + "sort" "strconv" "strings" "time" + "git.mchus.pro/mchus/quoteforge/internal/lotmatch" "git.mchus.pro/mchus/quoteforge/internal/models" "gorm.io/gorm" ) @@ -243,9 +245,91 @@ func (r *PricelistRepository) GetItems(pricelistID uint, offset, limit int, sear } } + if err := r.enrichItemsWithStock(items); err != nil { + return nil, 0, fmt.Errorf("enriching pricelist items with stock: %w", err) + } + return items, total, nil } +func (r *PricelistRepository) enrichItemsWithStock(items []models.PricelistItem) error { + if len(items) == 0 { + return nil + } + + resolver, err := lotmatch.NewLotResolverFromDB(r.db) + if err != nil { + return err + } + + type stockRow struct { + Partnumber string `gorm:"column:partnumber"` + Qty *float64 `gorm:"column:qty"` + } + rows := make([]stockRow, 0) + if err := r.db.Raw(` + SELECT s.partnumber, s.qty + FROM stock_log s + INNER JOIN ( + SELECT partnumber, MAX(date) AS max_date + FROM stock_log + GROUP BY partnumber + ) latest ON latest.partnumber = s.partnumber AND latest.max_date = s.date + WHERE s.qty IS NOT NULL + `).Scan(&rows).Error; err != nil { + return err + } + + lotTotals := make(map[string]float64, len(items)) + lotPartnumbers := make(map[string][]string, len(items)) + seenPartnumbers := make(map[string]map[string]struct{}, len(items)) + + for i := range rows { + row := rows[i] + if strings.TrimSpace(row.Partnumber) == "" { + continue + } + lotName, _, resolveErr := resolver.Resolve(row.Partnumber) + if resolveErr != nil || strings.TrimSpace(lotName) == "" { + continue + } + + if row.Qty != nil { + lotTotals[lotName] += *row.Qty + } + + pn := strings.TrimSpace(row.Partnumber) + if pn == "" { + continue + } + if _, ok := seenPartnumbers[lotName]; !ok { + seenPartnumbers[lotName] = make(map[string]struct{}, 4) + } + key := strings.ToLower(pn) + if _, exists := seenPartnumbers[lotName][key]; exists { + continue + } + seenPartnumbers[lotName][key] = struct{}{} + lotPartnumbers[lotName] = append(lotPartnumbers[lotName], pn) + } + + for i := range items { + lotName := items[i].LotName + if qty, ok := lotTotals[lotName]; ok { + qtyCopy := qty + items[i].AvailableQty = &qtyCopy + } + if partnumbers := lotPartnumbers[lotName]; len(partnumbers) > 0 { + sort.Slice(partnumbers, func(a, b int) bool { + return strings.ToLower(partnumbers[a]) < strings.ToLower(partnumbers[b]) + }) + items[i].Partnumbers = partnumbers + } + } + + return nil +} + // GetLotNames returns distinct lot names from pricelist items. func (r *PricelistRepository) GetLotNames(pricelistID uint) ([]string, error) { var lotNames []string