add hrm8k benchmark for both Korean and English (#2627)
* add hrm8k benchmark for both Korean and English

* apply precommit

* revise tasks so that models do not answer directly; use zeroshot_cot where possible

* add README

* Add hrm8k on the task-list

---------

Co-authored-by: Baber <[email protected]>
bzantium and baberabb authored Jan 20, 2025
1 parent f724be6 commit a5c344c
Showing 18 changed files with 848 additions and 129 deletions.
259 changes: 130 additions & 129 deletions lm_eval/tasks/README.md

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions lm_eval/tasks/hrm8k/README.md
@@ -0,0 +1,46 @@
# HRM8K

### Paper

Title: [Understand, Solve and Translate: Bridging the Multilingual Mathematical Reasoning Gap](https://www.arxiv.org/abs/2501.02448)

Large language models (LLMs) demonstrate exceptional performance on complex reasoning tasks. However, despite their strong reasoning capabilities in high-resource languages (e.g., English and Chinese), a significant performance gap persists in other languages. To investigate this gap in Korean, we introduce HRM8K, a benchmark comprising 8,011 English-Korean parallel bilingual math problems. Through systematic analysis of model behaviors, we identify a key finding: these performance disparities stem primarily from difficulties in comprehending non-English inputs, rather than limitations in reasoning capabilities. Based on these findings, we propose UST (Understand, Solve, and Translate), a method that strategically uses English as an anchor for reasoning and solution generation. By fine-tuning the model on 130k synthetically generated data points, UST achieves a 10.91% improvement on the HRM8K benchmark and reduces the multilingual performance gap from 11.6% to 0.7%. Additionally, we show that improvements from UST generalize effectively to different Korean domains, demonstrating that capabilities acquired from machine-verifiable content can be generalized to other areas. We publicly release the benchmark, training dataset, and models.

Homepage: https://huggingface.co/datasets/HAERAE-HUB/HRM8K


### Citation

```
@article{ko2025understand,
title={Understand, Solve and Translate: Bridging the Multilingual Mathematical Reasoning Gap},
author={Ko, Hyunwoo and Son, Guijin and Choi, Dasol},
journal={arXiv preprint arXiv:2501.02448},
year={2025}
}
```

### Groups and Tasks

#### Groups

* `hrm8k`: HRM8K comprises 8,011 instances for evaluation, sourced through a combination of translations from established English benchmarks (e.g., GSM8K, MATH, OmniMath, MMMLU) and original problems curated from existing Korean math exams. In this variant, instructions and questions are in Korean.
* `hrm8k_en`: English version of `hrm8k`; instructions and questions are in English.

#### Tasks

* `hrm8k_{gsm8k|ksm|math|mmmlu|omni_math}`
* `hrm8k_en_{gsm8k|ksm|math|mmmlu|omni_math}`
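As a usage sketch (the model name and batch size are placeholders, not part of this PR), the group or any subtask can be run through the standard lm-evaluation-harness CLI:

```shell
# Evaluate the Korean group; swap in hrm8k_en or a subtask such as hrm8k_gsm8k as needed
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct \
  --tasks hrm8k \
  --batch_size 8
```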

### Checklist

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
22 changes: 22 additions & 0 deletions lm_eval/tasks/hrm8k/default/_hrm8k_yaml
@@ -0,0 +1,22 @@
dataset_path: HAERAE-HUB/HRM8K
output_type: generate_until
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
num_fewshot: 0
generation_kwargs:
  until:
    - "</s>"
    - "<|end_of_text|>"
    - "<|endoftext|>"
    - "<|im_end|>"
  max_gen_toks: 512
  do_sample: false
  temperature: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
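The `until` stop strings above end generation at common end-of-text markers. A minimal illustration of the truncation semantics (hypothetical helper for clarity, not harness code):

```python
def truncate_at(text: str, stops) -> str:
    # Cut text at the earliest occurrence of any stop string.
    cut = len(text)
    for s in stops:
        i = text.find(s)
        if i != -1:
            cut = min(cut, i)
    return text[:cut]


stops = ["</s>", "<|end_of_text|>", "<|endoftext|>", "<|im_end|>"]
print(truncate_at("The answer is $\\boxed{4}$.<|im_end|>junk", stops))
# -> The answer is $\boxed{4}$.
```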
13 changes: 13 additions & 0 deletions lm_eval/tasks/hrm8k/default/hrm8k.yaml
@@ -0,0 +1,13 @@
group: hrm8k
task:
  - hrm8k_gsm8k
  - hrm8k_ksm
  - hrm8k_math
  - hrm8k_mmmlu
  - hrm8k_omni_math
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
3 changes: 3 additions & 0 deletions lm_eval/tasks/hrm8k/default/hrm8k_gsm8k.yaml
@@ -0,0 +1,3 @@
include: _hrm8k_yaml
dataset_name: GSM8K
task: hrm8k_gsm8k
3 changes: 3 additions & 0 deletions lm_eval/tasks/hrm8k/default/hrm8k_ksm.yaml
@@ -0,0 +1,3 @@
include: _hrm8k_yaml
dataset_name: KSM
task: hrm8k_ksm
3 changes: 3 additions & 0 deletions lm_eval/tasks/hrm8k/default/hrm8k_math.yaml
@@ -0,0 +1,3 @@
include: _hrm8k_yaml
dataset_name: MATH
task: hrm8k_math
4 changes: 4 additions & 0 deletions lm_eval/tasks/hrm8k/default/hrm8k_mmmlu.yaml
@@ -0,0 +1,4 @@
include: _hrm8k_yaml
dataset_name: MMMLU
task: hrm8k_mmmlu
doc_to_text: !function utils.doc_to_text_mmmlu
3 changes: 3 additions & 0 deletions lm_eval/tasks/hrm8k/default/hrm8k_omni_math.yaml
@@ -0,0 +1,3 @@
include: _hrm8k_yaml
dataset_name: OMNI_MATH
task: hrm8k_omni_math
285 changes: 285 additions & 0 deletions lm_eval/tasks/hrm8k/default/utils.py
@@ -0,0 +1,285 @@
import re
from typing import Dict, List


def doc_to_text(doc):
    # Korean prompt: "Solve the given problem. After solving it, write the final
    # answer in the following format: $\boxed{N}$. / Problem: ... / Answer:"
    text = (
        "주어진 문제를 풀어보세요.\n"
        "문제를 푼 후, 최종 답변을 다음과 같은 형식으로 작성하세요: $\\boxed{N}$.\n\n"
        f"문제: {doc['question'].strip()}\n답변:"
    )
    return text


def doc_to_text_mmmlu(doc):
    # Korean prompt: "Solve the given problem. After solving it, write the final
    # choice among the given options (1, 2, 3, 4) in the following format:
    # $\boxed{N}$. / Problem: ... / Answer:"
    text = (
        "주어진 문제를 풀어보세요.\n"
        "문제를 푼 후, 주어진 선택지 (1, 2, 3, 4) 중 최종 선택지를 다음 형식으로 작성하세요: $\\boxed{N}$.\n\n"
        f"문제: {doc['question'].strip()}\n답변:"
    )
    return text


def doc_to_target(doc):
    return postprocess(doc["answer"])


def postprocess(s):
    # Normalize numeric answers: integral floats collapse to ints ("3.0" -> "3");
    # non-numeric strings pass through unchanged.
    s = str(s).strip()
    try:
        float_value = float(s)
        return str(int(float_value)) if float_value.is_integer() else str(float_value)
    except Exception:
        return s


def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    candidate = results[0]

    gold = postprocess(doc["answer"])

    if not gold:
        print(doc, candidate, gold)
    if is_equiv(candidate, gold):
        retval = 1
    else:
        retval = 0

    return {
        "exact_match": retval,
    }


def is_equiv(str1, str2, verbose=False):
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    str1, str2 = parse_math_answer(str1), parse_math_answer(str2)

    try:
        ss1 = _strip_string(str1)
        ss1 = postprocess(ss1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except Exception:
        return str1 == str2


def parse_math_answer(raw_string):
    def remove_boxed(s):
        left = "\\boxed{"
        try:
            assert s[: len(left)] == left
            assert s[-1] == "}"
            answer = s[len(left) : -1]
            if "=" in answer:
                answer = answer.split("=")[-1].lstrip(" ")
            return answer
        except Exception:
            return None

    def last_boxed_only_string(string):
        idx = string.rfind("\\boxed")
        if idx < 0:
            idx = string.rfind("\\fbox")
            if idx < 0:
                return None
        i = idx
        right_brace_idx = None
        num_left_braces_open = 0
        while i < len(string):
            if string[i] == "{":
                num_left_braces_open += 1
            if string[i] == "}":
                num_left_braces_open -= 1
                if num_left_braces_open == 0:
                    right_brace_idx = i
                    break
            i += 1

        if right_brace_idx is None:
            retval = None
        else:
            retval = string[idx : right_brace_idx + 1]

        return retval

    def get_answer_with_dollar_sign(s):
        first_pattern = r"\$(.*)\$"
        last_match = None
        matches = re.findall(first_pattern, s)
        if matches:
            last_match = matches[-1]
            if "=" in last_match:
                last_match = last_match.split("=")[-1].lstrip(" ")
        return last_match

    def get_answer_without_dollar_sign(s):
        last_match = None
        if "=" in s:
            last_match = s.split("=")[-1].lstrip(" ").rstrip(".")
            if "\\n" in last_match:
                last_match = last_match.split("\\n")[0]
        else:
            pattern = r"(?:\$)?\d+(?:\.\d+)?(?![\w\d])"
            matches = re.findall(pattern, s)
            if matches:
                last_match = matches[-1]
        return last_match

    if "\\boxed" in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer


# code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except Exception:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except Exception:
        return string


def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{.". Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
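To illustrate the scoring path, here is a self-contained sketch (hypothetical helper names, simplified from the functions in this file) of how a `\boxed{...}` answer is pulled out of a completion and numerically normalized before the exact-match comparison:

```python
from typing import Optional


def last_boxed(s: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in s, or None (simplified)."""
    idx = s.rfind("\\boxed{")
    if idx < 0:
        return None
    depth = 0
    for i in range(idx + len("\\boxed"), len(s)):
        if s[i] == "{":
            depth += 1
        elif s[i] == "}":
            depth -= 1
            if depth == 0:
                return s[idx + len("\\boxed{") : i]
    return None


def normalize(ans: str) -> str:
    """Collapse integral floats ('3.0' -> '3'); other strings pass through."""
    ans = ans.strip()
    try:
        f = float(ans)
        return str(int(f)) if f.is_integer() else str(f)
    except ValueError:
        return ans


completion = "2 + 1 = 3, so the answer is $\\boxed{3.0}$."
print(normalize(last_boxed(completion)))  # -> 3
```

A prediction of `3.0` and a gold answer of `3` therefore compare equal after normalization, which is what makes `exact_match` robust to formatting variation.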
