Skip to content
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240729-173203.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Include models that depend on changed vars in state:modified, add state:modified.vars
selection method
time: 2024-07-29T17:32:03.368508-04:00
custom:
Author: michelleark
Issue: "4304"
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240923-190758.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Allow singular tests to be documented in properties.yml
time: 2024-09-23T19:07:58.151069+01:00
custom:
Author: aranke
Issue: "9005"
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class ParsedResource(ParsedResourceMandatory):
unrendered_config_call_dict: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
raw_code: str = ""
vars: Dict[str, Any] = field(default_factory=dict)

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/exposure.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Exposure(GraphResource):
tags: List[str] = field(default_factory=list)
config: ExposureConfig = field(default_factory=ExposureConfig)
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
url: Optional[str] = None
depends_on: DependsOn = field(default_factory=DependsOn)
refs: List[RefArgs] = field(default_factory=list)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/source_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ class SourceDefinition(ParsedSourceMandatory):
config: SourceConfig = field(default_factory=SourceConfig)
patch_path: Optional[str] = None
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
created_at: float = field(default_factory=lambda: time.time())
13 changes: 4 additions & 9 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,8 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
return [VersionSpecifier.from_version_string(v) for v in versions]


def _all_source_paths(
model_paths: List[str],
seed_paths: List[str],
snapshot_paths: List[str],
analysis_paths: List[str],
macro_paths: List[str],
) -> List[str]:
paths = chain(model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths)
def _all_source_paths(*args: List[str]) -> List[str]:
paths = chain(*args)
# Strip trailing slashes since the path is the same even though the name is not
stripped_paths = map(lambda s: s.rstrip("/"), paths)
return list(set(stripped_paths))
Expand Down Expand Up @@ -409,7 +403,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])

all_source_paths: List[str] = _all_source_paths(
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths, test_paths
)

docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
Expand Down Expand Up @@ -652,6 +646,7 @@ def all_source_paths(self) -> List[str]:
self.snapshot_paths,
self.analysis_paths,
self.macro_paths,
self.test_paths,
)

@property
Expand Down
38 changes: 26 additions & 12 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,35 @@ def __init__(self, package_name: str):
self.resource_type = NodeType.Model


class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}


class ConfiguredVar(Var):
def __init__(
self,
context: Dict[str, Any],
config: AdapterRequiredConfig,
project_name: str,
schema_yaml_vars: Optional[SchemaYamlVars] = None,
):
super().__init__(context, config.cli_vars)
self._config = config
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

def __call__(self, var_name, default=Var._VAR_NOTSET):
my_config = self._config.load_dependencies()[self._project_name]

var_found = False
var_value = None

# cli vars > active project > local project
if var_name in self._config.cli_vars:
return self._config.cli_vars[var_name]
var_found = True
var_value = self._config.cli_vars[var_name]

adapter_type = self._config.credentials.type
lookup = FQNLookup(self._project_name)
Expand All @@ -58,19 +70,21 @@ def __call__(self, var_name, default=Var._VAR_NOTSET):
all_vars.add(my_config.vars.vars_for(lookup, adapter_type))
all_vars.add(active_vars)

if var_name in all_vars:
return all_vars[var_name]
if not var_found and var_name in all_vars:
var_found = True
var_value = all_vars[var_name]

if default is not Var._VAR_NOTSET:
return default

return self.get_missing_var(var_name)
if not var_found and default is not Var._VAR_NOTSET:
var_found = True
var_value = default

if not var_found:
return self.get_missing_var(var_name)
else:
if self.schema_yaml_vars:
self.schema_yaml_vars.vars[var_name] = var_value

class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}
return var_value


class SchemaYamlContext(ConfiguredContext):
Expand All @@ -82,7 +96,7 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY

@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)
return ConfiguredVar(self._ctx, self.config, self._project_name, self.schema_yaml_vars)

@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
Expand Down
8 changes: 8 additions & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,14 @@ def get_missing_var(self, var_name):
# in the parser, just always return None.
return None

