
[P1] [Error] can not use bfloat16 and TypeError: Object of type type is not JSON serializable #102

Closed
mrsempress opened this issue Jun 6, 2024 · 22 comments
Labels: question (Further information is requested)

mrsempress commented Jun 6, 2024

Thanks for your wonderful model, but I have run into some problems.

  1. Cannot use bfloat16:
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1211, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 992, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1095, in _update_causal_mask
    causal_mask = torch.triu(causal_mask, diagonal=1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
  2. I ran main_demo.ipynb, but got this error:
Traceback (most recent call last):                                                                                                                                                
  File "/mnt/geogpt-gpfs/pyreft/inference.py", line 61, in <module>                                                                  
    _ = trainer.train()                                                                                                                                                           
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train                                                                                     
    return inner_training_loop(                                                                                                                                                   
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2116, in _inner_training_loop                                                                      
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)                                                                                           
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 371, in on_train_begin                                                                    
    return self.call_event("on_train_begin", args, state, control)                                                                                                                
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 415, in call_event                                                                        
    result = getattr(callback, event)(                                                                                                                                            
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 636, in on_train_begin                                                      
    model_config_json = model.config.to_json_string()                                                                                                                             
  File "/opt/conda/lib/python3.10/site-packages/transformers/configuration_utils.py", line 938, in to_json_string                                                                 
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  File "/opt/conda/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 201, in encode
    chunks = list(chunks)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/opt/conda/lib/python3.10/json/encoder.py", line 325, in _iterencode_list
    yield from chunks
  File "/opt/conda/lib/python3.10/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type type is not JSON serializable

I found issue #69, but I am using main_demo.ipynb, so it does not work for me.
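
For reference, the first error reproduces outside the model with a minimal sketch like this (assuming a CUDA machine on torch 2.0.x, where triu/tril had no bfloat16 CUDA kernel yet):

    import torch

    # On torch 2.0.x this raises:
    # RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
    mask = torch.full((4, 4), float("-inf"), dtype=torch.bfloat16, device="cuda")
    causal_mask = torch.triu(mask, diagonal=1)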

frankaging (Collaborator) commented Jun 6, 2024

@mrsempress hey, thanks for raising the issue.

On the second problem: the root cause is probably the one identified in #70 -- tensorboard is not well integrated yet.

As a result, you need to make sure to run your command with --report_to none or --report_to wandb.
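
For the notebook case (where there is no CLI flag to pass), a minimal sketch of the same workaround with a standard transformers setup (output_dir here is a placeholder):

    import transformers

    # report_to="none" removes the TensorBoard callback whose on_train_begin
    # tries to JSON-serialize the model config and fails.
    training_args = transformers.TrainingArguments(
        output_dir="./tmp",
        report_to="none",  # or report_to="wandb"
    )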

@frankaging frankaging changed the title [Error] KeyError: Parameter containing and TypeError: Object of type type is not JSON serializable [P1] KeyError: Parameter containing and TypeError: Object of type type is not JSON serializable Jun 6, 2024
@frankaging frankaging self-assigned this Jun 6, 2024
@frankaging frankaging added the question Further information is requested label Jun 6, 2024
frankaging (Collaborator) commented:

@mrsempress on the first problem, could you make sure there is only 1 GPU visible on your machine? (I know you ran with CUDA_VISIBLE_DEVICES=6, but this problem usually occurs when there are multiple GPUs.) Our script does not support multi-GPU training well at this point.

@mrsempress mrsempress changed the title [P1] KeyError: Parameter containing and TypeError: Object of type type is not JSON serializable [Error] can not use bfloat16 and TypeError: Object of type type is not JSON serializable Jun 6, 2024
@mrsempress mrsempress changed the title [Error] can not use bfloat16 and TypeError: Object of type type is not JSON serializable [P1] [Error] can not use bfloat16 and TypeError: Object of type type is not JSON serializable Jun 6, 2024
mrsempress (Author) commented:

Sorry, the first issue has been updated; it is that bfloat16 cannot be used.
For issue 2, I do not see tensorboard mentioned anywhere in main_demo.ipynb, and there is no argparse, so I do not understand why this applies.

frankaging (Collaborator) commented:

@mrsempress thanks.

For issue 1) could you provide your running script?

For issue 2) could you reproduce this error by running the notebook on Google Colab, and share the failing Colab with me?

Thanks! These will help me root-cause the issues.

mrsempress (Author) commented:

For issue 1) CUDA_VISIBLE_DEVICES=0 python examples/loreft/train.py -task gsm8k -model ../../models/Llama-7b-hf -seed 42 -l all -r 4 -p f7+l7 -e 12 -lr 9e-4 -type NodireftIntervention -gradient_accumulation_steps 4 -batch_size 8 -eval_batch_size 4 --dropout 0.05 --test_split validation --use_normalized_template --greedy_decoding --warmup_ratio 0.00 --weight_decay 0.06
Additionally, I would like to know how much memory hyperparameter tuning and training occupy in LoReFT. When I run hyperparameter tuning, it takes up over 60 GB of memory; I want to know whether that is caused solely by changing bfloat16 to float32. Also, how long does training usually take?

For issue 2), once I reproduce it successfully I will update the link. Currently, when installing pyreft, Colab prompts "you must restart the runtime in order to use newly installed versions", which takes some time. I only used the original ipynb without modifying the code, so you can also try the experiment. I am not sure whether it is a machine environment issue.

frankaging (Collaborator) commented:

@mrsempress Thanks.

> I want to know if it was only caused by changing bfloat to float32

Could you explain more about the change? Did you change examples/loreft/train.py, and what was the change?

For issue 2), I attached my local notebook, which does not encounter this issue:
main_demo.pdf

Could you check the version of your transformers library? Could you install version 4.39.3 and try again? It is most likely an env/setup issue, since all my experiments are running fine.
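
A quick way to verify the pin took effect (a sketch; the install command goes in your shell or a notebook cell):

    # after running: pip install transformers==4.39.3
    import transformers
    print(transformers.__version__)  # expect "4.39.3"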

frankaging (Collaborator) commented:

@mrsempress minor: in terms of memory profile, you could check our publicly released logs on wandb. This is for our arithmetic benchmarks; the 7B experiments were run on a 40G A100. I also attached Process GPU Memory Allocated (%) here:

[Screenshot: Process GPU Memory Allocated (%), 2024-06-06]

Please go to the logs, and trace out other details.

mrsempress (Author) commented:

> @mrsempress Thanks.
>
> Could you explain more about the change? Did you change examples/loreft/train.py, and what was the change?
>
> For issue 2), I attached my local notebook, which does not encounter this issue: main_demo.pdf
>
> Could you check the version of your transformers library? Could you install version 4.39.3 and try again?

I did not modify examples/loreft/train.py. Issue 1 means that I cannot use bfloat16, so I added --dtype float32 on the command line. I want to know why the memory usage is so large and why bfloat16 cannot be used. My previous transformers version was 4.40.2; after changing to 4.39.3, it still did not work for me.
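
For context, a back-of-envelope sketch of why float32 roughly doubles weight memory relative to bfloat16 (parameters only; activations, gradients, and optimizer state come on top):

    # 7B parameters; float32 = 4 bytes/element, bfloat16 = 2 bytes/element
    params = 7e9
    print(f"float32 weights:  {params * 4 / 2**30:.1f} GiB")  # ~26.1 GiB
    print(f"bfloat16 weights: {params * 2 / 2**30:.1f} GiB")  # ~13.0 GiB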

mrsempress (Author) commented:

> @mrsempress minor: in terms of memory profile, you could check our publicly released logs on wandb. This is for our arithmetic benchmarks; the 7B experiments were run on a 40G A100.
>
> Please go to the logs, and trace out other details.

Thank you for your patient reply. I now understand how much memory is actually required, but I need to find out whether it is because bfloat16 cannot be used, or whether other factors make the memory usage so large when running the same command.

frankaging (Collaborator) commented:

> I need to find out if it is due to bfloat16 not being able to be used or if ...

Hey! Yes, I think so. I am running with bf16, and that is probably why my memory usage is lower.

frankaging (Collaborator) commented:

@mrsempress what is your torch version?

frankaging (Collaborator) commented:

@mrsempress hey, this is probably an env issue. To resolve it, maybe create a clean conda env and install packages in the same versions as I have.

Here are the requirements.txt as well as the environment.yml file of my conda env:

requirements.txt

name: wuzhengx-310
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.12.12=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.13=h7f8727e_0
  - python=3.10.13=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - accelerate==0.29.1
    - aiofiles==23.2.1
    - aiohttp==3.9.3
    - aiosignal==1.3.1
    - alpaca-eval==0.6
    - altair==5.2.0
    - annotated-types==0.6.0
    - anyio==4.3.0
    - appdirs==1.4.4
    - argon2-cffi==23.1.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.3.0
    - asttokens==2.4.1
    - async-lru==2.0.4
    - async-timeout==4.0.3
    - attrs==23.2.0
    - babel==2.14.0
    - beautifulsoup4==4.12.3
    - bitsandbytes==0.42.0
    - bleach==6.1.0
    - cachetools==5.3.3
    - certifi==2024.2.2
    - cffi==1.16.0
    - charset-normalizer==3.3.2
    - click==8.1.7
    - colorama==0.4.6
    - comm==0.2.1
    - contourpy==1.2.0
    - cycler==0.12.1
    - dacite==1.8.1
    - datasets==2.18.0
    - debugpy==1.8.1
    - decorator==5.1.1
    - defusedxml==0.7.1
    - diffusers==0.27.2
    - dill==0.3.7
    - distro==1.9.0
    - docker-pycreds==0.4.0
    - einops==0.7.0
    - evaluate==0.4.1
    - exceptiongroup==1.2.0
    - executing==2.0.1
    - fastapi==0.110.0
    - fastjsonschema==2.19.1
    - ffmpy==0.3.2
    - filelock==3.13.1
    - fire==0.5.0
    - fonttools==4.49.0
    - fqdn==1.5.1
    - frozenlist==1.4.1
    - fsspec==2024.2.0
    - gcsfs==2024.2.0
    - gitdb==4.0.11
    - gitpython==3.1.42
    - google-api-core==2.18.0
    - google-auth==2.29.0
    - google-auth-oauthlib==1.2.0
    - google-cloud-core==2.4.1
    - google-cloud-storage==2.16.0
    - google-crc32c==1.5.0
    - google-resumable-media==2.7.0
    - googleapis-common-protos==1.63.0
    - gradio==3.50.0
    - gradio-client==0.6.1
    - h11==0.14.0
    - htmlmin==0.1.12
    - httpcore==1.0.4
    - httpx==0.27.0
    - huggingface-hub==0.20.3
    - idna==3.6
    - imagehash==4.3.1
    - importlib-metadata==7.1.0
    - importlib-resources==6.1.2
    - ipykernel==6.29.3
    - ipython==8.22.1
    - ipywidgets==8.1.1
    - isoduration==20.11.0
    - jedi==0.19.1
    - jinja2==3.1.3
    - joblib==1.3.2
    - json5==0.9.17
    - jsonpointer==2.4
    - jsonschema==4.21.1
    - jsonschema-specifications==2023.12.1
    - jupyter==1.0.0
    - jupyter-client==8.6.0
    - jupyter-console==6.6.3
    - jupyter-core==5.7.1
    - jupyter-events==0.9.0
    - jupyter-lsp==2.2.3
    - jupyter-server==2.12.5
    - jupyter-server-terminals==0.5.2
    - jupyterlab==4.1.2
    - jupyterlab-pygments==0.3.0
    - jupyterlab-server==2.25.3
    - jupyterlab-widgets==3.0.10
    - kiwisolver==1.4.5
    - llvmlite==0.42.0
    - markdown-it-py==3.0.0
    - markupsafe==2.1.5
    - matplotlib==3.7.4
    - matplotlib-inline==0.1.6
    - mdurl==0.1.2
    - mistune==3.0.2
    - mizani==0.9.3
    - mpmath==1.3.0
    - multidict==6.0.5
    - multimethod==1.11.2
    - multiprocess==0.70.15
    - nbclient==0.9.0
    - nbconvert==7.16.1
    - nbformat==5.9.2
    - nest-asyncio==1.6.0
    - networkx==3.2.1
    - ninja==1.11.1.1
    - notebook==7.1.1
    - notebook-shim==0.2.4
    - numba==0.59.1
    - numpy==1.26.4
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu12==8.9.2.26
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-nccl-cu12==2.19.3
    - nvidia-nvjitlink-cu12==12.3.101
    - nvidia-nvtx-cu12==12.1.105
    - oauthlib==3.2.2
    - openai==1.12.0
    - orjson==3.9.15
    - overrides==7.7.0
    - packaging==23.2
    - pandas==2.2.1
    - pandocfilters==1.5.1
    - parso==0.8.3
    - patsy==0.5.6
    - peft==0.11.1
    - pexpect==4.9.0
    - phik==0.12.4
    - pillow==10.2.0
    - pip==23.3.1
    - platformdirs==4.2.0
    - plotnine==0.12.4
    - prometheus-client==0.20.0
    - prompt-toolkit==3.0.43
    - proto-plus==1.23.0
    - protobuf==3.20.3
    - psutil==5.9.8
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyarrow==15.0.0
    - pyarrow-hotfix==0.6
    - pyasn1==0.6.0
    - pyasn1-modules==0.4.0
    - pycparser==2.21
    - pydantic==2.6.2
    - pydantic-core==2.16.3
    - pydub==0.25.1
    - pygments==2.17.2
    - pyparsing==3.1.1
    - pyreft==0.0.4
    - python-dateutil==2.8.2
    - python-dotenv==1.0.1
    - python-json-logger==2.0.7
    - python-multipart==0.0.9
    - pytz==2024.1
    - pyvene==0.1.2
    - pywavelets==1.6.0
    - pyyaml==6.0.1
    - pyzmq==25.1.2
    - qtconsole==5.5.1
    - qtpy==2.4.1
    - referencing==0.33.0
    - reft==0.0.1.dev0
    - regex==2023.12.25
    - requests==2.31.0
    - requests-oauthlib==2.0.0
    - responses==0.18.0
    - rfc3339-validator==0.1.4
    - rfc3986-validator==0.1.1
    - rich==13.7.1
    - rpds-py==0.18.0
    - rsa==4.9
    - ruff==0.3.0
    - safetensors==0.4.2
    - scikit-learn==1.4.1.post1
    - scipy==1.11.4
    - seaborn==0.12.2
    - semantic-version==2.10.0
    - send2trash==1.8.2
    - sentencepiece==0.1.96
    - sentry-sdk==1.40.6
    - setproctitle==1.3.3
    - setuptools==68.2.2
    - shellingham==1.5.4
    - six==1.16.0
    - smmap==5.0.1
    - sniffio==1.3.1
    - soupsieve==2.5
    - spaces==0.26.0
    - stack-data==0.6.3
    - starlette==0.36.3
    - statsmodels==0.14.1
    - sympy==1.12
    - termcolor==2.4.0
    - terminado==0.18.0
    - threadpoolctl==3.3.0
    - tiktoken==0.6.0
    - tinycss2==1.2.1
    - tokenizers==0.15.2
    - tomli==2.0.1
    - tomlkit==0.12.0
    - toolz==0.12.1
    - torch==2.2.1
    - tornado==6.4
    - tqdm==4.66.2
    - traitlets==5.14.1
    - transformers==4.39.3
    - triton==2.2.0
    - typeguard==4.2.1
    - typer==0.9.0
    - types-python-dateutil==2.8.19.20240106
    - typing-extensions==4.10.0
    - tzdata==2024.1
    - uri-template==1.3.0
    - urllib3==2.2.1
    - uvicorn==0.27.1
    - visions==0.7.6
    - wandb==0.16.3
    - wcwidth==0.2.13
    - webcolors==1.13
    - webencodings==0.5.1
    - websocket-client==1.7.0
    - websockets==11.0.3
    - wheel==0.41.2
    - widgetsnbextension==4.0.10
    - wordcloud==1.9.3
    - xxhash==3.4.1
    - yarl==1.9.4
    - ydata-profiling==4.7.0
    - zipp==3.18.1

please let me know if the problem still exists. thanks.
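
A minimal sketch of recreating that environment (assuming the YAML above is saved as environment.yml; the env name comes from its name: field):

    conda env create -f environment.yml
    conda activate wuzhengx-310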

mrsempress (Author) commented:

> > I need to find out if it is due to bfloat16 not being able to be used or if ...
>
> Hey! Yes, I think so. I am running with bf16, and that is probably why my memory usage is lower.

Ok~

mrsempress (Author) commented:

> @mrsempress what is your torch version?

My torch version is 2.0.1.

mrsempress (Author) commented:

> @mrsempress hey, this is probably an env issue. To resolve it, maybe create a clean conda env and install packages in the same versions as I have.
>
> [requirements.txt and environment.yml quoted in full above]
>
> please let me know if the problem still exists. thanks.

Ok, I will try.

mrsempress (Author) commented:

> @mrsempress minor: in terms of memory profile, you could check our publicly released logs on wandb. This is for our arithmetic benchmarks; the 7B experiments were run on a 40G A100. Please go to the logs, and trace out other details.

After I updated the version to make bfloat16 available, the 7B experiments for the arithmetic tasks need 52574 MiB when the batch size is 8, but your experiment runs on a 40G A100. That is to say, besides bfloat16, there must be other factors that reduce memory.
The command I ran is the one in examples/loreft/README.md:

python train.py -task math \
    -data_dir dataset \
    -model yahma/llama-7b-hf \
    -seed 42 \
    -l all -r 8 -p f7+l7 -e 12 -lr 9e-4 \
    -type LoreftIntervention \
    -gradient_accumulation_steps 2 \
    -batch_size 16 \
    -eval_batch_size 4 \
    --dropout 0.00 \
    --test_split test \
    --use_normalized_template \
    --share_weights \
    --warmup_ratio 0.1 \
    --greedy_decoding \
    --save_model

frankaging (Collaborator) commented:

@mrsempress this is expected, I think; I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, it can be doubled?

frankaging (Collaborator) commented:

this is the screenshot of one of the publicly released run stats:
https://wandb.ai/wuzhengx/ReFT_MuadDib_math/runs/xoumltuz/
[Screenshot: run stats from the linked wandb dashboard, 2024-06-12]

mrsempress (Author) commented:

> @mrsempress this is expected, I think; I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, it can be doubled?

So should I use -gradient_accumulation_steps 16 -batch_size 8 or -gradient_accumulation_steps 4 -batch_size 8?
I also want to know how to set hyperparameters like gradient_accumulation_steps, since we cannot directly follow the command in examples/loreft/README.md.

frankaging (Collaborator) commented:

> > @mrsempress this is expected, I think; I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, it can be doubled?
>
> So should I use -gradient_accumulation_steps 16 -batch_size 8 or -gradient_accumulation_steps 4 -batch_size 8?
>
> I also want to know how to set hyperparameters like gradient_accumulation_steps, since we cannot directly follow the command in examples/loreft/README.md.

Sorry about the confusion, but I think nothing changes here: for hyperparameters, what matters is the effective batch size, which is the per-device batch size (bounded by GPU memory) times the gradient accumulation steps (you can set whatever combination matches the effective batch size). For the script you sent earlier and my settings, both have an effective batch size of 32.

Note that in the paper we only report the effective batch size, not the per-device batch size.

Hope these help. Thanks.
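
A quick sketch of that identity, using the flag names from the train.py commands above:

    # effective_batch_size = batch_size * gradient_accumulation_steps
    for batch_size, grad_accum in [(4, 8), (8, 4), (16, 2)]:
        print(batch_size, "x", grad_accum, "=", batch_size * grad_accum)
    # all three give 32, so they are equivalent for hyperparameter purposes;
    # only per-step GPU memory and throughput differ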

mrsempress (Author) commented:

> gradient_accumulation_steps

Thank you for your reply. Now I understand how to set gradient_accumulation_steps and batch size.

johnson7788 commented:

> Thanks for your wonderful model, but I have got some problems.
>
> [original bfloat16 and JSON-serialization tracebacks omitted; see the issue description above]

For problem 1, I upgraded to torch 2.3.1 and it works.
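
A minimal check after the upgrade (a sketch; assumes a CUDA machine):

    # after: pip install --upgrade torch==2.3.1
    import torch
    mask = torch.zeros(4, 4, dtype=torch.bfloat16, device="cuda")
    print(torch.triu(mask, diagonal=1).dtype)  # torch.bfloat16, no RuntimeError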
