From 98c41c65b130b93e7b98cef43087816f86f67699 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sat, 20 Apr 2024 13:46:09 +0000 Subject: [PATCH 1/4] Fix _fix_fracs in MATH normalization --- src/lighteval/metrics/normalizations.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/lighteval/metrics/normalizations.py b/src/lighteval/metrics/normalizations.py index 2bf180007..f92b2c845 100644 --- a/src/lighteval/metrics/normalizations.py +++ b/src/lighteval/metrics/normalizations.py @@ -135,8 +135,8 @@ def _last_boxed_only_string(text: str) -> str | None: return retval - def _fix_fracs(text: str) -> str: - substrs = text.split("\\frac") + def _fix_fracs(string): + substrs = string.split("\\frac") new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] @@ -147,9 +147,10 @@ def _fix_fracs(text: str) -> str: else: try: assert len(substr) >= 2 - except AssertionError: - return text - a, b = substr + except Exception: + return string + a = substr[0] + b = substr[1] if b != "{": if len(substr) > 2: post_substr = substr[2:] @@ -162,8 +163,8 @@ def _fix_fracs(text: str) -> str: new_str += "{" + a + "}" + b + post_substr else: new_str += "{" + a + "}" + b - text = new_str - return text + string = new_str + return string def _fix_a_slash_b(text: str) -> str: """Source: https://github.com/hendrycks/math From a8fa3f562833c844f8523319d105613b67302c75 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sat, 20 Apr 2024 13:47:41 +0000 Subject: [PATCH 2/4] Refactor --- src/lighteval/metrics/normalizations.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lighteval/metrics/normalizations.py b/src/lighteval/metrics/normalizations.py index f92b2c845..5e39cabd2 100644 --- a/src/lighteval/metrics/normalizations.py +++ b/src/lighteval/metrics/normalizations.py @@ -135,8 +135,8 @@ def _last_boxed_only_string(text: str) -> str | None: return retval - def _fix_fracs(string): - substrs = string.split("\\frac") + def _fix_fracs(text: str) -> str: + substrs = text.split("\\frac") new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] @@ -148,7 +148,7 @@ def _fix_fracs(string): try: assert len(substr) >= 2 except Exception: - return string + return text a = substr[0] b = substr[1] if b != "{": @@ -163,8 +163,8 @@ def _fix_fracs(string): new_str += "{" + a + "}" + b + post_substr else: new_str += "{" + a + "}" + b - string = new_str - return string + text = new_str + return text def _fix_a_slash_b(text: str) -> str: """Source: https://github.com/hendrycks/math From 12c3fb662080896b78c05e8b6819f8311962d18a Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sat, 20 Apr 2024 13:48:17 +0000 Subject: [PATCH 3/4] Clean --- src/lighteval/metrics/normalizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lighteval/metrics/normalizations.py b/src/lighteval/metrics/normalizations.py index 5e39cabd2..fa7ff0f04 100644 --- a/src/lighteval/metrics/normalizations.py +++ b/src/lighteval/metrics/normalizations.py @@ -147,7 +147,7 @@ def _fix_fracs(text: str) -> str: else: try: assert len(substr) >= 2 - except Exception: + except AssertionError: return text a = substr[0] b = substr[1] From 8a3cbe0632fb66c4057873815d833176144c8f7e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sat, 20 Apr 2024 14:10:52 +0000 Subject: [PATCH 4/4] Add docstring --- src/lighteval/metrics/normalizations.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/lighteval/metrics/normalizations.py b/src/lighteval/metrics/normalizations.py index fa7ff0f04..b5b32d3e6 100644 --- a/src/lighteval/metrics/normalizations.py +++ b/src/lighteval/metrics/normalizations.py @@ -136,6 +136,24 @@ def _last_boxed_only_string(text: str) -> str | None: return retval def _fix_fracs(text: str) -> str: + """ + Fix the formatting of fractions in the given text. + Copied from: https://github.com/hendrycks/math/blob/357963a7f5501a6c1708cf3f3fb0cdf525642761/modeling/math_equivalence.py#L1 + + Args: + text (str): The input text. + + Returns: + str: The text with properly formatted fractions. + + Examples: + >>> _fix_fracs("\\frac12") + "\\frac{1}{2}" + >>> _fix_fracs("\\frac{3}{4}") + "\\frac{3}{4}" + >>> _fix_fracs("\\frac1{2}") + "\\frac{1}{2}" + """ substrs = text.split("\\frac") new_str = substrs[0] if len(substrs) > 1: