Skip to content

Commit ffbf3e2

Browse files
committed
Make cog_runtime warning more passive aggressive
1 parent b9e7214 commit ffbf3e2

File tree

10 files changed

+110
-36
lines changed

10 files changed

+110
-36
lines changed

pkg/cli/build.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ func buildCommand(cmd *cobra.Command, args []string) error {
8686
buildFast = cfg.Build.Fast
8787
}
8888
logCtx.Fast = buildFast
89-
logCtx.CogRuntime = cfg.Build.CogRuntime
89+
logCtx.CogRuntime = false
90+
if cfg.Build.CogRuntime != nil {
91+
logCtx.CogRuntime = *cfg.Build.CogRuntime
92+
}
9093

9194
imageName := cfg.Image
9295
if buildTag != "" {

pkg/cli/push.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ func push(cmd *cobra.Command, args []string) error {
7373
buildFast = cfg.Build.Fast
7474
}
7575
logCtx.Fast = buildFast
76-
logCtx.CogRuntime = cfg.Build.CogRuntime
76+
logCtx.CogRuntime = false
77+
if cfg.Build.CogRuntime != nil {
78+
logCtx.CogRuntime = *cfg.Build.CogRuntime
79+
}
7780

7881
imageName := cfg.Image
7982
if len(args) > 0 {

pkg/config/config.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type Build struct {
5656
CUDA string `json:"cuda,omitempty" yaml:"cuda,omitempty"`
5757
CuDNN string `json:"cudnn,omitempty" yaml:"cudnn,omitempty"`
5858
Fast bool `json:"fast,omitempty" yaml:"fast,omitempty"`
59-
CogRuntime bool `json:"cog_runtime,omitempty" yaml:"cog_runtime,omitempty"`
59+
CogRuntime *bool `json:"cog_runtime,omitempty" yaml:"cog_runtime,omitempty"`
6060
PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides,omitempty"`
6161

6262
pythonRequirementsContent []string
@@ -72,6 +72,7 @@ type Example struct {
7272
}
7373

7474
type Config struct {
75+
filename string
7576
Build *Build `json:"build" yaml:"build"`
7677
Image string `json:"image,omitempty" yaml:"image,omitempty"`
7778
Predict string `json:"predict,omitempty" yaml:"predict"`
@@ -82,6 +83,10 @@ type Config struct {
8283
parsedEnvironment map[string]string
8384
}
8485

86+
func (c *Config) Filename() string {
87+
return c.filename
88+
}
89+
8590
func DefaultConfig() *Config {
8691
return &Config{
8792
Build: &Build{

pkg/config/load.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func GetConfig(configFilename string) (*Config, string, error) {
2929
return nil, "", err
3030
}
3131
err = config.ValidateAndComplete(rootDir)
32+
config.filename = configFilename
3233
return config, rootDir, err
3334
}
3435

pkg/dockerfile/standard_generator.go

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ import (
1414
"github.com/replicate/cog/pkg/docker/command"
1515
"github.com/replicate/cog/pkg/dockercontext"
1616
"github.com/replicate/cog/pkg/registry"
17+
"github.com/replicate/cog/pkg/util"
1718
"github.com/replicate/cog/pkg/util/console"
1819
"github.com/replicate/cog/pkg/util/slices"
1920
"github.com/replicate/cog/pkg/util/version"
2021
"github.com/replicate/cog/pkg/weights"
22+
"golang.org/x/term"
23+
"gopkg.in/yaml.v2"
2124
)
2225

2326
const DockerignoreHeader = `# generated by replicate/cog
@@ -443,28 +446,17 @@ func (g *StandardGenerator) installCog() (string, error) {
443446
return "", nil
444447
}
445448

446-
if g.Config.Build.CogRuntime {
447-
// We need fast-* compliant Python version to reconstruct coglet venv PYTHONPATH
448-
if !CheckMajorMinorOnly(g.Config.Build.PythonVersion) {
449-
return "", fmt.Errorf("Python version must be <major>.<minor>")
450-
}
451-
m, err := NewMonobaseMatrix(http.DefaultClient)
452-
if err != nil {
453-
return "", err
454-
}
455-
cmds := []string{
456-
"ENV R8_COG_VERSION=coglet",
457-
"ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion,
458-
"RUN pip install " + m.LatestCoglet.URL,
459-
}
460-
return strings.Join(cmds, "\n"), nil
449+
if g.Config.Build.CogRuntime != nil && *g.Config.Build.CogRuntime {
450+
return g.installCogRuntime()
461451
}
462-
463-
// cog-runtime does not support training yet
464-
if g.Config.Train == "" {
465-
console.Warnf("Major Cog runtime upgrade available. Opt in now by setting build.cog_runtime: true in cog.yaml.")
466-
console.Warnf("More: https://replicate.com/changelog/2025-07-21-cog-runtime")
452+
accepted, err := g.askAboutCogRuntime()
453+
if err != nil {
454+
return "", err
467455
}
456+
if accepted {
457+
return g.installCogRuntime()
458+
}
459+
468460
data, filename, err := ReadWheelFile()
469461
if err != nil {
470462
return "", err
@@ -483,6 +475,75 @@ func (g *StandardGenerator) installCog() (string, error) {
483475
return strings.Join(lines, "\n"), nil
484476
}
485477

478+
func (g *StandardGenerator) installCogRuntime() (string, error) {
479+
// We need fast-* compliant Python version to reconstruct coglet venv PYTHONPATH
480+
if !CheckMajorMinorOnly(g.Config.Build.PythonVersion) {
481+
return "", fmt.Errorf("Python version must be <major>.<minor>")
482+
}
483+
m, err := NewMonobaseMatrix(http.DefaultClient)
484+
if err != nil {
485+
return "", err
486+
}
487+
cmds := []string{
488+
"ENV R8_COG_VERSION=coglet",
489+
"ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion,
490+
"RUN pip install " + m.LatestCoglet.URL,
491+
}
492+
return strings.Join(cmds, "\n"), nil
493+
}
494+
495+
func (g *StandardGenerator) askAboutCogRuntime() (bool, error) {
496+
// Training is not supported
497+
if g.Config.Train != "" {
498+
return false, nil
499+
}
500+
// Only warn if cog_runtime is not explicitly set
501+
if g.Config.Build.CogRuntime != nil {
502+
return false, nil
503+
}
504+
505+
console.Warnf("Major Cog runtime upgrade available. Opt in now by setting build.cog_runtime: true in cog.yaml.")
506+
console.Warnf("More: https://replicate.com/changelog/2025-07-21-cog-runtime")
507+
508+
// Skip question if not in an interactive shell
509+
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) || !term.IsTerminal(int(os.Stderr.Fd())) {
510+
return false, nil
511+
}
512+
513+
interactive := &console.InteractiveBool{
514+
Prompt: "Do you want to switch to the new Cog runtime?",
515+
Default: true,
516+
// NonDefaultFlag is not applicable here
517+
}
518+
cogRuntime, err := interactive.Read()
519+
if err != nil {
520+
return false, fmt.Errorf("failed to read from stdin: %v", err)
521+
}
522+
// Only add cog_runtime: true to cog.yaml if answer is yes
523+
// Otherwise leave it absent so we keep nagging
524+
if !cogRuntime {
525+
console.Warnf("Not switching. Add build.cog_runtime: false to disable this reminder.")
526+
return false, nil
527+
}
528+
g.Config.Build.CogRuntime = &cogRuntime
529+
530+
console.Infof("Adding build.cog_runtime: true to %s", g.Config.Filename())
531+
newYaml, err := yaml.Marshal(g.Config)
532+
if err != nil {
533+
return false, err
534+
}
535+
path := filepath.Join(g.Dir, g.Config.Filename())
536+
oldYaml, err := os.ReadFile(path)
537+
if err != nil {
538+
return false, err
539+
}
540+
merged, err := util.OverwriteYAML(newYaml, oldYaml)
541+
if err != nil {
542+
return false, err
543+
}
544+
return true, os.WriteFile(path, merged, 0o644)
545+
}
546+
486547
func (g *StandardGenerator) pipInstalls() (string, error) {
487548
var err error
488549
includePackages := []string{}

pkg/migrate/migrator_v1_v1fast.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configF
241241
defer file.Close()
242242
console.Infof("Writing config changes to %s.", configFilepath)
243243

244-
mergedCfgData, err := OverwrightYAML(data, content)
244+
mergedCfgData, err := util.OverwriteYAML(data, content)
245245
if err != nil {
246246
return err
247247
}

pkg/util/console/interactive.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ func (i InteractiveBool) Read() (bool, error) {
9191
reader := bufio.NewReader(os.Stdin)
9292
text, err := reader.ReadString('\n')
9393
if err != nil {
94-
if err == io.EOF {
94+
// Only translate error if a flag is set
95+
if err == io.EOF && i.NonDefaultFlag != "" {
9596
return false, fmt.Errorf("stdin is closed. If you're running in a script, you need to pass the '%s' option", i.NonDefaultFlag)
9697
}
9798
return false, err

pkg/migrate/overwrite_yaml.go renamed to pkg/util/overwrite_yaml.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
package migrate
1+
package util
22

33
import (
44
"fmt"
55

66
"gopkg.in/yaml.v3"
77
)
88

9-
func OverwrightYAML(sourceYaml []byte, destinationYaml []byte) ([]byte, error) {
9+
func OverwriteYAML(sourceYaml []byte, destinationYaml []byte) ([]byte, error) {
1010
var sourceNode yaml.Node
1111
err := yaml.Unmarshal(sourceYaml, &sourceNode)
1212
if err != nil {

pkg/migrate/overwrite_yaml_test.go renamed to pkg/util/overwrite_yaml_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package migrate
1+
package util
22

33
import (
44
"testing"
@@ -7,7 +7,7 @@ import (
77
)
88

99
/*
10-
func TestOverwrightYAML(t *testing.T) {
10+
func TestOverwriteYAML(t *testing.T) {
1111
var yamlData1 = `build:
1212
command: "build.sh"
1313
image: "my-image"
@@ -30,13 +30,13 @@ environment:
3030
- "VAR1=new_value1"
3131
- "VAR3=value3"
3232
`
33-
content, err := OverwrightYAML([]byte(yamlData1), []byte(yamlData2))
33+
content, err := OverwriteYAML([]byte(yamlData1), []byte(yamlData2))
3434
require.NoError(t, err)
3535
require.Equal(t, yamlData1, string(content))
3636
}
3737
*/
3838

39-
func TestOverwrightYAMLWithComments(t *testing.T) {
39+
func TestOverwriteYAMLWithComments(t *testing.T) {
4040
var sourceYaml = `build:
4141
command: "build_new.sh"
4242
image: "new-image"
@@ -73,12 +73,12 @@ environment:
7373
- "VAR3=value3"
7474
`
7575

76-
content, err := OverwrightYAML([]byte(sourceYaml), []byte(destinationYaml))
76+
content, err := OverwriteYAML([]byte(sourceYaml), []byte(destinationYaml))
7777
require.NoError(t, err)
7878
require.Equal(t, expected, string(content))
7979
}
8080

81-
func TestOverwrightYAMLWithLineComments(t *testing.T) {
81+
func TestOverwriteYAMLWithLineComments(t *testing.T) {
8282
var sourceYaml = `build:
8383
command: "build_new.sh"
8484
image: "new-image"
@@ -116,7 +116,7 @@ environment:
116116
- "VAR1=new_value1"
117117
- "VAR3=value3"
118118
`
119-
content, err := OverwrightYAML([]byte(sourceYaml), []byte(destinationYaml))
119+
content, err := OverwriteYAML([]byte(sourceYaml), []byte(destinationYaml))
120120
require.NoError(t, err)
121121
require.Equal(t, expected, string(content))
122122
}
@@ -177,7 +177,7 @@ build:
177177
# predict.py defines how predictions are run on your model
178178
predict: "predict.py:Predictor"
179179
`
180-
content, err := OverwrightYAML([]byte(sourceYaml), []byte(destinationYaml))
180+
content, err := OverwriteYAML([]byte(sourceYaml), []byte(destinationYaml))
181181
require.NoError(t, err)
182182
require.Equal(t, expected, string(content))
183183
}

test-integration/test_integration/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_run_with_piped_stdin_returned_to_stdout(tmpdir_factory, cog_binary):
131131
def test_run_shell_with_with_interactive_tty(tmpdir_factory, cog_binary):
132132
tmpdir = tmpdir_factory.mktemp("project")
133133
(tmpdir / "cog.yaml").write_text(
134-
"build:\n python_version: '3.13'\n",
134+
"build:\n python_version: '3.13'\n cog_runtime: true\n",
135135
encoding="utf-8",
136136
)
137137

0 commit comments

Comments
 (0)