Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions .github/workflows/python-api-setfit-cd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Continuous deployment for the SetFit inference Docker image.
# Fires when the SetFit docker sources change on `main`: builds and pushes the
# image, then dispatches a workflow in huggingface/api-inference to deploy it.
name: setfit-docker-cd
on:
  push:
    branches:
      - main
    paths:
      - "docker_images/setfit/**"
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      # NOTE(review): this job defines no `matrix`, so `${{ matrix.python-version }}`
      # expands to an empty string in the step name — harmless but misleading.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: "3.8"
      - name: Checkout
        uses: actions/checkout@v2
      # QEMU + Buildx enable (multi-platform) docker builds from this runner.
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v1
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install awscli
      # Tailscale VPN — presumably needed to reach private infra/registry;
      # verify against build_docker.py's requirements.
      - uses: tailscale/github-action@v1
        with:
          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
      - name: Update upstream
        env:
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
          DEFAULT_HOSTNAME: ${{ secrets.DEFAULT_HOSTNAME }}
          REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }}
          REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
        # build_docker.py writes the resulting image tag(s) as KEY=VALUE lines
        # into out.txt (consumed by the next step).
        run: |
          python build_docker.py setfit --out out.txt
      - name: Deploy on API
        run: |
          # Load the tags into the env
          cat out.txt >> $GITHUB_ENV
          export $(xargs < out.txt)
          echo ${SETFIT_CPU_TAG}
          # Weird single quote escape mechanism because string interpolation does
          # not work on single quote in bash
          curl -H "Authorization: Bearer ${{ secrets.API_GITHUB_TOKEN }}" https://api.github.com/repos/huggingface/api-inference/actions/workflows/update_community.yaml/dispatches -d '{"ref":"main","inputs":{"framework":"SETFIT","tag": "'"${SETFIT_CPU_TAG}"'"}}'
26 changes: 26 additions & 0 deletions .github/workflows/python-api-setfit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# CI for the SetFit inference Docker image: on every PR touching the SetFit
# docker sources, build the image and run its docker-level test suite.
name: setfit-docker

on:
  pull_request:
    paths:
      - "docker_images/setfit/**"
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      # NOTE(review): this job defines no `matrix`, so `${{ matrix.python-version }}`
      # expands to an empty string in the step name — harmless but misleading.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: "3.8"
      - name: Checkout
        uses: actions/checkout@v2
      # QEMU + Buildx enable (multi-platform) docker builds from this runner.
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v1
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install pytest pillow httpx
          pip install -e .
      # RUN_DOCKER_TESTS gates the slow docker-in-docker integration tests.
      - run: RUN_DOCKER_TESTS=1 pytest -sv tests/test_dockers.py::DockerImageTests::test_setfit
29 changes: 29 additions & 0 deletions docker_images/setfit/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Serving image for SetFit models, built on the uvicorn+gunicorn starter image.
FROM tiangolo/uvicorn-gunicorn:python3.8
LABEL maintainer="Tom Aarsen <[email protected]>"

# Add any system dependency here
# RUN apt-get update -y && apt-get install libXXX -y

# Install Python requirements before copying the app so this (slow) layer is
# cached independently of application-code changes.
COPY ./requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# prestart.sh is picked up by the tiangolo base image and run before the
# server boots (used here to warm the model cache) — see docker_images/setfit/prestart.sh.
COPY ./prestart.sh /app/


# Most DL models are quite large in terms of memory, using workers is a HUGE
# slowdown because of the fork and GIL with python.
# Using multiple pods seems like a better default strategy.
# Feel free to override if it does not make sense for your library.
ARG max_workers=1
ENV MAX_WORKERS=$max_workers
# Store downloaded models on the mounted /data volume rather than in the image.
ENV HUGGINGFACE_HUB_CACHE=/data

# Necessary on GPU environment docker.
# TIMEOUT env variable is used by nvcr.io/nvidia/pytorch:xx for another purpose
# rendering TIMEOUT defined by uvicorn impossible to use correctly
# We're overriding it to be renamed UVICORN_TIMEOUT
# UVICORN_TIMEOUT is a useful variable for very large models that take more
# than 30s (the default) to load in memory.
# If UVICORN_TIMEOUT is too low, uvicorn will simply never loads as it will
# kill workers all the time before they finish.
RUN sed -i 's/TIMEOUT/UVICORN_TIMEOUT/g' /gunicorn_conf.py
COPY ./app /app/app
Empty file.
91 changes: 91 additions & 0 deletions docker_images/setfit/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import functools
import logging
import os
from typing import Dict, Type

