Skip to content

Commit 7af1eae

Browse files
authored
Merge pull request #1093 from trheyi/main
Enhance knowledge base provider management with multi-language support
2 parents e6c5631 + 4bdd921 commit 7af1eae

File tree

9 files changed

+722
-308
lines changed

9 files changed

+722
-308
lines changed

kb/kb.go

Lines changed: 25 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ var Instance types.GraphRag = nil
2323

2424
// KnowledgeBase is the Knowledge Base instance
2525
type KnowledgeBase struct {
26-
Config *kbtypes.Config // Knowledge Base configuration
26+
Config *kbtypes.Config // Knowledge Base configuration
27+
Providers *kbtypes.ProviderConfig // Multi-language provider configurations
2728
*graphrag.GraphRag
2829
}
2930

@@ -52,6 +53,13 @@ func Load(appConfig config.Config) (*KnowledgeBase, error) {
5253
return nil, err
5354
}
5455

56+
// Load providers from directories
57+
providers, err := kbtypes.LoadProviders("kb")
58+
if err != nil {
59+
return nil, err
60+
}
61+
config.Providers = providers
62+
5563
// Set global configurations for providers to use
5664
kbtypes.SetGlobalPDF(config.PDF)
5765
kbtypes.SetGlobalFFmpeg(config.FFmpeg)
@@ -69,7 +77,7 @@ func Load(appConfig config.Config) (*KnowledgeBase, error) {
6977
}
7078

7179
// Set the instance
72-
instance := &KnowledgeBase{Config: &config, GraphRag: graphRag}
80+
instance := &KnowledgeBase{Config: &config, Providers: providers, GraphRag: graphRag}
7381

7482
// Set the instance to the global variable
7583
Instance = instance
@@ -88,48 +96,13 @@ func GetProviders(typ string, ids []string, locale string) ([]kbtypes.Provider,
8896
return nil, fmt.Errorf("knowledge base not initialized")
8997
}
9098

