
Support custom llm #503

Merged: 10 commits merged into refuel-ai:main on Aug 8, 2023

Conversation

@Sardhendu (Contributor) commented Aug 2, 2023

Custom Model and Global Dict ModelRegistry support

  • ModelRegistry: a dictionary object with helper methods and model info.

        from autolabel.models import MODEL_REGISTRY
        print(MODEL_REGISTRY)  # lists all providers and models in a table format for better viewability

  • Langchain imports are done inside each model's constructor.
  • Methods implemented to register all predefined LLM classes.
  • Registering a custom model:

        from autolabel.cache import BaseCache
        from autolabel.configs import AutolabelConfig
        from autolabel.models import MODEL_REGISTRY, BaseModel

        class MyCustommodel(BaseModel):
            def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
                super().__init__(config, cache)

            def _label(self):
                pass

            def get_cost(self):
                pass

            def returns_token_probs(self):
                pass

        MODEL_REGISTRY.register(model_name="my_custom_llm", model=MyCustommodel)

@Sardhendu (Contributor Author) commented Aug 2, 2023

@nihit: instantiating ModelRegistry outputs a global dictionary, MODEL_REGISTRY. This class is just a helper to view the list of providers and models. Let me know your thoughts about it. The rest of the code should be self-explanatory.

NOTE: [It's still pending]

Please skim through it and suggest changes, if needed.
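For illustration, a rough sketch of what such a helper could look like (hypothetical, not the merged implementation; only the register method and the global MODEL_REGISTRY name come from this PR):

    class ModelRegistry(dict):
        """Dict of model name -> model class, with a readable table-style repr."""

        def register(self, model_name: str, model: type) -> None:
            # Add (or overwrite) a model class under the given name
            self[model_name] = model

        def __repr__(self) -> str:
            # Render entries as a simple two-column table
            header = f"{'MODEL NAME':<30} CLASS"
            rows = [f"{name:<30} {cls.__name__}" for name, cls in self.items()]
            return "\n".join([header] + rows)

    # Global instance exposed by autolabel.models
    MODEL_REGISTRY = ModelRegistry()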

@nihit requested review from nihit and rajasbansal on August 2, 2023 15:50
src/autolabel/models/__init__.py (review thread, outdated)
@@ -30,6 +28,12 @@ class AnthropicLLM(BaseModel):

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)

        from langchain.chat_models import ChatAnthropic
Contributor

did we move this in order to prevent import errors? Should we catch import errors in that case and return the command to install the extras related to this model?

Contributor Author

This was done to mimic the imports in the previously deployed code. If langchain is a base dependency, it makes more sense to have top-level imports.

Agreed that we should be catching import errors during test cases.
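For reference, a minimal sketch of the constructor-level import with a friendlier failure mode (the extras name and error message are assumptions, not the library's actual behavior):

    from autolabel.cache import BaseCache
    from autolabel.configs import AutolabelConfig
    from autolabel.models import BaseModel

    class AnthropicLLM(BaseModel):
        def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
            super().__init__(config, cache)
            try:
                # Deferred import: only evaluated when this model is constructed
                from langchain.chat_models import ChatAnthropic
            except ImportError as e:
                # Hypothetical message; the real extras name may differ
                raise ImportError(
                    "anthropic support is not installed. "
                    "Try: pip install 'refuel-autolabel[anthropic]'"
                ) from e
            self.llm = ChatAnthropic()  # real constructor args omitted for brevity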


from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.schema import LLMResult, HumanMessage, Generation
Contributor

we can keep this out? langchain is a dependency of the library


from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import HumanMessage
Contributor

can keep this out?

@Sardhendu (Contributor Author)

@rajasbansal With the updated code, below is how we register the model:

register_model(name="custom", model_cls=MyCustommodel)

Note: the name has to be "custom".
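A plausible shape for this updated helper, sketched from the names used in this thread (hypothetical; the merged implementation may differ):

    from typing import Dict, Type

    from autolabel.models import BaseModel

    # Hypothetical module-level registry keyed by provider name
    MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {}

    def register_model(name: str, model_cls: Type[BaseModel]) -> None:
        # Map a provider name (e.g. "custom") to its model class so the
        # library can look it up when building a model from config
        MODEL_REGISTRY[name] = model_cls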

)


def _register_refuel() -> None:
Contributor

No need to do this in functions, we can add to the dictionary directly,

MODEL_REGISTRY = {
    ModelProvider.HUGGINGFACE_PIPELINE: HFPipelineLLM,
    ModelProvider.GOOGLE: PaLMLM,
    ...
}

import tiktoken

from autolabel.models import BaseModel
from autolabel.configs import AutolabelConfig
from autolabel.cache import BaseCache
from autolabel.schema import RefuelLLMResult

from langchain.chat_models import ChatOpenAI
Contributor

let's move this inside the model so that we don't fail if openai is not installed

from autolabel.models import BaseModel
from autolabel.configs import AutolabelConfig
from autolabel.cache import BaseCache
from autolabel.schema import RefuelLLMResult

from langchain.chat_models import ChatVertexAI
Contributor

let's move this

from autolabel.models import BaseModel
from autolabel.configs import AutolabelConfig
from autolabel.cache import BaseCache
from autolabel.schema import RefuelLLMResult
from langchain.llms import Cohere
Contributor

same: move

@rajasbansal (Contributor) left a comment

lgtm. Thanks for the changes!

@rajasbansal merged commit 8b929b7 into refuel-ai:main on Aug 8, 2023
2 checks passed