
Weights not aligned for pt and jax #22425

Closed
crystina-z opened this issue Mar 28, 2023 · 2 comments
crystina-z commented Mar 28, 2023

System Info

  • transformers version: 4.26.1
  • Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.17
  • Python version: 3.8.16
  • Huggingface_hub version: 0.13.1
  • PyTorch version (GPU?): 1.10.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.6.6 (tpu)
  • Jax version: 0.4.5
  • JaxLib version: 0.4.4
  • Using GPU in script?: (False)
  • Using distributed or parallel set-up in script?: (True)

Who can help?

@sanchit-gandhi since it's about the weights in jax, and @ArthurZucker @younesbelkada since the sample below is based on XLM-R

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForMaskedLM,
    AutoModelForMaskedLM,
)

model_name = "xlm-roberta-base"
config_name = model_name
tokenizer_name = model_name
num_labels = 1
config = AutoConfig.from_pretrained(config_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

sentence = ["This is a sentence."]

def jax_inference():
    # Flax model loaded from the Flax checkpoint.
    data = tokenizer(sentence, return_tensors='np', return_attention_mask=True)
    model = FlaxAutoModelForMaskedLM.from_pretrained(model_name, config=config)
    embedding = model(**data, params=model.params, train=False)[0]
    return embedding


def jax_from_pt_inference():
    # Flax model loaded from the PyTorch checkpoint (from_pt=True).
    data = tokenizer(sentence, return_tensors='np', return_attention_mask=True)
    model = FlaxAutoModelForMaskedLM.from_pretrained(model_name, config=config, from_pt=True)
    embedding = model(**data, params=model.params, train=False)[0]
    return embedding


def torch_inference():
    # PyTorch model loaded from the PyTorch checkpoint.
    data = tokenizer(sentence, return_tensors='pt', return_attention_mask=True)
    model = AutoModelForMaskedLM.from_pretrained(model_name, config=config)
    embedding = model(**data, return_dict=True).logits
    return embedding.cpu().detach().numpy()

e1 = jax_inference()
e2 = jax_from_pt_inference()
e3 = torch_inference()

print(e1[0, 0, :10])
print(e2[0, 0, :10])
print(e3[0, 0, :10])

Expected behavior

The above script is expected to output the same (or very close) values; however, it produces:

[64.36088     0.12701216 37.773556   26.37121    26.858221   28.791494  25.630554   21.905432   21.001484   25.389727  ]
[64.36088     0.12701216 37.773556   26.37121    26.858221   28.791494  25.630554   21.905432   21.001484   25.389727  ]
[64.29784     0.12513931 37.865814   26.475258   26.956318   28.914783 25.684874   21.950882   21.039997   25.494867  ]

The first two lines are identical (results from the JAX model, whether loaded from JAX weights or from PyTorch weights), but they differ from the third line, produced by the PyTorch model. The differences show up in the first or second decimal place (e.g., 64.36 vs 64.30, and even 26.371 vs 26.475), which is not negligible.
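A quick way to quantify the gap is to compare the arrays directly; this is a minimal sketch using NumPy on the e1/e3 outputs from the script above (the 1e-3 tolerance is only illustrative, not a transformers guarantee):

import numpy as np

# Largest element-wise discrepancy between the Flax and PyTorch logits.
print("max abs diff (jax vs pt):", np.abs(e1 - e3).max())

# Check agreement within an illustrative tolerance.
print("allclose (atol=1e-3):", np.allclose(e1, e3, atol=1e-3))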

@sanchit-gandhi

Hey @crystina-z! Thanks for the great code example! I see that you're running the script on TPU, could you try repeating the benchmark using the highest JAX matmul precision (see #15754 (comment) for details)? I think this should close the gap to PyTorch.

Some more details about the behaviour of JAX matmul here: google/jax#10413 (comment)
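As a minimal sketch of the suggested workaround (both calls below are standard JAX APIs; whether float32 matmul precision fully closes the gap on TPU is the assumption being tested):

import jax

# Option 1: raise the default matmul precision for the whole process.
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: scope the higher precision to the forward pass only.
# model/data as in jax_inference() from the reproduction above.
with jax.default_matmul_precision("float32"):
    embedding = model(**data, params=model.params, train=False)[0]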

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
