Skip to content
Merged

SRbench #2105

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions opencompass/configs/datasets/srbench/srbench_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import (
SRbenchDataset,SRbenchDatasetEvaluator,mydataset_postprocess
)


INFER_TEMPLATE = f'''
You will be provided with a set of input-output pairs. Based on these data, infer the mathematical relationship between y and multiple input variables. Please note that the possible mathematical operations include: +, -, *, /, exp, sqrt, sin, arcsin, and constant terms.
The input sample data are as follows:
{{prompt1}}
Based on the above data, please infer the possible formula. Ensure that your inference applies to all the provided data points, and consider both linear and nonlinear combinations.
Verify whether your formula applies to the following new data point and adjust it to ensure accuracy:
{{prompt2}}
Finally, please output only the formula string you inferred (e.g. y=x_0 * x_1), without any additional information.
'''

srbench_reader_cfg = dict(input_columns=['prompt1','prompt2'], output_column='Formula')

srbench_datasets = []

srbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt=INFER_TEMPLATE)
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)


srbench_eval_cfg = dict(
evaluator=dict(type=SRbenchDatasetEvaluator, path='opencompass/srbench'),
pred_postprocessor=dict(type=mydataset_postprocess),
pred_role='BOT',
)

srbench_datasets.append(
dict(
abbr='srbench',
type=SRbenchDataset,
path='opencompass/srbench',
reader_cfg=srbench_reader_cfg,
infer_cfg=srbench_infer_cfg,
eval_cfg=srbench_eval_cfg,
)
)

1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
from .siqa import * # noqa: F401, F403
from .smolinstruct import * # noqa: F401, F403
from .squad20 import SQuAD20Dataset, SQuAD20Evaluator # noqa: F401, F403
from .srbench import * # noqa: F401, F403
from .storycloze import * # noqa: F401, F403
from .strategyqa import * # noqa: F401, F403
from .subjective import * # noqa: F401, F403
Expand Down
179 changes: 179 additions & 0 deletions opencompass/datasets/srbench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# flake8: noqa
import os
import re

import numpy as np
import pandas as pd
import sympy as sp
from datasets import load_dataset
from sklearn.metrics import r2_score, root_mean_squared_error

from opencompass.datasets.base import BaseDataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class SRbenchDataset(BaseDataset):

@staticmethod
def load(path: str):
path = get_data_path(path)
base_path = os.path.join(path, 'Feynman')
formula_csv_path = os.path.join(base_path, 'FeynmanEquation_23.csv')
data_files_base_dir = os.path.join(base_path, 'Feynman_with_units')
dataset = load_dataset('csv', data_files=formula_csv_path)['train']
sample_data = []
prompt_1_out = []
prompt_2_out = []
for row in dataset:
n_var = int(row['n_variables'])
data_filename = str(row['Filename'])

data_file_path = os.path.join(data_files_base_dir, data_filename)
full_dataset = np.loadtxt(data_file_path)
rand_idx = np.random.choice(full_dataset.shape[0],
100,
replace=False)
sampled_data_i = full_dataset[rand_idx]
if isinstance(sampled_data_i, np.ndarray):
sample_data.append(sampled_data_i.tolist())
else:
sample_data.append(sampled_data_i)
# x = dataset[:, :n_var]
# y_true = dataset[:, -1]
if n_var == 2:
prompt_1 = '\n'.join([
f'x0={x1:.4f}, x1={x2:.4f}, y={y:.4f}'
for x1, x2, y in sampled_data_i[:-1]
])
prompt_2 = f'x0={sampled_data_i[-1, 0]:.4f}, x1={sampled_data_i[-1, 1]:.4f}, y={sampled_data_i[-1, 2]:.4f}'
else:
prompt_1 = '\n'.join([
f'x0={x1:.4f}, x1={x2:.4f}, x2={x3:.4f},y={y:.4f}'
for x1, x2, x3, y in sampled_data_i[:-1]
])
prompt_2 = f'x0={sampled_data_i[-1, 0]:.4f}, x1={sampled_data_i[-1, 1]:.4f},x3={sampled_data_i[-1, 2]:.4f}, y={sampled_data_i[-1, 3]:.4f}'

prompt_1_out.append(prompt_1)
prompt_2_out.append(prompt_2)
dataset = dataset.add_column(name='prompt1', column=prompt_1_out)
dataset = dataset.add_column(name='prompt2', column=prompt_2_out)
dataset = dataset.add_column(name='data_samples_list',
column=sample_data)
dataset = dataset.rename_column('n_variables', 'n_var')
return dataset


