Skip to content

Commit

Permalink
Merge branch 'development' into arksap2002/bugs/fix-error-window-size
Browse files Browse the repository at this point in the history
  • Loading branch information
arksap2002 committed Aug 19, 2024
2 parents c91275f + edfdba2 commit 7ba71b5
Show file tree
Hide file tree
Showing 65 changed files with 2,936 additions and 1,168 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ dependencies {

// https://mvnrepository.com/artifact/org.mockito/mockito-all
testImplementation("org.mockito:mockito-all:1.10.19")
testImplementation("org.mockito.kotlin:mockito-kotlin:5.1.0")

// https://mvnrepository.com/artifact/net.jqwik/jqwik
testImplementation("net.jqwik:jqwik:1.6.5")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ data class TestGenerationData(

// Code required of imports and package for generated tests
var importsCode: MutableSet<String> = mutableSetOf(),
var packageLine: String = "",
var packageName: String = "",
var runWith: String = "",
var otherInfo: String = "",

Expand All @@ -37,7 +37,7 @@ data class TestGenerationData(
resultName = ""
fileUrl = ""
importsCode = mutableSetOf()
packageLine = ""
packageName = ""
runWith = ""
otherInfo = ""
polyDepthReducing = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeRed
import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestCompiler
import org.jetbrains.research.testspark.core.test.TestsAssembler
import org.jetbrains.research.testspark.core.test.TestsPersistentStorage
Expand Down Expand Up @@ -44,6 +45,7 @@ data class FeedbackResponse(

class LLMWithFeedbackCycle(
private val report: Report,
private val language: SupportedLanguage,
private val initialPromptMessage: String,
private val promptSizeReductionStrategy: PromptSizeReductionStrategy,
// filename in which the test suite is saved in result path
Expand Down Expand Up @@ -99,6 +101,7 @@ class LLMWithFeedbackCycle(
// clearing test assembler's collected text on the previous attempts
testsAssembler.clear()
val response: LLMResponse = requestManager.request(
language = language,
prompt = nextPromptMessage,
indicator = indicator,
packageName = packageName,
Expand All @@ -119,6 +122,7 @@ class LLMWithFeedbackCycle(
continue
}
}

ResponseErrorCode.PROMPT_TOO_LONG -> {
if (promptSizeReductionStrategy.isReductionPossible()) {
nextPromptMessage = promptSizeReductionStrategy.reduceSizeAndGeneratePrompt()
Expand All @@ -132,11 +136,13 @@ class LLMWithFeedbackCycle(
break
}
}

ResponseErrorCode.EMPTY_LLM_RESPONSE -> {
nextPromptMessage =
"You have provided an empty answer! Please, answer my previous question with the same formats"
continue
}

ResponseErrorCode.TEST_SUITE_PARSING_FAILURE -> {
onWarningCallback?.invoke(WarningType.TEST_SUITE_PARSING_FAILED)
log.info { "Cannot parse a test suite from the LLM response. LLM response: '$response'" }
Expand All @@ -161,12 +167,15 @@ class LLMWithFeedbackCycle(
generatedTestSuite.updateTestCases(compilableTestCases.toMutableList())
} else {
for (testCaseIndex in generatedTestSuite.testCases.indices) {
val testCaseFilename = "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java"
val testCaseFilename = when (language) {
SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java"
SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt"
}

val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex)

val saveFilepath = testStorage.saveGeneratedTest(
generatedTestSuite.packageString,
generatedTestSuite.packageName,
testCaseRepresentation,
resultPath,
testCaseFilename,
Expand All @@ -177,7 +186,7 @@ class LLMWithFeedbackCycle(
}

val generatedTestSuitePath: String = testStorage.saveGeneratedTest(
generatedTestSuite.packageString,
generatedTestSuite.packageName,
testsPresenter.representTestSuite(generatedTestSuite),
resultPath,
testSuiteFilename,
Expand Down Expand Up @@ -205,8 +214,10 @@ class LLMWithFeedbackCycle(
// Compile the test file
indicator.setText("Compilation tests checking")

val testCasesCompilationResult = testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
val testSuiteCompilationResult = testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath)
val testCasesCompilationResult =
testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
val testSuiteCompilationResult =
testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath)

// saving the compilable test cases
compilableTestCases.addAll(testCasesCompilationResult.compilableTestCases)
Expand All @@ -216,7 +227,8 @@ class LLMWithFeedbackCycle(

onWarningCallback?.invoke(WarningType.COMPILATION_ERROR_OCCURRED)

nextPromptMessage = "I cannot compile the tests that you provided. The error is:\n${testSuiteCompilationResult.second}\n Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text."
nextPromptMessage =
"I cannot compile the tests that you provided. The error is:\n${testSuiteCompilationResult.second}\n Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text."
log.info { nextPromptMessage }
continue
}
Expand All @@ -226,7 +238,8 @@ class LLMWithFeedbackCycle(
generatedTestsArePassing = true

for (index in testCases.indices) {
report.testCaseList[index] = TestCase(index, testCases[index].name, testCases[index].toString(), setOf())
report.testCaseList[index] =
TestCase(index, testCases[index].name, testCases[index].toString(), setOf())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,45 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag
import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestsAssembler
import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
import org.jetbrains.research.testspark.core.utils.javaPackagePattern
import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern
import java.util.Locale

// TODO: find a better place for the below functions

/**
* Retrieves the package declaration from the given test suite code for any language.
*
* @param testSuiteCode The generated code of the test suite.
* @return The package name extracted from the test suite code, or an empty string if no package declaration was found.
*/
fun getPackageFromTestSuiteCode(testSuiteCode: String?, language: SupportedLanguage): String {
testSuiteCode ?: return ""
return when (language) {
SupportedLanguage.Kotlin -> kotlinPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty()
SupportedLanguage.Java -> javaPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty()
}
}

/**
* Retrieves the imports code from a given test suite code.
*
* @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned.
* @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result.
* @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned.
*/
fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String): MutableSet<String> {
testSuiteCode ?: return mutableSetOf()
return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence()
.filter { it.contains("^import".toRegex()) }
.filterNot { it.contains("evosuite".toRegex()) }
.filterNot { it.contains("RunWith".toRegex()) }
.filterNot { it.contains(classFQN.toRegex()) }.toMutableSet()
}

/**
* Returns the generated class name for a given test case.
*
Expand Down Expand Up @@ -38,6 +71,7 @@ fun getClassWithTestCaseName(testCaseName: String): String {
* @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null.
*/
fun executeTestCaseModificationRequest(
language: SupportedLanguage,
testCase: String,
task: String,
indicator: CustomProgressIndicator,
Expand All @@ -48,17 +82,10 @@ fun executeTestCaseModificationRequest(
// Update Token information
val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task"

var packageName = ""
testCase.split("\n")[0].let {
if (it.startsWith("package")) {
packageName = it
.removePrefix("package ")
.removeSuffix(";")
.trim()
}
}
val packageName = getPackageFromTestSuiteCode(testCase, language)

val response = requestManager.request(
language,
prompt,
indicator,
packageName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.research.testspark.core.data.ChatUserMessage
import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestsAssembler

abstract class RequestManager(var token: String) {
Expand All @@ -30,6 +31,7 @@ abstract class RequestManager(var token: String) {
* @return the generated TestSuite, or null and prompt message
*/
open fun request(
language: SupportedLanguage,
prompt: String,
indicator: CustomProgressIndicator,
packageName: String,
Expand All @@ -55,14 +57,15 @@ abstract class RequestManager(var token: String) {
}

return when (isUserFeedback) {
true -> processUserFeedbackResponse(testsAssembler, packageName)
false -> processResponse(testsAssembler, packageName)
true -> processUserFeedbackResponse(testsAssembler, packageName, language)
false -> processResponse(testsAssembler, packageName, language)
}
}

open fun processResponse(
testsAssembler: TestsAssembler,
packageName: String,
language: SupportedLanguage,
): LLMResponse {
// save the full response in the chat history
val response = testsAssembler.getContent()
Expand All @@ -75,7 +78,7 @@ abstract class RequestManager(var token: String) {
return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null)
}

val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName)
val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite()

return if (testSuiteGeneratedByLLM == null) {
LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null)
Expand All @@ -94,6 +97,7 @@ abstract class RequestManager(var token: String) {
open fun processUserFeedbackResponse(
testsAssembler: TestsAssembler,
packageName: String,
language: SupportedLanguage,
): LLMResponse {
val response = testsAssembler.getContent()

Expand All @@ -104,7 +108,7 @@ abstract class RequestManager(var token: String) {
return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null)
}

val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName)
val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite()

return if (testSuiteGeneratedByLLM == null) {
LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ internal class PromptBuilder(private var prompt: String) {
fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n"
}
for (interestingClass in interestingClasses) {
if (interestingClass.qualifiedName.startsWith("java")) {
if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) {
continue
}

Expand All @@ -88,7 +88,9 @@ internal class PromptBuilder(private var prompt: String) {
// Skip java methods
// TODO: checks for java methods should be done by a caller to make
// this class as abstract and language agnostic as possible.
if (method.containingClassQualifiedName.startsWith("java")) {
if (method.containingClassQualifiedName.startsWith("java") ||
method.containingClassQualifiedName.startsWith("kotlin")
) {
continue
}

Expand All @@ -106,8 +108,11 @@ internal class PromptBuilder(private var prompt: String) {
) = apply {
val keyword = "\$${PromptKeyword.POLYMORPHISM.text}"
if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) {
var fullText = ""

// If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable
var fullText = when {
polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable"
else -> ""
}
polymorphismRelations.forEach { entry ->
for (currentSubClass in entry.value) {
val subClassTypeName = when (currentSubClass.classType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.jetbrains.research.testspark.core.test

/**
* Language ID string should be the same as the language name in com.intellij.lang.Language
*/
enum class SupportedLanguage(val languageId: String) {
Java("JAVA"), Kotlin("kotlin")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.jetbrains.research.testspark.core.test

import org.jetbrains.research.testspark.core.test.data.TestLine

interface TestBodyPrinter {
/**
* Generates a test body as a string based on the provided parameters.
*
* @param testInitiatedText A string containing the upper part of the test case.
* @param lines A mutable list of `TestLine` objects representing the lines of the test body.
* @param throwsException The exception type that the test function throws, if any.
* @param name The name of the test function.
* @return A string representing the complete test body.
*/
fun printTestBody(
testInitiatedText: String,
lines: MutableList<TestLine>,
throwsException: String,
name: String,
): String
}
Loading

0 comments on commit 7ba71b5

Please sign in to comment.