Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions scripts/hitl-bo/check_persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Manual test for MongoDB persistence
# Run this to check the state before and after manually killing mongodbintegrationmvp.py

from pymongo import MongoClient
from datetime import datetime
import json

def check_mongodb_status():
"""Check MongoDB connection and experiment status."""
try:
client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=2000)
client.admin.command("ping")
print("✅ MongoDB is running and accessible")

db = client["ax_db"]
snapshots_col = db["ax_snapshots"]

experiment_name = "branin_experiment_k7m9"

# Get all snapshots for this experiment
snapshots = list(snapshots_col.find(
{"experiment_name": experiment_name}
).sort("timestamp", -1).limit(5))

if snapshots:
print(f"\n📊 Found {len(snapshots)} recent snapshots for '{experiment_name}':")
for i, snapshot in enumerate(snapshots):
timestamp = snapshot['timestamp']
trial_count = snapshot['trial_count']
snapshot_id = str(snapshot['_id'])[:8] + "..."
print(f" {i+1}. {timestamp} | {trial_count} trials | ID: {snapshot_id}")

# Show details of most recent
latest = snapshots[0]
print(f"\n🔍 Most recent snapshot details:")
print(f" Timestamp: {latest['timestamp']}")
print(f" Trial count: {latest['trial_count']}")
print(f" Document ID: {latest['_id']}")

return latest['trial_count']
else:
print(f"📊 No snapshots found for experiment '{experiment_name}'")
return 0

except Exception as e:
print(f"❌ Error: {e}")
return None

def cleanup_old_experiments():
"""Optional: Clean up old experiment data."""
try:
client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=2000)
db = client["ax_db"]
snapshots_col = db["ax_snapshots"]

experiment_name = "branin_experiment_k7m9"

count = snapshots_col.count_documents({"experiment_name": experiment_name})
print(f"\n🗑️ Found {count} total snapshots for '{experiment_name}'")

if count > 0:
choice = input("Do you want to delete all snapshots? (y/N): ").strip().lower()
if choice == 'y':
result = snapshots_col.delete_many({"experiment_name": experiment_name})
print(f" Deleted {result.deleted_count} snapshots")
return True

return False

except Exception as e:
print(f"❌ Error during cleanup: {e}")
return False

def main():
print("=" * 60)
print("🔍 MONGODB PERSISTENCE CHECKER")
print("=" * 60)
print("This tool helps you manually test the persistence of mongodbintegrationmvp.py")
print("")
print("How to use:")
print("1. Run this script to check current state")
print("2. Start mongodbintegrationmvp.py and let it run a few trials")
print("3. Kill mongodbintegrationmvp.py (Ctrl+C or close terminal)")
print("4. Run this script again to verify data was saved")
print("5. Start mongodbintegrationmvp.py again to see it resume")
print("")

while True:
print("\nChoose an option:")
print("1. Check current experiment status")
print("2. Clean up old experiments (delete all data)")
print("3. Exit")

choice = input("\nEnter choice (1-3): ").strip()

if choice == "1":
print("\n" + "-" * 40)
trial_count = check_mongodb_status()
print("-" * 40)

if trial_count is not None:
if trial_count == 0:
print("\n💡 No trials found. You can now:")
print(" - Start mongodbintegrationmvp.py to begin a new experiment")
else:
print(f"\n💡 Found {trial_count} trials. You can now:")
print(" - Start mongodbintegrationmvp.py to resume the experiment")
print(" - Or kill it partway through to test persistence")

elif choice == "2":
print("\n" + "-" * 40)
cleanup_old_experiments()
print("-" * 40)

elif choice == "3":
print("\n👋 Goodbye!")
break

else:
print("❌ Invalid choice. Please enter 1, 2, or 3.")

if __name__ == "__main__":
main()
254 changes: 254 additions & 0 deletions scripts/hitl-bo/mongodbintegration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# Generated by Honegumi (https://arxiv.org/abs/2502.06815)
# pip install ax-platform==0.4.3 numpy pymongo
import json
import os
from datetime import datetime
import random
import string

import numpy as np
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient, ObjectiveProperties
from pymongo import MongoClient, errors

obj1_name = "branin"
MAX_TRIALS = 19 # Configuration constant

# These will be set based on user choice
experiment_id = None
db_name = None