from api_inference_community.routes import pipeline_route, status_ok
from app.pipelines import Pipeline, TextClassificationPipeline
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Route


# Task and model are injected by the serving infrastructure via environment
# variables. NOTE: `get_pipeline` re-reads os.environ directly, so these
# module-level copies are informational only.
TASK = os.getenv("TASK")
MODEL_ID = os.getenv("MODEL_ID")


logger = logging.getLogger(__name__)


# Add the allowed tasks
# Supported tasks are:
# - text-generation
# - text-classification
# - token-classification
# - translation
# - summarization
# - automatic-speech-recognition
# - ...
# For instance
# from app.pipelines import AutomaticSpeechRecognitionPipeline
# ALLOWED_TASKS = {"automatic-speech-recognition": AutomaticSpeechRecognitionPipeline}
# You can check the requirements and expectations of each pipelines in their respective
# directories. Implement directly within the directories.
# Registry mapping task name -> Pipeline subclass; `get_pipeline` rejects any
# TASK not present here.
ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {
    "text-classification": TextClassificationPipeline,
}


@functools.lru_cache()
def get_pipeline() -> Pipeline:
    """Build — and memoize — the pipeline selected by the TASK / MODEL_ID env vars.

    Raises:
        EnvironmentError: if TASK is not registered in ``ALLOWED_TASKS``.
    """
    task, model_id = os.environ["TASK"], os.environ["MODEL_ID"]
    pipeline_cls = ALLOWED_TASKS.get(task)
    if pipeline_cls is None:
        raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
    return pipeline_cls(model_id)


# Both routes share the same catch-all path: GET (and other non-POST methods)
# resolve to the status endpoint, while POST requests are dispatched to the
# pipeline.
routes = [
    Route("/{whatever:path}", status_ok),
    Route("/{whatever:path}", pipeline_route, methods=["POST"]),
]

# Compress large responses; model outputs can be sizeable JSON payloads.
middleware = [Middleware(GZipMiddleware, minimum_size=1000)]
if os.environ.get("DEBUG", "") == "1":
    from starlette.middleware.cors import CORSMiddleware

    # Debug-only: allow cross-origin requests from any frontend.
    middleware.append(
        Middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_headers=["*"],
            allow_methods=["*"],
        )
    )

app = Starlette(routes=routes, middleware=middleware)


@app.on_event("startup")
async def startup_event():
    """Configure uvicorn access logging and eagerly warm the pipeline cache."""
    access_logger = logging.getLogger("uvicorn.access")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    )
    access_logger.handlers = [stream_handler]

    # Link between `api-inference-community` and framework code.
    app.get_pipeline = get_pipeline
    try:
        get_pipeline()
    except Exception:
        # We can fail so we can show exception later.
        pass


if __name__ == "__main__":
    # Executed by prestart.sh before the server boots: building the pipeline
    # here pre-downloads the model so the first request avoids cold-start cost.
    try:
        get_pipeline()
    except Exception:
        # We can fail so we can show exception later.
        pass
2 changes: 2 additions & 0 deletions docker_images/setfit/app/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from app.pipelines.base import Pipeline, PipelineException
from app.pipelines.text_classification import TextClassificationPipeline
16 changes: 16 additions & 0 deletions docker_images/setfit/app/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any


class Pipeline(ABC):
    """Abstract interface for task-specific inference pipelines.

    Subclasses load a model in ``__init__`` and run inference in ``__call__``.
    """

    @abstractmethod
    def __init__(self, model_id: str):
        # Load the model identified by `model_id` (e.g. from the Hugging Face Hub).
        raise NotImplementedError("Pipelines should implement an __init__ method")

    @abstractmethod
    def __call__(self, inputs: Any) -> Any:
        # Run inference on `inputs` and return a JSON-serializable result.
        raise NotImplementedError("Pipelines should implement a __call__ method")


class PipelineException(Exception):
    """Raised when a pipeline fails to process its inputs."""

    pass
32 changes: 32 additions & 0 deletions docker_images/setfit/app/pipelines/text_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Dict, List

from app.pipelines import Pipeline
from setfit import SetFitModel


