8 changes: 5 additions & 3 deletions dp_wizard/utils/code_generators/analyses/__init__.py
@@ -1,7 +1,9 @@
from typing import Protocol

from dp_wizard.types import AnalysisName
from dp_wizard.utils.code_generators.abstract_generator import AbstractGenerator
from dp_wizard.utils.code_generators.base_generators._base_generator import (
BaseGenerator,
)


class Analysis(Protocol): # pragma: no cover
@@ -16,15 +18,15 @@ def input_names(self) -> list[str]: ...

@staticmethod
def make_query(
code_gen: AbstractGenerator,
code_gen: BaseGenerator,
identifier: str,
accuracy_name: str,
stats_name: str,
) -> str: ...

@staticmethod
def make_output(
code_gen: AbstractGenerator,
code_gen: BaseGenerator,
column_name: str,
accuracy_name: str,
stats_name: str,
3 changes: 1 addition & 2 deletions dp_wizard/utils/code_generators/analyses/count/__init__.py
@@ -1,8 +1,7 @@
from dp_wizard_templates.code_template import Template

from dp_wizard import opendp_version
from dp_wizard import get_template_root, opendp_version
from dp_wizard.types import AnalysisName
from dp_wizard.utils.code_generators.abstract_generator import get_template_root

name = AnalysisName("Count")
blurb_md = """
dp_wizard/utils/code_generators/analyses/histogram/__init__.py
@@ -1,8 +1,7 @@
from dp_wizard_templates.code_template import Template

from dp_wizard import opendp_version
from dp_wizard import get_template_root, opendp_version
from dp_wizard.types import AnalysisName
from dp_wizard.utils.code_generators.abstract_generator import get_template_root

name = AnalysisName("Histogram")
blurb_md = """
24 changes: 24 additions & 0 deletions dp_wizard/utils/code_generators/base_generators/__init__.py
@@ -0,0 +1,24 @@
from dp_wizard.utils.code_generators.base_generators._stats_generator import (
StatsGenerator,
)
from dp_wizard.utils.code_generators.base_generators._synth_generator import (
SynthGenerator,
)


class AbstractGenerator(StatsGenerator, SynthGenerator):
"""
Each class in this hierarchy has its own set of concerns:

```
NotebookGenerator ScriptGenerator
└────────┬────────┘
AbstractGenerator
┌────────┴────────┐
StatsGenerator SynthGenerator
└────────┬────────┘
BaseGenerator
```
"""

pass
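
Aside (not part of the diff): a standalone sketch of the layering the docstring above describes. The mixin method bodies and the block names below are placeholders chosen for illustration, not the actual dp_wizard implementations.

```python
from abc import ABC, abstractmethod


class BaseGenerator(ABC):
    # Shared plumbing lives here; subclasses decide which code blocks to emit.
    @abstractmethod
    def _make_extra_blocks(self) -> dict[str, str]: ...


class StatsGenerator(BaseGenerator):
    # Stats-specific concerns (placeholder body).
    def _make_partial_stats_context(self) -> str:
        return "# stats context"


class SynthGenerator(BaseGenerator):
    # Synthetic-data-specific concerns (placeholder body).
    def _make_partial_synth_context(self) -> str:
        return "# synth context"


class AbstractGenerator(StatsGenerator, SynthGenerator):
    """Merges both concerns; still abstract, since _make_extra_blocks is undefined."""


class NotebookGenerator(AbstractGenerator):
    # Hypothetical block names; only the shape of the hierarchy matters here.
    def _make_extra_blocks(self) -> dict[str, str]:
        return {
            "STATS_BLOCK": self._make_partial_stats_context(),
            "SYNTH_BLOCK": self._make_partial_synth_context(),
        }


print(NotebookGenerator()._make_extra_blocks())
```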
dp_wizard/utils/code_generators/base_generators/_base_generator.py
@@ -11,21 +11,18 @@
from dp_wizard.utils.code_generators import (
AnalysisPlan,
make_column_config_block,
make_privacy_loss_block,
make_privacy_unit_block,
)
from dp_wizard.utils.code_generators.analyses import count, histogram
from dp_wizard.utils.code_generators.analyses import count
from dp_wizard.utils.dp_helper import confidence
from dp_wizard.utils.shared import make_cut_points

root = get_template_root(__file__)
root = get_template_root(Path(__file__).parent)
Copilot AI Oct 1, 2025

The get_template_root function call may fail if the function is not properly imported. The original code used get_template_root(__file__) which suggests a different signature.

Suggested change
root = get_template_root(Path(__file__).parent)
root = get_template_root(__file__)

Contributor Author

No, this would be a change in semantics.
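
Aside (not part of the thread): a standalone sketch of why the suggested revert would not be a no-op, assuming the relocated get_template_root now takes a directory rather than a module's __file__. The helper bodies below are hypothetical.

```python
from pathlib import Path

# Hypothetical helpers; not the dp_wizard implementations.


def get_template_root_old(calling_file: str) -> Path:
    # Old style: accept a module's __file__ and strip the filename itself.
    return Path(calling_file).parent


def get_template_root_new(template_dir: Path) -> Path:
    # New style (assumed): the caller passes the directory directly.
    return template_dir


# Under these assumptions, the old and new call sites in the diff are equivalent:
assert get_template_root_old(__file__) == get_template_root_new(Path(__file__).parent)

# But the suggested revert would hand the module file itself to the new helper,
# yielding a path that still ends in the filename -- the change in semantics
# the author is pointing out.
assert get_template_root_new(Path(__file__)) != get_template_root_old(__file__)
```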



def _analysis_has_bounds(analysis) -> bool:
return analysis.analysis_name != count.name


class AbstractGenerator(ABC):
class BaseGenerator(ABC):
def __init__(self, analysis_plan: AnalysisPlan):
self.analysis_plan = analysis_plan

@@ -58,9 +55,6 @@ def _get_root_template(self) -> str:
noun = self._get_notebook_or_script()
return f"{adj}_{noun}"

@abstractmethod
def _make_stats_context(self) -> str: ... # pragma: no cover

@abstractmethod
def _make_extra_blocks(self) -> dict[str, str]: ... # pragma: no cover

@@ -95,7 +89,9 @@ def template():
)
.fill_code_blocks(
IMPORTS_BLOCK=Template(template).finish(),
UTILS_BLOCK=(Path(__file__).parent.parent / "shared.py").read_text(),
UTILS_BLOCK=(
Path(__file__).parent.parent.parent / "shared.py"
).read_text(),
**self._make_extra_blocks(),
)
.fill_comment_blocks(
@@ -176,17 +172,6 @@ def _make_column_config_dict(self):
def _make_confidence_note(self):
return f"{int(confidence * 100)}% confidence interval"

def _make_stats_queries(self):
to_return = [
self._make_python_cell(
f"confidence = {confidence} # {self._make_confidence_note()}"
)
]
for column_name in self.analysis_plan.columns.keys():
to_return.append(self._make_query(column_name))

return "\n".join(to_return)

def _make_query(self, column_name):
plan = self.analysis_plan.columns[column_name]
identifier = ColumnIdentifier(column_name)
@@ -235,177 +220,3 @@ def _make_weights_expression(self):
)
+ "]"
)

def _make_partial_stats_context(self):

from dp_wizard.utils.code_generators.analyses import (
get_analysis_by_name,
has_bins,
)

bin_column_names = [
ColumnIdentifier(name)
for name, plan in self.analysis_plan.columns.items()
if has_bins(get_analysis_by_name(plan[0].analysis_name))
]

privacy_unit_block = make_privacy_unit_block(
contributions=self.analysis_plan.contributions,
contributions_entity=self.analysis_plan.contributions_entity,
)
privacy_loss_block = make_privacy_loss_block(
pure=False,
epsilon=self.analysis_plan.epsilon,
max_rows=self.analysis_plan.max_rows,
)

is_just_histograms = all(
plan_column[0].analysis_name == histogram.name
for plan_column in self.analysis_plan.columns.values()
)
margins_list = (
# Histograms don't need margins.
"[]"
if is_just_histograms
else self._make_margins_list(
bin_names=[f"{name}_bin" for name in bin_column_names],
groups=self.analysis_plan.groups,
max_rows=self.analysis_plan.max_rows,
)
)
extra_columns = ", ".join(
[
f"{ColumnIdentifier(name)}_bin_expr"
for name, plan in self.analysis_plan.columns.items()
if has_bins(get_analysis_by_name(plan[0].analysis_name))
]
)
return (
Template("stats_context", root)
.fill_expressions(
MARGINS_LIST=margins_list,
EXTRA_COLUMNS=extra_columns,
OPENDP_V_VERSION=f"v{opendp_version}",
WEIGHTS=self._make_weights_expression(),
)
.fill_code_blocks(
PRIVACY_UNIT_BLOCK=privacy_unit_block,
PRIVACY_LOSS_BLOCK=privacy_loss_block,
)
)

def _make_partial_synth_context(self):
privacy_unit_block = make_privacy_unit_block(
contributions=self.analysis_plan.contributions,
contributions_entity=self.analysis_plan.contributions_entity,
)
# If there are no groups and all analyses have bounds (so we have cut points),
# then OpenDP requires that pure DP be used for contingency tables.

privacy_loss_block = make_privacy_loss_block(
pure=not self.analysis_plan.groups
and all(
_analysis_has_bounds(analyses[0])
for analyses in self.analysis_plan.columns.values()
),
epsilon=self.analysis_plan.epsilon,
max_rows=self.analysis_plan.max_rows,
)
return (
Template("synth_context", root)
.fill_expressions(
OPENDP_V_VERSION=f"v{opendp_version}",
)
.fill_code_blocks(
PRIVACY_UNIT_BLOCK=privacy_unit_block,
PRIVACY_LOSS_BLOCK=privacy_loss_block,
)
)

def _make_synth_query(self):
def template(synth_context, COLUMNS, CUTS):
synth_query = (
synth_context.query()
.select(COLUMNS)
.contingency_table(
# Numeric columns will generally require cut points,
# unless they contain only a few distinct values.
cuts=CUTS,
# If you know the possible values for particular columns,
# supply them here to use your privacy budget more efficiently:
# keys={"your_column": ["known_value"]},
)
)
contingency_table = synth_query.release()

# Calling
# [`project_melted()`](https://docs.opendp.org/en/OPENDP_V_VERSION/api/python/opendp.extras.mbi.html#opendp.extras.mbi.ContingencyTable.project_melted)
# returns a dataframe with one row per combination of values.
# We'll first check the number of possible rows,
# to make sure it's not too large:

# +
from math import prod

possible_rows = prod([len(v) for v in contingency_table.keys.values()])
(
contingency_table.project_melted([COLUMNS])
if possible_rows < 100_000
else "Too big!"
)
# -

# Finally, a contingency table can also be used
# to create synthetic data by calling
# [`synthesize()`](https://docs.opendp.org/en/OPENDP_V_VERSION/api/python/opendp.extras.mbi.html#opendp.extras.mbi.ContingencyTable.synthesize).
# (There may be warnings from upstream libraries
# which we can ignore for now.)

# +
import warnings

with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
synthetic_data = contingency_table.synthesize()
synthetic_data # type: ignore
# -

# The make_cut_points() call could be moved into generated code,
# but that would require more complex templating,
# and more reliance on helper functions.
cuts = {
k: sorted(
{
# TODO: Error if float cut points are used with integer data.
# Is an upstream fix possible?
# (Sort the set because we might get int collisions,
# and repeated cut points are also an error.)
int(x)
for x in make_cut_points(
lower_bound=int(v[0].lower_bound),
upper_bound=int(v[0].upper_bound),
# bin_count is not set for mean: default to 10.
bin_count=v[0].bin_count or 10,
)
}
)
for (k, v) in self.analysis_plan.columns.items()
if _analysis_has_bounds(v[0])
}
return (
Template(template)
.fill_expressions(
OPENDP_V_VERSION=f"v{opendp_version}",
COLUMNS=", ".join(
repr(k)
for k in (
list(self.analysis_plan.columns.keys())
+ self.analysis_plan.groups
)
),
)
.fill_values(
CUTS=cuts,
)
.finish()
)