def branin(x1, x2):
"""Branin function - a common benchmark for optimization."""
y = float(
(x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
+ 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
+ 10
)
return y


def generate_random_id(length=4):
"""Generate a random alphanumeric ID."""
return ''.join(random.choices(string.ascii_lowercase + string.digits, k=length))


def get_user_choice():
"""Ask user whether to continue previous experiment or start new one."""
print("\n" + "="*50)
print("EXPERIMENT SETUP")
print("="*50)
print("Choose an option:")
print("1. Continue previous experiment (use existing database)")
print("2. Start new experiment (create new database)")

while True:
choice = input("\nEnter your choice (1 or 2): ").strip()
if choice == "1":
return "continue"
elif choice == "2":
return "new"
else:
print("Invalid choice. Please enter 1 or 2.")


def setup_experiment_config(choice):
"""Set up experiment configuration based on user choice."""
global experiment_id, db_name

if choice == "continue":
# Use default previous experiment settings
experiment_id = f"{obj1_name}_experiment_k7m9"
db_name = "ax_db"
print(f"\nContinuing previous experiment...")
print(f"Database: {db_name}")
print(f"Experiment ID: {experiment_id}")
else:
# Create new experiment with random ID
random_id = generate_random_id()
experiment_id = f"{obj1_name}_experiment_{random_id}"
db_name = f"ax_db_{random_id}"
print(f"\nStarting new experiment...")
print(f"Database: {db_name}")
print(f"Experiment ID: {experiment_id}")

return experiment_id, db_name


# Get user choice and setup configuration
user_choice = get_user_choice()
experiment_id, db_name = setup_experiment_config(user_choice)


# Connect to MongoDB
mongo_client = MongoClient(
"mongodb://localhost:27017/", serverSelectionTimeoutMS=5000
)
# Test the connection
mongo_client.admin.command("ping")
db = mongo_client[db_name] # Use dynamic database name
snapshots_col = db["ax_snapshots"] # Collection for storing JSON snapshots
print(f"Connected to MongoDB successfully (Database: {db_name})")

# Experiment configuration
parameters = [
{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "x2", "type": "range", "bounds": [0.0, 10.0]},
]
objectives = {obj1_name: ObjectiveProperties(minimize=True)}


def save_ax_snapshot_to_mongodb(ax_client, experiment_name):
"""Save Ax client snapshot to MongoDB with timestamp (append, don't overwrite)."""
try:
# Insert document first to get unique ID
snapshot_doc = {
"experiment_name": experiment_name,
"snapshot_data": {}, # Placeholder, will be updated
"timestamp": datetime.now().isoformat(),
"trial_count": (
len(ax_client.get_trials_data_frame())
if ax_client.get_trials_data_frame() is not None
else 0
),
}

# Insert a new document for every snapshot (no overwrite)
result = snapshots_col.insert_one(snapshot_doc)

# Use database ID in temp filename to avoid conflicts
temp_file = f"temp_{experiment_name}_{result.inserted_id}_snapshot.json"
ax_client.save_to_json_file(temp_file)

with open(temp_file, "r") as f:
snapshot_data = json.load(f)

# Update the document with actual snapshot data
snapshots_col.update_one(
{"_id": result.inserted_id},
{"$set": {"snapshot_data": snapshot_data}}
)

os.remove(temp_file)

print(f"Snapshot saved to MongoDB at {snapshot_doc['timestamp']} (ID: {result.inserted_id})")
return result.inserted_id

except Exception as e:
print(f"Error saving snapshot: {e}")
return None


def load_ax_snapshot_from_mongodb(experiment_name):
"""Load most recent Ax client snapshot from MongoDB."""
try:
# Find the most recent snapshot
record = snapshots_col.find_one(
{"experiment_name": experiment_name},
sort=[("timestamp", -1)], # Most recent first
)

if record:
# Use database ID in temp filename to avoid conflicts
temp_file = f"temp_{experiment_name}_{record['_id']}_snapshot.json"
with open(temp_file, "w") as f:
json.dump(record["snapshot_data"], f)

# Load AxClient from file
ax_client = AxClient.load_from_json_file(temp_file)

# Clean up temp file
os.remove(temp_file)

print(
f"Loaded snapshot from {record['timestamp']} with "
f"{record['trial_count']} trials"
)
return ax_client
else:
print("No existing snapshot found")
return None

except Exception as e:
print(f"Error loading snapshot: {e}")
return None


# Load existing experiment or create new one
ax_client = load_ax_snapshot_from_mongodb(experiment_id)

if ax_client is None:
# Create new experiment (Ax will use default generation strategy)
ax_client = AxClient()
ax_client.create_experiment(
name=experiment_id, parameters=parameters, objectives=objectives
)
print(f"Created new experiment '{experiment_id}' with default generation strategy")

# Save initial snapshot
save_ax_snapshot_to_mongodb(ax_client, experiment_id)
else:
print(f"Resuming existing experiment '{experiment_id}'")

# Get current trial count to determine how many more trials to run
current_trials = ax_client.get_trials_data_frame()
start_trial = len(current_trials) if current_trials is not None else 0

print(f"Starting optimization: running trials {start_trial} to {MAX_TRIALS-1}")

for i in range(start_trial, MAX_TRIALS):
# Get next trial
parameterization, trial_index = ax_client.get_next_trial()

# Extract parameters
x1 = parameterization["x1"]
x2 = parameterization["x2"]

print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}")

# Save snapshot before running experiment (preserves pending trial)
save_ax_snapshot_to_mongodb(ax_client, experiment_id)

# Evaluate objective function
results = branin(x1, x2)

# Format raw_data as expected by AxClient
raw_data = {obj1_name: results}

# Complete trial
ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

# Save snapshot after completing trial
save_ax_snapshot_to_mongodb(ax_client, experiment_id)

# Get current best for progress tracking
best_parameters, best_metrics = ax_client.get_best_parameters()
best_value = best_metrics[0][obj1_name]
print(
f"Trial {trial_index}: result={results:.3f} | "
f"Best so far: {best_value:.3f}"
)

print("\nOptimization completed!")
best_parameters, best_metrics = ax_client.get_best_parameters()
print(f"Best parameters: {best_parameters}")
print(f"Best metrics: {best_metrics}")

# Save final snapshot
save_ax_snapshot_to_mongodb(ax_client, experiment_id)

# Print experiment summary
trials_df = ax_client.get_trials_data_frame()
if trials_df is not None:
print(f"Total trials completed: {len(trials_df)}")
print(f"Best objective value: {trials_df[obj1_name].min():.6f}")

# Clean up MongoDB connection
mongo_client.close()
print("MongoDB connection closed")

# Optional: Display trials data frame for debugging
print("\nTrials Summary:")
print(ax_client.get_trials_data_frame())
Loading
Loading