Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Commit

Permalink
fix: memory leak #11
Browse files Browse the repository at this point in the history
  • Loading branch information
jonafeucht committed Jun 12, 2024
1 parent 3a74628 commit a402a25
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ services:
- models:/root/.cache/huggingface/hub:rw
environment:
- DEFAULT_MODEL_NAME
- BATCH_SIZE
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
Expand All @@ -41,6 +42,7 @@ services:
- models:/root/.cache/huggingface/hub:rw
environment:
- DEFAULT_MODEL_NAME
- BATCH_SIZE
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
Expand All @@ -61,6 +63,7 @@ volumes:
- Create a `.env` file and set the preferred values.
```sh
DEFAULT_MODEL_NAME=Falconsai/nsfw_image_detection
BATCH_SIZE=5
DEFAULT_SCORE=0.7
ACCESS_TOKEN=
Expand Down
23 changes: 23 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,34 @@ services:
- models:/root/.cache/huggingface/hub:rw
environment:
- DEFAULT_MODEL_NAME
- BATCH_SIZE
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
- API_KEYS
restart: unless-stopped

image_video_classification_cuda:
image: ghcr.io/doppeltilde/image_video_classification:latest-cuda
ports:
- "8000:8000"
volumes:
- models:/root/.cache/huggingface/hub:rw
environment:
- DEFAULT_MODEL_NAME
- BATCH_SIZE
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
- API_KEYS
restart: unless-stopped
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [ gpu ]

volumes:
models:
10 changes: 10 additions & 0 deletions src/routes/api/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import filetype
from src.shared.shared import check_model
from src.middleware.auth.auth import get_api_key
import torch

router = APIRouter()

Expand Down Expand Up @@ -84,6 +85,10 @@ async def image_classification(
print("File is not a valid image.")
return {"error": str(e)}

finally:
del classifier
torch.cuda.empty_cache()


@router.post("/api/multi-image-classification", dependencies=[Depends(get_api_key)])
async def multi_image_classification(
Expand Down Expand Up @@ -139,6 +144,7 @@ async def multi_image_classification(
)
finally:
img.close()

else:
img.close()
return HTTPException(
Expand All @@ -149,4 +155,8 @@ async def multi_image_classification(
img.close()
return {"error": str(e)}

finally:
del classifier
torch.cuda.empty_cache()

return image_list
13 changes: 13 additions & 0 deletions src/routes/api/image_query_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from src.shared.shared import check_model, default_score
from src.middleware.auth.auth import get_api_key
import torch

router = APIRouter()

Expand Down Expand Up @@ -141,10 +142,15 @@ async def image_query_classification(
raise HTTPException(
status_code=400, detail=f"Error classifying image: {e}"
)

except Exception as e:
print("File is not a valid image.")
return {"error": str(e)}

finally:
del classifier
torch.cuda.empty_cache()

return totalResults

except Exception as e:
Expand Down Expand Up @@ -216,6 +222,8 @@ async def multi_image_query_classification(
)
finally:
img.close()
del classifier
torch.cuda.empty_cache()

# Check Static Image
else:
Expand Down Expand Up @@ -247,11 +255,16 @@ async def multi_image_query_classification(
)
finally:
img.close()

except Exception as e:
print("File is not a valid image.")
img.close()
return {"error": str(e)}

finally:
del classifier
torch.cuda.empty_cache()

totalResults.append({index: image_list})

return totalResults
3 changes: 3 additions & 0 deletions src/routes/api/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import filetype
from src.shared.shared import check_model, default_score
from src.middleware.auth.auth import get_api_key
import torch

router = APIRouter()

Expand Down Expand Up @@ -128,3 +129,5 @@ async def video_classification(
finally:
tf.close()
os.remove(tf.name)
del classifier
torch.cuda.empty_cache()
2 changes: 2 additions & 0 deletions src/shared/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

access_token = os.getenv("ACCESS_TOKEN", None)
default_model_name = os.getenv("DEFAULT_MODEL_NAME", "Falconsai/nsfw_image_detection")
default_batch_size = os.getenv("BATCH_SIZE", 5)
default_score = os.getenv("DEFAULT_SCORE", 0.7)
device = 0 if torch.cuda.is_available() else -1

Expand All @@ -25,6 +26,7 @@ def check_model(model_name):
model=_model_name,
token=access_token,
device=device,
batch_size=default_batch_size,
)

return classifier
Expand Down

0 comments on commit a402a25

Please sign in to comment.