|
1 | 1 | """Manages source control systems (e.g. Git)."""
|
2 | 2 |
|
3 | 3 | import os
|
4 |
| -from collections.abc import Iterator, Mapping |
| 4 | +from collections.abc import Iterable, Iterator, Mapping |
5 | 5 | from contextlib import contextmanager
|
6 | 6 | from functools import partial
|
7 | 7 | from typing import TYPE_CHECKING, Literal, Optional, Union, overload
|
8 | 8 |
|
9 | 9 | from funcy import group_by
|
10 |
| -from scmrepo.base import Base # noqa: F401 |
| 10 | +from scmrepo.base import Base # noqa: TC002 |
11 | 11 | from scmrepo.git import Git
|
12 | 12 | from scmrepo.noscm import NoSCM
|
13 | 13 |
|
14 | 14 | from dvc.exceptions import DvcException
|
| 15 | +from dvc.log import logger |
15 | 16 | from dvc.progress import Tqdm
|
16 | 17 |
|
17 | 18 | if TYPE_CHECKING:
|
18 | 19 | from scmrepo.progress import GitProgressEvent
|
19 | 20 |
|
20 | 21 | from dvc.fs import FileSystem
|
21 | 22 |
|
| 23 | +logger = logger.getChild(__name__) |
| 24 | + |
22 | 25 |
|
23 | 26 | class SCMError(DvcException):
|
24 | 27 | """Base class for source control management errors."""
|
@@ -283,3 +286,36 @@ def lfs_prefetch(fs: "FileSystem", paths: list[str]):
|
283 | 286 | include=[(path if path.startswith("/") else f"/{path}") for path in paths],
|
284 | 287 | progress=pbar.update_git,
|
285 | 288 | )
|
| 289 | + |
| 290 | + |
| 291 | +def add_no_submodules( |
| 292 | + scm: "Base", |
| 293 | + paths: Union[str, Iterable[str]], |
| 294 | + update: bool = False, |
| 295 | + force: bool = False, |
| 296 | +) -> None: |
| 297 | + """Stage paths to Git, excluding those inside submodules.""" |
| 298 | + |
| 299 | + if not isinstance(scm, Git): |
| 300 | + return |
| 301 | + |
| 302 | + if isinstance(paths, str): |
| 303 | + paths = [paths] |
| 304 | + |
| 305 | + submodule_roots = {os.path.join(scm.root_dir, sub) for sub in scm.list_submodules()} |
| 306 | + |
| 307 | + repo_paths: list[str] = [] |
| 308 | + skipped_paths: list[str] = [] |
| 309 | + |
| 310 | + for p in paths: |
| 311 | + abs_path = os.path.abspath(p) |
| 312 | + if abs_path in submodule_roots or abs_path.startswith(tuple(submodule_roots)): |
| 313 | + skipped_paths.append(p) |
| 314 | + else: |
| 315 | + repo_paths.append(p) |
| 316 | + |
| 317 | + if skipped_paths: |
| 318 | + msg = "Skipping staging for path(s) inside submodules: %s" |
| 319 | + logger.debug(msg, ", ".join(skipped_paths)) |
| 320 | + |
| 321 | + scm.add(repo_paths, update=update, force=force) # type: ignore[call-arg] |
0 commit comments