src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1550,7 +1550,7 @@ def get_valid_ratio(self, mask):
     def get_proposal_pos_embed(self, proposals):
         """Get the position embedding of the proposals."""

-        num_pos_feats = 128
+        num_pos_feats = self.config.d_model // 2
         temperature = 10000
         scale = 2 * math.pi

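Note: the hard-coded 128 only matches the default d_model of 256. Each of the four box coordinates (cx, cy, w, h) gets num_pos_feats sinusoidal features, so the embedding has 4 * num_pos_feats channels; with num_pos_feats = d_model // 2 this is always 2 * d_model, the size the downstream projection expects, so two-stage models with non-default hidden sizes (such as tiny test configs) no longer break. A minimal standalone sketch of the expansion (the function name and shapes are illustrative, not the library's API):

import math
import torch

def proposal_pos_embed_sketch(proposals: torch.Tensor, d_model: int) -> torch.Tensor:
    # proposals: (batch_size, num_queries, 4) normalized boxes in [0, 1]
    num_pos_feats = d_model // 2  # previously the hard-coded 128, correct only for d_model == 256
    temperature = 10000
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    pos = proposals * scale
    pos = pos[:, :, :, None] / dim_t  # (batch_size, num_queries, 4, num_pos_feats)
    # interleave sin/cos and flatten: 4 * num_pos_feats == 2 * d_model channels
    pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
    return pos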
@@ -1969,12 +1969,11 @@ def forward(
             outputs_coord = outputs_coord_logits.sigmoid()
             outputs_classes.append(outputs_class)
             outputs_coords.append(outputs_coord)
-        # Keep batch_size as first dimension
-        outputs_class = torch.stack(outputs_classes, dim=1)
-        outputs_coord = torch.stack(outputs_coords, dim=1)
+        outputs_class = torch.stack(outputs_classes)
+        outputs_coord = torch.stack(outputs_coords)

-        logits = outputs_class[:, -1]
-        pred_boxes = outputs_coord[:, -1]
+        logits = outputs_class[-1]
+        pred_boxes = outputs_coord[-1]

         loss, loss_dict, auxiliary_outputs = None, None, None
         if labels is not None:
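Note: with dim=1 dropped, torch.stack uses its default dim=0, so the per-decoder-layer predictions are stacked along a new leading axis and the last layer is read with plain [-1] indexing. A toy illustration with made-up shapes:

import torch

num_layers, batch_size, num_queries, num_labels = 6, 2, 300, 91
per_layer = [torch.randn(batch_size, num_queries, num_labels) for _ in range(num_layers)]
stacked = torch.stack(per_layer)  # default dim=0 -> (num_layers, batch_size, num_queries, num_labels)
logits = stacked[-1]              # last decoder layer, (batch_size, num_queries, num_labels)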
@@ -2000,7 +1999,7 @@ def forward(
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
if self.config.two_stage:
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
outputs_loss["enc_outputs"] = {"logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}

loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
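Note: the old line contained two mistakes: it attached the first-stage predictions to the model output object rather than to the outputs_loss dict that is actually handed to the criterion, and it used the key "pred_logits" where the loss code reads "logits". A rough sketch of the dict the criterion receives when two_stage is on (dummy tensors, purely illustrative shapes):

import torch

batch_size, num_queries, num_labels = 2, 300, 91
outputs_loss = {
    "logits": torch.randn(batch_size, num_queries, num_labels),  # last decoder layer
    "pred_boxes": torch.rand(batch_size, num_queries, 4),        # last decoder layer
    "enc_outputs": {                                             # first-stage (encoder) proposals
        "logits": torch.randn(batch_size, num_queries, num_labels),
        "pred_boxes": torch.rand(batch_size, num_queries, 4),
    },
}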
@@ -2232,7 +2231,7 @@ def forward(self, outputs, targets):
                 List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                 losses applied, see each loss' doc.
         """
-        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}

         # Retrieve the matching between the outputs of the last layer and the targets
         indices = self.matcher(outputs_without_aux, targets)
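Note: this mirrors the existing "auxiliary_outputs" exclusion so the first-stage predictions never leak into the dict used to match the final decoder layer. An equivalent, slightly tidier membership test (a style alternative, not what the patch uses):

outputs_without_aux = {k: v for k, v in outputs.items() if k not in ("auxiliary_outputs", "enc_outputs")}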
@@ -2264,14 +2263,10 @@ def forward(self, outputs, targets):
             enc_outputs = outputs["enc_outputs"]
             bin_targets = copy.deepcopy(targets)
             for bt in bin_targets:
-                bt["labels"] = torch.zeros_like(bt["labels"])
+                bt["class_labels"] = torch.zeros_like(bt["class_labels"])
             indices = self.matcher(enc_outputs, bin_targets)
             for loss in self.losses:
-                kwargs = {}
-                if loss == "labels":
-                    # Logging is enabled only for the last layer
-                    kwargs["log"] = False
-                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
+                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
                 l_dict = {k + "_enc": v for k, v in l_dict.items()}
                 losses.update(l_dict)

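Note: two fixes in one hunk. The target dicts in this codebase key labels as "class_labels", so zeroing a nonexistent "labels" key failed; and the kwargs/log plumbing is dropped, presumably because this model's loss functions do not accept DETR's log keyword. The binarization itself, as a small standalone sketch: the first stage is class-agnostic, so every ground-truth label collapses to class 0 before matching against the encoder proposals:

import copy
import torch

targets = [{"class_labels": torch.tensor([17, 3]), "boxes": torch.rand(2, 4)}]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
    # class-agnostic first stage: every object becomes class 0 ("foreground")
    bt["class_labels"] = torch.zeros_like(bt["class_labels"])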
2 changes: 1 addition & 1 deletion src/transformers/models/deta/modeling_deta.py
@@ -1455,7 +1455,7 @@ def get_valid_ratio(self, mask):
     def get_proposal_pos_embed(self, proposals):
         """Get the position embedding of the proposals."""

-        num_pos_feats = 128
+        num_pos_feats = self.config.d_model // 2
         temperature = 10000
         scale = 2 * math.pi

15 changes: 15 additions & 0 deletions tests/models/deformable_detr/test_modeling_deformable_detr.py
@@ -544,6 +544,21 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

def test_two_stage_training(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
config.two_stage = True
config.auxiliary_loss = True
config.with_box_refine = True

model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()


TOLERANCE = 1e-4

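A quick way to exercise just the new test locally (a standard pytest invocation, assumed rather than taken from the patch):

python -m pytest tests/models/deformable_detr/test_modeling_deformable_detr.py -k test_two_stage_training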