Skip to content

Commit 9efb306

Browse files
authored
Add fast generator for cog build (#2108)
* Add fast generator * Add copy weights step * Add package installation step * Add user cache to the monobase * Resolve monobase caching into the local userspace to allow for monobase caching between builds * Add rsync copy * Check if weight files have changed before checksum * Split install into separate steps * We can get the tarballs for each of these layers instead of creating 1 big tarball. * Create requirements.txt in the build tmp directory * Fix unit tests * Fix lint * Add basic unit tests * Use UV_CACHE_DIR and mount uv cache * Remove --skip-cuda from monobase build * Monobase now handles empty CUDA env vars * Fix file not found when evaluating weights * Add UV_LINK_MODE=copy to the uv install commands * Add UV_COMPILE_BYTECODE env var * Remove verbosity from monobase exec * Fix integration test * Switch tini and exec --------- Signed-off-by: Will Sackfield <[email protected]>
1 parent 6eb2d2e commit 9efb306

26 files changed

+913
-83
lines changed

pkg/cli/debug.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error {
3737
return err
3838
}
3939

40-
generator, err := dockerfile.NewGenerator(cfg, projectDir)
40+
generator, err := dockerfile.NewGenerator(cfg, projectDir, false)
4141
if err != nil {
4242
return fmt.Errorf("Error creating Dockerfile generator: %w", err)
4343
}

pkg/config/compatibility.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
276276
func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
277277
for _, compat := range TFCompatibilityMatrix {
278278
if compat.TF == ver && version.Equal(compat.CUDA, cuda) {
279-
name, cpuVersion, _, _, err = splitPinnedPythonRequirement(compat.TFGPUPackage)
279+
name, cpuVersion, _, _, err = SplitPinnedPythonRequirement(compat.TFGPUPackage)
280280
return name, cpuVersion, err
281281
}
282282
}

pkg/config/config.go

Lines changed: 8 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string,
220220

221221
func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
222222
for _, pkg := range c.Build.pythonRequirementsContent {
223-
pkgName, version, _, _, err := splitPinnedPythonRequirement(pkg)
223+
pkgName, version, _, _, err := SplitPinnedPythonRequirement(pkg)
224224
if err != nil {
225225
// package is not in package==version format
226226
continue
@@ -331,7 +331,11 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
331331

332332
includePackageNames := []string{}
333333
for _, pkg := range includePackages {
334-
includePackageNames = append(includePackageNames, packageName(pkg))
334+
packageName, err := PackageName(pkg)
335+
if err != nil {
336+
return "", err
337+
}
338+
includePackageNames = append(includePackageNames, packageName)
335339
}
336340

337341
// Include all the requirements and remove our include packages if they exist
@@ -352,7 +356,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
352356
}
353357
}
354358

355-
packageName := packageName(archPkg)
359+
packageName, _ := PackageName(archPkg)
356360
if packageName != "" {
357361
foundIdx := -1
358362
for i, includePkg := range includePackageNames {
@@ -390,7 +394,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
390394
// pythonPackageForArch takes a package==version line and
391395
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture
392396
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
393-
name, version, findLinksList, extraIndexURLs, err := splitPinnedPythonRequirement(pkg)
397+
name, version, findLinksList, extraIndexURLs, err := SplitPinnedPythonRequirement(pkg)
394398
if err != nil {
395399
// It's not pinned, so just return the line verbatim
396400
return pkg, []string{}, []string{}, nil
@@ -562,50 +566,6 @@ Compatible cuDNN version is: %s`, c.Build.CuDNN, tfVersion, tfCuDNN)
562566
return nil
563567
}
564568

565-
// splitPythonPackage returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
566-
// in the form name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
567-
func splitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
568-
pinnedPackageRe := regexp.MustCompile(`(?:([a-zA-Z0-9\-_]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)
569-
570-
matches := pinnedPackageRe.FindAllStringSubmatch(requirement, -1)
571-
if matches == nil {
572-
return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
573-
}
574-
575-
nameFound := false
576-
versionFound := false
577-
578-
for _, match := range matches {
579-
if match[1] != "" {
580-
name = match[1]
581-
nameFound = true
582-
}
583-
584-
if match[2] != "" {
585-
version = match[2]
586-
versionFound = true
587-
}
588-
589-
if match[3] != "" {
590-
findLinks = append(findLinks, match[3])
591-
}
592-
593-
if match[4] != "" {
594-
findLinks = append(findLinks, match[4])
595-
}
596-
597-
if match[5] != "" {
598-
extraIndexURLs = append(extraIndexURLs, match[5])
599-
}
600-
}
601-
602-
if !nameFound || !versionFound {
603-
return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
604-
}
605-
606-
return name, version, findLinks, extraIndexURLs, nil
607-
}
608-
609569
func sliceContains(slice []string, s string) bool {
610570
for _, el := range slice {
611571
if el == s {
@@ -614,11 +574,3 @@ func sliceContains(slice []string, s string) bool {
614574
}
615575
return false
616576
}
617-
618-
func packageName(pipRequirement string) string {
619-
match := PipPackageNameRegex.FindStringSubmatch(pipRequirement)
620-
if len(match) <= 1 {
621-
return ""
622-
}
623-
return match[1]
624-
}

pkg/config/config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ func TestSplitPinnedPythonRequirement(t *testing.T) {
691691
}
692692

693693
for _, tc := range testCases {
694-
name, version, findLinks, extraIndexURLs, err := splitPinnedPythonRequirement(tc.input)
694+
name, version, findLinks, extraIndexURLs, err := SplitPinnedPythonRequirement(tc.input)
695695

696696
if tc.expectedError {
697697
require.Error(t, err)

pkg/config/requirements.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package config
2+
3+
import (
4+
"bufio"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"path/filepath"
9+
"regexp"
10+
"sort"
11+
)
12+
13+
func GenerateRequirements(tmpDir string, config *Config) (string, error) {
14+
// Deduplicate packages between the requirements.txt and the python packages directive.
15+
packageNames := make(map[string]string)
16+
17+
// Read the python packages configuration.
18+
for _, requirement := range config.Build.PythonPackages {
19+
packageName, err := PackageName(requirement)
20+
if err != nil {
21+
return "", err
22+
}
23+
packageNames[packageName] = requirement
24+
}
25+
26+
// Read the python requirements.
27+
if config.Build.PythonRequirements != "" {
28+
fh, err := os.Open(config.Build.PythonRequirements)
29+
if err != nil {
30+
return "", err
31+
}
32+
scanner := bufio.NewScanner(fh)
33+
for scanner.Scan() {
34+
requirement := scanner.Text()
35+
packageName, err := PackageName(requirement)
36+
if err != nil {
37+
return "", err
38+
}
39+
packageNames[packageName] = requirement
40+
}
41+
}
42+
43+
// If we don't have any packages skip further processing
44+
if len(packageNames) == 0 {
45+
return "", nil
46+
}
47+
48+
// Sort the package names by alphabetical order.
49+
keys := make([]string, 0, len(packageNames))
50+
for k := range packageNames {
51+
keys = append(keys, k)
52+
}
53+
sort.Strings(keys)
54+
55+
// Render the expected contents
56+
requirementsContent := ""
57+
for _, k := range keys {
58+
requirementsContent += packageNames[k] + "\n"
59+
}
60+
61+
// Check against the old requirements contents
62+
requirementsFile := filepath.Join(tmpDir, "requirements.txt")
63+
_, err := os.Stat(requirementsFile)
64+
if !errors.Is(err, os.ErrNotExist) {
65+
bytes, err := os.ReadFile(requirementsFile)
66+
if err != nil {
67+
return "", err
68+
}
69+
oldRequirementsContents := string(bytes)
70+
if oldRequirementsContents == requirementsFile {
71+
return requirementsFile, nil
72+
}
73+
}
74+
75+
// Write out a new requirements file
76+
err = os.WriteFile(requirementsFile, []byte(requirementsContent), 0o644)
77+
if err != nil {
78+
return "", err
79+
}
80+
return requirementsFile, nil
81+
}
82+
83+
// pinnedRequirementRe matches, anywhere in a requirements.txt line, either a
// pinned name==version pair (groups 1 and 2), a --find-links=<url> option
// (group 3), a -f <url> option (group 4) or an --extra-index-url=<url> option
// (group 5). It is compiled once because the parser runs for every requirement.
var pinnedRequirementRe = regexp.MustCompile(`(?:([a-zA-Z0-9\-_]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)

// SplitPinnedPythonRequirement returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
// in the form name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
//
// Both --find-links= and -f values are collected into findLinks, in the order
// they appear. If the line contains several name==version pairs the last one
// wins. An error is returned when nothing in the line matches, or when options
// are present but no name==version pair is.
func SplitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
	matches := pinnedRequirementRe.FindAllStringSubmatch(requirement, -1)
	if matches == nil {
		return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
	}

	// Each match fills the groups of exactly one alternative; the others are "".
	for _, match := range matches {
		if match[1] != "" {
			name = match[1]
		}
		if match[2] != "" {
			version = match[2]
		}
		if match[3] != "" {
			findLinks = append(findLinks, match[3])
		}
		if match[4] != "" {
			findLinks = append(findLinks, match[4])
		}
		if match[5] != "" {
			extraIndexURLs = append(extraIndexURLs, match[5])
		}
	}

	// name and version are only ever assigned non-empty values, so an empty
	// string here means the corresponding group never matched.
	if name == "" || version == "" {
		return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
	}

	return name, version, findLinks, extraIndexURLs, nil
}
126+
127+
func PackageName(pipRequirement string) (string, error) {
128+
name, _, _, _, err := SplitPinnedPythonRequirement(pipRequirement)
129+
return name, err
130+
}

pkg/config/requirements_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package config
2+
3+
import (
4+
"path/filepath"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestGenerateRequirements(t *testing.T) {
11+
tmpDir := t.TempDir()
12+
build := Build{
13+
PythonPackages: []string{"torch==2.5.1"},
14+
}
15+
config := Config{
16+
Build: &build,
17+
}
18+
requirementsFile, err := GenerateRequirements(tmpDir, &config)
19+
require.NoError(t, err)
20+
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
21+
}

pkg/docker/build.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,22 @@ import (
88
"strings"
99

1010
"github.com/replicate/cog/pkg/config"
11+
"github.com/replicate/cog/pkg/dockerfile"
1112

1213
"github.com/replicate/cog/pkg/util"
1314
"github.com/replicate/cog/pkg/util/console"
1415
)
1516

16-
func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
17+
func Build(dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
1718
var args []string
1819

20+
userCache, err := dockerfile.UserCache()
21+
if err != nil {
22+
return err
23+
}
24+
1925
args = append(args,
20-
"buildx", "build",
26+
"buildx", "build", "--build-context", "usercache="+userCache,
2127
)
2228

2329
if util.IsAppleSiliconMac(runtime.GOOS, runtime.GOARCH) {
@@ -65,7 +71,7 @@ func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, pr
6571
cmd.Dir = dir
6672
cmd.Stdout = os.Stderr // redirect stdout to stderr - build output is all messaging
6773
cmd.Stderr = os.Stderr
68-
cmd.Stdin = strings.NewReader(dockerfile)
74+
cmd.Stdin = strings.NewReader(dockerfileContents)
6975

7076
console.Debug("$ " + strings.Join(cmd.Args, " "))
7177
return cmd.Run()

pkg/dockerfile/base.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func (g *BaseImageGenerator) GenerateDockerfile() (string, error) {
178178
return "", err
179179
}
180180

181-
generator, err := NewGenerator(conf, "")
181+
generator, err := NewGenerator(conf, "", false)
182182
if err != nil {
183183
return "", err
184184
}

pkg/dockerfile/build_tempdir.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package dockerfile
2+
3+
import (
4+
"os"
5+
"path"
6+
"time"
7+
)
8+
9+
// BuildCogTempDir makes sure the .cog/tmp directory exists underneath dir,
// creating any missing directories with mode 0o755, and returns its path.
func BuildCogTempDir(dir string) (string, error) {
	cogTmp := path.Join(dir, ".cog/tmp")
	err := os.MkdirAll(cogTmp, 0o755)
	if err != nil {
		return "", err
	}
	return cogTmp, nil
}
16+
17+
func BuildTempDir(dir string) (string, error) {
18+
rootTmp, err := BuildCogTempDir(dir)
19+
if err != nil {
20+
return "", err
21+
}
22+
23+
if err := os.MkdirAll(rootTmp, 0o755); err != nil {
24+
return "", err
25+
}
26+
// tmpDir ends up being something like dir/.cog/tmp/build20240620123456.000000
27+
now := time.Now().Format("20060102150405.000000")
28+
tmpDir, err := os.MkdirTemp(rootTmp, "build"+now)
29+
if err != nil {
30+
return "", err
31+
}
32+
return tmpDir, nil
33+
}

pkg/dockerfile/build_tempdir_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package dockerfile
2+
3+
import (
4+
"path/filepath"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestBuildCogTempDir(t *testing.T) {
11+
tmpDir := t.TempDir()
12+
cogTmpDir, err := BuildCogTempDir(tmpDir)
13+
require.NoError(t, err)
14+
require.Equal(t, filepath.Join(tmpDir, ".cog/tmp"), cogTmpDir)
15+
}

0 commit comments

Comments
 (0)