Skip to content

Commit 4447c9e

Browse files
committed
exp run: fix failure when dependencies are from inside submodules
Fixes #7186. Fixes #10823.
1 parent a72402c commit 4447c9e

File tree

5 files changed

+63
-6
lines changed

5 files changed

+63
-6
lines changed

dvc/repo/experiments/executor/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,9 @@ def save(
298298
stages = dvc.commit([], recursive=recursive, force=True, relink=False)
299299
exp_hash = cls.hash_exp(stages)
300300
if include_untracked:
301-
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
301+
from dvc.scm import add_no_submodules
302+
303+
add_no_submodules(dvc.scm, include_untracked, force=True) # type: ignore[call-arg]
302304

303305
with cls.auto_push(dvc):
304306
cls.commit(
@@ -513,7 +515,9 @@ def reproduce(
513515
stages = dvc.reproduce(*args, **kwargs)
514516
if paths := cls._get_top_level_paths(dvc):
515517
logger.debug("Staging top-level files: %s", paths)
516-
dvc.scm_context.add(paths)
518+
from dvc.scm import add_no_submodules
519+
520+
add_no_submodules(dvc.scm, paths)
517521

518522
exp_hash = cls.hash_exp(stages)
519523
if not repro_dry:

dvc/repo/scm_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def _make_git_add_cmd(paths: Union[str, Iterable[str]]) -> str:
4141
def add(self, paths: Union[str, Iterable[str]]) -> None:
4242
from scmrepo.exceptions import UnsupportedIndexFormat
4343

44+
from dvc.scm import add_no_submodules
45+
4446
try:
45-
return self.scm.add(paths)
47+
add_no_submodules(self.scm, paths)
4648
except UnsupportedIndexFormat:
4749
link = "https://github.com/iterative/dvc/issues/610"
4850
add_cmd = self._make_git_add_cmd([relpath(path) for path in paths])

dvc/scm.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,27 @@
11
"""Manages source control systems (e.g. Git)."""
22

33
import os
4-
from collections.abc import Iterator, Mapping
4+
from collections.abc import Iterable, Iterator, Mapping
55
from contextlib import contextmanager
66
from functools import partial
77
from typing import TYPE_CHECKING, Literal, Optional, Union, overload
88

99
from funcy import group_by
10-
from scmrepo.base import Base # noqa: F401
10+
from scmrepo.base import Base # noqa: TC002
1111
from scmrepo.git import Git
1212
from scmrepo.noscm import NoSCM
1313

1414
from dvc.exceptions import DvcException
15+
from dvc.log import logger
1516
from dvc.progress import Tqdm
1617

1718
if TYPE_CHECKING:
1819
from scmrepo.progress import GitProgressEvent
1920

2021
from dvc.fs import FileSystem
2122

23+
logger = logger.getChild(__name__)
24+
2225

2326
class SCMError(DvcException):
2427
"""Base class for source control management errors."""
@@ -283,3 +286,36 @@ def lfs_prefetch(fs: "FileSystem", paths: list[str]):
283286
include=[(path if path.startswith("/") else f"/{path}") for path in paths],
284287
progress=pbar.update_git,
285288
)
289+
290+
291+
def add_no_submodules(
    scm: "Base",
    paths: Union[str, Iterable[str]],
    update: bool = False,
    force: bool = False,
) -> None:
    """Stage paths to Git, excluding those inside submodules.

    Does nothing when ``scm`` is not a ``Git`` instance.

    Args:
        scm: SCM object used for staging.
        paths: A single path or an iterable of paths. Each one is resolved
            with ``os.path.abspath`` (so relative paths are interpreted
            against the current working directory) before being compared
            with the submodule roots.
        update: Forwarded unchanged to ``scm.add``.
        force: Forwarded unchanged to ``scm.add``.
    """
    if not isinstance(scm, Git):
        return

    if isinstance(paths, str):
        paths = [paths]

    # Normalize the roots the same way the candidate paths are normalized
    # below, so separator or ".." differences cannot defeat the comparison.
    submodule_roots = {
        os.path.abspath(os.path.join(scm.root_dir, sub))
        for sub in scm.list_submodules()
    }
    # Descendants are matched against "<root><sep>", not the bare root: a plain
    # string-prefix test would also classify a sibling such as
    # "submodule-data" as living inside "submodule".
    submodule_prefixes = tuple(os.path.join(root, "") for root in submodule_roots)

    repo_paths: list[str] = []
    skipped_paths: list[str] = []

    for p in paths:
        abs_path = os.path.abspath(p)
        if abs_path in submodule_roots or abs_path.startswith(submodule_prefixes):
            skipped_paths.append(p)
        else:
            repo_paths.append(p)

    if skipped_paths:
        msg = "Skipping staging for path(s) inside submodules: %s"
        logger.debug(msg, ", ".join(skipped_paths))

        if not repo_paths:
            # Every requested path was filtered out, so there is nothing to
            # stage. NOTE(review): returning here avoids handing the backend an
            # empty path list, which it may reject or (with ``update=True``)
            # treat as "all tracked files" -- confirm against scmrepo.
            return

    scm.add(repo_paths, update=update, force=force)  # type: ignore[call-arg]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ dependencies = [
6767
"requests>=2.22",
6868
"rich>=12",
6969
"ruamel.yaml>=0.17.11",
70-
"scmrepo>=3.3.8,<4",
70+
"scmrepo>=3.5.2,<4",
7171
"shortuuid>=0.5",
7272
"shtab<2,>=1.3.4",
7373
"tabulate>=0.8.7",

tests/func/experiments/test_experiments.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,18 @@ def test_custom_commit_message(tmp_dir, scm, dvc, tmp):
793793
)
794794
)
795795
assert scm.resolve_commit(exp).message == "custom commit message"
796+
797+
798+
@pytest.mark.parametrize("dep", ["submodule", "submodule/file"])
def test_experiments_run_with_submodule_dependencies(dvc, scm, make_tmp_dir, dep):
    """Run an experiment whose stage depends on a submodule path.

    ``dep`` is either the submodule directory itself or a file inside it.
    """
    # Separate Git repository with one committed file, used as the submodule
    # source.
    upstream = make_tmp_dir("external_repo", scm=True)
    upstream.scm_gen("file", "content", commit="add file")
    upstream_url = os.fspath(upstream)

    # Register the upstream repository under "submodule" and commit the
    # resulting .gitmodules file.
    registry = scm.pygit2.repo.submodules
    registry.add(upstream_url, "submodule")
    registry.update(init=True)
    scm.add_commit([".gitmodules"], message="add submodule")

    dvc.stage.add(cmd="echo foo", deps=[dep], name="foo")

    results = dvc.experiments.run()
    assert results

0 commit comments

Comments
 (0)