Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions scripts/mongodbintegrationmvp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Generated by Honegumi (https://arxiv.org/abs/2502.06815)
# pip install ax-platform==0.4.3 numpy pymongo
import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
from pymongo import MongoClient # Added import for MongoDB


obj1_name = "branin"
MAX_TRIALS = 19 # Configuration constant


def branin(x1, x2):
y = float(
(x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
+ 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
+ 10
)

return y


# Connect to MongoDB
tmongo_client = MongoClient("mongodb://localhost:27017/")
db = tmongo_client["ax_db"]
experiments_col = db["experiments"]

# Experiment configuration
parameters = [
{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "x2", "type": "range", "bounds": [0.0, 10.0]},
]
objectives = {obj1_name: ObjectiveProperties(minimize=True)}

# Use Ax's default Sobol trials for 2D problems (aligns with GitHub comment)
SOBOL_TRIALS = 5

# Load existing experiment state or initialize new
record = experiments_col.find_one({"experiment_name": obj1_name})
if record:
saved_trials = record.get("trials", [])
n_existing = len(saved_trials)

# Calculate remaining Sobol trials: max(target_sobol - existing, 0)
remaining_sobol = max(SOBOL_TRIALS - n_existing, 0)

if remaining_sobol > 0:
generation_strategy = GenerationStrategy([
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please give this a try using Ax's built-in serialization method (save_to_json_file) and using the JSON snapshots that get sent to and reloaded from. In the current implementation, it doesn't preserve the Sobol sequence.

It should follow the same logic in https://colab.research.google.com/drive/1A2p1oUSsD8Edlu2haaB-FBSSjim3HTLS?usp=sharing

Please also add to each MongoDB document the timestamp of when the snapshot is saved and use the timestamp to load the most recent snapshot each time.

{"model": Models.SOBOL, "num_trials": remaining_sobol},
{"model": Models.GPEI, "num_trials": -1}
])
print(f"Will run {remaining_sobol} more Sobol trials (have {n_existing} existing)")
else:
# Remove Sobol step entirely when remaining_sobol = 0
generation_strategy = GenerationStrategy([
{"model": Models.GPEI, "num_trials": -1}
])
print(f"Skipping Sobol (have {n_existing} trials), going to GP")

ax_client = AxClient(generation_strategy=generation_strategy)
ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives)

# Replay saved trials
for t in saved_trials:
ax_client.complete_trial(trial_index=t["trial_index"], raw_data=t["raw_data"])
start_i = len(saved_trials)
else:
# Use the SAME custom generation strategy for new experiments
generation_strategy = GenerationStrategy([
{"model": Models.SOBOL, "num_trials": SOBOL_TRIALS},
{"model": Models.GPEI, "num_trials": -1}
])
ax_client = AxClient(generation_strategy=generation_strategy)
ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives)
start_i = 0
experiments_col.insert_one({"experiment_name": obj1_name, "trials": []})
print(f"Starting new experiment with {SOBOL_TRIALS} Sobol trials")

for i in range(start_i, MAX_TRIALS):

parameterization, trial_index = ax_client.get_next_trial()

# extract parameters
x1 = parameterization["x1"]
x2 = parameterization["x2"]

results = branin(x1, x2)
# Format raw_data as expected by AxClient (dict mapping objective name to value)
raw_data = {obj1_name: results}

ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

# Save trial results to MongoDB with parameters for debugging
experiments_col.update_one(
{"experiment_name": obj1_name},
{"$push": {"trials": {
"trial_index": trial_index,
"raw_data": raw_data,
"parameters": parameterization
}}},
)

print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}, result={results:.3f}")

best_parameters, metrics = ax_client.get_best_parameters()
print(f"Best parameters: {best_parameters}")
print(f"Best metrics: {metrics}")
Loading