Skip to content

Commit 407311b

Browse files
committed
cog train: add image argument, -o, -e
Signed-off-by: Yorick van Pelt <[email protected]>
1 parent f862da2 commit 407311b

File tree

1 file changed

+52
-20
lines changed

1 file changed

+52
-20
lines changed

pkg/cli/train.go

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli
22

33
import (
4+
"fmt"
45
"os"
56
"os/signal"
67
"syscall"
@@ -15,16 +16,21 @@ import (
1516
)
1617

1718
var (
19+
trainEnvFlags []string
1820
trainInputFlags []string
21+
trainOutPath string
1922
)
2023

2124
func newTrainCommand() *cobra.Command {
2225
cmd := &cobra.Command{
23-
Use: "train",
26+
Use: "train [image]",
2427
Short: "Run a training",
2528
Long: `Run a training.
2629
27-
It will build the model in the current directory and train it.`,
30+
If 'image' is passed, it will run the training on that Docker image.
31+
It must be an image that has been built by Cog.
32+
33+
Otherwise, it will build the model in the current directory and train it.`,
2834
RunE: cmdTrain,
2935
Args: cobra.MaximumNArgs(1),
3036
Hidden: true,
@@ -33,37 +39,62 @@ It will build the model in the current directory and train it.`,
3339
addBuildProgressOutputFlag(cmd)
3440
addDockerfileFlag(cmd)
3541
addUseCudaBaseImageFlag(cmd)
42+
addGpusFlag(cmd)
3643

3744
cmd.Flags().StringArrayVarP(&trainInputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i [email protected]")
38-
cmd.Flags().StringVarP(&outPath, "output", "o", "weights", "Output path")
45+
cmd.Flags().StringVarP(&trainOutPath, "output", "o", "weights", "Output path")
46+
cmd.Flags().StringArrayVarP(&trainEnvFlags, "env", "e", []string{}, "Environment variables, in the form name=value")
3947

4048
return cmd
4149
}
4250

4351
func cmdTrain(cmd *cobra.Command, args []string) error {
4452
imageName := ""
4553
volumes := []docker.Volume{}
46-
gpus := ""
54+
gpus := gpusFlag
4755

48-
// Build image
56+
if len(args) == 0 {
57+
// Build image
4958

50-
cfg, projectDir, err := config.GetConfig(projectDirFlag)
51-
if err != nil {
52-
return err
53-
}
59+
cfg, projectDir, err := config.GetConfig(projectDirFlag)
60+
if err != nil {
61+
return err
62+
}
5463

55-
if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildProgressOutput); err != nil {
56-
return err
57-
}
64+
if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildProgressOutput); err != nil {
65+
return err
66+
}
5867

59-
// Base image doesn't have /src in it, so mount as volume
60-
volumes = append(volumes, docker.Volume{
61-
Source: projectDir,
62-
Destination: "/src",
63-
})
68+
// Base image doesn't have /src in it, so mount as volume
69+
volumes = append(volumes, docker.Volume{
70+
Source: projectDir,
71+
Destination: "/src",
72+
})
73+
74+
if gpus == "" && cfg.Build.GPU {
75+
gpus = "all"
76+
}
77+
} else {
78+
// Use existing image
79+
imageName = args[0]
6480

65-
if cfg.Build.GPU {
66-
gpus = "all"
81+
exists, err := docker.ImageExists(imageName)
82+
if err != nil {
83+
return fmt.Errorf("Failed to determine if %s exists: %w", imageName, err)
84+
}
85+
if !exists {
86+
console.Infof("Pulling image: %s", imageName)
87+
if err := docker.Pull(imageName); err != nil {
88+
return fmt.Errorf("Failed to pull %s: %w", imageName, err)
89+
}
90+
}
91+
conf, err := image.GetConfig(imageName)
92+
if err != nil {
93+
return err
94+
}
95+
if gpus == "" && conf.Build.GPU {
96+
gpus = "all"
97+
}
6798
}
6899

69100
console.Info("")
@@ -74,6 +105,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
74105
Image: imageName,
75106
Volumes: volumes,
76107
Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"},
108+
Env: trainEnvFlags,
77109
})
78110

79111
go func() {
@@ -100,5 +132,5 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
100132
}
101133
}()
102134

103-
return predictIndividualInputs(predictor, trainInputFlags, outPath)
135+
return predictIndividualInputs(predictor, trainInputFlags, trainOutPath)
104136
}

0 commit comments

Comments
 (0)