Skip to content

Commit cca8bd0

Browse files
committed
introduce support for models
Signed-off-by: Nicolas De Loof <[email protected]>
1 parent 5687bcd commit cca8bd0

File tree

10 files changed

+444
-216
lines changed

10 files changed

+444
-216
lines changed

loader/loader_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3862,3 +3862,39 @@ services:
38623862
InterfaceName: "eth0",
38633863
})
38643864
}
3865+
3866+
func TestModel(t *testing.T) {
3867+
p, err := loadYAML(`
3868+
name: model
3869+
services:
3870+
test_array:
3871+
models:
3872+
- foo
3873+
3874+
test_mapping:
3875+
models:
3876+
foo:
3877+
endpoint_var: MODEL_URL
3878+
3879+
models:
3880+
foo:
3881+
model: ai/model
3882+
context_size: 1024
3883+
runtime_flags:
3884+
- "--some-flag"
3885+
`)
3886+
assert.NilError(t, err)
3887+
assert.DeepEqual(t, p.Models["foo"], types.ModelConfig{
3888+
Model: "ai/model",
3889+
ContextSize: 1024,
3890+
RuntimeFlags: []string{"--some-flag"},
3891+
})
3892+
assert.DeepEqual(t, p.Services["test_array"].Models, map[string]*types.ServiceModelConfig{
3893+
"foo": nil,
3894+
})
3895+
assert.DeepEqual(t, p.Services["test_mapping"].Models, map[string]*types.ServiceModelConfig{
3896+
"foo": {
3897+
EndpointVariable: "MODEL_URL",
3898+
},
3899+
})
3900+
}

override/merge.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ func init() {
6363
mergeSpecials["services.*.labels"] = mergeToSequence
6464
mergeSpecials["services.*.volumes.*.volume.labels"] = mergeToSequence
6565
mergeSpecials["services.*.logging"] = mergeLogging
66+
mergeSpecials["services.*.models"] = mergeModels
6667
mergeSpecials["services.*.networks"] = mergeNetworks
6768
mergeSpecials["services.*.sysctls"] = mergeToSequence
6869
mergeSpecials["services.*.tmpfs"] = mergeToSequence
@@ -158,6 +159,12 @@ func mergeDependsOn(c any, o any, path tree.Path) (any, error) {
158159
return mergeMappings(right, left, path)
159160
}
160161

162+
func mergeModels(c any, o any, path tree.Path) (any, error) {
163+
right := convertIntoMapping(c, nil)
164+
left := convertIntoMapping(o, nil)
165+
return mergeMappings(right, left, path)
166+
}
167+
161168
func mergeNetworks(c any, o any, path tree.Path) (any, error) {
162169
right := convertIntoMapping(c, nil)
163170
left := convertIntoMapping(o, nil)

schema/compose-spec.json

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@
3636
"description": "The services that will be used by your application."
3737
},
3838

39+
"models": {
40+
"type": "object",
41+
"patternProperties": {
42+
"^[a-zA-Z0-9._-]+$": {
43+
"$ref": "#/definitions/model"
44+
}
45+
},
46+
"description": "Language models that will be used by your application."
47+
},
48+
49+
3950
"networks": {
4051
"type": "object",
4152
"patternProperties": {
@@ -511,6 +522,27 @@
511522
"type": "string",
512523
"description": "Network mode. Values can be 'bridge', 'host', 'none', 'service:[service name]', or 'container:[container name]'."
513524
},
525+
"models": {
526+
"oneOf": [
527+
{"$ref": "#/definitions/list_of_strings"},
528+
{"type": "object",
529+
"patternProperties": {
530+
"^[a-zA-Z0-9._-]+$": {
531+
"type": "object",
532+
"properties": {
533+
"endpoint_var": {
534+
"type": "string",
535+
"description": "Environment variable set to AI model endpoint."
536+
}
537+
},
538+
"additionalProperties": false,
539+
"patternProperties": {"^x-": {}}
540+
}
541+
}
542+
}
543+
],
544+
"description": "AI Models to use, referencing entries under the top-level models key."
545+
},
514546
"networks": {
515547
"oneOf": [
516548
{"$ref": "#/definitions/list_of_strings"},
@@ -1530,6 +1562,32 @@
15301562
"patternProperties": {"^x-": {}}
15311563
},
15321564

1565+
"model": {
1566+
"type": "object",
1567+
"description": "Language Model for the Compose application.",
1568+
"properties": {
1569+
"name": {
1570+
"type": "string",
1571+
"description": "Custom name for this model."
1572+
},
1573+
"model": {
1574+
"type": "string",
1575+
"description": "Language Model to run."
1576+
},
1577+
"context_size": {
1578+
"type": "integer"
1579+
},
1580+
"runtime_flags": {
1581+
"type": "array",
1582+
"items": {"type": "string"},
1583+
"description": "Raw runtime flags to pass to the inference engine."
1584+
}
1585+
},
1586+
"required": ["model"],
1587+
"additionalProperties": false,
1588+
"patternProperties": {"^x-": {}}
1589+
},
1590+
15331591
"command": {
15341592
"oneOf": [
15351593
{

transform/canonical.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ func init() {
3333
transformers["services.*.label_file"] = transformStringOrList
3434
transformers["services.*.extends"] = transformExtends
3535
transformers["services.*.gpus"] = transformGpus
36-
transformers["services.*.networks"] = transformServiceNetworks
36+
transformers["services.*.networks"] = transformStringSliceToMap
37+
transformers["services.*.models"] = transformStringSliceToMap
3738
transformers["services.*.volumes.*"] = transformVolumeMount
3839
transformers["services.*.dns"] = transformStringOrList
3940
transformers["services.*.devices.*"] = transformDeviceMapping

transform/services.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ func transformService(data any, p tree.Path, ignoreParseError bool) (any, error)
2929
}
3030
}
3131

32-
func transformServiceNetworks(data any, _ tree.Path, _ bool) (any, error) {
32+
func transformStringSliceToMap(data any, _ tree.Path, _ bool) (any, error) {
3333
if slice, ok := data.([]any); ok {
34-
networks := make(map[string]any, len(slice))
34+
mapping := make(map[string]any, len(slice))
3535
for _, net := range slice {
36-
networks[net.(string)] = nil
36+
mapping[net.(string)] = nil
3737
}
38-
return networks, nil
38+
return mapping, nil
3939
}
4040
return data, nil
4141
}

types/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ type Secrets map[string]SecretConfig
101101
// Configs is a map of ConfigObjConfig
102102
type Configs map[string]ConfigObjConfig
103103

104+
type Models map[string]ModelConfig
105+
104106
// Extensions is a map of custom extension
105107
type Extensions map[string]any
106108

0 commit comments

Comments
 (0)