We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Please set the device to `cuda:0` at the very beginning of the script.
# Reproduction script: evaluate the JAX YOLO loss (and its gradients with
# respect to the PyTorch model parameters) on head outputs and targets that
# were dumped as .npy files into the working directory, reporting whether any
# input contains NaN before printing the loss value.
import torch
import numpy as np
import os
import jax
import jax.numpy as jnp
from jax import ops as jops
from jax.nn import one_hot, sigmoid
from jax import lax
import jax.scipy.special as sc
import optax


def _select_device():
    """Return the torch device string this script should run on.

    When ``CONTEXT_DEVICE_TARGET`` is ``'GPU'`` the second-to-last GPU listed
    in ``CUDA_VISIBLE_DEVICES`` is used. CUDA renumbers the visible GPUs as
    0..n-1 (in listed order) inside the process, so the device is addressed by
    its *position* in that list rather than by the physical id written there.
    Falls back to ``cuda:0`` when the variable is unset or lists fewer than two
    GPUs; any other target uses ``'cpu'``.
    """
    if os.environ.get('CONTEXT_DEVICE_TARGET') != 'GPU':
        return 'cpu'
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        # No restriction: every GPU is visible, start from the first one.
        return 'cuda:0'
    devices = [d for d in visible.split(',') if d.strip()]
    return f'cuda:{max(len(devices) - 2, 0)}'


final_device = _select_device()

# Imported after device selection, matching the original script's order.
from network.cv.yolov3_darknet53.Yolov3_torch import YOLOV3DarkNet53 as yolov3_torch


def loss_yolo_jax():
    """Build and return the JAX YOLOv4 loss object from the project module."""
    from network.cv.yolov4.yolov4_pytorch import yolov4loss_jax
    yolo_obj = yolov4loss_jax()
    return yolo_obj


def _load_tensor(path, device):
    """Load the .npy array at *path* as a torch tensor placed on *device*."""
    return torch.from_numpy(np.load(path)).to(device)


# Head outputs, saved as output_torch[i][j].npy for i, j in 0..2; kept in the
# same nested 3x3 tuple layout the loss is called with.
yolo_out = tuple(
    tuple(_load_tensor(f'./output_torch[{i}][{j}].npy', final_device) for j in range(3))
    for i in range(3)
)

batch_y_true_0_torch, batch_y_true_1_torch, batch_y_true_2_torch = (
    _load_tensor(f'./batch_y_true_{k}_torch.npy', final_device) for k in range(3)
)
batch_gt_box0_torch, batch_gt_box1_torch, batch_gt_box2_torch = (
    _load_tensor(f'./batch_gt_box{k}_torch.npy', final_device) for k in range(3)
)
input_shape = _load_tensor('./input_shape.npy', final_device)

# Report NaNs in every input *before* evaluating the loss, so the report is
# still printed if the loss computation itself fails.
nan_report = [
    (f'yolo_out{n}', tensor)
    for n, tensor in enumerate((t for group in yolo_out for t in group), start=1)
]
nan_report += [
    ('batch_y_true_0_torch', batch_y_true_0_torch),
    ('batch_y_true_1_torch', batch_y_true_1_torch),
    ('batch_y_true_2_torch', batch_y_true_2_torch),
    ('batch_gt_box0_torch', batch_gt_box0_torch),
    ('batch_gt_box1_torch', batch_gt_box1_torch),
    ('batch_gt_box2_torch', batch_gt_box2_torch),
    ('input_shape', input_shape),
]
for label, tensor in nan_report:
    print(f'{label};', torch.isnan(tensor).any())

model_pt = yolov3_torch(is_training=True)
model_pt.train()
model_torch = model_pt.to(final_device)
# state_dict() holds parameters and buffers; all are copied to host numpy and
# then cast to float32 JAX arrays.
params_torch = {key: value.detach().cpu().numpy() for key, value in model_torch.state_dict().items()}

loss_jax_fun = loss_yolo_jax()
params_jax = {name: jnp.array(value, dtype=jnp.float32) for name, value in params_torch.items()}

# value_and_grad differentiates only w.r.t. the first argument (params_jax);
# the remaining arguments are forwarded to calc_loss unchanged.
# NOTE(review): those arguments are torch tensors (on GPU when final_device is
# CUDA). JAX ops cannot consume CUDA torch tensors directly — confirm that
# calc_loss expects torch inputs, or convert them to jnp arrays first.
loss_jax, jax_grads = jax.value_and_grad(loss_jax_fun.calc_loss)(
    params_jax,
    yolo_out,
    batch_y_true_0_torch,
    batch_y_true_1_torch,
    batch_y_true_2_torch,
    batch_gt_box0_torch,
    batch_gt_box1_torch,
    batch_gt_box2_torch,
    input_shape,
)

# NOTE(review): the label says "torch" but the value is the JAX loss; kept
# byte-identical in case downstream tooling parses this line.
print('loss_torch_result;', np.array(loss_jax))
Code and data links: https://drive.google.com/file/d/1lxHI_OQwjSUCj7vszNIzl_NmkyVF6JQu/view?usp=sharing
The text was updated successfully, but these errors were encountered:
Hi - if I'm not mistaken, this looks like a duplicate of #23625. Let's close this one and continue the discussion there.
Sorry, something went wrong.
No branches or pull requests
Description
Please set the device to `cuda:0` at the very beginning of the script.
System info (python version, jaxlib version, accelerator, etc.)
Code and data links: https://drive.google.com/file/d/1lxHI_OQwjSUCj7vszNIzl_NmkyVF6JQu/view?usp=sharing
The text was updated successfully, but these errors were encountered: