Skip to content

Commit 9adf43d

Browse files
committed
fix(downloader): do not download model files if not necessary
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 74ee146 commit 9adf43d

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

core/config/model_config_loader.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -287,18 +287,20 @@ func (bcl *ModelConfigLoader) Preload(modelPath string) error {
287287
if config.IsModelURL() {
288288
modelFileName := config.ModelFileName()
289289
uri := downloader.URI(config.Model)
290-
// check if file exists
291-
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
292-
err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
293-
if err != nil {
294-
return err
290+
if uri.ResolveURL() != config.Model {
291+
// check if file exists
292+
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
293+
err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
294+
if err != nil {
295+
return err
296+
}
295297
}
296-
}
297298

298-
cc := bcl.configs[i]
299-
c := &cc
300-
c.PredictionOptions.Model = modelFileName
301-
bcl.configs[i] = *c
299+
cc := bcl.configs[i]
300+
c := &cc
301+
c.PredictionOptions.Model = modelFileName
302+
bcl.configs[i] = *c
303+
}
302304
}
303305

304306
if config.IsMMProjURL() {

pkg/downloader/uri.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,26 @@ func (s URI) ResolveURL() string {
214214
repository = strings.Replace(repository, HuggingFacePrefix2, "", 1)
215215
// convert repository to a full URL.
216216
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
217-
owner := strings.Split(repository, "/")[0]
218-
repo := strings.Split(repository, "/")[1]
219217

220-
branch := "main"
221-
if strings.Contains(repo, "@") {
222-
branch = strings.Split(repository, "@")[1]
218+
repoPieces := strings.Split(repository, "/")
219+
repoID := strings.Split(repository, "@")
220+
if len(repoPieces) < 3 {
221+
return string(s)
223222
}
224-
filepath := strings.Split(repository, "/")[2]
225-
if strings.Contains(filepath, "@") {
226-
filepath = strings.Split(filepath, "@")[0]
223+
224+
owner := repoPieces[0]
225+
repo := repoPieces[1]
226+
227+
branch := "main"
228+
filepath := repoPieces[2]
229+
230+
if len(repoID) > 1 {
231+
if strings.Contains(repo, "@") {
232+
branch = repoID[1]
233+
}
234+
if strings.Contains(filepath, "@") {
235+
filepath = repoID[2]
236+
}
227237
}
228238

229239
return fmt.Sprintf("%s/%s/%s/resolve/%s/%s", HF_ENDPOINT, owner, repo, branch, filepath)

0 commit comments

Comments
 (0)