Skip to content

Commit 7b97da5

Browse files
authored
Add fast pusher for fast builds (#2114)
* Add fast pusher for fast builds * Separates push into standard and fast * Generates the necessary tar files to push * Calls the backend to create a file to push * Add Login Token to HTTP POST request * Remove redundant tar files * When building/pushing make sure we remove redundant tar files we have created previously. * Add tests to fast push * Fix injecting docker command into push * Add LoadLoginToken to docker command object * Add a create src tarfile error * Mock CreateTarFile * Format generate apt tarball errors * Inject the docker command object * Allow mocking of docker exec commands * Use command to create apt tar file * Fix bad variable assignment * Remove aptTarPath assignment * Add a comment on the use of image on CreateTarFile * Change the config arg to a packages array arg * Use snakeCase for apt tarball prefix/suffix consts * Use snakeCase in fast push consts * Use error groups to wait for multiple uploads * Remove ineffectual assignment * Do not use url as a variable name * Remove size from the backend call * Set test environment with t.Setenv * Put MockCommand in dockertest package * Fix pip_freeze integration test
1 parent 60017a9 commit 7b97da5

36 files changed

+925
-252
lines changed

pkg/cli/baseimage.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,6 @@ func baseImageGeneratorFromFlags() (*dockerfile.BaseImageGenerator, error) {
148148
baseImageCUDAVersion,
149149
baseImagePythonVersion,
150150
baseImageTorchVersion,
151+
docker.NewDockerCommand(),
151152
)
152153
}

pkg/cli/debug.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/spf13/cobra"
77

88
"github.com/replicate/cog/pkg/config"
9+
"github.com/replicate/cog/pkg/docker"
910
"github.com/replicate/cog/pkg/dockerfile"
1011
"github.com/replicate/cog/pkg/global"
1112
"github.com/replicate/cog/pkg/util/console"
@@ -37,7 +38,8 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error {
3738
return err
3839
}
3940

40-
generator, err := dockerfile.NewGenerator(cfg, projectDir, false)
41+
command := docker.NewDockerCommand()
42+
generator, err := dockerfile.NewGenerator(cfg, projectDir, false, command)
4143
if err != nil {
4244
return fmt.Errorf("Error creating Dockerfile generator: %w", err)
4345
}

pkg/cli/push.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ func push(cmd *cobra.Command, args []string) error {
6060
}
6161

6262
if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast); err != nil {
63-
6463
return err
6564
}
6665

@@ -69,7 +68,8 @@ func push(cmd *cobra.Command, args []string) error {
6968
console.Info("Fast push enabled.")
7069
}
7170