def __call__(self, var_name: str, default: Any = ModelConfiguredVar._VAR_NOTSET) -> Any:
var_value = super().__call__(var_name, default)

if self._node and hasattr(self._node, "vars"):
self._node.vars[var_name] = var_value

return var_value


class RuntimeVar(ModelConfiguredVar):
pass
Expand Down
17 changes: 17 additions & 0 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class SchemaSourceFile(BaseSourceFile):
# created too, but those are in 'sources'
sop: List[SourceKey] = field(default_factory=list)
env_vars: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
unrendered_configs: Dict[str, Any] = field(default_factory=dict)
pp_dict: Optional[Dict[str, Any]] = None
pp_test_index: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -318,6 +319,22 @@ def get_all_test_ids(self):
test_ids.extend(self.data_tests[key][name])
return test_ids

def add_vars(self, vars: Dict[str, Any], yaml_key: str, name: str) -> None:
if yaml_key not in self.vars:
self.vars[yaml_key] = {}

if name not in self.vars[yaml_key]:
self.vars[yaml_key][name] = vars

def get_vars(self, yaml_key: str, name: str) -> Dict[str, Any]:
if yaml_key not in self.vars:
return {}

if name not in self.vars[yaml_key]:
return {}

return self.vars[yaml_key][name]

def add_unrendered_config(self, unrendered_config, yaml_key, name, version=None):
versioned_name = f"{name}_v{version}" if version is not None else name

Expand Down
50 changes: 49 additions & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
SavedQuery,
SeedNode,
SemanticModel,
SingularTestNode,
SourceDefinition,
UnitTestDefinition,
UnitTestFileFixture,
Expand Down Expand Up @@ -89,7 +90,7 @@
RefName = str


def find_unique_id_for_package(storage, key, package: Optional[PackageName]):
def find_unique_id_for_package(storage, key, package: Optional[PackageName]) -> Optional[UniqueID]:
if key not in storage:
return None

Expand Down Expand Up @@ -470,6 +471,43 @@ class AnalysisLookup(RefableLookup):
_versioned_types: ClassVar[set] = set()


class SingularTestLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

def get_unique_id(self, search_name, package: Optional[PackageName]) -> Optional[UniqueID]:
return find_unique_id_for_package(self.storage, search_name, package)

def find(
self, search_name, package: Optional[PackageName], manifest: "Manifest"
) -> Optional[SingularTestNode]:
unique_id = self.get_unique_id(search_name, package)
if unique_id is not None:
return self.perform_lookup(unique_id, manifest)
return None

def add_singular_test(self, source: SingularTestNode) -> None:
if source.search_name not in self.storage:
self.storage[source.search_name] = {}

self.storage[source.search_name][source.package_name] = source.unique_id

def populate(self, manifest: "Manifest") -> None:
for node in manifest.nodes.values():
if isinstance(node, SingularTestNode):
self.add_singular_test(node)

def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SingularTestNode:
if unique_id not in manifest.nodes:
raise dbt_common.exceptions.DbtInternalError(
f"Singular test {unique_id} found in cache but not found in manifest"
)
node = manifest.nodes[unique_id]
assert isinstance(node, SingularTestNode)
return node


