Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=pygit2
extension-pkg-whitelist=pygit2, pydantic

# Add files or directories to the blacklist. They should be base names, not
# paths.
Expand Down
9 changes: 4 additions & 5 deletions dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ class NoRemoteError(ConfigError):
def get_compiled_schema():
from voluptuous import Schema

from .config_schema import SCHEMA
from .schema.config import SCHEMA

return Schema(SCHEMA)


def to_bool(value):
from .config_schema import Bool
from .schema.config import Bool

return Bool(value)

Expand Down Expand Up @@ -191,7 +191,7 @@ def _load_paths(conf, filename):
abs_conf_dir = os.path.abspath(os.path.dirname(filename))

def resolve(path):
from .config_schema import RelPath
from dvc.schema.config import RelPath

if os.path.isabs(path) or re.match(r"\w+://", path):
return path
Expand All @@ -207,10 +207,9 @@ def resolve(path):

@staticmethod
def _to_relpath(conf_dir, path):
from dvc.schema.config import RelPath
from dvc.utils import relpath

from .config_schema import RelPath

if re.match(r"\w+://", path):
return path

Expand Down
51 changes: 41 additions & 10 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import collections
import contextlib
import logging
import os
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Callable, Union

from voluptuous import MultipleInvalid

Expand All @@ -28,6 +27,7 @@
if TYPE_CHECKING:
from dvc.repo import Repo


logger = logging.getLogger(__name__)

DVC_FILE = "Dvcfile"
Expand Down Expand Up @@ -77,7 +77,7 @@ def check_dvc_filename(path):


class FileMixin:
SCHEMA = None
SCHEMA: Callable

def __init__(self, repo, path, verify=True, **kwargs):
self.repo = repo
Expand Down Expand Up @@ -122,7 +122,7 @@ def _check_gitignored(self):
if self._is_git_ignored():
raise FileIsGitIgnored(self.path)

def _load(self):
def _load(self, use_pydantic: bool = False):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To test the new schema, toggle this to True; validation should then use the new pydantic schema.

# it raises the proper exceptions by priority:
# 1. when the file doesn't exists
# 2. filename is not a DVC file
Expand All @@ -140,16 +140,27 @@ def _load(self):
with self.repo.fs.open(self.path, encoding="utf-8") as fd:
stage_text = fd.read()
d = parse_yaml(stage_text, self.path)
return self.validate(d, self.relpath), stage_text
return (
self.validate(d, self.relpath, use_pydantic=use_pydantic),
stage_text,
)

@classmethod
def validate(cls, d, fname=None):
assert isinstance(cls.SCHEMA, collections.abc.Callable)
def validate(cls, d, fname=None, use_pydantic: bool = False):
    """Validate parsed file contents ``d`` and return the validated data.

    When ``use_pydantic`` is true, the experimental pydantic validator
    (``cls.validate_pyd``) is tried first. If it raises
    ``NotImplementedError``, validation falls through to ``cls.SCHEMA``.

    Args:
        d: Parsed contents of the file.
        fname: File name, used only in log and error messages.
        use_pydantic: Try the pydantic schema before ``cls.SCHEMA``.

    Raises:
        StageFileFormatError: If ``cls.SCHEMA`` rejects ``d``; the
            original ``MultipleInvalid`` is chained as the cause.
    """
    if use_pydantic:
        logger.debug("using experimental pydantic schema for %s", fname)
        # A class without a pydantic schema signals that by raising
        # NotImplementedError; swallow it and use SCHEMA below instead.
        with contextlib.suppress(NotImplementedError):
            return cls.validate_pyd(d, fname)

    try:
        return cls.SCHEMA(d)  # pylint: disable=not-callable
    except MultipleInvalid as exc:
        # Chain explicitly so the schema error stays attached as the cause.
        raise StageFileFormatError(
            f"'{fname}' format error: {exc}"
        ) from exc

@classmethod
def validate_pyd(cls, d, fname=None):
    """Pydantic-based validation hook; not implemented at this level.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

def remove(self, force=False): # pylint: disable=unused-argument
with contextlib.suppress(FileNotFoundError):
os.unlink(self.path)
Expand Down Expand Up @@ -199,13 +210,29 @@ def merge(self, ancestor, other):
stage.merge(ancestor.stage, other.stage)
self.dump(stage)

@classmethod
def validate_pyd(cls, d, fname=None):
    """Pydantic-based validation is not provided for this file type.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


class PipelineFile(FileMixin):
"""Abstraction for pipelines file, .yaml + .lock combined."""

from dvc.schema import COMPILED_MULTI_STAGE_SCHEMA as SCHEMA
from dvc.stage.loader import StageLoader as LOADER

@classmethod
def validate_pyd(cls, d, fname=None):
    """Validate ``d`` against the pydantic dvc.yaml schema.

    Args:
        d: Parsed contents of the file.
        fname: File name, used only in the error message.

    Returns:
        ``d`` itself, unchanged; the parsed model is discarded.

    Raises:
        StageFileFormatError: If the schema rejects ``d``; the pydantic
            ``ValidationError`` is chained as the cause.
    """
    from pydantic import ValidationError

    from dvc.schema.dvc_yaml import get_schema

    try:
        get_schema().parse_obj(d)
    except ValidationError as exc:
        raise StageFileFormatError(
            f"'{fname}' format error: {exc}"
        ) from exc
    return d

@property
def _lockfile(self):
    """Lockfile for this file: same path with the extension swapped for ``.lock``."""
    stem, _ext = os.path.splitext(self.path)
    return Lockfile(self.repo, stem + ".lock")
Expand Down Expand Up @@ -329,7 +356,11 @@ def migrate_lock_v1_to_v2(d, version_info):

class Lockfile(FileMixin):
@classmethod
def validate(cls, d, fname=None):
def validate_pyd(cls, d, fname=None):
raise NotImplementedError

@classmethod
def validate(cls, d, fname=None, use_pydantic: bool = False):
schema = get_lockfile_schema(d)
try:
return schema(d)
Expand All @@ -339,9 +370,9 @@ def validate(cls, d, fname=None):
def _verify_filename(self):
pass # lockfile path is hardcoded, so no need to verify here

def _load(self):
def _load(self, use_pydantic: bool = False):
try:
return super()._load()
return super()._load(use_pydantic=use_pydantic)
except StageFileDoesNotExistError:
# we still need to account for git-ignored dvc.lock file
# even though it may not exist or have been .dvcignored
Expand Down
2 changes: 1 addition & 1 deletion dvc/fs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _resolve_remote_refs(config, remote_conf):

def get_cloud_fs(repo, **kwargs):
from dvc.config import ConfigError
from dvc.config_schema import SCHEMA, Invalid
from dvc.schema.config import SCHEMA, Invalid

remote_conf = get_fs_config(repo.config, **kwargs)
try:
Expand Down
2 changes: 1 addition & 1 deletion dvc/objects/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, repo):
elif "dir" not in config:
settings = None
else:
from dvc.config_schema import LOCAL_COMMON
from dvc.schema.config import LOCAL_COMMON

settings = {"url": config["dir"]}
for opt in LOCAL_COMMON.keys():
Expand Down
File renamed without changes.
8 changes: 8 additions & 0 deletions dvc/schema/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Extra


class BaseModel(PydanticBaseModel):
    # Shared base for the schema models in this package.

    class Config:
        # Extra (undeclared) keys are forbidden.
        # TODO: figure out a way to make it configurable
        extra = Extra.forbid
File renamed without changes.
190 changes: 190 additions & 0 deletions dvc/schema/dvc_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import Field, validator

from dvc.types import OptStr

from .base import BaseModel


class OutProps(BaseModel):
    # Properties that can be attached to a stage output.

    cache: bool = Field(default=True, description="Cache output by DVC")
    persist: bool = Field(
        default=False, description="Persist output between runs"
    )
    checkpoint: bool = Field(
        default=False,
        description=(
            "Indicate that the output is associated with in-code checkpoints"
        ),
    )
    desc: Optional[str] = Field(
        default=None,
        title="Description",
        description="User description for the output",
    )


class MetricProps(OutProps):
    # Metrics declare no properties beyond those of a plain output.
    ...


# Alias used wherever a field holds a file path; a plain `str`, no validation.
FilePath = str


class PlotProps(OutProps):
    # Output properties plus the default rendering options of a plot.

    template: Optional[FilePath] = Field(
        default=None, description="Default plot template"
    )
    x: OptStr = Field(
        default=None, description="Default field name to use as x-axis data"
    )
    y: OptStr = Field(
        default=None, description="Default field name to use as y-axis data"
    )
    x_label: OptStr = Field(
        default=None, description="Default label for the x-axis"
    )
    y_label: OptStr = Field(
        default=None, description="Default label for the y-axis"
    )
    title: OptStr = Field(default=None, description="Default plot title")
    header: bool = Field(
        default=False,
        description="Whether the target CSV or TSV has a header or not",
    )


class LiveProps(PlotProps):
    # Plot properties plus the two dvclive switches; both default to on.

    summary: bool = Field(
        default=True,
        description="Signals dvclive to dump latest metrics file",
    )
    html: bool = Field(
        default=True,
        description="Signals dvclive to produce training report",
    )


# A `vars` entry given as a string, e.g. "file.txt", "file.txt:foo,bar",
# "file.txt:foo" (presumably a file with optional keys to import - confirm).
VarImportSpec = str  # TODO: validate the format here?
# A `vars` entry given inline as a mapping, e.g. {"foo" (str): "foobar" (Any)}.
LocalVarKey = str
LocalVarValue = Any
# A `vars` section is a list that may mix both forms above.
VarsSpec = List[Union[VarImportSpec, Dict[LocalVarKey, LocalVarValue]]]

# Key name of the param, usually from `params.yaml`.
ParamKey = str
# Each `params` entry is a bare key or a mapping of a file path to a list of keys.
ParamsSpec = List[Union[ParamKey, Dict[FilePath, List[ParamKey]]]]


class WithDescription(BaseModel):
    # Mixin-style model contributing the optional `desc` field.

    desc: OptStr = Field(
        default=None,
        title="Description",
        description="Description of the stage",
    )


class StageDefinition(WithDescription, BaseModel):
    """This is the raw one, which could be parametrized."""

    # `cmd` is the only required field; every other field has a default.
    cmd: Union[str, List[str]] = Field(
        ..., description="Command to run", title="Command(s)"
    )  # required
    wdir: OptStr = Field(
        None, description="Working directory", title="Working Directory"
    )
    deps: List[FilePath] = Field(
        default_factory=list,
        description="Dependencies for the stage",
        title="Dependencies",
    )
    params: ParamsSpec = Field(
        default_factory=list,
        description="Params for the stage",
        title="Parameter Dependencies",
    )
    vars: VarsSpec = Field(
        default_factory=list,
        description="Variables for the stage",
        title="Variables",
    )
    frozen: bool = Field(False, description="Assume stage as unchanged")
    meta: Any = Field(
        None, description="Additional information/metadata", title="Metadata"
    )
    always_changed: bool = Field(
        False, description="Assume stage as always changed"
    )
    # Each entry is either a bare path or a {path: properties} mapping.
    outs: List[Union[Dict[FilePath, OutProps], FilePath]] = Field(
        default_factory=list,
        # Previously duplicated the `meta` description by copy-paste.
        description="Outputs of the stage",
        title="Outputs",
    )
    # A plot path may map to one property set or to a list of them.
    plots: List[
        Union[Dict[FilePath, Union[PlotProps, List[PlotProps]]], FilePath]
    ] = Field(
        default_factory=list, description="Plots of the stage", title="Plots"
    )
    metrics: List[Union[Dict[FilePath, MetricProps], FilePath]] = Field(
        default_factory=list,
        description="Metrics of the stage",
        title="Metrics",
    )
    # `live` is a single path or a {path: properties} mapping, never a
    # list, so an absent value defaults to None instead of the former
    # `default_factory=list`, which did not match the declared type.
    live: Optional[Union[Dict[FilePath, LiveProps], FilePath]] = Field(
        None,
        description="Declare output as dvclive",
        title="Dvclive",
    )

# Note: we don't support parametrization in props and in
# frozen/always_changed/meta yet.


# Separate alias to tell strings that may carry parametrization apart from
# ordinary strings; currently still a plain `str`.
ParametrizedString = str  # TODO: validate with constr()?

ListAny = List[Any]
DictStrAny = Dict[str, Any]

FOREACH_DESC = """\
Iterable to loop through in foreach. Can be a parametrized string, list \
or a dictionary.

The stages will be generated by iterating through this data, by substituting
data in the `do` block."""

DO_DESC = """\
Parametrized stage definition that'll be substituted over for each of the
value from the foreach data."""


class ForeachDo(BaseModel):
    # A `foreach`/`do` pair; both keys are required.

    foreach: Union[ParametrizedString, ListAny, DictStrAny] = Field(
        ...,
        description=FOREACH_DESC,
    )
    do: StageDefinition = Field(
        ...,
        description=DO_DESC,
    )


# A stage entry is either a `foreach`/`do` pair or a plain stage definition.
Definition = Union[ForeachDo, StageDefinition]
# Alias for the keys of the `stages` mapping; a plain `str`.
StageName = str


class Schema(BaseModel):
    # Top-level model: optional `vars` plus a mapping of stage name to
    # stage definition.

    vars: VarsSpec = Field(
        default_factory=list,
        title="Variables",
        description="Variables for the parametrization",
    )
    stages: Dict[StageName, Definition] = Field(
        default_factory=dict,
        description="List of stages",
    )

    @validator("stages", each_item=True, pre=True)
    @classmethod
    def validate_stages(cls, v: Any):
        # Applied to each stage entry: choose the model from the keys
        # present, then parse the entry with it.
        if not isinstance(v, dict):
            raise TypeError("must be a dict")

        has_foreach_keys = bool(v.keys() & {"foreach", "do"})
        model = ForeachDo if has_foreach_keys else StageDefinition
        return model.parse_obj(v)

    class Config:
        title = "dvc.yaml schema"


def get_schema(extra: str = "forbid") -> Type[Schema]:
    """Return the ``Schema`` model class.

    ``extra`` is only asserted to be truthy; the same class is returned
    whatever its value.
    """
    assert extra
    return Schema
Loading