| | """
|
| | This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py
|
| | """
|
| | import re
|
| | from typing import Any, Dict
|
| |
|
| |
|
| | def _fix_fracs(string):
|
| | substrs = string.split("\\frac")
|
| | new_str = substrs[0]
|
| | if len(substrs) > 1:
|
| | substrs = substrs[1:]
|
| | for substr in substrs:
|
| | new_str += "\\frac"
|
| | if len(substr) > 0 and substr[0] == "{":
|
| | new_str += substr
|
| | else:
|
| | try:
|
| | assert len(substr) >= 2
|
| | except:
|
| | return string
|
| | a = substr[0]
|
| | b = substr[1]
|
| | if b != "{":
|
| | if len(substr) > 2:
|
| | post_substr = substr[2:]
|
| | new_str += "{" + a + "}{" + b + "}" + post_substr
|
| | else:
|
| | new_str += "{" + a + "}{" + b + "}"
|
| | else:
|
| | if len(substr) > 2:
|
| | post_substr = substr[2:]
|
| | new_str += "{" + a + "}" + b + post_substr
|
| | else:
|
| | new_str += "{" + a + "}" + b
|
| | string = new_str
|
| | return string
|
| |
|
| |
|
| | def _fix_a_slash_b(string):
|
| | if len(string.split("/")) != 2:
|
| | return string
|
| | a = string.split("/")[0]
|
| | b = string.split("/")[1]
|
| | try:
|
| | if "sqrt" not in a:
|
| | a = int(a)
|
| | if "sqrt" not in b:
|
| | b = int(b)
|
| | assert string == "{}/{}".format(a, b)
|
| | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
| | return new_string
|
| | except:
|
| | return string
|
| |
|
| |
|
| | def _fix_sqrt(string):
|
| | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
| | return _string
|
| |
|
| |
|
def strip_string(string):
    """Normalize an answer string into a canonical textual form.

    Applies a fixed sequence of rewrites: LaTeX cleanup, removal of units,
    symbols and quotes, and number formatting. It then canonicalizes
    ``\\sqrt``, ``\\frac`` and ``a/b`` forms via ``_fix_sqrt``,
    ``_fix_fracs`` and ``_fix_a_slash_b``.

    Args:
        string: the answer; non-string values are converted with ``str()``.

    Returns:
        The normalized string, possibly empty.
    """
    string = str(string).strip()

    # Linebreaks.
    string = string.replace("\n", "")

    # Trailing periods.
    string = string.rstrip(".")

    # LaTeX spacing commands "\!" and "\ ".
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")

    # Collapse doubled backslashes (two passes, to undo double escaping).
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")

    # \tfrac and \dfrac -> \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # Delimiter sizing commands.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Drop a trailing \text{...}, unless that would leave nothing.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    # Degree signs.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Dollar signs, escaped or bare.
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    # Remaining \text wrappers (their contents are kept) and "x\in".
    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")

    # Percent signs. The escaped form is removed twice, because the first
    # removal can expose a new "\%" (e.g. "\\%%" -> "\%").
    string = string.replace("\\%", "")
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # Add a leading zero to bare decimals: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Multiplication dots.
    string = string.replace("\\cdot", "")

    # Infinity spellings.
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # "+\infty" -> "\infty".
    string = string.replace("+\\infty", "\\infty")

    # The word "and", and bold formatting.
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # \mbox{...} annotations.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Quotes. str.replace returns a new string, so the result must be assigned.
    string = string.replace("'", "")
    string = string.replace("\"", "")

    # Imaginary unit written as "j" when no "i" is present.
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Trailing zeros after a decimal point: "2.00" -> "2", "2.0," -> "2,".
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "x=" or "k =" in "x=5".
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # Brace \frac arguments: \frac12 -> \frac{1}{2}.
    string = _fix_fracs(string)

    # Integer a/b -> \frac{a}{b}.
    string = _fix_a_slash_b(string)

    return string
|
| |
|
def extract_answer(pred_str):
    """Extract the final answer from a model completion or a reference solution.

    Strategies, tried in order:

    1. If ``"boxed"`` occurs, take the text after its last occurrence. If that
       text starts with ``{``, use the brace-balanced content. Otherwise use
       the text up to the next ``$``.
    2. Else take the text after the last ``"he answer is"`` (this matches both
       "The answer is" and "the answer is").
    3. Else take the content of the last ```` ```output ```` block.
    4. Else take the last number in the text, with commas removed.

    The candidate is then cut at its first newline. A leading ``:`` and a
    trailing ``.`` and ``/`` are dropped, and the result is normalized with
    ``strip_string``.

    Returns:
        The normalized answer. Returns ``""`` without normalization when
        ``"boxed"`` is the very end of ``pred_str``.
    """
    if 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if len(ans) == 0:
            return ""
        if ans[0] == '{':
            # Copy characters until the brace opened right after "boxed" is
            # closed, so nested groups like \frac{1}{2} stay intact. If it is
            # never closed, the rest of the text is taken.
            depth = 1
            a = ''
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                a += c
        else:
            a = ans.split('$')[0].strip()
        pred = a
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    else:
        # Compute the program output once, instead of once for the test and
        # again for the value.
        program_output = extract_program_output(pred_str)
        if program_output != "":
            pred = program_output
        else:
            # Raw string: "\d" and "\." in a plain literal are invalid escapes.
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            pred = numbers[-1] if numbers else ''

    # Light cleanup of the raw candidate before full normalization.
    pred = pred.split("\n")[0]
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred
|
| |
|
| |
|
def extract_program(result: str, last_only=True):
    """Collect the code inside ```` ```python ```` fenced blocks of ``result``.

    A block opens on a line starting with ```` ```python ```` and closes on
    the next line starting with ```` ``` ````. An unclosed block runs to the
    end of the text. Every collected line is followed by ``"\\n"``.

    Args:
        result: text that may contain fenced Python code.
        last_only: when true (default), return only the last block. Otherwise
            return all blocks concatenated, each preceded by
            ``"\\n# ========\\n"``.

    Returns:
        The extracted program text, or ``""`` if no block was found.
    """
    chunks = []
    inside = False
    for line in result.split("\n"):
        # "```python" must be tested before the bare "```" fence it starts with.
        if line.startswith("```python"):
            if last_only:
                chunks.clear()
            else:
                chunks.append("\n# ========\n")
            inside = True
        elif line.startswith("```"):
            inside = False
        elif inside:
            chunks.append(line + "\n")
    return "".join(chunks)
