-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This version doesn't create a dict each time function is called. This reduces memory consumption and increases speed. Each enum name is directly mapped to its corresponding string without looking up in dictionary, this is more efficient and quicker in execution.
- Loading branch information
1 parent
953a6a7
commit 40e9b44
Showing
1 changed file
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from enum import Enum | ||
from typing import Optional, Iterator, List | ||
|
||
from pydantic.dataclasses import dataclass | ||
|
||
from codeflash.verification.comparator import comparator | ||
|
||
|
||
class TestType(Enum): | ||
EXISTING_UNIT_TEST = 1 | ||
INSPIRED_REGRESSION = 2 | ||
GENERATED_REGRESSION = 3 | ||
|
||
def to_name(self) -> str: | ||
if self == TestType.EXISTING_UNIT_TEST: | ||
return "⚙️ Existing Unit Tests" | ||
elif self == TestType.INSPIRED_REGRESSION: | ||
return "🎨 Inspired Regression Tests" | ||
elif self == TestType.GENERATED_REGRESSION: | ||
return "🌀 Generated Regression Tests" | ||
|
||
|
||
@dataclass(frozen=True) | ||
class InvocationId: | ||
test_module_path: str # The fully qualified name of the test module | ||
test_class_name: Optional[str] # The name of the class where the test is defined | ||
test_function_name: ( | ||
str # The name of the test_function. Does not include the components of the file_name | ||
) | ||
function_getting_tested: str | ||
iteration_id: Optional[str] | ||
|
||
# test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id | ||
def id(self): | ||
return f"{self.test_module_path}:{self.test_class_name or ''}.{self.test_function_name}:{self.function_getting_tested}:{self.iteration_id}" | ||
|
||
@staticmethod | ||
def from_str_id(string_id: str): | ||
components = string_id.split(":") | ||
assert len(components) == 4 | ||
second_components = components[1].split(".") | ||
if len(second_components) == 1: | ||
test_class_name = None | ||
test_function_name = second_components[0] | ||
else: | ||
test_class_name = second_components[0] | ||
test_function_name = second_components[1] | ||
return InvocationId( | ||
test_module_path=components[0], | ||
test_class_name=test_class_name, | ||
test_function_name=test_function_name, | ||
function_getting_tested=components[2], | ||
iteration_id=components[3], | ||
) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class FunctionTestInvocation: | ||
id: InvocationId # The fully qualified name of the function invocation (id) | ||
file_name: str # The file where the test is defined | ||
did_pass: bool # Whether the test this function invocation was part of, passed or failed | ||
runtime: Optional[int] # Time in nanoseconds | ||
test_framework: str # unittest or pytest | ||
test_type: TestType | ||
return_value: Optional[object] # The return value of the function invocation | ||
|
||
|
||
class TestResults: | ||
test_results: list[FunctionTestInvocation] | ||
|
||
def __init__(self, test_results=None): | ||
if test_results is None: | ||
test_results = [] | ||
self.test_results = test_results | ||
|
||
def add(self, function_test_invocation: FunctionTestInvocation) -> None: | ||
self.test_results.append(function_test_invocation) | ||
|
||
def merge(self, other: "TestResults") -> None: | ||
self.test_results.extend(other.test_results) | ||
|
||
def get_by_id(self, invocation_id: InvocationId) -> Optional[FunctionTestInvocation]: | ||
return next((r for r in self.test_results if r.id == invocation_id), None) | ||
|
||
def get_all_ids(self) -> List[InvocationId]: | ||
return [test_result.id for test_result in self.test_results] | ||
|
||
def get_test_pass_fail_report(self) -> str: | ||
passed = 0 | ||
failed = 0 | ||
for test_result in self.test_results: | ||
if test_result.did_pass: | ||
passed += 1 | ||
else: | ||
failed += 1 | ||
return f"Passed: {passed}, Failed: {failed}" | ||
|
||
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: | ||
report = {} | ||
for test_type in TestType: | ||
report[test_type] = {"passed": 0, "failed": 0} | ||
for test_result in self.test_results: | ||
if test_result.did_pass: | ||
report[test_result.test_type]["passed"] += 1 | ||
else: | ||
report[test_result.test_type]["failed"] += 1 | ||
return report | ||
|
||
@staticmethod | ||
def report_to_string(report: dict[TestType, dict[str, int]]) -> str: | ||
return " ".join( | ||
[ | ||
f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" | ||
for test_type in TestType | ||
] | ||
) | ||
|
||
def total_passed_runtime(self) -> int: | ||
for result in self.test_results: | ||
if result.did_pass and result.runtime is None: | ||
logging.debug(f"Ignoring test case that passed but had no runtime -> {result.id}") | ||
timing = sum( | ||
[ | ||
result.runtime | ||
for result in self.test_results | ||
if (result.did_pass and result.runtime is not None) | ||
] | ||
) | ||
return timing | ||
|
||
def __iter__(self) -> Iterator[FunctionTestInvocation]: | ||
return iter(self.test_results) | ||
|
||
def __len__(self) -> int: | ||
return len(self.test_results) | ||
|
||
def __getitem__(self, index: int) -> FunctionTestInvocation: | ||
return self.test_results[index] | ||
|
||
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: | ||
self.test_results[index] = value | ||
|
||
def __delitem__(self, index: int) -> None: | ||
del self.test_results[index] | ||
|
||
def __contains__(self, value: FunctionTestInvocation) -> bool: | ||
return value in self.test_results | ||
|
||
def __bool__(self) -> bool: | ||
return bool(self.test_results) | ||
|
||
def __eq__(self, other: TestResults): | ||
# Unordered comparison | ||
if type(self) != type(other): | ||
return False | ||
if len(self) != len(other): | ||
return False | ||
for test_result in self: | ||
other_test_result = other.get_by_id(test_result.id) | ||
if other_test_result is None: | ||
return False | ||
if ( | ||
test_result.file_name != other_test_result.file_name | ||
or test_result.did_pass != other_test_result.did_pass | ||
or test_result.runtime != other_test_result.runtime | ||
or test_result.test_framework != other_test_result.test_framework | ||
or test_result.test_type != other_test_result.test_type | ||
or not comparator(test_result.return_value, other_test_result.return_value) | ||
): | ||
return False | ||
return True |