Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions scripts/mongodbintegrationmvp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Generated by Honegumi (https://arxiv.org/abs/2502.06815)
# pip install ax-platform==0.4.3 numpy pymongo
import numpy as np
import json
import os
from datetime import datetime
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models
from pymongo import MongoClient, errors


# Name of the objective metric. The script also passes it as the experiment
# name when loading, creating and snapshotting the Ax experiment.
obj1_name = "branin"
# Exclusive upper bound of the trial range the optimization loop iterates over.
MAX_TRIALS = 19 # Configuration constant


def branin(x1, x2):
    """Evaluate the two-dimensional Branin benchmark function.

    Args:
        x1: First coordinate.
        x2: Second coordinate.

    Returns:
        The function value at ``(x1, x2)`` as a built-in ``float``.
    """
    # Constant coefficients of the formula, named for readability.
    quadratic_coeff = 5.1 / (4 * np.pi**2)
    linear_coeff = 5.0 / np.pi
    cosine_scale = 10 * (1 - 1.0 / (8 * np.pi))

    inner = x2 - quadratic_coeff * x1**2 + linear_coeff * x1 - 6.0
    return float(inner**2 + cosine_scale * np.cos(x1) + 10)


# Connect to MongoDB with error handling.
# TODO: point this at a MongoDB Atlas cluster; for now it targets a local
# instance.
# NOTE(review): the reviewer asked whether this explicit handling is needed,
# since the errors would bubble up on their own; it is kept unless the
# natural traceback turns out to be descriptive enough.
try:
    mongo_client = MongoClient(
        "mongodb://localhost:27017/", serverSelectionTimeoutMS=5000
    )
    # Test the connection: the ping forces a round trip so an unreachable
    # server is reported here rather than on the first real query.
    mongo_client.admin.command('ping')
    db = mongo_client["ax_db"]
    snapshots_col = db["ax_snapshots"]  # Collection for storing JSON snapshots
    print("Connected to MongoDB successfully")
except errors.ServerSelectionTimeoutError:
    print("Failed to connect to MongoDB. Is MongoDB running?")
    # Raise SystemExit directly: the bare ``exit`` helper is injected by the
    # ``site`` module and is not guaranteed to exist in every interpreter.
    raise SystemExit(1) from None
except Exception as e:
    print(f"MongoDB connection error: {e}")
    raise SystemExit(1) from None

# Experiment configuration
# Search space: two continuous "range" parameters, each with [lower, upper]
# bounds.
parameters = [
{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "x2", "type": "range", "bounds": [0.0, 10.0]},
]
# Single objective keyed by obj1_name and flagged for minimization.
objectives = {obj1_name: ObjectiveProperties(minimize=True)}

# Use Ax's default Sobol trials for 2D problems.
# Passed as num_trials of the Sobol generation step when a new experiment is
# created.
SOBOL_TRIALS = 5

def save_ax_snapshot_to_mongodb(ax_client, experiment_name):
    """Save an Ax client snapshot to MongoDB as a new timestamped document.

    The client is serialised through a scratch JSON file, wrapped together
    with the experiment name, an ISO-format timestamp and the current trial
    count, and inserted into ``snapshots_col``. Every call appends a new
    document; earlier snapshots are never overwritten.

    Args:
        ax_client: Client whose state is saved; must provide
            ``save_to_json_file`` and ``get_trials_data_frame``.
        experiment_name: Name stored on the document and used to build the
            scratch file name.

    Returns:
        True if the snapshot was inserted, False if any step failed (the
        error is printed rather than raised).
    """
    # NOTE(review): the reviewer asked whether this should return the
    # database ID of the inserted snapshot; the boolean return is kept here
    # so the existing contract does not change.
    try:
        temp_file = f"temp_{experiment_name}_snapshot.json"
        try:
            ax_client.save_to_json_file(temp_file)
            with open(temp_file, 'r') as f:
                snapshot_data = json.load(f)
        finally:
            # Best-effort cleanup so a failed save does not leave the scratch
            # file behind; it may not exist if serialisation itself failed.
            try:
                os.remove(temp_file)
            except OSError:
                pass

        # Fetch the trials frame once instead of once for the None check and
        # again for the length.
        trials_df = ax_client.get_trials_data_frame()
        snapshot_doc = {
            "experiment_name": experiment_name,
            "snapshot_data": snapshot_data,
            "timestamp": datetime.now().isoformat(),
            "trial_count": len(trials_df) if trials_df is not None else 0,
        }

        # Insert a new document for every snapshot (no overwrite)
        snapshots_col.insert_one(snapshot_doc)

        print(f"Snapshot saved to MongoDB at {snapshot_doc['timestamp']}")
        return True

    except Exception as e:
        print(f"Error saving snapshot: {e}")
        return False


def load_ax_snapshot_from_mongodb(experiment_name):
    """Load the most recent Ax client snapshot for an experiment from MongoDB.

    Looks up the document in ``snapshots_col`` with the greatest
    ``timestamp`` for ``experiment_name``, writes its ``snapshot_data`` to a
    scratch JSON file and rebuilds the client from that file.

    Args:
        experiment_name: Name the snapshot documents were stored under; also
            used to build the scratch file name.

    Returns:
        The restored ``AxClient``, or None when no snapshot exists or any
        step failed (the error is printed rather than raised).
    """
    try:
        # Timestamps are stored as ISO-format strings, so a descending sort
        # puts the most recent snapshot first.
        record = snapshots_col.find_one(
            {"experiment_name": experiment_name},
            sort=[("timestamp", -1)],
        )

        if not record:
            print("No existing snapshot found")
            return None

        temp_file = f"temp_{experiment_name}_snapshot.json"
        try:
            with open(temp_file, 'w') as f:
                json.dump(record["snapshot_data"], f)

            # Load AxClient from file
            ax_client = AxClient.load_from_json_file(temp_file)
        finally:
            # Best-effort cleanup so a failed load does not leave the scratch
            # file behind.
            try:
                os.remove(temp_file)
            except OSError:
                pass

        print(
            f"Loaded snapshot from {record['timestamp']} with "
            f"{record['trial_count']} trials"
        )
        return ax_client

    except Exception as e:
        print(f"Error loading snapshot: {e}")
        return None


# Load existing experiment or create new one.
# NOTE(review): the reviewer asked not to reuse obj1_name as the experiment
# name and suggested a dedicated variable instead (e.g. an f-string built from
# obj1_name plus a hard-coded, externally generated 4-character suffix). The
# same applies to every save call below. The key is left unchanged here
# because it selects which stored snapshots are loaded.
ax_client = load_ax_snapshot_from_mongodb(obj1_name)

if ax_client is None:
    # Create new experiment: quasi-random Sobol trials first, then GPEI.
    generation_strategy = GenerationStrategy([
        GenerationStep(
            model=Models.SOBOL,
            num_trials=SOBOL_TRIALS,
            min_trials_observed=1,
            max_parallelism=5,
            model_kwargs={"seed": 999},  # For reproducibility
        ),
        GenerationStep(
            model=Models.GPEI,
            num_trials=-1,
            max_parallelism=3,
            model_kwargs={},
        ),
    ])
    ax_client = AxClient(generation_strategy=generation_strategy)
    ax_client.create_experiment(
        name=obj1_name,
        parameters=parameters,
        objectives=objectives,
    )
    print(f"Created new experiment with {SOBOL_TRIALS} Sobol trials")

    # Save initial snapshot
    save_ax_snapshot_to_mongodb(ax_client, obj1_name)
else:
    print("Resuming existing experiment")

# Get current trial count to determine how many more trials to run, so a
# resumed run only spends what is left of MAX_TRIALS.
# NOTE(review): the reviewer mused that a budget variable stored in MongoDB
# could replace this later; out of scope for now.
# NOTE(review): len() counts every row of the trials frame. If that includes
# a trial that was snapshotted as pending but never completed, a resumed run
# would skip it - confirm against Ax's behaviour.
current_trials = ax_client.get_trials_data_frame()
start_trial = len(current_trials) if current_trials is not None else 0

if start_trial < MAX_TRIALS:
    print(f"Starting optimization: running trials {start_trial} to {MAX_TRIALS-1}")
else:
    # Avoid announcing an empty, backwards range such as "19 to 18".
    print(f"Trial budget of {MAX_TRIALS} already used; no new trials to run")

for _ in range(start_trial, MAX_TRIALS):
    # Reset each iteration so the error handler below never hits an unbound
    # name or reports the previous trial's index when get_next_trial() fails.
    trial_index = None
    try:
        # Get next trial
        parameterization, trial_index = ax_client.get_next_trial()

        # Extract parameters
        x1 = parameterization["x1"]
        x2 = parameterization["x2"]

        print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}")

        # Save snapshot before running experiment (preserves pending trial)
        save_ax_snapshot_to_mongodb(ax_client, obj1_name)

        # Evaluate objective function
        results = branin(x1, x2)

        # Format raw_data as expected by AxClient
        raw_data = {obj1_name: results}

        # Complete trial
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

        # Save snapshot after completing trial
        save_ax_snapshot_to_mongodb(ax_client, obj1_name)

        # Get current best for progress tracking; fall back to printing the
        # raw result when the best point cannot be determined.
        # NOTE(review): the reviewer suggested removing these try-excepts.
        try:
            best_parameters, best_metrics = ax_client.get_best_parameters()
            best_value = best_metrics[0][obj1_name]
            print(
                f"Trial {trial_index}: result={results:.3f} | "
                f"Best so far: {best_value:.3f}"
            )
        except Exception:
            print(f"Trial {trial_index}: result={results:.3f}")

    except Exception as e:
        # Best-effort: report the failure and move on to the next iteration.
        print(f"Error in trial {trial_index}: {e}")

print("\nOptimization completed!")
# NOTE(review): the reviewer considers this try-except unnecessary as well.
try:
    best_parameters, best_metrics = ax_client.get_best_parameters()
    print(f"Best parameters: {best_parameters}")
    print(f"Best metrics: {best_metrics}")

    # Save final snapshot
    save_ax_snapshot_to_mongodb(ax_client, obj1_name)

    # Print experiment summary
    trials_df = ax_client.get_trials_data_frame()
    if trials_df is not None:
        print(f"Total trials completed: {len(trials_df)}")
        print(f"Best objective value: {trials_df[obj1_name].min():.6f}")

except Exception as e:
    print(f"Error getting best parameters: {e}")

# Clean up MongoDB connection.
# NOTE(review): the reviewer noted that a top-level try-except-finally could
# be added later if restarts are found to leave connections behind.
mongo_client.close()
print("MongoDB connection closed")

# Optional: Display trials data frame for debugging
try:
    print("\nTrials Summary:")
    print(ax_client.get_trials_data_frame())
except Exception as e:
    print(f"Error displaying trials: {e}")
Loading