72-
err = docker.Push(imageName)
71+
command := docker.NewDockerCommand()
72+
err = docker.Push(imageName, buildFast, projectDir, command)
7373
if err != nil {
7474
if strings.Contains(err.Error(), "NAME_UNKNOWN") {
7575
return fmt.Errorf("Unable to find existing Replicate model for %s. "+

pkg/config/config_test.go

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -670,41 +670,6 @@ func TestBlankBuild(t *testing.T) {
670670
require.Equal(t, false, config.Build.GPU)
671671
}
672672

673-
func TestSplitPinnedPythonRequirement(t *testing.T) {
674-
testCases := []struct {
675-
input string
676-
expectedName string
677-
expectedVersion string
678-
expectedFindLinks []string
679-
expectedExtraIndexURLs []string
680-
expectedError bool
681-
}{
682-
{"package1==1.0.0", "package1", "1.0.0", nil, nil, false},
683-
{"package1==1.0.0+alpha", "package1", "1.0.0+alpha", nil, nil, false},
684-
{"--find-links=link1 --find-links=link2 package3==3.0.0", "package3", "3.0.0", []string{"link1", "link2"}, nil, false},
685-
{"package4==4.0.0 --extra-index-url=url1 --extra-index-url=url2", "package4", "4.0.0", nil, []string{"url1", "url2"}, false},
686-
{"-f link1 --find-links=link2 package5==5.0.0 --extra-index-url=url1 --extra-index-url=url2", "package5", "5.0.0", []string{"link1", "link2"}, []string{"url1", "url2"}, false},
687-
{"package6 --find-links=link1 --find-links=link2 --extra-index-url=url1 --extra-index-url=url2", "", "", nil, nil, true},
688-
{"invalid package", "", "", nil, nil, true},
689-
{"package8==", "", "", nil, nil, true},
690-
{"==8.0.0", "", "", nil, nil, true},
691-
}
692-
693-
for _, tc := range testCases {
694-
name, version, findLinks, extraIndexURLs, err := SplitPinnedPythonRequirement(tc.input)
695-
696-
if tc.expectedError {
697-
require.Error(t, err)
698-
} else {
699-
require.NoError(t, err)
700-
require.Equal(t, tc.expectedName, name, "input: "+tc.input)
701-
require.Equal(t, tc.expectedVersion, version, "input: "+tc.input)
702-
require.Equal(t, tc.expectedFindLinks, findLinks, "input: "+tc.input)
703-
require.Equal(t, tc.expectedExtraIndexURLs, extraIndexURLs, "input: "+tc.input)
704-
}
705-
}
706-
}
707-
708673
func TestPythonRequirementsForArchWithAddedPackage(t *testing.T) {
709674
config := &Config{
710675
Build: &Build{

pkg/config/requirements.go

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,10 @@
11
package config
22

33
import (
4-
"bufio"
5-
"errors"
64
"fmt"
7-
"os"
8-
"path/filepath"
95
"regexp"
10-
"sort"
116
)
127

13-
func GenerateRequirements(tmpDir string, config *Config) (string, error) {
14-
// Deduplicate packages between the requirements.txt and the python packages directive.
15-
packageNames := make(map[string]string)
16-
17-
// Read the python packages configuration.
18-
for _, requirement := range config.Build.PythonPackages {
19-
packageName, err := PackageName(requirement)
20-
if err != nil {
21-
return "", err
22-
}
23-
packageNames[packageName] = requirement
24-
}
25-
26-
// Read the python requirements.
27-
if config.Build.PythonRequirements != "" {
28-
fh, err := os.Open(config.Build.PythonRequirements)
29-
if err != nil {
30-
return "", err
31-
}
32-
scanner := bufio.NewScanner(fh)
33-
for scanner.Scan() {
34-
requirement := scanner.Text()
35-
packageName, err := PackageName(requirement)
36-
if err != nil {
37-
return "", err
38-
}
39-
packageNames[packageName] = requirement
40-
}
41-
}
42-
43-
// If we don't have any packages skip further processing
44-
if len(packageNames) == 0 {
45-
return "", nil
46-
}
47-
48-
// Sort the package names by alphabetical order.
49-
keys := make([]string, 0, len(packageNames))
50-
for k := range packageNames {
51-
keys = append(keys, k)
52-
}
53-
sort.Strings(keys)
54-
55-
// Render the expected contents
56-
requirementsContent := ""
57-
for _, k := range keys {
58-
requirementsContent += packageNames[k] + "\n"
59-
}
60-
61-
// Check against the old requirements contents
62-
requirementsFile := filepath.Join(tmpDir, "requirements.txt")
63-
_, err := os.Stat(requirementsFile)
64-
if !errors.Is(err, os.ErrNotExist) {
65-
bytes, err := os.ReadFile(requirementsFile)
66-
if err != nil {
67-
return "", err
68-
}
69-
oldRequirementsContents := string(bytes)
70-
if oldRequirementsContents == requirementsFile {
71-
return requirementsFile, nil
72-
}
73-
}
74-
75-
// Write out a new requirements file
76-
err = os.WriteFile(requirementsFile, []byte(requirementsContent), 0o644)
77-
if err != nil {
78-
return "", err
79-
}
80-
return requirementsFile, nil
81-
}
82-
838
// SplitPinnedPythonRequirement returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
849
// in the form name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
8510
func SplitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {

pkg/config/requirements_test.go

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,42 @@
11
package config
22

33
import (
4-
"path/filepath"
54
"testing"
65

76
"github.com/stretchr/testify/require"
87
)
98

10-
func TestGenerateRequirements(t *testing.T) {
11-
tmpDir := t.TempDir()
12-
build := Build{
13-
PythonPackages: []string{"torch==2.5.1"},
9+
func TestSplitPinnedPythonRequirement(t *testing.T) {
10+
testCases := []struct {
11+
input string
12+
expectedName string
13+
expectedVersion string
14+
expectedFindLinks []string
15+
expectedExtraIndexURLs []string
16+
expectedError bool
17+
}{
18+
{"package1==1.0.0", "package1", "1.0.0", nil, nil, false},
19+
{"package1==1.0.0+alpha", "package1", "1.0.0+alpha", nil, nil, false},
20+
{"--find-links=link1 --find-links=link2 package3==3.0.0", "package3", "3.0.0", []string{"link1", "link2"}, nil, false},
21+
{"package4==4.0.0 --extra-index-url=url1 --extra-index-url=url2", "package4", "4.0.0", nil, []string{"url1", "url2"}, false},
22+
{"-f link1 --find-links=link2 package5==5.0.0 --extra-index-url=url1 --extra-index-url=url2", "package5", "5.0.0", []string{"link1", "link2"}, []string{"url1", "url2"}, false},
23+
{"package6 --find-links=link1 --find-links=link2 --extra-index-url=url1 --extra-index-url=url2", "", "", nil, nil, true},
24+
{"invalid package", "", "", nil, nil, true},
25+
{"package8==", "", "", nil, nil, true},
26+
{"==8.0.0", "", "", nil, nil, true},
1427
}
15-
config := Config{
16-
Build: &build,
28+
29+
for _, tc := range testCases {
30+
name, version, findLinks, extraIndexURLs, err := SplitPinnedPythonRequirement(tc.input)
31+
32+
if tc.expectedError {
33+
require.Error(t, err)
34+
} else {
35+
require.NoError(t, err)
36+
require.Equal(t, tc.expectedName, name, "input: "+tc.input)
37+
require.Equal(t, tc.expectedVersion, version, "input: "+tc.input)
38+
require.Equal(t, tc.expectedFindLinks, findLinks, "input: "+tc.input)
39+
require.Equal(t, tc.expectedExtraIndexURLs, extraIndexURLs, "input: "+tc.input)
40+
}
1741
}
18-
requirementsFile, err := GenerateRequirements(tmpDir, &config)
19-
require.NoError(t, err)
20-
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
2142
}

pkg/docker/apt.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package docker
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/hex"
6+
"errors"
7+
"fmt"
8+
"os"
9+
"path/filepath"
10+
"sort"
11+
"strings"
12+
)
13+
14+
const aptTarballPrefix = "apt."
15+
const aptTarballSuffix = ".tar.zst"
16+
17+
func CreateAptTarball(tmpDir string, command Command, packages ...string) (string, error) {
18+
if len(packages) > 0 {
19+
sort.Strings(packages)
20+
hash := sha256.New()
21+
hash.Write([]byte(strings.Join(packages, " ")))
22+
hexHash := hex.EncodeToString(hash.Sum(nil))
23+
aptTarFile := aptTarballPrefix + hexHash + aptTarballSuffix
24+
aptTarPath := filepath.Join(tmpDir, aptTarFile)
25+
26+
if _, err := os.Stat(aptTarPath); errors.Is(err, os.ErrNotExist) {
27+
// Remove previous apt tar files.
28+
err = removeAptTarballs(tmpDir)
29+
if err != nil {
30+
return "", err
31+
}
32+
33+
// Create the apt tar file
34+
_, err = command.CreateAptTarFile(tmpDir, aptTarFile, packages...)
35+
if err != nil {
36+
return "", err
37+
}
38+
}
39+
40+
return aptTarFile, nil
41+
}
42+
return "", nil
43+
}
44+
45+
func CurrentAptTarball(tmpDir string) (string, error) {
46+
files, err := os.ReadDir(tmpDir)
47+
if err != nil {
48+
return "", fmt.Errorf("os read dir error: %w", err)
49+
}
50+
51+
for _, file := range files {
52+
fileName := file.Name()
53+
if strings.HasPrefix(fileName, aptTarballPrefix) && strings.HasSuffix(fileName, aptTarballSuffix) {
54+
return filepath.Join(tmpDir, fileName), nil
55+
}
56+
}
57+
58+
return "", nil
59+
}
60+
61+
func removeAptTarballs(tmpDir string) error {
62+
files, err := os.ReadDir(tmpDir)
63+
if err != nil {
64+
return err
65+
}
66+
67+
for _, file := range files {
68+
fileName := file.Name()
69+
if strings.HasPrefix(fileName, aptTarballPrefix) && strings.HasSuffix(fileName, aptTarballSuffix) {
70+
err = os.Remove(filepath.Join(tmpDir, fileName))
71+
if err != nil {
72+
return err
73+
}
74+
}
75+
}
76+
77+
return nil
78+
}

pkg/docker/apt_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package docker
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
8+
"github.com/replicate/cog/pkg/docker/dockertest"
9+
)
10+
11+
func TestCreateAptTarball(t *testing.T) {
12+
dir := t.TempDir()
13+
command := dockertest.NewMockCommand()
14+
tarball, err := CreateAptTarball(dir, command, []string{}...)
15+
require.NoError(t, err)
16+
require.Equal(t, "", tarball)
17+
}

pkg/docker/build.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"strings"
99

1010
"github.com/replicate/cog/pkg/config"
11-
"github.com/replicate/cog/pkg/dockerfile"
1211

1312
"github.com/replicate/cog/pkg/util"
1413
"github.com/replicate/cog/pkg/util/console"
@@ -17,7 +16,7 @@ import (
1716
func Build(dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
1817
var args []string
1918

20-
userCache, err := dockerfile.UserCache()
19+
userCache, err := UserCache()
2120
if err != nil {
2221
return err
2322
}

pkg/docker/command.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package docker
2+
3+
type Command interface {
4+
Push(string) error
5+
LoadLoginToken(string) (string, error)
6+
CreateTarFile(string, string, string, string) (string, error)
7+
CreateAptTarFile(string, string, ...string) (string, error)
8+
}
9+
10+
type CredentialHelperInput struct {
11+
Username string
12+
Secret string
13+
ServerURL string
14+
}

0 commit comments

Comments
 (0)