Skip to content

Commit 22dc9bc

Browse files
authored
fix(ml): batch axis not being added for recognition model (#12588)
* fix has_batch_axis * fix typing
1 parent fa095c3 commit 22dc9bc

File tree

2 files changed

+1
-7
lines changed

2 files changed

+1
-7
lines changed

Diff for: machine-learning/app/models/facial_recognition/recognition.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from app.models.base import InferenceModel
1414
from app.models.transforms import decode_cv2
1515
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
16-
from app.sessions import has_batch_axis
1716

1817

1918
class FaceRecognizer(InferenceModel):
@@ -27,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any)
2726

2827
def _load(self) -> ModelSession:
2928
session = self._make_session(self.model_path)
30-
if self.batch and not has_batch_axis(session):
29+
if self.batch and str(session.get_inputs()[0].shape[0]) != "batch":
3130
self._add_batch_axis(self.model_path)
3231
session = self._make_session(self.model_path)
3332
self.model = ArcFaceONNX(

Diff for: machine-learning/app/sessions/__init__.py

-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +0,0 @@
1-
from app.schemas import ModelSession
2-
3-
4-
def has_batch_axis(session: ModelSession) -> bool:
5-
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0

0 commit comments

Comments
 (0)