refactor(embedding): refactor model loading logic
Refactor the model loading logic in the LocalEmbedding class to separate loading the tokenizer and the neural network model into distinct functions for better organization and clarity.
phodal committed Jul 7, 2024
1 parent 2c0f0c0 commit 3d1f4a8
Showing 1 changed file with 27 additions and 4 deletions.
@@ -75,18 +75,41 @@ open class LocalEmbedding(
         fun create(): LocalEmbedding {
             val classLoader = Thread.currentThread().getContextClassLoader()
 
-            val tokenizerStream = classLoader.getResourceAsStream("model/tokenizer.json")!!
-            val onnxStream = classLoader.getResourceAsStream("model/model.onnx")!!
+            val tokenizer = loadTokenizer(classLoader)!!
 
-            val tokenizer = HuggingFaceTokenizer.newInstance(tokenizerStream, null)
             val ortEnv = OrtEnvironment.getEnvironment()
+            val session = loadNetwork(classLoader, ortEnv)!!
+
+            return LocalEmbedding(tokenizer, session, ortEnv)
+        }
+
+        /**
+         * Loads a neural network model from the specified class loader and creates an OrtSession using the provided OrtEnvironment.
+         *
+         * @param classLoader the ClassLoader used to load the model file
+         * @param ortEnv the OrtEnvironment used to create the OrtSession
+         * @return the OrtSession created from the loaded model file, or null if an error occurs
+         */
+        fun loadNetwork(classLoader: ClassLoader, ortEnv: OrtEnvironment): OrtSession? {
             val sessionOptions = OrtSession.SessionOptions()
 
+            val onnxStream = classLoader.getResourceAsStream("model/model.onnx")!!
             // load onnxPath as byte[]
             val onnxPathAsByteArray = onnxStream.readAllBytes()
             val session = ortEnv.createSession(onnxPathAsByteArray, sessionOptions)
+            return session
+        }
 
-            return LocalEmbedding(tokenizer, session, ortEnv)
+        /**
+         * Loads a HuggingFaceTokenizer using the provided ClassLoader.
+         *
+         * @param classLoader the ClassLoader used to load the tokenizer resource
+         * @return a HuggingFaceTokenizer instance loaded from the "model/tokenizer.json" resource, or null if the tokenizer could not be loaded
+         */
+        fun loadTokenizer(classLoader: ClassLoader): HuggingFaceTokenizer? {
+            val tokenizerStream = classLoader.getResourceAsStream("model/tokenizer.json")
+            val tokenizer = HuggingFaceTokenizer.newInstance(tokenizerStream, null)
+            return tokenizer
+        }
     }
 }
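For context, a caller-side sketch (not part of this commit) of how the refactored pieces compose. It assumes loadTokenizer and loadNetwork are public members of LocalEmbedding's companion object, as the diff indicates, that the three-argument constructor is accessible to the caller, and that the model resources ship on the classpath; the buildEmbedding wrapper and the error messages are illustrative only.

    import ai.onnxruntime.OrtEnvironment
    // Illustrative sketch: the import for LocalEmbedding depends on its package in the host project.

    fun buildEmbedding(): LocalEmbedding {
        // The one-line factory path is unchanged by this refactor:
        //   val embedding = LocalEmbedding.create()

        // The split also lets a caller assemble the parts itself and handle a null
        // result explicitly instead of relying on the !! operators inside create():
        val classLoader = Thread.currentThread().getContextClassLoader()
        val tokenizer = LocalEmbedding.loadTokenizer(classLoader)
            ?: error("model/tokenizer.json is missing from the classpath")
        val ortEnv = OrtEnvironment.getEnvironment()
        val session = LocalEmbedding.loadNetwork(classLoader, ortEnv)
            ?: error("model/model.onnx is missing from the classpath")
        return LocalEmbedding(tokenizer, session, ortEnv)
    }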