91-
// Get the configuration
92-
conf := knowledgeBase.Config
93-
if conf == nil {
94-
return nil, fmt.Errorf("configuration not found")
99+
// Default locale to "en" if empty
100+
if locale == "" {
101+
locale = "en"
95102
}
96103

97-
providers := []*kbtypes.Provider{}
98-
switch typ {
99-
case "chunking":
100-
providers = conf.Chunkings
101-
102-
case "converter":
103-
providers = conf.Converters
104-
105-
case "embedding":
106-
providers = conf.Embeddings
107-
108-
case "extractor":
109-
providers = conf.Extractors
110-
111-
case "fetcher":
112-
providers = conf.Fetchers
113-
114-
case "searcher":
115-
providers = conf.Searchers
116-
117-
case "reranker":
118-
providers = conf.Rerankers
119-
120-
case "vote":
121-
providers = conf.Votes
122-
123-
case "weight":
124-
providers = conf.Weights
125-
126-
case "score":
127-
providers = conf.Scores
128-
129-
default:
130-
return nil, fmt.Errorf("invalid provider type: %s", typ)
131-
132-
}
104+
// Get providers for the requested type and language
105+
providers := knowledgeBase.Providers.GetProviders(typ, locale)
133106

134107
// Filter empty ids
135108
filteredIds := []string{}
@@ -149,8 +122,13 @@ func GetProviders(typ string, ids []string, locale string) ([]kbtypes.Provider,
149122
return filteredProviders, nil
150123
}
151124

152-
// GetProvider returns a provider by id
125+
// GetProvider returns a provider by id with default language "en"
153126
func GetProvider(typ string, id string) (*kbtypes.Provider, error) {
127+
return GetProviderWithLanguage(typ, id, "en")
128+
}
129+
130+
// GetProviderWithLanguage returns a provider by id, type, and language
131+
func GetProviderWithLanguage(typ string, id string, locale string) (*kbtypes.Provider, error) {
154132
if Instance == nil {
155133
return nil, fmt.Errorf("knowledge base not initialized")
156134
}
@@ -160,53 +138,10 @@ func GetProvider(typ string, id string) (*kbtypes.Provider, error) {
160138
return nil, fmt.Errorf("knowledge base not initialized")
161139
}
162140

163-
conf := knowledgeBase.Config
164-
if conf == nil {
165-
return nil, fmt.Errorf("configuration not found")
166-
}
167-
168-
providers := []*kbtypes.Provider{}
169-
switch typ {
170-
case "chunking":
171-
providers = conf.Chunkings
172-
173-
case "converter":
174-
providers = conf.Converters
175-
176-
case "embedding":
177-
providers = conf.Embeddings
178-
179-
case "extractor":
180-
providers = conf.Extractors
181-
182-
case "fetcher":
183-
providers = conf.Fetchers
184-
185-
case "searcher":
186-
providers = conf.Searchers
187-
188-
case "reranker":
189-
providers = conf.Rerankers
190-
191-
case "vote":
192-
providers = conf.Votes
193-
194-
case "weight":
195-
providers = conf.Weights
196-
197-
case "score":
198-
providers = conf.Scores
199-
200-
default:
201-
return nil, fmt.Errorf("invalid provider type: %s", typ)
202-
}
203-
204-
// Find the provider by id
205-
for _, provider := range providers {
206-
if provider.ID == id {
207-
return provider, nil
208-
}
141+
// Default locale to "en" if empty
142+
if locale == "" {
143+
locale = "en"
209144
}
210145

211-
return nil, fmt.Errorf("provider %s not found", id)
146+
return knowledgeBase.Providers.GetProvider(typ, id, locale)
212147
}

kb/kb_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"testing"
55

66
"github.com/yaoapp/yao/config"
7+
kbtypes "github.com/yaoapp/yao/kb/types"
78
"github.com/yaoapp/yao/test"
89
)
910

@@ -12,8 +13,134 @@ func TestLoad(t *testing.T) {
1213
test.Prepare(&testing.T{}, config.Conf)
1314
defer test.Clean()
1415

16+
kb, err := Load(config.Conf)
17+
if err != nil {
18+
t.Fatalf("Failed to load knowledge base: %v", err)
19+
}
20+
21+
// Test that providers are loaded
22+
if kb != nil && kb.Providers != nil {
23+
t.Logf("Knowledge base loaded successfully with providers")
24+
}
25+
}
26+
27+
func TestGetProviders(t *testing.T) {
28+
// Setup
29+
test.Prepare(&testing.T{}, config.Conf)
30+
defer test.Clean()
31+
32+
_, err := Load(config.Conf)
33+
if err != nil {
34+
t.Fatalf("Failed to load knowledge base: %v", err)
35+
}
36+
37+
// Test getting providers for different languages
38+
testCases := []struct {
39+
providerType string
40+
locale string
41+
expectEmpty bool
42+
}{
43+
{"chunking", "en", false},
44+
{"embedding", "en", false},
45+
{"chunking", "zh-cn", false},
46+
{"embedding", "zh-cn", false},
47+
{"chunking", "nonexistent", false}, // Should fallback to "en"
48+
}
49+
50+
for _, tc := range testCases {
51+
providers, err := GetProviders(tc.providerType, []string{}, tc.locale)
52+
if err != nil {
53+
t.Errorf("Failed to get %s providers for locale %s: %v", tc.providerType, tc.locale, err)
54+
continue
55+
}
56+
57+
if tc.expectEmpty && len(providers) > 0 {
58+
t.Errorf("Expected empty providers for %s/%s, got %d", tc.providerType, tc.locale, len(providers))
59+
} else if !tc.expectEmpty && len(providers) == 0 {
60+
t.Logf("No providers found for %s/%s (this may be expected if no provider files exist)", tc.providerType, tc.locale)
61+
} else {
62+
t.Logf("Found %d providers for %s/%s", len(providers), tc.providerType, tc.locale)
63+
}
64+
}
65+
}
66+
67+
func TestGetProviderWithLanguage(t *testing.T) {
68+
// Setup
69+
test.Prepare(&testing.T{}, config.Conf)
70+
defer test.Clean()
71+
1572
_, err := Load(config.Conf)
1673
if err != nil {
1774
t.Fatalf("Failed to load knowledge base: %v", err)
1875
}
76+
77+
// Test getting a specific provider with language
78+
provider, err := GetProviderWithLanguage("chunking", "__yao.structured", "en")
79+
if err != nil {
80+
t.Logf("Provider __yao.structured not found for chunking/en: %v (this may be expected if provider files don't exist)", err)
81+
} else {
82+
t.Logf("Found provider: %s", provider.ID)
83+
}
84+
85+
// Test language fallback
86+
provider, err = GetProviderWithLanguage("chunking", "__yao.structured", "nonexistent")
87+
if err != nil {
88+
t.Logf("Provider __yao.structured not found with fallback: %v (this may be expected if provider files don't exist)", err)
89+
} else {
90+
t.Logf("Found provider with fallback: %s", provider.ID)
91+
}
92+
}
93+
94+
func TestLoadProviders(t *testing.T) {
95+
// Setup
96+
test.Prepare(&testing.T{}, config.Conf)
97+
defer test.Clean()
98+
99+
// Test loading providers from a directory
100+
providers, err := kbtypes.LoadProviders("kb")
101+
if err != nil {
102+
t.Fatalf("Failed to load providers: %v", err)
103+
}
104+
105+
if providers == nil {
106+
t.Fatal("Providers config is nil")
107+
}
108+
109+
// Check if provider maps are initialized
110+
if providers.Chunkings == nil {
111+
t.Error("Chunkings map is nil")
112+
}
113+
if providers.Embeddings == nil {
114+
t.Error("Embeddings map is nil")
115+
}
116+
117+
t.Logf("Loaded providers successfully")
118+
}
119+
120+
func TestProviderConfigGetProviders(t *testing.T) {
121+
// Setup
122+
test.Prepare(&testing.T{}, config.Conf)
123+
defer test.Clean()
124+
125+
providers, err := kbtypes.LoadProviders("kb")
126+
if err != nil {
127+
t.Fatalf("Failed to load providers: %v", err)
128+
}
129+
130+
// Test getting providers for different types and languages
131+
testCases := []string{"chunking", "embedding", "converter", "extractor", "fetcher"}
132+
133+
for _, providerType := range testCases {
134+
// Test with "en"
135+
enProviders := providers.GetProviders(providerType, "en")
136+
t.Logf("Found %d %s providers for 'en'", len(enProviders), providerType)
137+
138+
// Test with "zh-cn"
139+
zhProviders := providers.GetProviders(providerType, "zh-cn")
140+
t.Logf("Found %d %s providers for 'zh-cn'", len(zhProviders), providerType)
141+
142+
// Test with nonexistent language (should fallback to "en")
143+
fallbackProviders := providers.GetProviders(providerType, "nonexistent")
144+
t.Logf("Found %d %s providers for 'nonexistent' (fallback)", len(fallbackProviders), providerType)
145+
}
19146
}

kb/types/config.go

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -272,20 +272,31 @@ func (c *Config) resolveAllEnvVars() error {
272272

273273
// resolveProviderEnvVars resolves environment variables in provider configurations
274274
func (c *Config) resolveProviderEnvVars() error {
275-
providerLists := [][]*Provider{
276-
c.Chunkings, c.Embeddings, c.Converters, c.Extractors,
277-
c.Fetchers, c.Searchers, c.Rerankers, c.Votes, c.Weights, c.Scores,
278-
}
279-
280-
for _, providers := range providerLists {
281-
for _, provider := range providers {
282-
for _, option := range provider.Options {
283-
if option.Properties != nil {
284-
resolved, err := c.resolveEnvVars(option.Properties)
285-
if err != nil {
286-
return err
275+
if c.Providers == nil {
276+
return nil
277+
}
278+
279+
// Resolve env vars for all provider types and languages
280+
providerMaps := []map[string][]*Provider{
281+
c.Providers.Chunkings, c.Providers.Embeddings, c.Providers.Converters, c.Providers.Extractors,
282+
c.Providers.Fetchers, c.Providers.Searchers, c.Providers.Rerankers, c.Providers.Votes,
283+
c.Providers.Weights, c.Providers.Scores,
284+
}
285+
286+
for _, providerMap := range providerMaps {
287+
if providerMap == nil {
288+
continue
289+
}
290+
for _, providers := range providerMap {
291+
for _, provider := range providers {
292+
for _, option := range provider.Options {
293+
if option.Properties != nil {
294+
resolved, err := c.resolveEnvVars(option.Properties)
295+
if err != nil {
296+
return err
297+
}
298+
option.Properties = resolved
287299
}
288-
option.Properties = resolved
289300
}
290301
}
291302
}
@@ -335,8 +346,13 @@ func (c *Config) ComputeFeatures() Features {
335346

336347
// File format support (based on converters)
337348
converterMap := make(map[string]bool)
338-
for _, provider := range c.Converters {
339-
converterMap[provider.ID] = true
349+
if c.Providers != nil && c.Providers.Converters != nil {
350+
// Check all languages for converter availability
351+
for _, providers := range c.Providers.Converters {
352+
for _, provider := range providers {
353+
converterMap[provider.ID] = true
354+
}
355+
}
340356
}
341357

342358
features.PlainText = true // Plain text is always supported as a basic feature
@@ -346,13 +362,28 @@ func (c *Config) ComputeFeatures() Features {
346362
features.ImageAnalysis = converterMap["__yao.vision"]
347363

348364
// Advanced features
349-
features.EntityExtraction = len(c.Extractors) > 0
350-
features.WebFetching = len(c.Fetchers) > 0
351-
features.CustomSearch = len(c.Searchers) > 0
352-
features.ResultReranking = len(c.Rerankers) > 0
353-
features.SegmentVoting = len(c.Votes) > 0
354-
features.SegmentWeighting = len(c.Weights) > 0
355-
features.SegmentScoring = len(c.Scores) > 0
365+
if c.Providers != nil {
366+
features.EntityExtraction = c.hasProvidersInAnyLanguage(c.Providers.Extractors)
367+
features.WebFetching = c.hasProvidersInAnyLanguage(c.Providers.Fetchers)
368+
features.CustomSearch = c.hasProvidersInAnyLanguage(c.Providers.Searchers)
369+
features.ResultReranking = c.hasProvidersInAnyLanguage(c.Providers.Rerankers)
370+
features.SegmentVoting = c.hasProvidersInAnyLanguage(c.Providers.Votes)
371+
features.SegmentWeighting = c.hasProvidersInAnyLanguage(c.Providers.Weights)
372+
features.SegmentScoring = c.hasProvidersInAnyLanguage(c.Providers.Scores)
373+
}
356374

357375
return features
358376
}
377+
378+
// hasProvidersInAnyLanguage checks if there are providers available in any language
379+
func (c *Config) hasProvidersInAnyLanguage(providerMap map[string][]*Provider) bool {
380+
if providerMap == nil {
381+
return false
382+
}
383+
for _, providers := range providerMap {
384+
if len(providers) > 0 {
385+
return true
386+
}
387+
}
388+
return false
389+
}

0 commit comments

Comments
 (0)