diff --git a/internal/article/categories.go b/internal/article/categories.go index 85ba390..f611b58 100644 --- a/internal/article/categories.go +++ b/internal/article/categories.go @@ -71,13 +71,31 @@ func ResolveLotCategoriesStrict(local *localdb.LocalDB, serverPricelistID uint, if err != nil { return nil, err } + missing := make([]string, 0) for _, lot := range lotNames { cat := strings.TrimSpace(cats[lot]) if cat == "" { - return nil, &MissingCategoryForLotError{LotName: lot} + missing = append(missing, lot) + continue } cats[lot] = cat } + if len(missing) > 0 { + fallback, err := local.GetLocalComponentCategoriesByLotNames(missing) + if err != nil { + return nil, err + } + for _, lot := range missing { + if cat := strings.TrimSpace(fallback[lot]); cat != "" { + cats[lot] = cat + } + } + for _, lot := range missing { + if strings.TrimSpace(cats[lot]) == "" { + return nil, &MissingCategoryForLotError{LotName: lot} + } + } + } return cats, nil } diff --git a/internal/article/categories_test.go b/internal/article/categories_test.go index dbdef64..eb8c027 100644 --- a/internal/article/categories_test.go +++ b/internal/article/categories_test.go @@ -45,6 +45,49 @@ func TestResolveLotCategoriesStrict_MissingCategoryReturnsError(t *testing.T) { } } +func TestResolveLotCategoriesStrict_FallbackToLocalComponents(t *testing.T) { + local, err := localdb.New(filepath.Join(t.TempDir(), "local.db")) + if err != nil { + t.Fatalf("init local db: %v", err) + } + t.Cleanup(func() { _ = local.Close() }) + + if err := local.SaveLocalPricelist(&localdb.LocalPricelist{ + ServerID: 2, + Source: "estimate", + Version: "S-2026-02-11-002", + Name: "test", + CreatedAt: time.Now(), + SyncedAt: time.Now(), + }); err != nil { + t.Fatalf("save local pricelist: %v", err) + } + localPL, err := local.GetLocalPricelistByServerID(2) + if err != nil { + t.Fatalf("get local pricelist: %v", err) + } + if err := local.SaveLocalPricelistItems([]localdb.LocalPricelistItem{ + {PricelistID: localPL.ID, LotName: "CPU_B", LotCategory: "", Price: 10}, + }); err != nil { + t.Fatalf("save local items: %v", err) + } + if err := local.DB().Create(&localdb.LocalComponent{ + LotName: "CPU_B", + Category: "CPU", + LotDescription: "cpu", + }).Error; err != nil { + t.Fatalf("save local components: %v", err) + } + + cats, err := ResolveLotCategoriesStrict(local, 2, []string{"CPU_B"}) + if err != nil { + t.Fatalf("expected fallback, got error: %v", err) + } + if cats["CPU_B"] != "CPU" { + t.Fatalf("expected CPU, got %q", cats["CPU_B"]) + } +} + func TestGroupForLotCategory(t *testing.T) { if g, ok := GroupForLotCategory("cpu"); !ok || g != GroupCPU { t.Fatalf("expected cpu -> GroupCPU") @@ -53,4 +96,3 @@ func TestGroupForLotCategory(t *testing.T) { t.Fatalf("expected SFP to be excluded") } } - diff --git a/internal/localdb/components.go b/internal/localdb/components.go index a199e4f..013edb6 100644 --- a/internal/localdb/components.go +++ b/internal/localdb/components.go @@ -242,6 +242,31 @@ func (l *LocalDB) GetLocalComponent(lotName string) (*LocalComponent, error) { return &component, nil } +// GetLocalComponentCategoriesByLotNames returns category for each lot_name in the local component cache. +// Missing lots are not included in the map; caller is responsible for strict validation. +func (l *LocalDB) GetLocalComponentCategoriesByLotNames(lotNames []string) (map[string]string, error) { + result := make(map[string]string, len(lotNames)) + if len(lotNames) == 0 { + return result, nil + } + + type row struct { + LotName string `gorm:"column:lot_name"` + Category string `gorm:"column:category"` + } + var rows []row + if err := l.db.Model(&LocalComponent{}). + Select("lot_name, category"). + Where("lot_name IN ?", lotNames). + Find(&rows).Error; err != nil { + return nil, err + } + for _, r := range rows { + result[r.LotName] = r.Category + } + return result, nil +} + // GetLocalComponentCategories returns distinct categories from local components func (l *LocalDB) GetLocalComponentCategories() ([]string, error) { var categories []string @@ -302,4 +327,3 @@ func (l *LocalDB) NeedComponentSync(maxAgeHours int) bool { } return time.Since(*syncTime).Hours() > float64(maxAgeHours) } -