Skip to content

Commit d1857b3

Browse files
authored
state:modified vars, behind "state_modified_compare_vars" behaviour flag (#10793)
1 parent 2ff3f20 commit d1857b3

File tree

22 files changed

+614
-21
lines changed

22 files changed

+614
-21
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
kind: Features
2+
body: Include models that depend on changed vars in state:modified, add state:modified.vars
3+
selection method
4+
time: 2024-07-29T17:32:03.368508-04:00
5+
custom:
6+
Author: michelleark
7+
Issue: "4304"

core/dbt/artifacts/resources/v1/components.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class ParsedResource(ParsedResourceMandatory):
197197
unrendered_config_call_dict: Dict[str, Any] = field(default_factory=dict)
198198
relation_name: Optional[str] = None
199199
raw_code: str = ""
200+
vars: Dict[str, Any] = field(default_factory=dict)
200201

201202
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
202203
dct = super().__post_serialize__(dct, context)

core/dbt/artifacts/resources/v1/exposure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Exposure(GraphResource):
4141
tags: List[str] = field(default_factory=list)
4242
config: ExposureConfig = field(default_factory=ExposureConfig)
4343
unrendered_config: Dict[str, Any] = field(default_factory=dict)
44+
vars: Dict[str, Any] = field(default_factory=dict)
4445
url: Optional[str] = None
4546
depends_on: DependsOn = field(default_factory=DependsOn)
4647
refs: List[RefArgs] = field(default_factory=list)

core/dbt/artifacts/resources/v1/source_definition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@ class SourceDefinition(ParsedSourceMandatory):
6969
config: SourceConfig = field(default_factory=SourceConfig)
7070
patch_path: Optional[str] = None
7171
unrendered_config: Dict[str, Any] = field(default_factory=dict)
72+
vars: Dict[str, Any] = field(default_factory=dict)
7273
relation_name: Optional[str] = None
7374
created_at: float = field(default_factory=lambda: time.time())

core/dbt/context/configured.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,35 @@ def __init__(self, package_name: str):
3131
self.resource_type = NodeType.Model
3232

3333

34+
class SchemaYamlVars:
35+
def __init__(self):
36+
self.env_vars = {}
37+
self.vars = {}
38+
39+
3440
class ConfiguredVar(Var):
3541
def __init__(
3642
self,
3743
context: Dict[str, Any],
3844
config: AdapterRequiredConfig,
3945
project_name: str,
46+
schema_yaml_vars: Optional[SchemaYamlVars] = None,
4047
):
4148
super().__init__(context, config.cli_vars)
4249
self._config = config
4350
self._project_name = project_name
51+
self.schema_yaml_vars = schema_yaml_vars
4452

4553
def __call__(self, var_name, default=Var._VAR_NOTSET):
4654
my_config = self._config.load_dependencies()[self._project_name]
4755

56+
var_found = False
57+
var_value = None
58+
4859
# cli vars > active project > local project
4960
if var_name in self._config.cli_vars:
50-
return self._config.cli_vars[var_name]
61+
var_found = True
62+
var_value = self._config.cli_vars[var_name]
5163

5264
adapter_type = self._config.credentials.type
5365
lookup = FQNLookup(self._project_name)
@@ -58,19 +70,21 @@ def __call__(self, var_name, default=Var._VAR_NOTSET):
5870
all_vars.add(my_config.vars.vars_for(lookup, adapter_type))
5971
all_vars.add(active_vars)
6072

61-
if var_name in all_vars:
62-
return all_vars[var_name]
73+
if not var_found and var_name in all_vars:
74+
var_found = True
75+
var_value = all_vars[var_name]
6376

64-
if default is not Var._VAR_NOTSET:
65-
return default
66-
67-
return self.get_missing_var(var_name)
77+
if not var_found and default is not Var._VAR_NOTSET:
78+
var_found = True
79+
var_value = default
6880

81+
if not var_found:
82+
return self.get_missing_var(var_name)
83+
else:
84+
if self.schema_yaml_vars:
85+
self.schema_yaml_vars.vars[var_name] = var_value
6986

70-
class SchemaYamlVars:
71-
def __init__(self):
72-
self.env_vars = {}
73-
self.vars = {}
87+
return var_value
7488

7589

7690
class SchemaYamlContext(ConfiguredContext):
@@ -82,7 +96,7 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY
8296

8397
@contextproperty()
8498
def var(self) -> ConfiguredVar:
85-
return ConfiguredVar(self._ctx, self.config, self._project_name)
99+
return ConfiguredVar(self._ctx, self.config, self._project_name, self.schema_yaml_vars)
86100

87101
@contextmember()
88102
def env_var(self, var: str, default: Optional[str] = None) -> str:

core/dbt/context/providers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,14 @@ def get_missing_var(self, var_name):
790790
# in the parser, just always return None.
791791
return None
792792

793+
def __call__(self, var_name: str, default: Any = ModelConfiguredVar._VAR_NOTSET) -> Any:
794+
var_value = super().__call__(var_name, default)
795+
796+
if self._node and hasattr(self._node, "vars"):
797+
self._node.vars[var_name] = var_value
798+
799+
return var_value
800+
793801

794802
class RuntimeVar(ModelConfiguredVar):
795803
pass

core/dbt/contracts/files.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class SchemaSourceFile(BaseSourceFile):
213213
sop: List[SourceKey] = field(default_factory=list)
214214
env_vars: Dict[str, Any] = field(default_factory=dict)
215215
unrendered_configs: Dict[str, Any] = field(default_factory=dict)
216+
vars: Dict[str, Any] = field(default_factory=dict)
216217
pp_dict: Optional[Dict[str, Any]] = None
217218
pp_test_index: Optional[Dict[str, Any]] = None
218219

@@ -353,6 +354,22 @@ def delete_from_unrendered_configs(self, yaml_key, name):
353354
if not self.unrendered_configs[yaml_key]:
354355
del self.unrendered_configs[yaml_key]
355356

357+
def add_vars(self, vars: Dict[str, Any], yaml_key: str, name: str) -> None:
358+
if yaml_key not in self.vars:
359+
self.vars[yaml_key] = {}
360+
361+
if name not in self.vars[yaml_key]:
362+
self.vars[yaml_key][name] = vars
363+
364+
def get_vars(self, yaml_key: str, name: str) -> Dict[str, Any]:
365+
if yaml_key not in self.vars:
366+
return {}
367+
368+
if name not in self.vars[yaml_key]:
369+
return {}
370+
371+
return self.vars[yaml_key][name]
372+
356373
def add_env_var(self, var, yaml_key, name):
357374
if yaml_key not in self.env_vars:
358375
self.env_vars[yaml_key] = {}

core/dbt/contracts/graph/nodes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,19 +369,30 @@ def same_contract(self, old, adapter_type=None) -> bool:
369369
# This would only apply to seeds
370370
return True
371371

372+
def same_vars(self, old) -> bool:
373+
return self.vars == old.vars
374+
372375
def same_contents(self, old, adapter_type) -> bool:
373376
if old is None:
374377
return False
375378

376379
# Need to ensure that same_contract is called because it
377380
# could throw an error
378381
same_contract = self.same_contract(old, adapter_type)
382+
383+
# Legacy behaviour
384+
if not get_flags().state_modified_compare_vars:
385+
same_vars = True
386+
else:
387+
same_vars = self.same_vars(old)
388+
379389
return (
380390
self.same_body(old)
381391
and self.same_config(old)
382392
and self.same_persisted_description(old)
383393
and self.same_fqn(old)
384394
and self.same_database_representation(old)
395+
and same_vars
385396
and same_contract
386397
and True
387398
)
@@ -1251,6 +1262,9 @@ def same_config(self, old: "SourceDefinition") -> bool:
12511262
old.unrendered_config,
12521263
)
12531264

