Skip to content

Commit 1247b48

Browse files
authored
Fix monobeam client not authenticating properly (#2148)
* Monobeam client was recreating a standard HTTP client rather than using the one injected.
1 parent 59545cc commit 1247b48

File tree

2 files changed

+36
-27
lines changed

2 files changed

+36
-27
lines changed

pkg/monobeam/client.go

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,11 @@ func (c *Client) UploadFile(ctx context.Context, objectType string, digest strin
8888

8989
// Start upload
9090
uploadUrl := startUploadURL(objectType, digest)
91-
client := &http.Client{}
9291
req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl.String(), nil)
9392
if err != nil {
9493
return err
9594
}
96-
resp, err := client.Do(req)
95+
resp, err := c.client.Do(req)
9796
if err != nil {
9897
return err
9998
}
@@ -103,7 +102,7 @@ func (c *Client) UploadFile(ctx context.Context, objectType string, digest strin
103102
if resp.StatusCode == http.StatusConflict {
104103
return nil
105104
} else if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
106-
return errors.New("Bad response: " + strconv.Itoa(resp.StatusCode))
105+
return errors.New("Bad response from monobeam: " + strconv.Itoa(resp.StatusCode))
107106
}
108107

109108
// Decode the JSON payload
@@ -150,7 +149,7 @@ func (c *Client) UploadFile(ctx context.Context, objectType string, digest strin
150149
if err != nil {
151150
return err
152151
}
153-
beginResp, err := client.Do(req)
152+
beginResp, err := c.client.Do(req)
154153
if err != nil {
155154
return err
156155
}
@@ -165,7 +164,7 @@ func (c *Client) UploadFile(ctx context.Context, objectType string, digest strin
165164
return err
166165
}
167166
for i := 0; i < 100; i++ {
168-
final, err := checkVerificationStatus(req, client)
167+
final, err := c.checkVerificationStatus(req)
169168
if final {
170169
return err
171170
}
@@ -175,27 +174,8 @@ func (c *Client) UploadFile(ctx context.Context, objectType string, digest strin
175174
return nil
176175
}
177176

178-
func baseURL() url.URL {
179-
return url.URL{
180-
Scheme: env.SchemeFromEnvironment(),
181-
Host: HostFromEnvironment(),
182-
}
183-
}
184-
185-
func startUploadURL(objectType string, digest string) url.URL {
186-
uploadUrl := baseURL()
187-
uploadUrl.Path = strings.Join([]string{"", "uploads", objectType, "sha256", digest}, "/")
188-
return uploadUrl
189-
}
190-
191-
func verificationURL(objectType string, digest string, uuid string) url.URL {
192-
verificationUrl := baseURL()
193-
verificationUrl.Path = strings.Join([]string{"", "uploads", objectType, "sha256", digest, uuid, "verification"}, "/")
194-
return verificationUrl
195-
}
196-
197-
func checkVerificationStatus(req *http.Request, client *http.Client) (bool, error) {
198-
checkResp, err := client.Do(req)
177+
func (c *Client) checkVerificationStatus(req *http.Request) (bool, error) {
178+
checkResp, err := c.client.Do(req)
199179
if err != nil {
200180
return true, err
201181
}
@@ -219,3 +199,22 @@ func checkVerificationStatus(req *http.Request, client *http.Client) (bool, erro
219199

220200
return false, nil
221201
}
202+
203+
func baseURL() url.URL {
204+
return url.URL{
205+
Scheme: env.SchemeFromEnvironment(),
206+
Host: HostFromEnvironment(),
207+
}
208+
}
209+
210+
func startUploadURL(objectType string, digest string) url.URL {
211+
uploadUrl := baseURL()
212+
uploadUrl.Path = strings.Join([]string{"", "uploads", objectType, "sha256", digest}, "/")
213+
return uploadUrl
214+
}
215+
216+
func verificationURL(objectType string, digest string, uuid string) url.URL {
217+
verificationUrl := baseURL()
218+
verificationUrl.Path = strings.Join([]string{"", "uploads", objectType, "sha256", digest, uuid, "verification"}, "/")
219+
return verificationUrl
220+
}

pkg/monobeam/client_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@ import (
1313
"github.com/stretchr/testify/require"
1414
"github.com/vbauerster/mpb/v8"
1515

16+
"github.com/replicate/cog/pkg/docker/dockertest"
1617
"github.com/replicate/cog/pkg/env"
18+
r8HTTP "github.com/replicate/cog/pkg/http"
1719
"github.com/replicate/cog/pkg/weights"
1820
)
1921

2022
func TestUploadFile(t *testing.T) {
2123
// Setup mock http server
2224
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25+
require.Equal(t, r8HTTP.UserAgent(), r.Header.Get(r8HTTP.UserAgentHeader))
2326
w.WriteHeader(http.StatusConflict)
2427
}))
2528
defer server.Close()
@@ -44,7 +47,14 @@ func TestUploadFile(t *testing.T) {
4447
require.NoError(t, err)
4548
}
4649

47-
client := NewClient(http.DefaultClient)
50+
// Setup mock command
51+
command := dockertest.NewMockCommand()
52+
53+
// Setup http client
54+
httpClient, err := r8HTTP.ProvideHTTPClient(command)
55+
require.NoError(t, err)
56+
57+
client := NewClient(httpClient)
4858
ctx := context.Background()
4959
p := mpb.New(
5060
mpb.WithRefreshRate(180 * time.Millisecond),

0 commit comments

Comments
 (0)