When calculating the loss, the input data does not contain NaN, but the output contains NaN #23625

Open
CZXIANGOvO opened this issue Sep 13, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@CZXIANGOvO

CZXIANGOvO commented Sep 13, 2024

Description

Please specify cuda:0 at the very beginning.

import torch
import numpy as np
import os
import jax
import jax.numpy as jnp
from jax import ops as jops
from jax.nn import one_hot, sigmoid
from jax import lax
import jax.scipy.special as sc
import optax

if "CONTEXT_DEVICE_TARGET" in os.environ and os.environ['CONTEXT_DEVICE_TARGET'] == 'GPU':
    devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
    device = devices[-2]
    final_device = "cuda:" + device
else:
    final_device = 'cpu'


from network.cv.yolov4.yolov4_pytorch import YOLOV4CspDarkNet53_torch as yolov4_torch

def loss_yolo_jax():
    from network.cv.yolov4.yolov4_pytorch import yolov4loss_jax
    yolo_obj = yolov4loss_jax()
    return yolo_obj


y_true_0 = np.load('./yolo_out[0][0].npy')
yolo_out1 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[0][1].npy')
yolo_out2 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[0][2].npy')
yolo_out3 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[1][0].npy')
yolo_out4 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[1][1].npy')
yolo_out5 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[1][2].npy')
yolo_out6 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[2][0].npy')
yolo_out7 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[2][1].npy')
yolo_out8 = torch.from_numpy(y_true_0).to(final_device)
y_true_0 = np.load('./yolo_out[2][2].npy')
yolo_out9 = torch.from_numpy(y_true_0).to(final_device)

yolo_out = ((yolo_out1,yolo_out2,yolo_out3),(yolo_out4,yolo_out5,yolo_out6),(yolo_out7,yolo_out8,yolo_out9))


model_pt = yolov4_torch()
model_pt.train()
model_torch = model_pt.to(final_device)



y_true_0 = np.load('./y_true_0.npy')
y_true_0 = torch.from_numpy(y_true_0).to(final_device)

y_true_1 = np.load('./y_true_1.npy')
y_true_1 = torch.from_numpy(y_true_1).to(final_device)

y_true_2 = np.load('./y_true_2.npy')

y_true_2 = torch.from_numpy(y_true_2).to(final_device)


gt_0 = np.load('./gt_0.npy')
gt_0 = torch.from_numpy(gt_0).to(final_device)

gt_1 = np.load('./gt_1.npy')
gt_1 = torch.from_numpy(gt_1).to(final_device)


gt_2 = np.load('./gt_2.npy')
gt_2 = torch.from_numpy(gt_2).to(final_device)

input_shape_t = np.load('./input_shape_t.npy')
input_shape_t = torch.from_numpy(input_shape_t).to(final_device)

params_torch = {key: value.detach().cpu().numpy() for key, value in model_torch.state_dict().items()}


loss_jax_fun = loss_yolo_jax()
params_jax = {name: jnp.array(value, dtype=jnp.float32) for name, value in params_torch.items()}
loss_jax, jax_grads = jax.value_and_grad(loss_jax_fun.calc_loss)(params_jax, yolo_out, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape_t)


yolo_out1 = torch.isnan(yolo_out1).any()
print('yolo_out1;',yolo_out1) 
yolo_out2 = torch.isnan(yolo_out2).any()
print('yolo_out2;',yolo_out2) 
yolo_out3 = torch.isnan(yolo_out3).any()
print('yolo_out3;',yolo_out3) 
yolo_out4 = torch.isnan(yolo_out4).any()
print('yolo_out4;',yolo_out4)
yolo_out5 = torch.isnan(yolo_out5).any()
print('yolo_out5;',yolo_out5) 
yolo_out6 = torch.isnan(yolo_out6).any()
print('yolo_out6;',yolo_out6) 
yolo_out7 = torch.isnan(yolo_out7).any()
print('yolo_out7;',yolo_out7)
yolo_out8 = torch.isnan(yolo_out8).any()
print('yolo_out8;',yolo_out8)
yolo_out9 = torch.isnan(yolo_out9).any()
print('yolo_out9;',yolo_out9) 
y_true_0 = torch.isnan(y_true_0).any()
print('y_true_0;',y_true_0) 
y_true_1 = torch.isnan(y_true_1).any()
print('y_true_1;',y_true_1)
y_true_2 = torch.isnan(y_true_2).any()
print('y_true_2;',y_true_2) 
gt_0 = torch.isnan(gt_0).any()
print('gt_0;',gt_0) 
gt_1 = torch.isnan(gt_1).any()
print('gt_1;',gt_1) 
gt_2 = torch.isnan(gt_2).any()
print('gt_2;',gt_2) 
input_shape_t = torch.isnan(input_shape_t).any()
print('input_shape_t;',input_shape_t) 

print('loss_torch_result;',np.array(loss_jax)) 

Screenshot 2024-09-13 204340

System info (python version, jaxlib version, accelerator, etc.)

Code and data links: https://drive.google.com/file/d/1-edrk7_sxSgdu7cmXQXf6JsT57xiG1Hb/view?usp=sharing

@CZXIANGOvO CZXIANGOvO added the bug Something isn't working label Sep 13, 2024
@lockwo
Contributor

lockwo commented Sep 15, 2024

Is there a MVC? This code doesn't run for me

@CZXIANGOvO
Author

Is there a MVC? This code doesn't run for me

Which part doesn't run? You need to set final_device yourself at the beginning; you can delete this block:

if "CONTEXT_DEVICE_TARGET" in os.environ and os.environ['CONTEXT_DEVICE_TARGET'] == 'GPU':
    devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
    device = devices[-2]
    final_device = "cuda:" + device
else:
    final_device = 'cpu'


@jakevdp
Collaborator

jakevdp commented Sep 18, 2024

Hi @CZXIANGOvO – it's going to be hard to help with specifics here absent an MVC (also known as a minimal reproducible example). If you're able to re-work your example so that others can run it and see the same errors you are seeing, then we could offer specific guidance.

Absent that, though, in general it's not surprising to see NaN outputs for inputs without NaNs: it just means that you're calling some function in your model in a way that is undefined to floating point precision. Here's a simple example of this:

>>> import jax.numpy as jnp

>>> def f(x, y):
...   return x * jnp.exp(y)

>>> f(1.0, 1.0)
Array(2.7182817, dtype=float32, weak_type=True)

>>> f(0.0, 100.0)
Array(nan, dtype=float32, weak_type=True)

More than likely, somewhere in your model you have an expression that is evaluating to NaN for reasons like this.
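One rough way to narrow this down by hand is to check intermediate values for anything non-finite, since an inf that never shows up in an isnan check can still turn into NaN one step later. The sketch below applies that idea to the small f(x, y) example; the helper name find_nonfinite and the dictionary of intermediates are hypothetical and only illustrate the approach, they are not part of the reporter's model or of the JAX API.

```python
import jax.numpy as jnp

def find_nonfinite(named_arrays):
    """Return the names of arrays that contain NaN or +/-inf (hypothetical helper)."""
    return [
        name
        for name, value in named_arrays.items()
        if not bool(jnp.all(jnp.isfinite(jnp.asarray(value))))
    ]

x = jnp.float32(0.0)
y = jnp.float32(100.0)

# Name each intermediate so the first non-finite one is easy to spot.
exp_y = jnp.exp(y)      # overflows float32 and becomes inf
product = x * exp_y     # 0 * inf is NaN

intermediates = {"x": x, "exp_y": exp_y, "x_times_exp_y": product}
bad = find_nonfinite(intermediates)
print(bad)
```

Here the inputs x and y are both finite, yet exp_y is already inf before the NaN appears, which is why checking only the inputs with isnan does not rule out a NaN loss.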

The best way to debug this is to start digging in to your model to figure out exactly where this is coming from. One way to do this is to enable the jax_debug_nans flag, as described here: https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager
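As a minimal sketch of what enabling the flag looks like, the snippet below turns on jax_debug_nans through jax.config.update and reruns the small f(x, y) example from this comment. It assumes a recent JAX version; with the flag set, the NaN-producing operation is expected to raise a FloatingPointError instead of silently returning NaN. The try/except wrapper and the raised variable are only there for illustration.

```python
import jax
import jax.numpy as jnp

# Enable NaN checking globally; JAX then raises when an operation produces NaN.
jax.config.update("jax_debug_nans", True)

def f(x, y):
    return x * jnp.exp(y)

try:
    f(0.0, 100.0)   # exp(100) overflows to inf in float32, and 0 * inf is NaN
    raised = False
except FloatingPointError as err:
    raised = True
    print("NaN detected:", err)

# Switch the check off again so later code runs at normal speed.
jax.config.update("jax_debug_nans", False)
```

In a real model the same flag can be set once at the top of the script, before the call to jax.value_and_grad, so that the error points at the first operation that produced the NaN.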

I hope that helps get you on the right path!