class TextClassificationPipeline(Pipeline):
    def __init__(
        self,
        model_id: str,
    ) -> None:
        """Load the SetFit model identified by ``model_id`` from the Hugging Face Hub."""
        self.model = SetFitModel.from_pretrained(model_id)

    def __call__(self, inputs: str) -> List[Dict[str, float]]:
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`list`: The object returned should be a list of one list like [[{"label": "positive", "score": 0.99}]] containing:
            - "label": A string representing what the label/class is. There can be multiple labels.
            - "score": A score between 0 and 1 describing how confident the model is for this label/class.
        """
        probs = self.model.predict_proba([inputs], as_numpy=True)
        # predict_proba on a single-element batch normally yields a
        # (1, num_classes) matrix; fall back to treating a flat vector as the
        # row itself instead of silently returning None (previous bug: a
        # missing else-branch made non-2D outputs return None).
        row = probs[0] if probs.ndim == 2 else probs
        # Map class indices to human-readable names when the model provides
        # an `id2label` mapping.
        id2label = getattr(self.model, "id2label", {}) or {}
        return [
            [
                # Stringify the index fallback so "label" is always a string
                # (per the documented contract), and coerce numpy scalars to
                # plain floats so the result is JSON-serializable.
                {"label": id2label.get(idx, str(idx)), "score": float(prob)}
                for idx, prob in enumerate(row)
            ]
        ]
1 change: 1 addition & 0 deletions docker_images/setfit/prestart.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run by the tiangolo/uvicorn-gunicorn base image before the server boots:
# app/main.py's __main__ block warms the pipeline (pre-downloads model weights).
python app/main.py
4 changes: 4 additions & 0 deletions docker_images/setfit/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Exact pins for reproducible Docker image builds.
starlette==0.27.0
api-inference-community==0.0.32
huggingface_hub==0.11.0
setfit==0.7.0
Empty file.
59 changes: 59 additions & 0 deletions docker_images/setfit/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
from typing import Dict
from unittest import TestCase, skipIf

from app.main import ALLOWED_TASKS, get_pipeline


# Must contain at least one example of each implemented pipeline
# Tests do not check the actual values of the model output, so small dummy
# models are recommended for faster tests.
# Maps task name -> Hub model id exercised by the docker integration tests.
TESTABLE_MODELS: Dict[str, str] = {
    "text-classification": "tomaarsen/setfit-all-MiniLM-L6-v2-sst2-32-shot"
}


# Every task name known to the inference API; tasks absent from ALLOWED_TASKS
# must be rejected by `get_pipeline` (see test_unsupported_tasks below).
# Fixed: "feature-extraction" and "sentence-similarity" were each listed twice
# (redundant in a set literal); entries deduplicated and sorted.
ALL_TASKS = {
    "audio-classification",
    "audio-to-audio",
    "automatic-speech-recognition",
    "conversational",
    "feature-extraction",
    "fill-mask",
    "image-classification",
    "question-answering",
    "sentence-similarity",
    "speech-segmentation",
    "summarization",
    "table-question-answering",
    "tabular-classification",
    "tabular-regression",
    "text-classification",
    "text-to-image",
    "text-to-speech",
    "text2text-generation",
    "token-classification",
    "zero-shot-classification",
}


class PipelineTestCase(TestCase):
    """Sanity checks on the task registry exposed by ``app.main``."""

    @skipIf(
        os.path.dirname(os.path.dirname(__file__)).endswith("common"),
        "common is a special case",
    )
    def test_has_at_least_one_task_enabled(self):
        # The shared "common" template ships with no tasks; every concrete
        # framework image must register at least one.
        self.assertGreater(
            len(ALLOWED_TASKS.keys()), 0, "You need to implement at least one task"
        )

    def test_unsupported_tasks(self):
        # Every known task NOT registered in ALLOWED_TASKS must be rejected
        # by `get_pipeline` with an EnvironmentError.
        # NOTE(review): get_pipeline is lru_cached, but lru_cache does not
        # cache raised exceptions, so each iteration re-reads the env vars.
        unsupported_tasks = ALL_TASKS - ALLOWED_TASKS.keys()
        for unsupported_task in unsupported_tasks:
            with self.subTest(msg=unsupported_task, task=unsupported_task):
                os.environ["TASK"] = unsupported_task
                os.environ["MODEL_ID"] = "XX"
                with self.assertRaises(EnvironmentError):
                    get_pipeline()
Loading