Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jonafeucht committed Jun 12, 2024
1 parent 629fa39 commit 40c8ec4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
27 changes: 18 additions & 9 deletions src/routes/api/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ async def image_classification(
model_name: str = Query(None),
):

classifier = check_model(model_name)

try:
# Read the file as bytes
contents = await file.read()
Expand All @@ -40,6 +38,8 @@ async def image_classification(
# Check if the image is a GIF and if it's animated
if img.format.lower() == "gif":
try:
classifier = check_model(model_name)

results = []
with ThreadPoolExecutor() as executor:
futures = [
Expand All @@ -56,6 +56,8 @@ async def image_classification(
)
finally:
img.close()
del classifier
torch.cuda.empty_cache()

# Check Static Image
else:
Expand All @@ -66,15 +68,17 @@ async def image_classification(
# Encode the file to base64
base64Image = base64.b64encode(contents).decode("utf-8")

res = classifier(base64Image)
res2 = classifier(base64Image)

return res
return res2
except (ValueError, IOError) as e:
raise HTTPException(
status_code=400, detail=f"Error classifying image: {e}"
)
finally:
img.close()
del res2
torch.cuda.empty_cache()
else:
return HTTPException(
status_code=400, detail="The uploaded file is not a valid image."
Expand All @@ -86,15 +90,14 @@ async def image_classification(
return {"error": str(e)}

finally:
del classifier
img.close()
torch.cuda.empty_cache()


@router.post("/api/multi-image-classification", dependencies=[Depends(get_api_key)])
async def multi_image_classification(
files: List[UploadFile] = File(), model_name: str = Query(None)
):
classifier = check_model(model_name)

image_list = []

Expand All @@ -112,6 +115,8 @@ async def multi_image_classification(
# Check if the image is a GIF and if it's animated
if img.format.lower() == "gif":
try:
classifier = check_model(model_name)

results = []
with ThreadPoolExecutor() as executor:
futures = [
Expand All @@ -128,22 +133,26 @@ async def multi_image_classification(
)
finally:
img.close()
del classifier
torch.cuda.empty_cache()

# Check Static Image
else:
try:
# Encode the file to base64
base64Image = base64.b64encode(contents).decode("utf-8")

res = classifier(base64Image)
image_list.append({index: res})
res2 = classifier(base64Image)
image_list.append({index: res2})

except (ValueError, IOError) as e:
raise HTTPException(
status_code=400, detail=f"Error classifying image: {e}"
)
finally:
img.close()
del res2
torch.cuda.empty_cache()

else:
img.close()
Expand All @@ -156,7 +165,7 @@ async def multi_image_classification(
return {"error": str(e)}

finally:
del classifier
img.close()
torch.cuda.empty_cache()

return image_list
40 changes: 31 additions & 9 deletions src/routes/api/image_query_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def process_image(
print("Error processing image:", str(e))
return results

finally:
img.close()
del result
torch.cuda.empty_cache()


@router.post("/api/image-query-classification", dependencies=[Depends(get_api_key)])
async def image_query_classification(
Expand All @@ -78,7 +83,6 @@ async def image_query_classification(
contents = await file.read()

for model_name in model_names:
classifier = check_model(model_name)
_score = score or default_score

try:
Expand All @@ -97,6 +101,8 @@ async def image_query_classification(
if img.format.lower() == "gif":

try:
classifier = check_model(model_name)

res = await asyncio.get_event_loop().run_in_executor(
executor,
process_image,
Expand All @@ -115,6 +121,11 @@ async def image_query_classification(
status_code=400, detail="The uploaded GIF is not animated."
)

finally:
img.close()
del classifier
torch.cuda.empty_cache()

# Check Static Image
else:
try:
Expand All @@ -125,9 +136,9 @@ async def image_query_classification(
# Encode the file to base64
base64Image = base64.b64encode(contents).decode("utf-8")

res = classifier(base64Image)
res2 = classifier(base64Image)

label_scores = {i["label"]: i["score"] for i in res}
label_scores = {i["label"]: i["score"] for i in res2}
for l in labels[:]:
if l in label_scores and label_scores[l] >= _score:
results.append(
Expand All @@ -143,12 +154,17 @@ async def image_query_classification(
status_code=400, detail=f"Error classifying image: {e}"
)

finally:
img.close()
del res2
torch.cuda.empty_cache()

except Exception as e:
print("File is not a valid image.")
return {"error": str(e)}

finally:
del classifier
img.close()
torch.cuda.empty_cache()

return totalResults
Expand Down Expand Up @@ -180,13 +196,13 @@ async def multi_image_query_classification(

for index, file in enumerate(files):
# Read the file as bytes
contents = await file.read()
image_list = []

for model_name in model_names:
try:
contents = await file.read()

labels_copy = labels.copy()
classifier = check_model(model_name)

# Check if the image is in fact an image
try:
Expand All @@ -202,6 +218,8 @@ async def multi_image_query_classification(
# Check if the image is a GIF and if it's animated
if img.format.lower() == "gif":
try:
classifier = check_model(model_name)

res = await asyncio.get_event_loop().run_in_executor(
executor,
process_image,
Expand Down Expand Up @@ -235,9 +253,9 @@ async def multi_image_query_classification(
# Encode the file to base64
base64Image = base64.b64encode(contents).decode("utf-8")

res = classifier(base64Image)
res2 = classifier(base64Image)

label_scores = {i["label"]: i["score"] for i in res}
label_scores = {i["label"]: i["score"] for i in res2}
for l in labels[:]:
if l in label_scores and label_scores[l] >= _score:
results.append(
Expand All @@ -255,14 +273,18 @@ async def multi_image_query_classification(
)
finally:
img.close()
del classifier
del res2
torch.cuda.empty_cache()

except Exception as e:
print("File is not a valid image.")
img.close()
return {"error": str(e)}

finally:
img.close()
torch.cuda.empty_cache()

totalResults.append({index: image_list})

return totalResults
8 changes: 5 additions & 3 deletions src/routes/api/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,16 @@ def process_video(
if return_on_first_matching_label or set(labels) == m:
break

# Release the video capture object
vc.release()

# Return the results
return results
except Exception as e:
return e

finally:
vc.release()
del result
torch.cuda.empty_cache()


@router.post("/api/video-classification", dependencies=[Depends(get_api_key)])
async def video_classification(
Expand Down

0 comments on commit 40c8ec4

Please sign in to comment.