-
Notifications
You must be signed in to change notification settings - Fork 8
Integrating MongoDB MVP #383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
cfae831
4353f4f
0d5ae9d
fe004b8
633f4b2
c5c7a97
8c499f9
932d37b
454d471
3d65800
89b7183
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,257 @@ | ||
| # Generated by Honegumi (https://arxiv.org/abs/2502.06815) | ||
| # pip install ax-platform==0.4.3 numpy pymongo | ||
| import numpy as np | ||
| import json | ||
| import os | ||
| from datetime import datetime | ||
| from ax.service.ax_client import AxClient, ObjectiveProperties | ||
| from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep | ||
| from ax.modelbridge.registry import Models | ||
| from pymongo import MongoClient, errors | ||
|
|
||
|
|
||
| obj1_name = "branin" | ||
| MAX_TRIALS = 19 # Configuration constant | ||
|
|
||
|
|
||
| def branin(x1, x2): | ||
| """Branin function - a common benchmark for optimization.""" | ||
| y = float( | ||
| (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2 | ||
| + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1) | ||
| + 10 | ||
| ) | ||
| return y | ||
|
|
||
|
|
||
| # Connect to MongoDB with error handling | ||
| try: | ||
| mongo_client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=5000) | ||
| # Test the connection | ||
| mongo_client.admin.command('ping') | ||
| db = mongo_client["ax_db"] | ||
| snapshots_col = db["ax_snapshots"] # Collection for storing JSON snapshots | ||
| print("Connected to MongoDB successfully") | ||
| except errors.ServerSelectionTimeoutError: | ||
| print("Failed to connect to MongoDB. Is MongoDB running?") | ||
| exit(1) | ||
| except Exception as e: | ||
| print(f"MongoDB connection error: {e}") | ||
| exit(1) | ||
|
|
||
| # Experiment configuration | ||
| parameters = [ | ||
| {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]}, | ||
| {"name": "x2", "type": "range", "bounds": [0.0, 10.0]}, | ||
| ] | ||
| objectives = {obj1_name: ObjectiveProperties(minimize=True)} | ||
|
|
||
| # Use Ax's default Sobol trials for 2D problems | ||
| SOBOL_TRIALS = 5 | ||
|
|
||
| def save_ax_snapshot_to_mongodb(ax_client, experiment_name): | ||
Gawthaman marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Save Ax client snapshot to MongoDB with timestamp (append, don't overwrite).""" | ||
| try: | ||
| temp_file = f"temp_{experiment_name}_snapshot.json" | ||
| ax_client.save_to_json_file(temp_file) | ||
|
|
||
| with open(temp_file, 'r') as f: | ||
| snapshot_data = json.load(f) | ||
|
|
||
| snapshot_doc = { | ||
| "experiment_name": experiment_name, | ||
| "snapshot_data": snapshot_data, | ||
| "timestamp": datetime.now().isoformat(), | ||
| "trial_count": len(ax_client.get_trials_data_frame()) if ax_client.get_trials_data_frame() is not None else 0 | ||
| } | ||
|
|
||
| # Insert a new document for every snapshot (no overwrite) | ||
| snapshots_col.insert_one(snapshot_doc) | ||
|
|
||
| os.remove(temp_file) | ||
|
|
||
| print(f"Snapshot saved to MongoDB at {snapshot_doc['timestamp']}") | ||
| return True | ||
|
|
||
| except Exception as e: | ||
| print(f"Error saving snapshot: {e}") | ||
| return False | ||
| """Save Ax client snapshot to MongoDB with timestamp.""" | ||
| try: | ||
| # Save to temporary JSON file first | ||
| temp_file = f"temp_{experiment_name}_snapshot.json" | ||
| ax_client.save_to_json_file(temp_file) | ||
|
|
||
| # Read the JSON content | ||
| with open(temp_file, 'r') as f: | ||
| snapshot_data = json.load(f) | ||
|
|
||
| # Create MongoDB document | ||
| snapshot_doc = { | ||
| "experiment_name": experiment_name, | ||
| "snapshot_data": snapshot_data, | ||
| "timestamp": datetime.now().isoformat(), | ||
| "trial_count": len(ax_client.get_trials_data_frame()) if ax_client.get_trials_data_frame() is not None else 0 | ||
| } | ||
|
|
||
| # Upsert the snapshot (replace if exists, insert if not) | ||
| snapshots_col.replace_one( | ||
| {"experiment_name": experiment_name}, | ||
| snapshot_doc, | ||
| upsert=True | ||
| ) | ||
|
|
||
| # Clean up temp file | ||
| os.remove(temp_file) | ||
|
|
||
| print(f"Snapshot saved to MongoDB at {snapshot_doc['timestamp']}") | ||
| return True | ||
|
|
||
| except Exception as e: | ||
| print(f"Error saving snapshot: {e}") | ||
| return False | ||
|
|
||
|
|
||
| def load_ax_snapshot_from_mongodb(experiment_name): | ||
| """Load most recent Ax client snapshot from MongoDB.""" | ||
| try: | ||
| # Find the most recent snapshot | ||
| record = snapshots_col.find_one( | ||
| {"experiment_name": experiment_name}, | ||
| sort=[("timestamp", -1)] # Most recent first | ||
| ) | ||
|
|
||
| if record: | ||
| # Save snapshot data to temporary file | ||
| temp_file = f"temp_{experiment_name}_snapshot.json" | ||
| with open(temp_file, 'w') as f: | ||
| json.dump(record["snapshot_data"], f) | ||
|
|
||
| # Load AxClient from file | ||
| ax_client = AxClient.load_from_json_file(temp_file) | ||
|
|
||
| # Clean up temp file | ||
| os.remove(temp_file) | ||
|
|
||
| print( | ||
| f"Loaded snapshot from {record['timestamp']} with " | ||
| f"{record['trial_count']} trials" | ||
| ) | ||
| return ax_client | ||
| else: | ||
| print("No existing snapshot found") | ||
| return None | ||
|
|
||
| except Exception as e: | ||
| print(f"Error loading snapshot: {e}") | ||
| return None | ||
|
|
||
|
|
||
| # Load existing experiment or create new one | ||
| ax_client = load_ax_snapshot_from_mongodb(obj1_name) | ||
|
||
|
|
||
| if ax_client is None: | ||
| # Create new experiment | ||
| generation_strategy = GenerationStrategy([ | ||
| GenerationStep( | ||
| model=Models.SOBOL, | ||
| num_trials=SOBOL_TRIALS, | ||
| min_trials_observed=1, | ||
| max_parallelism=5, | ||
| model_kwargs={"seed": 999}, # For reproducibility | ||
| ), | ||
| GenerationStep( | ||
| model=Models.GPEI, | ||
| num_trials=-1, | ||
| max_parallelism=3, | ||
| model_kwargs={}, | ||
| ), | ||
| ]) | ||
| ax_client = AxClient(generation_strategy=generation_strategy) | ||
| ax_client.create_experiment( | ||
| name=obj1_name, | ||
| parameters=parameters, | ||
| objectives=objectives | ||
| ) | ||
| print(f"Created new experiment with {SOBOL_TRIALS} Sobol trials") | ||
|
|
||
| # Save initial snapshot | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
||
| else: | ||
| print("Resuming existing experiment") | ||
|
|
||
| # Get current trial count to determine how many more trials to run | ||
| current_trials = ax_client.get_trials_data_frame() | ||
|
||
| start_trial = len(current_trials) if current_trials is not None else 0 | ||
|
|
||
| print(f"Starting optimization: running trials {start_trial} to {MAX_TRIALS-1}") | ||
|
|
||
| for i in range(start_trial, MAX_TRIALS): | ||
| try: | ||
| # Get next trial | ||
| parameterization, trial_index = ax_client.get_next_trial() | ||
|
|
||
| # Extract parameters | ||
| x1 = parameterization["x1"] | ||
| x2 = parameterization["x2"] | ||
|
|
||
| print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}") | ||
|
|
||
| # Save snapshot before running experiment (preserves pending trial) | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
|
||
| # Evaluate objective function | ||
| results = branin(x1, x2) | ||
|
|
||
| # Format raw_data as expected by AxClient | ||
| raw_data = {obj1_name: results} | ||
|
|
||
| # Complete trial | ||
| ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data) | ||
|
|
||
| # Save snapshot after completing trial | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
|
||
| # Get current best for progress tracking | ||
| try: | ||
| best_parameters, best_metrics = ax_client.get_best_parameters() | ||
| best_value = best_metrics[0][obj1_name] | ||
| print( | ||
| f"Trial {trial_index}: result={results:.3f} | " | ||
| f"Best so far: {best_value:.3f}" | ||
| ) | ||
| except Exception: | ||
|
||
| print(f"Trial {trial_index}: result={results:.3f}") | ||
|
|
||
| except Exception as e: | ||
| print(f"Error in trial {trial_index}: {e}") | ||
| continue | ||
|
|
||
| print("\nOptimization completed!") | ||
| try: | ||
| best_parameters, best_metrics = ax_client.get_best_parameters() | ||
| print(f"Best parameters: {best_parameters}") | ||
| print(f"Best metrics: {best_metrics}") | ||
|
|
||
| # Save final snapshot | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
|
||
| # Print experiment summary | ||
| trials_df = ax_client.get_trials_data_frame() | ||
| if trials_df is not None: | ||
| print(f"Total trials completed: {len(trials_df)}") | ||
| print(f"Best objective value: {trials_df[obj1_name].min():.6f}") | ||
|
|
||
| except Exception as e: | ||
|
||
| print(f"Error getting best parameters: {e}") | ||
|
|
||
| # Clean up MongoDB connection | ||
| mongo_client.close() | ||
|
||
| print("MongoDB connection closed") | ||
|
|
||
| # Optional: Display trials data frame for debugging | ||
| try: | ||
| print("\nTrials Summary:") | ||
| print(ax_client.get_trials_data_frame()) | ||
| except Exception as e: | ||
| print(f"Error displaying trials: {e}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this explicit error handling? Seems like it would just bubble up naturally (unless you found that the error that bubbled up naturally was non-descript).
As a note for later, we'll set this up with a MongoDB Atlas cluster