
The result is not good #49

Open
primary-studyer opened this issue Jul 1, 2020 · 13 comments

Comments

@primary-studyer

dcs->epcohStep=260000
top 10 ACC=0.767, MRR=0.32433587301587297, MAP=0.32433587301587297, nDCG=0.42961201846689157
top 5 ACC=0.6995, MRR=0.44343166666666667, MAP=0.44343166666666667, nDCG=0.5078651066901124
top 1 ACC=0.4761, MRR=0.4761, MAP=0.4761, nDCG=0.4761
The dataset is the Java data provided by CodeSearchNet. This is the best result I reached during training, with poolsize set to 1000; it does not reach the 0.9+ you mentioned earlier.
I split the dataset into train and valid parts as you described. It feels like valid plays the same role as a test set.
The results of running search are very poor. How can I fix this so that the search results are clearly relevant?
Earlier, when I ran the epoch500 model you provided on a large codebase, few of the results were relevant either. I couldn't find the cause back then, and now the same thing happens with my own data. Looking forward to your reply.
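
For reference, here is a minimal sketch of how pool-based ACC@k and MRR are usually computed in this kind of evaluation. This is my assumption about the protocol behind the numbers above, not the repository's exact evaluation code:

```python
import numpy as np

# For each query, `rank` is the 1-based position of the ground-truth snippet
# among `poolsize` candidates sorted by similarity to the query.
def acc_at_k(ranks, k):
    return float(np.mean([1.0 if r <= k else 0.0 for r in ranks]))

def mrr_at_k(ranks, k):
    return float(np.mean([1.0 / r if r <= k else 0.0 for r in ranks]))

ranks = [1, 3, 12, 2, 7]                            # toy ranks for five queries
print(acc_at_k(ranks, 10), mrr_at_k(ranks, 10))     # 0.8 0.395...
```

A larger pool makes ranking harder, so scores at poolsize=1000 are not directly comparable to scores reported at poolsize=10,000.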

@guxd
Owner

guxd commented Jul 1, 2020

Are you using the PyTorch version?
Your dataset may be on the small side, so the hyperparameters need to be re-tuned.
When I released the epoch500 model I probably had not yet tuned hyperparameters with AutoML; the PyTorch version I tested later can reach above 0.9.
Also, setting poolsize to 10,000 or 100,000 is more reasonable.

@primary-studyer
Author

> Are you using the PyTorch version?
> Your dataset may be on the small side, so the hyperparameters need to be re-tuned.
> When I released the epoch500 model I probably had not yet tuned hyperparameters with AutoML; the PyTorch version I tested later can reach above 0.9.
> Also, setting poolsize to 10,000 or 100,000 is more reasonable.

This dataset has about 230k training samples and about 15k validation samples. I want to first reach 0.9 with poolsize=1000 and then try 10000.
Do you have any suggestions on hyperparameter tuning? My current configuration is:

```python
        # parameters
        'name_len': 6,
        'api_len': 30,
        'tokens_len': 50,
        'desc_len': 30,
        'n_words': 10000,  # len(vocabulary) + 1
        # vocabulary info
        'vocab_name': 'vocab.name.json',
        'vocab_api': 'vocab.apiseq.json',
        'vocab_tokens': 'vocab.tokens.json',
        'vocab_desc': 'vocab.desc.json',

        # training_params
        'batch_size': 64,
        'chunk_size': 200000,
        'nb_epoch': 15,
        #'optimizer': 'adam',
        'learning_rate': 2.08e-4,
        'adam_epsilon': 1e-8,
        'warmup_steps': 5000,
        'fp16': False,
        'fp16_opt_level': 'O1',  # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3'].
                                 # See details at https://nvidia.github.io/apex/amp.html

        # model_params
        'emb_size': 512,
        'n_hidden': 512,  # number of hidden dimensions of the code/desc representation
        # recurrent
        'lstm_dims': 256,  # * 2
        'margin': 0.3986,
        'sim_measure': 'cos',  # similarity measure: cos, poly, sigmoid, euc, gesd, aesd. See https://arxiv.org/pdf/1508.01585.pdf
                               # cos, poly and sigmoid are fast with a simple dot product, while euc, gesd and aesd are slower due to vector normalization.
    }
    return conf
```

@guxd
Owner

guxd commented Jul 1, 2020

Hyperparameter tuning needs a dedicated platform. For manual tuning you can refer to the parameter ranges at the end of automl_config.yaml (see the sketch below).
I will train again on our data and give you a trained model, so you can test it on your data and see how it performs.
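
As an illustration of that manual search, here is a minimal random-search sketch. The intervals below are placeholders I made up, not the actual ranges from automl_config.yaml; substitute the values listed at the end of that file:

```python
import math
import random

# Placeholder search space (assumed ranges, NOT the ones in automl_config.yaml).
search_space = {
    'learning_rate': (1e-5, 1e-3),    # sampled log-uniformly below
    'margin':        (0.1, 0.6),
    'emb_size':      [128, 256, 512],
    'lstm_dims':     [128, 256, 512],
}

def sample_config(space):
    cfg = {}
    for name, rng in space.items():
        if isinstance(rng, list):
            cfg[name] = random.choice(rng)                     # categorical
        elif name == 'learning_rate':
            lo, hi = rng                                       # log-uniform
            cfg[name] = 10 ** random.uniform(math.log10(lo), math.log10(hi))
        else:
            lo, hi = rng                                       # uniform
            cfg[name] = random.uniform(lo, hi)
    return cfg

for trial in range(5):
    print(sample_config(search_space))   # train and validate each sampled config
```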

@primary-studyer
Author

> Hyperparameter tuning needs a dedicated platform. For manual tuning you can refer to the parameter ranges at the end of automl_config.yaml.
> I will train again on our data and give you a trained model, so you can test it on your data and see how it performs.

Thank you. My dataset being small does make this harder to work with.

@guxd
Owner

guxd commented Jul 3, 2020

https://drive.google.com/file/d/15HoKv0efrVXNTsqCxoq2Swgh6ohuq5jI/view?usp=sharing
Here is a trained model; pool size was set to 10000. The top-1 accuracy results are as follows:
[image: top-1 accuracy results]

@primary-studyer
Author

> https://drive.google.com/file/d/15HoKv0efrVXNTsqCxoq2Swgh6ohuq5jI/view?usp=sharing
> Here is a trained model; pool size was set to 10000.
> [image]

Since my own model's query results are not good, I went back through the pipeline from the start and found a problem. Take "convert inputstream to string" as an example:
The method name is inputstreamToString, which I tokenize as "inputstream to string" and store in the .name.h5 file.
For the desc.h5 file, should I store "inputstream" or "input stream"?
And for the token.h5 file, for statements involving InputStream, should I store "inputstream" or "input stream"?
Some identifiers should be kept whole and others need to be split. How should I handle this, given that I cannot tell which words should be partially split?

@guxd
Owner

guxd commented Jul 12, 2020

We simply applied a camel-case split to the tokens in the code; queries were not split. You can try both; a sketch of the split is below.
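
For concreteness, a minimal camel-case split sketch. This is my own illustration of the idea, not the repository's exact preprocessing code:

```python
import re

def camel_split(token):
    # Split on camel-case boundaries, keeping acronyms and digits together,
    # e.g. "inputstreamToString" -> ["inputstream", "to", "string"].
    parts = re.findall(r"[A-Z]+(?=[A-Z][a-z])|[A-Z]?[a-z]+|[A-Z]+|\d+", token)
    return [p.lower() for p in parts]

print(camel_split("inputstreamToString"))   # ['inputstream', 'to', 'string']
print(camel_split("InputStream"))           # ['input', 'stream']
```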

@primary-studyer
Author

> We simply applied a camel-case split to the tokens in the code; queries were not split. You can try both.

When training DCS on another dataset, top-10 ACC with poolsize=1000 reached 0.90+.
When querying, I converted "inputstream to string" into "input stream to string". The highest similarity in the top 10 was about 0.90, but the results were not relevant.
When I deleted everything else from the codebase and related files and kept only the InputStreamToString snippet, i.e. len(codebase)==1, the cosine similarity was 0.93. But that snippet never appeared in the results above, where the highest score was only about 0.90; the 0.93 snippet was not shown.

@guxd
Owner

guxd commented Jul 12, 2020

Thanks for the clue; there may still be a bug in the code. If you find the cause, please let me know.

@primary-studyer
Author

primary-studyer commented Jul 18, 2020

> Thanks for the clue; there may still be a bug in the code. If you find the cause, please let me know.

The cause seems to be the batch_size in `data_loader = torch.utils.data.DataLoader(dataset=use_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=1)` in repr_code.py; your code defaults it to 10000. When I run the query "input stream to string", the top-5 results are as follows
(this is with the codebase trimmed down to length 5, i.e. only 5 entries; when the codebase contains all code snippets, the similarity of the first result, 0.9439719, drops to about 0.72):

```
('public static String inputStreamToString(InputStream is) throws MPException { String value = ""; if (is != null) { try { ByteArrayOutputStream result = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; int length; while ((length = is.read(buffer)) != -1) { result.write(buffer, 0, length); } value = result.toString("UTF-8"); } catch (Exception ex) {throw new com.mercadopago.exceptions.MPException(ex); } } return value; }\r\n', 0.9439719)

('private static String inputStreamToString(InputStream in) { try { BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(in)); StringBuilder stringBuilder = new StringBuilder(); String line = null; while ((line = bufferedReader.readLine()) != null) { stringBuilder.append(line + "\n"); } bufferedReader.close(); return stringBuilder.toString(); } catch (IOException e) { throw new RuntimeException("Failed to parse input stream", e); } }\r\n', 0.66112924)

('public static String inputStreamToString(InputStream in) { StringBuffer buffer = new StringBuffer(); try { BufferedReader br = new BufferedReader(new InputStreamReader(in, "UTF-8"), 1024); String line; while ((line = br.readLine()) != null) { buffer.append(line); } } catch (IOException iox) { LOGR.warning(iox.getMessage()); } return buffer.toString(); }\r\n', 0.6445415)

('public static String inputStreamToString(InputStream in) { Scanner scanner = new Scanner(in, "UTF-8"); String content = scanner.useDelimiter("\\A").next(); scanner.close(); return content; }', 0.49575543)

('public static String readInputStreamToString( InputStream inputStream ) {try {List bytesList = new ArrayList(); byte b = 0; while( (b = (byte) inputStream.read()) != -1 ) {bytesList.add(b); } inputStream.close(); byte[] bArray = new byte[bytesList.size()]; for( int i = 0; i < bArray.length; i++ ) {bArray[i] = bytesList.get(i);} String file = new String(bArray); return file; } catch (IOException e) { e.printStackTrace(); return null; } }\r\n', 0.48657355)
```

When I set batch_size to 1, the top-10 results are as follows:
```
('public static int cuStreamWriteValue32 ( CUstream stream, CUdeviceptr addr, int value, int flags ) { return checkResult ( cuStreamWriteValue32Native ( stream, addr, value, flags ) ) ; } \n', 0.9701143)

('@OverRide protected Response encode ( final SamlRegisteredService service, final Response samlResponse, final HttpServletResponse httpResponse, final HttpServletRequest httpRequest, final SamlRegisteredServiceServiceProviderMetadataFacade adaptor, final String relayState, final String binding, final RequestAbstractType authnRequest, final Object assertion ) throws SamlException { LOGGER . trace ( " " , binding, adaptor . getEntityId ( ) ) ; if ( binding . equalsIgnoreCase ( SAMLConstants . SAML2_ARTIFACT_BINDING_URI ) ) { val encoder = new SamlResponseArtifactEncoder ( getSamlResponseBuilderConfigurationContext ( ) . getVelocityEngineFactory ( ) , adaptor, httpRequest, httpResponse, authnRequest, getSamlResponseBuilderConfigurationContext ( ) . getTicketRegistry ( ) , getSamlResponseBuilderConfigurationContext ( ) . getSamlArtifactTicketFactory ( ) , getSamlResponseBuilderConfigurationContext ( ) . getTicketGrantingTicketCookieGenerator ( ) , getSamlResponseBuilderConfigurationContext ( ) . getSamlArtifactMap ( ) ) ; return encoder . encode ( authnRequest, samlResponse, relayState ) ; } if ( binding . equalsIgnoreCase ( SAMLConstants . SAML2_POST_SIMPLE_SIGN_BINDING_URI ) ) { val encoder = new SamlResponsePostSimpleSignEncoder ( getSamlResponseBuilderConfigurationContext ( ) . getVelocityEngineFactory ( ) , adaptor, httpResponse, httpRequest ) ; return encoder . encode ( authnRequest, samlResponse, relayState ) ; } val encoder = new SamlResponsePostEncoder ( getSamlResponseBuilderConfigurationContext ( ) . getVelocityEngineFactory ( ) , adaptor, httpResponse, httpRequest ) ; return encoder . encode ( authnRequest, samlResponse, relayState ) ; } \n', 0.9691605)

('public static int cuStreamWriteValue64 ( CUstream stream, CUdeviceptr addr, long value, int flags ) { return checkResult ( cuStreamWriteValue64Native ( stream, addr, value, flags ) ) ; } \n', 0.96685445)

('public static ByteBuf encode ( ByteBufAllocator allocator, int streamId, boolean fragmentFollows, boolean complete, boolean next, Payload payload ) { return FLYWEIGHT . encode ( allocator, streamId, fragmentFollows, complete, next, 0, payload . hasMetadata ( ) ? payload . metadata ( ) . retain ( ) : null, payload . data ( ) . retain ( ) ) ; } \n', 0.9657972)

('public static < S > Stream < S > stream ( Iterable < S > input ) { return stream ( input, false ) ; } \n', 0.96461725)

('private byte [ ] decryptStream ( byte [ ] key, byte [ ] keepassFile ) throws IOException { CryptoInformation cryptoInformation = new CryptoInformation ( KeePassHeader . VERSION_SIGNATURE_LENGTH, keepassHeader . getMasterSeed ( ) , keepassHeader . getTransformSeed ( ) , keepassHeader . getEncryptionIV ( ) , keepassHeader . getTransformRounds ( ) , keepassHeader . getHeaderSize ( ) ) ; return decrypter . decryptDatabase ( key, cryptoInformation, keepassFile ) ; } \n', 0.96334887)

('private ByteBuf encodeReadHoldingRegisters ( ReadHoldingRegistersResponse response, ByteBuf buffer ) { buffer . writeByte ( response . getFunctionCode ( ) . getCode ( ) ) ; buffer . writeByte ( response . getRegisters ( ) . readableBytes ( ) ) ; buffer . writeBytes ( response . getRegisters ( ) ) ; return buffer ; } \n', 0.96252537)

('@OverRide public byte [ ] encode ( Endpoint endpoint ) { return ( endpoint . host ( ) + fieldDelimiter + endpoint . port ( ) + fieldDelimiter + endpoint . weight ( ) ) . getBytes ( StandardCharsets . UTF_8 ) ; } \n', 0.96244895)

('private static InputStreamWithMetadata compressStreamWithGZIP ( InputStream inputStream ) throws SnowflakeSQLException { FileBackedOutputStream tempStream = new FileBackedOutputStream ( MAX_BUFFER_SIZE, true ) ; try { DigestOutputStream digestStream = new DigestOutputStream ( tempStream, MessageDigest . getInstance ( " " ) ) ; CountingOutputStream countingStream = new CountingOutputStream ( digestStream ) ; GZIPOutputStream gzipStream ; gzipStream = new GZIPOutputStream ( countingStream, true ) ; IOUtils . copy ( inputStream, gzipStream ) ; inputStream . close ( ) ; gzipStream . finish ( ) ; gzipStream . flush ( ) ; countingStream . flush ( ) ; return new InputStreamWithMetadata ( countingStream . getCount ( ) , Base64 . encodeAsString ( digestStream . getMessageDigest ( ) . digest ( ) ) , tempStream ) ; } catch ( IOException | NoSuchAlgorithmException ex ) { logger . error ( " " , ex ) ; throw new SnowflakeSQLException ( ex, SqlState . INTERNAL_ERROR, ErrorCode . INTERNAL_ERROR . getMessageCode ( ) , " " ) ; } } \n', 0.96195817)

('public static < U > Stream < U > stream ( final Spliterator < U > it ) { return StreamSupport . stream ( it, false ) ; } \n', 0.96152276)
```

The results with batch_size=1 are more relevant, and I don't know why; I'm not familiar with where the effect of this batching comes from. If you find the cause, I would appreciate your reply.

@guxd
Owner

guxd commented Jul 18, 2020

See line 50 of repr_code.py. With the current setting, chunk_size=2000000, your codebase must contain at least that many code snippets before the system saves your code vectors. I suspect you were searching with stale code vectors (please verify).
Workaround: when the codebase is smaller than chunk_size, store only a single code vector file, and make the corresponding change in search.py when loading the vectors.

I won't change this on my side for now, to avoid introducing other problems; I'll do a complete fix and test it when I have more time.
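
For anyone applying the workaround, here is a minimal sketch of the idea. The helper names and file layout are my own assumptions, not the repository's actual API:

```python
import numpy as np

def save_code_vectors(vecs, chunk_size, out_prefix="vecs"):
    # Write all vectors to one file when the codebase fits in a single chunk,
    # otherwise one file per full chunk (the original design).
    n = len(vecs)
    if n <= chunk_size:
        np.save(f"{out_prefix}.part0.npy", np.asarray(vecs, dtype=np.float32))
        return 1
    n_parts = 0
    for i in range(0, n, chunk_size):
        np.save(f"{out_prefix}.part{n_parts}.npy",
                np.asarray(vecs[i:i + chunk_size], dtype=np.float32))
        n_parts += 1
    return n_parts

def load_code_vectors(n_parts, out_prefix="vecs"):
    # search.py would need the matching change: load however many parts exist.
    return np.concatenate([np.load(f"{out_prefix}.part{i}.npy") for i in range(n_parts)])

vecs = [np.random.rand(512).astype(np.float32) for _ in range(1500)]
parts = save_code_vectors(vecs, chunk_size=200000)   # small codebase -> 1 file
all_vecs = load_code_vectors(parts)                  # shape (1500, 512)
```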

@primary-studyer
Author

primary-studyer commented Jul 18, 2020 via email

@guxd
Owner

guxd commented Jun 16, 2021

The bug in the PyTorch version has been fixed. The problem was the line h_n = h_n.transpose(1, 0).contiguous() in modules.py.
Removing that line happened to improve the validation scores, so it was accidentally commented out at the time; adding it back is all that is needed.
The test results are fine now.
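
For readers hitting the same symptom, a minimal standalone sketch of why that line matters when flattening a bidirectional LSTM's final hidden state (an illustration, not the module from the repository):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=4, batch_first=True, bidirectional=True)
x = torch.randn(3, 5, 8)             # [batch=3, seq_len=5, emb=8]
_, (h_n, _) = lstm(x)                 # h_n: [num_directions=2, batch=3, hidden=4]

# Correct: move the batch dimension first, then flatten per sample.
h_ok = h_n.transpose(1, 0).contiguous().view(3, -1)    # [3, 8], one row per sample

# Without the transpose, view() walks the direction axis first, so each row
# mixes hidden states belonging to different samples in the batch.
h_bad = h_n.contiguous().view(3, -1)
print(torch.allclose(h_ok, h_bad))    # False for batch size > 1
```

Note that with batch size 1 the missing transpose is harmless, which may explain why batch_size=1 appeared to give more relevant results earlier in this thread.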
