Skip to content

Commit 2fe25d9

Browse files
authored
Merge pull request #64 from iflytek/feat/statedriven&reconciler
add unit tests for basic components
2 parents 94d0bfe + 8a633b6 commit 2fe25d9

File tree

23 files changed

+5477
-12
lines changed

23 files changed

+5477
-12
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ require (
5959
github.com/spf13/afero v1.15.0 // indirect
6060
github.com/spf13/cast v1.10.0 // indirect
6161
github.com/spf13/pflag v1.0.10 // indirect
62+
github.com/stretchr/objx v0.5.2 // indirect
6263
github.com/subosito/gotenv v1.6.0 // indirect
6364
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
6465
github.com/ugorji/go/codec v1.3.0 // indirect

internal/core/orchestrator/orchestrator.go

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,15 @@ import (
1313
"fmt"
1414
)
1515

16+
// CurrentShimletGetter 定义获取当前shimlet ID的函数类型
17+
type CurrentShimletGetter func() string
18+
1619
type Orchestrator struct {
17-
shimReg *typereg.TypeReg[shimlet.Shimlet]
18-
goalSetReg map[string]*goal.GoalSet
19-
specStore spec.Store
20-
queue *workqueue.Queue
20+
shimReg *typereg.TypeReg[shimlet.Shimlet]
21+
goalSetReg map[string]*goal.GoalSet
22+
specStore spec.Store
23+
queue *workqueue.Queue
24+
currentShimletGetter CurrentShimletGetter
2125
}
2226

2327
func NewOrchestrator(
@@ -27,10 +31,31 @@ func NewOrchestrator(
2731
specStore spec.Store,
2832
) *Orchestrator {
2933
return &Orchestrator{
30-
queue: queue,
31-
shimReg: shimReg,
32-
goalSetReg: pipeReg,
33-
specStore: specStore,
34+
queue: queue,
35+
shimReg: shimReg,
36+
goalSetReg: pipeReg,
37+
specStore: specStore,
38+
currentShimletGetter: func() string {
39+
return config.Get().CurrentShimlet
40+
},
41+
}
42+
}
43+
44+
// NewOrchestratorWithShimletGetter 创建一个可以自定义currentShimletGetter的Orchestrator实例
45+
// 主要用于测试
46+
func NewOrchestratorWithShimletGetter(
47+
shimReg *typereg.TypeReg[shimlet.Shimlet],
48+
pipeReg map[string]*goal.GoalSet,
49+
queue *workqueue.Queue,
50+
specStore spec.Store,
51+
currentShimletGetter CurrentShimletGetter,
52+
) *Orchestrator {
53+
return &Orchestrator{
54+
queue: queue,
55+
shimReg: shimReg,
56+
goalSetReg: pipeReg,
57+
specStore: specStore,
58+
currentShimletGetter: currentShimletGetter,
3459
}
3560
}
3661

@@ -46,7 +71,7 @@ func (o *Orchestrator) Provision(spec *dto.RequirementSpec) error {
4671

4772
// RequirementSpec 持久化 部署期望
4873
spec.ReplicaCount = 1
49-
spec.ShimletName = config.Get().CurrentShimlet
74+
spec.ShimletName = o.currentShimletGetter()
5075
// 如果这里是更新, 则需要 对应goalset reconcile 检测到 不一致 并调用ensure 闭环
5176
o.specStore.Set(spec.ServiceId, spec)
5277

@@ -59,7 +84,7 @@ func (o *Orchestrator) Provision(spec *dto.RequirementSpec) error {
5984
// DeleteService 删除指定的模型服务
6085
func (o *Orchestrator) DeleteService(serviceID string) error {
6186
// 获取当前使用的shimlet
62-
currentShimletId := config.Get().CurrentShimlet
87+
currentShimletId := o.currentShimletGetter()
6388
runtimeShimlet, err := o.shimReg.GetSingleton(currentShimletId)
6489
if err != nil {
6590
log.Error("get runtime shimlet error", err)
@@ -82,7 +107,7 @@ func (o *Orchestrator) GetServiceStatus(serviceID string) (*dto.RuntimeStatus, e
82107
return nil, fmt.Errorf("serviceID is required")
83108
}
84109

85-
currentShimletId := config.Get().CurrentShimlet
110+
currentShimletId := o.currentShimletGetter()
86111
runtimeShimlet, err := o.shimReg.GetSingleton(currentShimletId)
87112
if err != nil {
88113
return nil, err
@@ -95,4 +120,4 @@ func (o *Orchestrator) GetServiceStatus(serviceID string) (*dto.RuntimeStatus, e
95120
status.EndPoint += "/v1/chat/completions"
96121
}
97122
return status, nil
98-
}
123+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,131 @@
11
package orchestrator
2+
3+
import (
4+
"astron-xmod-shim/internal/config"
5+
"astron-xmod-shim/internal/core/goal"
6+
"astron-xmod-shim/internal/core/shimlet"
7+
"astron-xmod-shim/internal/core/spec"
8+
"astron-xmod-shim/internal/core/typereg"
9+
"astron-xmod-shim/internal/core/workqueue"
10+
dto "astron-xmod-shim/internal/dto/deploy"
11+
"os"
12+
"testing"
13+
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/mock"
16+
)
17+
18+
// MockShimlet 是shimlet.Shimlet接口的简化mock实现
19+
type MockShimlet struct {
20+
mock.Mock
21+
}
22+
23+
func (m *MockShimlet) ID() string {
24+
args := m.Called()
25+
return args.String(0)
26+
}
27+
28+
func (m *MockShimlet) Description() string {
29+
args := m.Called()
30+
return args.String(0)
31+
}
32+
33+
func (m *MockShimlet) InitWithConfig(confPath string) error {
34+
args := m.Called(confPath)
35+
return args.Error(0)
36+
}
37+
38+
func (m *MockShimlet) Apply(spec *dto.RequirementSpec) error {
39+
args := m.Called(spec)
40+
return args.Error(0)
41+
}
42+
43+
func (m *MockShimlet) Delete(resourceId string) error {
44+
args := m.Called(resourceId)
45+
return args.Error(0)
46+
}
47+
48+
func (m *MockShimlet) Status(resourceId string) (*dto.RuntimeStatus, error) {
49+
args := m.Called(resourceId)
50+
status, _ := args.Get(0).(*dto.RuntimeStatus)
51+
return status, args.Error(1)
52+
}
53+
54+
func (m *MockShimlet) ListDeployedServices() ([]string, error) {
55+
args := m.Called()
56+
services, _ := args.Get(0).([]string)
57+
return services, args.Error(1)
58+
}
59+
60+
// 简化的测试用例
61+
func TestOrchestrator(t *testing.T) {
62+
// 创建临时配置文件
63+
configContent := `
64+
current-shimlet: "k8s"
65+
shimlets:
66+
k8s:
67+
config-path: "/conf/k8s-shimlet.yaml"
68+
`
69+
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
70+
assert.NoError(t, err)
71+
defer os.Remove(tmpFile.Name())
72+
73+
_, err = tmpFile.WriteString(configContent)
74+
assert.NoError(t, err)
75+
assert.NoError(t, tmpFile.Close())
76+
77+
// 设置配置文件路径
78+
config.SetConfigPath(tmpFile.Name())
79+
80+
// 测试Provision方法
81+
t.Run("Provision", func(t *testing.T) {
82+
// 创建测试需要的组件
83+
mockQueue := workqueue.New()
84+
defer mockQueue.ShutDown()
85+
86+
mockSpecStore := spec.NewMemoryStore()
87+
shimReg := typereg.New[shimlet.Shimlet]()
88+
89+
// 创建并配置mock shimlet
90+
mockShim := new(MockShimlet)
91+
mockShim.On("ID").Return("k8s")
92+
mockShim.On("InitWithConfig", "/conf/k8s-shimlet.yaml").Return(nil)
93+
94+
// 注册mock shimlet到注册中心
95+
shimReg.AutoRegister(mockShim)
96+
97+
// 创建orchestrator实例,使用固定的shimlet ID "k8s"
98+
orchestrator := NewOrchestratorWithShimletGetter(
99+
shimReg,
100+
map[string]*goal.GoalSet{},
101+
mockQueue,
102+
mockSpecStore,
103+
func() string {
104+
return "k8s"
105+
},
106+
)
107+
108+
serviceSpec := &dto.RequirementSpec{
109+
ServiceId: "test-service",
110+
GoalSetName: "test-goalset",
111+
ResourceRequirements: &dto.ResourceRequirements{},
112+
}
113+
114+
err := orchestrator.Provision(serviceSpec)
115+
assert.NoError(t, err)
116+
117+
// 验证规格被保存
118+
savedSpec := mockSpecStore.Get("test-service")
119+
assert.NotNil(t, savedSpec)
120+
121+
// 验证项目被添加到队列
122+
assert.Equal(t, 1, mockQueue.Len())
123+
124+
// 清理队列
125+
if mockQueue.Len() > 0 {
126+
_, done := mockQueue.Get()
127+
done()
128+
}
129+
})
130+
131+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package spec
2+
3+
import (
4+
dto "astron-xmod-shim/internal/dto/deploy"
5+
"testing"
6+
)
7+
8+
func TestMemoryStoreBasicOperations(t *testing.T) {
9+
// 创建内存存储
10+
store := NewMemoryStore()
11+
12+
// 创建测试数据
13+
serviceID := "test-service"
14+
spec := &dto.RequirementSpec{
15+
ServiceId: serviceID,
16+
ModelName: "test-model",
17+
ReplicaCount: 1,
18+
}
19+
20+
// 测试设置和获取
21+
store.Set(serviceID, spec)
22+
retrieved := store.Get(serviceID)
23+
if retrieved == nil {
24+
t.Error("Expected to retrieve non-nil spec")
25+
} else if retrieved.ServiceId != serviceID {
26+
t.Errorf("Expected service ID '%s', got '%s'", serviceID, retrieved.ServiceId)
27+
} else if retrieved.ModelName != "test-model" {
28+
t.Errorf("Expected model name 'test-model', got '%s'", retrieved.ModelName)
29+
}
30+
31+
// 测试删除
32+
store.Delete(serviceID)
33+
retrieved = store.Get(serviceID)
34+
if retrieved != nil {
35+
t.Error("Expected nil after delete")
36+
}
37+
}
38+
39+
func TestMemoryStoreMultipleServices(t *testing.T) {
40+
// 创建内存存储
41+
store := NewMemoryStore()
42+
43+
// 创建多个服务规格
44+
spec1 := &dto.RequirementSpec{ServiceId: "service-1", ModelName: "model-1"}
45+
spec2 := &dto.RequirementSpec{ServiceId: "service-2", ModelName: "model-2"}
46+
spec3 := &dto.RequirementSpec{ServiceId: "service-3", ModelName: "model-3"}
47+
48+
// 设置多个服务
49+
store.Set("service-1", spec1)
50+
store.Set("service-2", spec2)
51+
store.Set("service-3", spec3)
52+
53+
// 验证所有服务都能正确获取
54+
if retrieved := store.Get("service-1"); retrieved == nil || retrieved.ModelName != "model-1" {
55+
t.Error("Failed to retrieve service-1")
56+
}
57+
if retrieved := store.Get("service-2"); retrieved == nil || retrieved.ModelName != "model-2" {
58+
t.Error("Failed to retrieve service-2")
59+
}
60+
if retrieved := store.Get("service-3"); retrieved == nil || retrieved.ModelName != "model-3" {
61+
t.Error("Failed to retrieve service-3")
62+
}
63+
64+
// 验证不存在的服务返回nil
65+
if retrieved := store.Get("non-existent"); retrieved != nil {
66+
t.Error("Expected nil for non-existent service")
67+
}
68+
69+
// 删除一个服务并验证
70+
store.Delete("service-2")
71+
if retrieved := store.Get("service-2"); retrieved != nil {
72+
t.Error("Expected nil after deleting service-2")
73+
}
74+
// 确保其他服务仍然存在
75+
if retrieved := store.Get("service-1"); retrieved == nil {
76+
t.Error("service-1 should still exist")
77+
}
78+
}

0 commit comments

Comments
 (0)