|
| |
|
| |
|
def extract_program_output(pred_str):
    """Return the stripped content of the last ```` ```output ```` block.

    The block ends at the next ```` ``` ```` fence, or at the end of the text
    if there is none.

    Returns:
        The block content with surrounding whitespace removed, or ``""`` if
        ``pred_str`` contains no ```` ```output ```` marker.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    tail = pred_str.rsplit(marker, 1)[1]
    body, _, _ = tail.partition("```")
    return body.strip()
|
| |
|
| |
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for one dataset example.

    If the example already has a ``'gt_cot'`` key, that value is returned
    together with the normalized ``'gt'`` field. Otherwise the reference
    solution and answer are read from the fields used by ``data_name``. The
    solution is converted with ``str()`` and stripped; a missing solution
    becomes the string ``"None"``. The answer is normalized with
    ``strip_string``.

    Raises:
        NotImplementedError: if ``data_name`` is not a supported dataset.
    """
    # Pre-processed examples already carry their reference fields.
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    if data_name in ["math", 'ocw']:
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        # GSM8K stores "<reasoning> #### <answer>" in one field.
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot = example['code']
        gt_ans = example['target']
    elif data_name == "svamp":
        gt_cot = example['Equation']
        gt_ans = example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        # Remove parenthesised annotations from the answer text.
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name in ("mawps", "bbh"):
        gt_cot = None
        gt_ans = example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ['integer_number', 'decimal_number']:
            # Numeric answers may be written as a fraction, with thousands
            # separators, or as a percentage; turn them into numbers.
            if '/' in gt_ans:
                numerator, denominator = gt_ans.split('/')[:2]
                gt_ans = int(numerator) / int(denominator)
            elif ',' in gt_ans:
                gt_ans = float(gt_ans.replace(',', ''))
            elif '%' in gt_ans:
                gt_ans = float(gt_ans.split('%')[0]) / 100
            else:
                gt_ans = float(gt_ans)
    else:
        raise NotImplementedError(data_name)

    return str(gt_cot).strip(), strip_string(gt_ans)
|
| |
|
| |
|
def parse_question(example, data_name):
    """Build the question text for one example of ``data_name``.

    - ``asdiv``: ``body`` and ``question`` joined by a space.
    - ``svamp``: ``Body`` (with a period appended if missing), then ``Question``.
    - ``tabmwp``: a prompt containing the optional table title, the table,
      the question and, if present, the answer choices.
    - otherwise: the first of ``question``, ``problem``, ``Question`` or
      ``input`` present in ``example``.

    Returns:
        The question with surrounding whitespace stripped.

    Raises:
        AssertionError: if no question text was found (only while asserts
            are enabled).
    """
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title = example['table_title']
        title_str = f'regarding "{title}" ' if title else ""
        question = (
            f'Read the following table {title_str}and answer a question:\n'
            f'{example["table"]}\n{example["question"]}'
        )
        if example['choices']:
            question += f' Please select from the following options: {example["choices"]}'
    else:
        candidate_keys = ('question', 'problem', 'Question', 'input')
        question = next((example[key] for key in candidate_keys if key in example), "")
    assert question != ""
    return question.strip()
|
| |
|
| |
|
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a model ``result`` into a normalized prediction.

    - Empty ``result`` or the string ``'error'``: returns ``(None, None)``.
    - ``prompt_type`` containing ``"program_only"``: the prediction is the
      last ```` ```output ```` block of ``result``.
    - ``prompt_type`` of ``"pot"``/``"pal"`` with ``execute`` set: the last
      Python block is run through ``executor.apply``, which must return a
      ``(prediction, report)`` pair.
    - Otherwise the answer is extracted with ``extract_answer``.

    Returns:
        ``(prediction, report)``. The prediction is normalized with
        ``strip_string``; the report is ``None`` unless the executor ran.
    """
    if not result or result == 'error':
        return None, None

    if "program_only" in prompt_type:
        return strip_string(extract_program_output(result)), None

    if prompt_type in ["pot", "pal"] and execute:
        prediction, report = executor.apply(extract_program(result))
        return strip_string(prediction), report

    return strip_string(extract_answer(result)), None
|
| |
|