diff --git a/easybuild/easyblocks/p/pytorch.py b/easybuild/easyblocks/p/pytorch.py index d7a18e59982..86da42a51a4 100644 --- a/easybuild/easyblocks/p/pytorch.py +++ b/easybuild/easyblocks/p/pytorch.py @@ -280,6 +280,28 @@ def test_step(self): tests_out, tests_ec = test_result + # Show failed subtests to aid in debugging failures + # I.e. patterns like + # === FAIL: test_add_scalar_relu (quantization.core.test_quantized_op.TestQuantizedOps) === + # --- ERROR: test_all_to_all_group_cuda (__main__.TestDistBackendWithSpawn) --- + regex = r"^[=-]+\n(FAIL|ERROR): (test_.*?)\s\(.*\n[=-]+\n" + failed_test_cases = re.findall(regex, tests_out, re.M) + # And patterns like: + # FAILED test_ops_gradients.py::TestGradientsCPU::test_fn_grad_linalg_det_singular_cpu_complex128 - [snip] + regex = r"^(FAILED) \w+\.py.*::(test_.*?) - " + failed_test_cases.extend(re.findall(regex, tests_out, re.M)) + if failed_test_cases: + errored_test_cases = sorted(m[1] for m in failed_test_cases if m[0] == 'ERROR') + failed_test_cases = sorted(m[1] for m in failed_test_cases if m[0] != 'ERROR') + msg = [] + if errored_test_cases: + msg.append("Found %d individual tests that exited with an error: %s" + % (len(errored_test_cases), ', '.join(errored_test_cases))) + if failed_test_cases: + msg.append("Found %d individual tests with failed assertions: %s" + % (len(failed_test_cases), ', '.join(failed_test_cases))) + self.log.warning("\n".join(msg)) + def get_count_for_pattern(regex, text): """Match the regexp containing a single group and return the integer value of the matched group. Return zero if no or more than 1 match was found and warn for the latter case @@ -308,7 +330,7 @@ def get_count_for_pattern(regex, text): # test_fx failed! 
regex = (r"^Ran (?P[0-9]+) tests.*$\n\n" r"FAILED \((?P.*)\)$\n" - r"(?:^(?:(?!failed!).)*$\n)*" + r"(?:^(?:(?!failed!).)*$\n){0,5}" r"(?P.*) failed!(?: Received signal: \w+)?\s*$") for m in re.finditer(regex, tests_out, re.M): @@ -324,7 +346,17 @@ def get_count_for_pattern(regex, text): # Grep for patterns like: # ===================== 2 failed, 128 passed, 2 skipped, 2 warnings in 3.43s ===================== - regex = r"^=+ (?P.*) in [0-9]+\.*[0-9]*[a-zA-Z]* =+$\n(?P.*) failed!$" + # test_quantization failed! + # OR: + # ===================== 2 failed, 128 passed, 2 skipped, 2 warnings in 63.43s (0:01:03) ========= + # FINISHED PRINTING LOG FILE + # test_quantization failed! + + regex = ( + r"^=+ (?P.*) in [0-9]+\.*[0-9]*[a-zA-Z]* (\([0-9]+:[0-9]+:[0-9]+\) )?=+$\n" + r"(?:^(?:(?!failed!).)*$\n){0,5}" + r"(?P.*) failed!$" + ) for m in re.finditer(regex, tests_out, re.M): # E.g. '2 failed, 128 passed, 2 skipped, 2 warnings'