1265+
def same_vars(self, other: "SourceDefinition") -> bool:
1266+
return self.vars == other.vars
1267+
12541268
def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
12551269
# existing when it didn't before is a change!
12561270
if old is None:
@@ -1264,13 +1278,20 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
12641278
# freshness changes are changes, I guess
12651279
# metadata/tags changes are not "changes"
12661280
# patching/description changes are not "changes"
1281+
# Legacy behaviour
1282+
if not get_flags().state_modified_compare_vars:
1283+
same_vars = True
1284+
else:
1285+
same_vars = self.same_vars(old)
1286+
12671287
return (
12681288
self.same_database_representation(old)
12691289
and self.same_fqn(old)
12701290
and self.same_config(old)
12711291
and self.same_quoting(old)
12721292
and self.same_freshness(old)
12731293
and self.same_external(old)
1294+
and same_vars
12741295
and True
12751296
)
12761297

@@ -1367,12 +1388,21 @@ def same_config(self, old: "Exposure") -> bool:
13671388
old.unrendered_config,
13681389
)
13691390

1391+
def same_vars(self, old: "Exposure") -> bool:
1392+
return self.vars == old.vars
1393+
13701394
def same_contents(self, old: Optional["Exposure"]) -> bool:
13711395
# existing when it didn't before is a change!
13721396
# metadata/tags changes are not "changes"
13731397
if old is None:
13741398
return True
13751399