def _packages_to_search(
current_project: str,
node_package: str,
Expand Down Expand Up @@ -869,6 +907,9 @@ class Manifest(MacroMethods, dbtClassMixin):
_analysis_lookup: Optional[AnalysisLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_singular_test_lookup: Optional[SingularTestLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_parsing_info: ParsingInfo = field(
default_factory=ParsingInfo,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
Expand Down Expand Up @@ -1264,6 +1305,12 @@ def analysis_lookup(self) -> AnalysisLookup:
self._analysis_lookup = AnalysisLookup(self)
return self._analysis_lookup

@property
def singular_test_lookup(self) -> SingularTestLookup:
if self._singular_test_lookup is None:
self._singular_test_lookup = SingularTestLookup(self)
return self._singular_test_lookup

@property
def external_node_unique_ids(self):
return [node.unique_id for node in self.nodes.values() if node.is_external_node]
Expand Down Expand Up @@ -1708,6 +1755,7 @@ def __reduce_ex__(self, protocol):
self._semantic_model_by_measure_lookup,
self._disabled_lookup,
self._analysis_lookup,
self._singular_test_lookup,
)
return self.__class__, args

Expand Down
27 changes: 27 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ def same_contract(self, old, adapter_type=None) -> bool:
# This would only apply to seeds
return True

def same_vars(self, old) -> bool:
if get_flags().state_modified_compare_vars:
return self.vars == old.vars
else:
return True

def same_contents(self, old, adapter_type) -> bool:
if old is None:
return False
Expand All @@ -382,6 +388,7 @@ def same_contents(self, old, adapter_type) -> bool:
and self.same_persisted_description(old)
and self.same_fqn(old)
and self.same_database_representation(old)
and self.same_vars(old)
and same_contract
and True
)
Expand Down Expand Up @@ -1251,6 +1258,12 @@ def same_config(self, old: "SourceDefinition") -> bool:
old.unrendered_config,
)

def same_vars(self, other: "SourceDefinition") -> bool:
if get_flags().state_modified_compare_vars:
return self.vars == other.vars
else:
return True

def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
# existing when it didn't before is a change!
if old is None:
Expand All @@ -1271,6 +1284,7 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
and self.same_quoting(old)
and self.same_freshness(old)
and self.same_external(old)
and self.same_vars(old)
and True
)

Expand Down Expand Up @@ -1367,6 +1381,12 @@ def same_config(self, old: "Exposure") -> bool:
old.unrendered_config,
)

def same_vars(self, old: "Exposure") -> bool:
if get_flags().state_modified_compare_vars:
return self.vars == old.vars
else:
return True

def same_contents(self, old: Optional["Exposure"]) -> bool:
# existing when it didn't before is a change!
# metadata/tags changes are not "changes"
Expand All @@ -1383,6 +1403,7 @@ def same_contents(self, old: Optional["Exposure"]) -> bool:
and self.same_label(old)
and self.same_depends_on(old)
and self.same_config(old)
and self.same_vars(old)
and True
)

Expand Down Expand Up @@ -1634,6 +1655,7 @@ class ParsedNodePatch(ParsedPatch):
latest_version: Optional[NodeVersion]
constraints: List[Dict[str, Any]]
deprecation_date: Optional[datetime]
vars: Dict[str, Any]
time_spine: Optional[TimeSpine] = None


Expand All @@ -1642,6 +1664,11 @@ class ParsedMacroPatch(ParsedPatch):
arguments: List[MacroArgument] = field(default_factory=list)


@dataclass
class ParsedSingularTestPatch(ParsedPatch):
pass


# ====================================
# Node unions/categories
# ====================================
Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ class UnparsedAnalysisUpdate(HasConfig, HasColumnDocs, HasColumnProps, HasYamlMe
access: Optional[str] = None


@dataclass
class UnparsedSingularTestUpdate(HasConfig, HasColumnProps, HasYamlMetadata):
pass


@dataclass
class UnparsedNodeUpdate(HasConfig, HasColumnTests, HasColumnAndTestProps, HasYamlMetadata):
quote_columns: Optional[bool] = None
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ class ProjectFlags(ExtensibleDbtClassMixin):
require_resource_names_without_spaces: bool = False
source_freshness_run_project_hooks: bool = False
state_modified_compare_more_unrendered_values: bool = False
state_modified_compare_vars: bool = False

@property
def project_only_flags(self) -> Dict[str, Any]:
Expand All @@ -350,6 +351,7 @@ def project_only_flags(self) -> Dict[str, Any]:
"require_resource_names_without_spaces": self.require_resource_names_without_spaces,
"source_freshness_run_project_hooks": self.source_freshness_run_project_hooks,
"state_modified_compare_more_unrendered_values": self.state_modified_compare_more_unrendered_values,
"state_modified_compare_vars": self.state_modified_compare_vars,
}


Expand Down
Loading