def mydataset_postprocess(formula_str):

formula_str = formula_str.replace('×', '*').replace('·',
'*').replace('÷', '/')
formula_str = formula_str.replace('−', '-').replace('^', '**')
formula_str = formula_str.replace('“', '"').replace('”',
'"').replace('’', "'")

formula_str = formula_str.replace('`', '').replace('$', '').strip()

formula_str = formula_str.split('\n')[0].strip()

formula_str = re.sub(r'[^\w\s\+\-\*/\^\=\.\(\)]', '', formula_str)

return formula_str.strip()


class SRbenchDatasetEvaluator(BaseEvaluator):

def __init__(self, path=''):
self.dataset = SRbenchDataset.load(path)

def parse_formula(self, formula_str, n_var=2):
try:
if '=' in formula_str:
_, expr_str = formula_str.split('=', 1)
else:
expr_str = formula_str
variables = [sp.Symbol(f'x{i}') for i in range(n_var)]
expr = sp.sympify(expr_str)
func = sp.lambdify(variables, expr, modules='numpy')
return func
except Exception as e:
print(f'[Parse Error] {formula_str}\n{e}')
return None

def is_symbolically_equivalent(self, formula1, formula2, n_var=2):
try:
expr1 = sp.sympify(
formula1.split('=')[1] if '=' in formula1 else formula1)
expr2 = sp.sympify(
formula2.split('=')[1] if '=' in formula2 else formula2)

return sp.simplify(expr1 - expr2) == 0
except Exception:
return False

def score(self, predictions, references) -> dict:
metrics = {
'LLM_Score': None,
'RMSE': None,
'SymbolicMatch': False,
'R2': 0
}
metrics_out = {
'LLM_Score': None,
'RMSE': None,
'Accuray': False,
'R2': 0
}
result = pd.DataFrame({
'GT': pd.Series(dtype=str),
'Pred': pd.Series(dtype=str),
'Score': pd.Series(dtype=float),
'RMSE': pd.Series(dtype=float),
'R2': pd.Series(dtype=float),
'SymbolicMatch': pd.Series(dtype=bool)
})

for row in range(len(references)):
# metrics['LLM_Score'] = float(self.llm_evaluate(predictions[row], references[row], mllm='gpt-4o'))
n_var = self.dataset[row]['n_var']
data_sample = self.dataset[row]['data_samples_list']
data_sample = np.array(data_sample)
x = data_sample[:, :n_var]
y_true = data_sample[:, -1]
func = self.parse_formula(predictions[row], n_var=n_var)
if func is not None:
x_vars = [x[:, i] for i in range(n_var)]
y_pred = func(*x_vars)
if np.isscalar(y_pred):
y_pred = np.full_like(y_true, y_pred)
metrics['RMSE'] = root_mean_squared_error(y_true, y_pred)
metrics['R2'] = r2_score(y_true, y_pred)
else:
metrics['R2'] = 0
metrics['RMSE'] = np.inf
metrics['SymbolicMatch'] = self.is_symbolically_equivalent(
predictions[row], references[row], n_var)
result = result._append(
{
'GT': references[row],
'Pred': predictions[row],
'RMSE': metrics['RMSE'],
'R2': metrics['R2'],
'SymbolicMatch': bool(metrics['SymbolicMatch'])
},
ignore_index=True)

if not result.empty:
symbolic_accuracy = result['SymbolicMatch'].sum() / len(result)
R2_out = result['R2'].sum() / len(result)
RMSE_out = result['RMSE'].sum() / len(result)

metrics_out = {
'RMSE': RMSE_out,
'R2': R2_out,
'Accuracy': symbolic_accuracy
}

return metrics_out
10 changes: 10 additions & 0 deletions opencompass/utils/datasets_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,11 @@
"hf_id": "",
"local": "./data/ChemBench4K",
},
"opencompass/srbench": {
"ms_id": "",
"hf_id": "",
"local": "./data/srbench",
},
"opencompass/nejmaibench": {
"ms_id": "",
"hf_id": "",
Expand Down Expand Up @@ -819,6 +824,11 @@
"http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ChemBench4K.zip",
"md5": "fc23fd21b2566a5dbbebfa4601d7779c"
},
"/srbench": {
"url":
"http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/srbench.zip",
"md5": "ab6c5308f7930ac9fbc516ab757feef1"
},
"nejmaibench": {
"url":
"http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip",
Expand Down