Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions easybuild/easyblocks/p/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,28 @@ def test_step(self):

tests_out, tests_ec = test_result

# Show failed subtests to aid in debugging failures
# I.e. patterns like
# === FAIL: test_add_scalar_relu (quantization.core.test_quantized_op.TestQuantizedOps) ===
# --- ERROR: test_all_to_all_group_cuda (__main__.TestDistBackendWithSpawn) ---
regex = r"^[=-]+\n(FAIL|ERROR): (test_.*?)\s\(.*\n[=-]+\n"
failed_test_cases = re.findall(regex, tests_out, re.M)
# And patterns like:
# FAILED test_ops_gradients.py::TestGradientsCPU::test_fn_grad_linalg_det_singular_cpu_complex128 - [snip]
regex = r"^(FAILED) \w+\.py.*::(test_.*?) - "
failed_test_cases.extend(re.findall(regex, tests_out, re.M))
if failed_test_cases:
errored_test_cases = sorted(m[1] for m in failed_test_cases if m[0] == 'ERROR')
failed_test_cases = sorted(m[1] for m in failed_test_cases if m[0] != 'ERROR')
msg = []
if errored_test_cases:
msg.append("Found %d individual tests that exited with an error: %s"
% (len(errored_test_cases), ', '.join(errored_test_cases)))
if failed_test_cases:
msg.append("Found %d individual tests with failed assertions: %s"
% (len(failed_test_cases), ', '.join(failed_test_cases)))
self.log.warning("\n".join(msg))

def get_count_for_pattern(regex, text):
"""Match the regexp containing a single group and return the integer value of the matched group.
Return zero if no or more than 1 match was found and warn for the latter case
Expand Down Expand Up @@ -308,7 +330,7 @@ def get_count_for_pattern(regex, text):
# test_fx failed!
regex = (r"^Ran (?P<test_cnt>[0-9]+) tests.*$\n\n"
r"FAILED \((?P<failure_summary>.*)\)$\n"
r"(?:^(?:(?!failed!).)*$\n)*"
r"(?:^(?:(?!failed!).)*$\n){0,5}"
r"(?P<failed_test_suite_name>.*) failed!(?: Received signal: \w+)?\s*$")

for m in re.finditer(regex, tests_out, re.M):
Expand All @@ -324,7 +346,17 @@ def get_count_for_pattern(regex, text):

# Grep for patterns like:
# ===================== 2 failed, 128 passed, 2 skipped, 2 warnings in 3.43s =====================
regex = r"^=+ (?P<failure_summary>.*) in [0-9]+\.*[0-9]*[a-zA-Z]* =+$\n(?P<failed_test_suite_name>.*) failed!$"
# test_quantization failed!
# OR:
# ===================== 2 failed, 128 passed, 2 skipped, 2 warnings in 63.43s (01:03:43) =========
# FINISHED PRINTING LOG FILE
# test_quantization failed!

regex = (
r"^=+ (?P<failure_summary>.*) in [0-9]+\.*[0-9]*[a-zA-Z]* (\([0-9]+:[0-9]+:[0-9]+\) )?=+$\n"
r"(?:^(?:(?!failed!).)*$\n){0,5}"
r"(?P<failed_test_suite_name>.*) failed!$"
)

for m in re.finditer(regex, tests_out, re.M):
# E.g. '2 failed, 128 passed, 2 skipped, 2 warnings'
Expand Down