
When calculating the loss, the input data does not contain NaN, but the output contains NaN #23717

Closed
CZXIANGOvO opened this issue Sep 18, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@CZXIANGOvO

Description

Please set the target device to cuda:0 at the very beginning of the script before running it.

import torch
import numpy as np
import os
import jax
import jax.numpy as jnp
from jax import ops as jops
from jax.nn import one_hot, sigmoid
from jax import lax
import jax.scipy.special as sc
import optax


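# Select the device. CUDA_VISIBLE_DEVICES renumbers the visible GPUs starting from 0,
# so a physical GPU ID taken from that variable is not a valid torch index
# (and devices[-2] raises IndexError when only one GPU is listed). Use cuda:0 instead.
if os.environ.get('CONTEXT_DEVICE_TARGET') == 'GPU':
    final_device = 'cuda:0'
else:
    final_device = 'cpu'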


from network.cv.yolov3_darknet53.Yolov3_torch import YOLOV3DarkNet53 as yolov3_torch

def loss_yolo_jax():
    from network.cv.yolov4.yolov4_pytorch import yolov4loss_jax
    yolo_obj = yolov4loss_jax()
    return yolo_obj


def load_tensor(path):
    """Load a saved .npy array and move it to final_device as a torch tensor."""
    return torch.from_numpy(np.load(path)).to(final_device)


# Nine saved network outputs, grouped below as a 3 x 3 nested tuple.
yolo_out1 = load_tensor('./output_torch[0][0].npy')
yolo_out2 = load_tensor('./output_torch[0][1].npy')
yolo_out3 = load_tensor('./output_torch[0][2].npy')
yolo_out4 = load_tensor('./output_torch[1][0].npy')
yolo_out5 = load_tensor('./output_torch[1][1].npy')
yolo_out6 = load_tensor('./output_torch[1][2].npy')
yolo_out7 = load_tensor('./output_torch[2][0].npy')
yolo_out8 = load_tensor('./output_torch[2][1].npy')
yolo_out9 = load_tensor('./output_torch[2][2].npy')

yolo_out = ((yolo_out1, yolo_out2, yolo_out3),
            (yolo_out4, yolo_out5, yolo_out6),
            (yolo_out7, yolo_out8, yolo_out9))

# Targets (batch_y_true_*) and ground-truth boxes (batch_gt_box*).
batch_y_true_0_torch = load_tensor('./batch_y_true_0_torch.npy')
batch_y_true_1_torch = load_tensor('./batch_y_true_1_torch.npy')
batch_y_true_2_torch = load_tensor('./batch_y_true_2_torch.npy')
batch_gt_box0_torch = load_tensor('./batch_gt_box0_torch.npy')
batch_gt_box1_torch = load_tensor('./batch_gt_box1_torch.npy')
batch_gt_box2_torch = load_tensor('./batch_gt_box2_torch.npy')

input_shape = load_tensor('./input_shape.npy')


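model_pt = yolov3_torch(is_training=True)
model_pt.train()
model_torch = model_pt.to(final_device)

# Copy the PyTorch weights into float32 JAX arrays.
params_torch = {key: value.detach().cpu().numpy() for key, value in model_torch.state_dict().items()}
params_jax = {name: jnp.array(value, dtype=jnp.float32) for name, value in params_torch.items()}


def to_jax(tensor):
    """Copy a (possibly CUDA) torch tensor to the host and wrap it as a JAX array."""
    return jnp.asarray(tensor.detach().cpu().numpy())


# JAX functions expect JAX/NumPy arrays rather than torch tensors, so convert every
# input before calling the loss (tree_map walks the nested yolo_out tuples).
loss_jax_fun = loss_yolo_jax()
loss_jax, jax_grads = jax.value_and_grad(loss_jax_fun.calc_loss)(
    params_jax,
    jax.tree_util.tree_map(to_jax, yolo_out),
    to_jax(batch_y_true_0_torch), to_jax(batch_y_true_1_torch), to_jax(batch_y_true_2_torch),
    to_jax(batch_gt_box0_torch), to_jax(batch_gt_box1_torch), to_jax(batch_gt_box2_torch),
    to_jax(input_shape))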



# Check every input for NaN and Inf (Inf can also turn into NaN inside the loss),
# without overwriting the input tensors.
inputs_to_check = {
    'yolo_out1': yolo_out1, 'yolo_out2': yolo_out2, 'yolo_out3': yolo_out3,
    'yolo_out4': yolo_out4, 'yolo_out5': yolo_out5, 'yolo_out6': yolo_out6,
    'yolo_out7': yolo_out7, 'yolo_out8': yolo_out8, 'yolo_out9': yolo_out9,
    'batch_y_true_0_torch': batch_y_true_0_torch,
    'batch_y_true_1_torch': batch_y_true_1_torch,
    'batch_y_true_2_torch': batch_y_true_2_torch,
    'batch_gt_box0_torch': batch_gt_box0_torch,
    'batch_gt_box1_torch': batch_gt_box1_torch,
    'batch_gt_box2_torch': batch_gt_box2_torch,
    'input_shape': input_shape,
}
for name, tensor in inputs_to_check.items():
    has_nan = torch.isnan(tensor).any().item()
    has_inf = torch.isinf(tensor).any().item()
    print(f'{name}: has_nan={has_nan}, has_inf={has_inf}')

print('loss_jax_result:', np.array(loss_jax))
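torch.isnan only detects NaN. Inputs that contain ±Inf pass that check but can still produce NaN inside a loss (for example inf - inf or 0 * inf). The sketch below is a hypothetical helper, report_nonfinite, that is not part of the original script; it counts both kinds of non-finite values in the saved arrays using NumPy alone, so it can run on the raw .npy files without a GPU:

```python
import numpy as np


def report_nonfinite(arrays):
    """Return {name: (n_nan, n_inf)} for every array containing NaN or Inf.

    `arrays` maps a label to anything np.asarray accepts, such as an array
    loaded with np.load or a CPU torch tensor.
    """
    report = {}
    for name, value in arrays.items():
        arr = np.asarray(value)
        if not np.issubdtype(arr.dtype, np.floating):
            continue  # integer/bool arrays cannot hold NaN or Inf
        n_nan = int(np.isnan(arr).sum())
        n_inf = int(np.isinf(arr).sum())
        if n_nan or n_inf:
            report[name] = (n_nan, n_inf)
    return report


example = {
    "clean": np.array([0.5, 1.0], dtype=np.float32),
    "has_inf": np.array([1.0, np.inf, -np.inf], dtype=np.float32),
    "shape": np.array([416, 416]),  # integer array, skipped
}
print(report_nonfinite(example))  # {'has_inf': (0, 2)}
```

For this issue it could be called with the same file names the script loads, e.g. report_nonfinite({'batch_gt_box0': np.load('./batch_gt_box0_torch.npy')}); an empty dict means every floating-point input is finite.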

[Screenshot 2024-09-18 184040]
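The body of calc_loss is not shown in this issue, so the following is only a hypothetical illustration of one common way finite inputs yield a NaN loss. It assumes the loss applies binary cross-entropy to sigmoid probabilities (the script does import sigmoid from jax.nn). When sigmoid saturates to exactly 1.0 in float32, log(1 - p) becomes -inf, and multiplying it by a zero label weight gives NaN. The names naive_bce and stable_bce are illustrative, not from the original code:

```python
import jax
import jax.numpy as jnp


def naive_bce(logits, labels):
    # Cross-entropy on probabilities: once sigmoid saturates to 1.0,
    # log(1 - p) = -inf and (1 - label) * -inf = 0 * -inf = NaN.
    p = jax.nn.sigmoid(logits)
    return -jnp.mean(labels * jnp.log(p) + (1.0 - labels) * jnp.log(1.0 - p))


def stable_bce(logits, labels):
    # Same loss written on logits: log(sigmoid(x)) = log_sigmoid(x) and
    # log(1 - sigmoid(x)) = log_sigmoid(-x), both finite for finite x.
    return -jnp.mean(labels * jax.nn.log_sigmoid(logits)
                     + (1.0 - labels) * jax.nn.log_sigmoid(-logits))


logits = jnp.array([100.0, -3.0], dtype=jnp.float32)  # finite, no NaN
labels = jnp.array([1.0, 0.0], dtype=jnp.float32)
print(naive_bce(logits, labels))   # nan
print(stable_bce(logits, labels))  # finite value
```

Writing the loss in terms of logits avoids taking the log of a saturated probability. optax, which the script already imports, also provides optax.sigmoid_binary_cross_entropy(logits, labels), which takes logits directly.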

System info (python version, jaxlib version, accelerator, etc.)

Code and data links: https://drive.google.com/file/d/1lxHI_OQwjSUCj7vszNIzl_NmkyVF6JQu/view?usp=sharing
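To find which operation inside the loss first produces a NaN, JAX can raise an error instead of silently propagating it: with the jax_debug_nans option enabled, JAX raises FloatingPointError at the first primitive whose output contains NaN (the option can also be set with the JAX_DEBUG_NANS=True environment variable). A minimal sketch, where toy_loss is a hypothetical stand-in for calc_loss and find_first_nan is an illustrative helper:

```python
import jax
import jax.numpy as jnp


def toy_loss(x):
    # Hypothetical stand-in for calc_loss: sqrt of a negative value is NaN.
    return jnp.sum(jnp.sqrt(x - 1.0))


def find_first_nan(fn, *args):
    """Run fn with NaN checking on; return the error message, or None if no NaN appears."""
    jax.config.update("jax_debug_nans", True)
    try:
        jax.block_until_ready(fn(*args))
        return None
    except FloatingPointError as err:
        return str(err)
    finally:
        jax.config.update("jax_debug_nans", False)  # restore the default


print(find_first_nan(toy_loss, jnp.array([2.0, 0.5])))  # message describing the NaN
print(find_first_nan(toy_loss, jnp.array([2.0, 3.0])))  # None
```

The same pattern could wrap the jax.value_and_grad call from the script, since the check also applies to NaNs produced while computing gradients.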

CZXIANGOvO added the bug (Something isn't working) label on Sep 18, 2024
@jakevdp
Collaborator

jakevdp commented Sep 18, 2024

Hi - if I'm not mistaken, this looks like a duplicate of #23625. Let's close this one and continue the discussion there.

jakevdp closed this as completed on Sep 18, 2024