1400+
# Legacy behaviour
1401+
if not get_flags().state_modified_compare_vars:
1402+
same_vars = True
1403+
else:
1404+
same_vars = self.same_vars(old)
1405+
13761406
return (
13771407
self.same_fqn(old)
13781408
and self.same_exposure_type(old)
@@ -1383,6 +1413,7 @@ def same_contents(self, old: Optional["Exposure"]) -> bool:
13831413
and self.same_label(old)
13841414
and self.same_depends_on(old)
13851415
and self.same_config(old)
1416+
and same_vars
13861417
and True
13871418
)
13881419

@@ -1634,6 +1665,7 @@ class ParsedNodePatch(ParsedPatch):
16341665
latest_version: Optional[NodeVersion]
16351666
constraints: List[Dict[str, Any]]
16361667
deprecation_date: Optional[datetime]
1668+
vars: Dict[str, Any]
16371669
time_spine: Optional[TimeSpine] = None
16381670

16391671

core/dbt/contracts/project.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ class ProjectFlags(ExtensibleDbtClassMixin):
343343
source_freshness_run_project_hooks: bool = False
344344
skip_nodes_if_on_run_start_fails: bool = False
345345
state_modified_compare_more_unrendered_values: bool = False
346+
state_modified_compare_vars: bool = False
346347

347348
@property
348349
def project_only_flags(self) -> Dict[str, Any]:
@@ -352,6 +353,7 @@ def project_only_flags(self) -> Dict[str, Any]:
352353
"source_freshness_run_project_hooks": self.source_freshness_run_project_hooks,
353354
"skip_nodes_if_on_run_start_fails": self.skip_nodes_if_on_run_start_fails,
354355
"state_modified_compare_more_unrendered_values": self.state_modified_compare_more_unrendered_values,
356+
"state_modified_compare_vars": self.state_modified_compare_vars,
355357
}
356358

357359

core/dbt/graph/selector_methods.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,7 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
752752
"modified.relation": self.check_modified_factory("same_database_representation"),
753753
"modified.macros": self.check_modified_macros,
754754
"modified.contract": self.check_modified_contract("same_contract", adapter_type),
755+
"modified.vars": self.check_modified_factory("same_vars"),
755756
}
756757
if selector in state_checks:
757758
checker = state_checks[selector]

0 commit comments

Comments
 (0)