[Bounty] PyTorch & HuggingFace Interface #139

Open — wants to merge 510 commits into base: main
Changes from 250 commits (510 commits total)
1b27532
layer test
risingsunomi Aug 11, 2024
0876c79
layer test
risingsunomi Aug 11, 2024
e5c56c4
layer test
risingsunomi Aug 11, 2024
696c3bb
layer test
risingsunomi Aug 11, 2024
e807a63
layer test
risingsunomi Aug 11, 2024
0b589ee
layer test
risingsunomi Aug 11, 2024
ced3879
layer test
risingsunomi Aug 11, 2024
c6693a8
layer test
risingsunomi Aug 11, 2024
30e971d
layer test
risingsunomi Aug 11, 2024
8b4e624
fixing layer issue
risingsunomi Aug 11, 2024
bcb499e
temp and layer test
risingsunomi Aug 11, 2024
724c6c4
temp and layer test
risingsunomi Aug 11, 2024
e23f3f7
temp and layer test
risingsunomi Aug 11, 2024
7f13a6d
temp and layer test
risingsunomi Aug 11, 2024
ec92328
temp and layer test
risingsunomi Aug 11, 2024
fc3d224
temp and layer test
risingsunomi Aug 11, 2024
f14a339
temp and layer test
risingsunomi Aug 11, 2024
4f4a9e1
temp and layer test
risingsunomi Aug 11, 2024
3da44b3
change temp
risingsunomi Aug 11, 2024
0a4a003
change temp and alpha
risingsunomi Aug 11, 2024
e351501
change temp and alpha
risingsunomi Aug 11, 2024
3251567
change temp
risingsunomi Aug 11, 2024
16e4f7e
change temp
risingsunomi Aug 11, 2024
6083927
change temp
risingsunomi Aug 11, 2024
5b02fd1
change sampling
risingsunomi Aug 11, 2024
9805ac2
change sampling
risingsunomi Aug 11, 2024
8da3114
change sampling
risingsunomi Aug 11, 2024
c62dd2d
change sampling
risingsunomi Aug 11, 2024
0f7f96d
change sampling
risingsunomi Aug 11, 2024
fc36619
change sampling
risingsunomi Aug 11, 2024
b5f98d5
remove softmax
risingsunomi Aug 11, 2024
52d608f
remove softmax
risingsunomi Aug 11, 2024
b17a9ab
float long issue
risingsunomi Aug 11, 2024
69552e0
float long issue
risingsunomi Aug 11, 2024
1ee8a10
float long issue
risingsunomi Aug 11, 2024
1d9f482
float long issue
risingsunomi Aug 11, 2024
2ca9689
float long issue
risingsunomi Aug 11, 2024
0b8c9f2
cleaning up utils.py
risingsunomi Aug 11, 2024
94de83f
removing broken llama.py
risingsunomi Aug 11, 2024
f63f4b0
Merge pull request #2 from exo-explore/main
risingsunomi Aug 11, 2024
ab273d3
Merge branch 'main' into main
risingsunomi Aug 22, 2024
c365749
Merge pull request #4 from exo-explore/main
risingsunomi Aug 24, 2024
0b8221f
Merge pull request #5 from exo-explore/main
risingsunomi Aug 24, 2024
226a0ac
removing unittest, update inference return type, fixing converting te…
risingsunomi Aug 25, 2024
e11bebd
adding nvidia quadro and t1000 support
risingsunomi Aug 25, 2024
778cb6e
updating test, updating model selection for smaller quant llama3 model
risingsunomi Aug 25, 2024
56aae50
added updating model options to update_deps.py
risingsunomi Aug 25, 2024
7df4640
updating inference class init to take shard, updating pytorch test_in…
risingsunomi Aug 25, 2024
aa769ca
adding updates for inference_engine.py
risingsunomi Aug 25, 2024
08e8b41
reducing layer amount for llama3-2b-base
risingsunomi Aug 25, 2024
dd2812b
fixing gpu tensor to numpy conversion issues, updating top_p_sampling…
risingsunomi Aug 25, 2024
7bcd35e
forward rewrite, adding in caching with dynamic cache, cache conversi…
risingsunomi Aug 26, 2024
3beea22
updates to caching, stuck on issue with infer_prompt and infer_tensor…
risingsunomi Aug 26, 2024
87a14ca
trying to fix infer problems
risingsunomi Aug 26, 2024
356bf2f
switched everything to use caching, did more prep for encoding the to…
risingsunomi Aug 26, 2024
aa89032
fixing test
risingsunomi Aug 26, 2024
b9331d7
adding init py for old python versions
risingsunomi Aug 26, 2024
2c7aa9c
update readme and add in init pys
risingsunomi Aug 26, 2024
6da3e94
adding more tests
risingsunomi Aug 26, 2024
d0bc93c
adding more try catch to move through tests
risingsunomi Aug 26, 2024
0e221b2
tests
risingsunomi Aug 26, 2024
9fc9fdb
added position embeddings, update test
risingsunomi Aug 26, 2024
2635b4c
tests
risingsunomi Aug 26, 2024
86e89eb
adding back tests
risingsunomi Aug 27, 2024
64fbacd
adding another test
risingsunomi Aug 27, 2024
fb7c73f
Merge pull request #6 from exo-explore/main
risingsunomi Aug 27, 2024
0d93130
added gc collect to remove gpu, fixed tokenizers warning
risingsunomi Aug 27, 2024
0ae716d
fixing device
risingsunomi Aug 27, 2024
7705639
adding smaller model test
risingsunomi Aug 27, 2024
81d597d
testing
risingsunomi Aug 28, 2024
f1d3e31
added tinyllama
risingsunomi Aug 28, 2024
bf0e606
changing top_p
risingsunomi Aug 28, 2024
432efb5
updating test
risingsunomi Aug 28, 2024
2cdc14c
adding A10, adding test
risingsunomi Aug 28, 2024
ed5bea7
removing reloading of shard, changing temp and top_p
risingsunomi Aug 28, 2024
ea41845
Merge pull request #7 from exo-explore/main
risingsunomi Aug 28, 2024
46667b6
Merge pull request #8 from risingsunomi/pr139-dev
risingsunomi Aug 28, 2024
032c9b1
rewrite of sharded model using new split testing of huggingface models
risingsunomi Sep 1, 2024
626b223
building out new hf.py class, testing qwen and llama3 8b
risingsunomi Sep 1, 2024
f983e93
trying to load in weights but transformers/pytorch doesnt allow that …
risingsunomi Sep 4, 2024
d142be0
adding more testing, refining logit selection
risingsunomi Sep 13, 2024
be8d7fb
working split model test, updating class
risingsunomi Sep 15, 2024
9d1ecdd
working on class and inference engine updates
risingsunomi Sep 15, 2024
4b0df06
building out inference engine test
risingsunomi Sep 15, 2024
623468c
adding working tests, update to forward function to just use input_id…
risingsunomi Sep 16, 2024
19b322d
cleaning up code and tests, debugging and adding in cleaned up loggin…
risingsunomi Sep 16, 2024
cc2c14c
getting infer and stop token issues
risingsunomi Sep 16, 2024
583629c
add tracking of next token and other logits into the full input_ids s…
risingsunomi Sep 17, 2024
7ec5bb8
grpc testing
risingsunomi Sep 17, 2024
5903e63
grpc testing
risingsunomi Sep 17, 2024
e7a3fd0
grpc testing
risingsunomi Sep 17, 2024
f6eec5a
grpc testing
risingsunomi Sep 17, 2024
d441a51
grpc testing
risingsunomi Sep 17, 2024
e7f6dcb
grpc testing
risingsunomi Sep 17, 2024
ba5b005
grpc testing
risingsunomi Sep 17, 2024
6242d76
grpc testing
risingsunomi Sep 17, 2024
5630731
grpc testing
risingsunomi Sep 17, 2024
4a29268
testing passing hidden states in inference_state
risingsunomi Sep 17, 2024
2daf65f
testing passing hidden states in inference_state
risingsunomi Sep 17, 2024
36d5cde
fixing scalar issue, reversing passing hidden_states
risingsunomi Sep 17, 2024
6917f30
inference bug fix, grpc testing
risingsunomi Sep 17, 2024
adab336
inference bug fix, grpc testing
risingsunomi Sep 17, 2024
73146dd
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
929386d
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
32b8f67
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
c86facb
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
d15b20d
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
5e41bc4
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
b29c5f8
fixing hf model for hidden_states
risingsunomi Sep 17, 2024
ddaa79c
fixing kvcache issue
risingsunomi Sep 17, 2024
3164d38
fixing kvcache issue
risingsunomi Sep 17, 2024
e8532bc
fixing kvcache issue
risingsunomi Sep 17, 2024
6a5b8db
fixing kvcache issue
risingsunomi Sep 17, 2024
515687d
working on passing past input_ids between infers and nodes
risingsunomi Sep 18, 2024
92ebdd5
implemented infer caching and passing cache information via inference…
risingsunomi Sep 19, 2024
f0795bd
removing dynamic cache passing in inference_state as model does its o…
risingsunomi Sep 19, 2024
b8f15a0
removed clearning cache on infer prompt and only on finished infer te…
risingsunomi Sep 19, 2024
d0f3cb7
hidden state dropping between nodes issue
risingsunomi Sep 19, 2024
fa6f263
hidden state dropping between nodes issue
risingsunomi Sep 19, 2024
2b0e7b5
hidden state dropping between nodes issue
risingsunomi Sep 19, 2024
f793c00
Merge branch 'main' of github.com:exo-explore/exo into exo-fork-update
risingsunomi Sep 19, 2024
131c158
Merge branch 'exo-fork-update' of github.com:risingsunomi/exo-nvidia …
risingsunomi Sep 19, 2024
43a1f61
Merge pull request #11 from risingsunomi/exo-fork-update
risingsunomi Sep 19, 2024
8398409
Merge branch 'exo-explore:main' into main
risingsunomi Sep 26, 2024
09572c1
Merge github.com:exo-explore/exo into exo-explore-main
risingsunomi Oct 2, 2024
c861f30
Merge pull request #13 from risingsunomi/exo-explore-main
risingsunomi Oct 2, 2024
cee3e31
cleaning up code, removing helpers.py
risingsunomi Oct 2, 2024
f95942f
Merge branch 'main' of github.com:risingsunomi/exo-nvidia
risingsunomi Oct 2, 2024
c3ea732
Merge branch 'main' into HEAD
AlexCheema Oct 6, 2024
57e14e8
adding needed libs to setup.py, fixing 4 space to 2 space issue, addi…
risingsunomi Oct 6, 2024
9fe3ec6
cleaning up code, added pytorch engine to llama 3.2 1b model shard in…
risingsunomi Oct 6, 2024
447da4a
Merge pull request #14 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 6, 2024
b44f6e9
updating pytorch requirement
risingsunomi Oct 6, 2024
6e1ab58
Merge pull request #15 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 6, 2024
936e60a
trying tokenizer fixes for llama3.1
risingsunomi Oct 6, 2024
43c3c62
detecting 3.1 for adding padding token and using autotokenizer for ll…
risingsunomi Oct 6, 2024
75a29f4
updating models.py to use instruct version
risingsunomi Oct 7, 2024
e407404
fixing autotokenizer
risingsunomi Oct 7, 2024
47ff4b3
Merge pull request #16 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 7, 2024
668668f
making it so position and cache is computed every forward on hf model
risingsunomi Oct 7, 2024
4e356f8
loading cached input_ids when passing hidden states
risingsunomi Oct 7, 2024
a5ef04a
loading cached iids from infer state fix
risingsunomi Oct 7, 2024
e888baa
device fix
risingsunomi Oct 7, 2024
7d9eb17
position id fix
risingsunomi Oct 7, 2024
1198628
fixing inference instance state issues between nodes
risingsunomi Oct 7, 2024
d25b7ac
node testing
risingsunomi Oct 7, 2024
0721a4c
node inference fix
risingsunomi Oct 7, 2024
49b682b
node inference fix
risingsunomi Oct 7, 2024
77a52a5
node inference fix
risingsunomi Oct 7, 2024
2b3397f
node inference fix
risingsunomi Oct 7, 2024
2e588af
node inference fix
risingsunomi Oct 7, 2024
e2eba05
node inference fix
risingsunomi Oct 7, 2024
d7699eb
node inference fix
risingsunomi Oct 7, 2024
bd9bf4f
inference between nodes fixed by always calculating position id and i…
risingsunomi Oct 7, 2024
72bed37
Merge pull request #17 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 7, 2024
913a008
cleaning up code
risingsunomi Oct 7, 2024
a1c1c76
Merge pull request #18 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 7, 2024
d509e43
Merge branch 'main' into HEAD
AlexCheema Oct 7, 2024
216e83d
Merge branch 'main' into HEAD
AlexCheema Oct 7, 2024
b518f73
comma and other text issue fix
risingsunomi Oct 7, 2024
296dff6
Merge pull request #19 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 7, 2024
9d24779
adding threadpooling to forward and logit sampling
risingsunomi Oct 9, 2024
fe6ae45
Merge pull request #20 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 9, 2024
2c55634
Merge branch 'exo-explore:main' into main
risingsunomi Oct 10, 2024
d4fb74f
rename (PyTorch, pytorch) -> (Torch, torch)
AlexCheema Oct 10, 2024
edf1c3d
add ci jobs for chatgpt_api_integration_test_torch_linux_cpu and chat…
AlexCheema Oct 10, 2024
0fd6711
add ci jobs for chatgpt_api_integration_test_torch_linux_cpu and chat…
AlexCheema Oct 10, 2024
a4feeab
ci filters
AlexCheema Oct 10, 2024
55fd482
rm comments
AlexCheema Oct 10, 2024
da39519
ci
AlexCheema Oct 10, 2024
89f1be0
Merge branch 'main' into HEAD
AlexCheema Oct 10, 2024
b6f6afc
Merge remote-tracking branch 'origin/main' into HEAD
AlexCheema Oct 10, 2024
5eb6c34
fixed torch device selection
risingsunomi Oct 11, 2024
ed64437
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Oct 11, 2024
18d41eb
fixing imports
risingsunomi Oct 11, 2024
c73ed76
Merge pull request #21 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 11, 2024
9ecbf0c
fixing chatgpt_api mistake
risingsunomi Oct 11, 2024
79c9e70
Merge branch 'exo-explore:main' into main
risingsunomi Oct 11, 2024
ebfd44a
Merge pull request #22 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 11, 2024
dae2cbe
removing old pytorch folder
risingsunomi Oct 11, 2024
1c1dd06
Merge pull request #23 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 11, 2024
55ae027
Update README.md
risingsunomi Oct 11, 2024
4b6a86d
set all torch models in models.py
AlexCheema Oct 11, 2024
830d33d
in torch, explicitly set the device when initilaizing the model
AlexCheema Oct 11, 2024
074dfe3
spacing
AlexCheema Oct 11, 2024
d9cfcc4
add model mlx-community/Qwen2-0.5B-Instruct-4bit
AlexCheema Oct 11, 2024
c3e1934
Merge branch 'exo-explore:main' into main
risingsunomi Oct 12, 2024
2c056b4
code changes from PR feedback, working on splitting of weights
risingsunomi Oct 12, 2024
da5c28d
Merge branch 'exo-explore:main' into pr139-dev-oct24
risingsunomi Oct 12, 2024
83a723b
doing more work toward individual safetensor loading, adding back dev…
risingsunomi Oct 13, 2024
47be250
working on split model, moving to server for more vram
risingsunomi Oct 13, 2024
ea0d4b1
change to hf downloader as was not getting all safetensor files
risingsunomi Oct 13, 2024
30b7991
splitting model still work in progress as transformers still seems to…
risingsunomi Oct 13, 2024
3a2c431
updating readme
risingsunomi Oct 13, 2024
4def538
Merge branch 'main' into pr139-dev-oct24
risingsunomi Oct 13, 2024
b35224c
Merge pull request #24 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 13, 2024
6c6e7b2
successful splitting model test with only loading needed weights, imp…
risingsunomi Oct 14, 2024
55ffdc7
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Oct 14, 2024
aacdeb5
adding model sharding to inference engine, doing testing with inferen…
risingsunomi Oct 14, 2024
ce702d1
fixing layer range issue
risingsunomi Oct 14, 2024
e387a79
fixing layer range issue
risingsunomi Oct 14, 2024
e0ba2bb
fixing layer range issue
risingsunomi Oct 14, 2024
5b9638f
checking if ram over usaage even if reducing layers on large models
risingsunomi Oct 14, 2024
664f29f
half layer inference engine testing
risingsunomi Oct 14, 2024
2591fab
fixing layer amount with sharded modeling
risingsunomi Oct 14, 2024
99dac57
adding qwen2.5 3B for testing
risingsunomi Oct 14, 2024
c12526f
Merge pull request #25 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 14, 2024
493cd3e
updating inference engine test
risingsunomi Oct 14, 2024
de23294
cleaning up utils and split model
risingsunomi Oct 14, 2024
d5a02be
Merge pull request #26 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 14, 2024
e7470b1
bugfix in llm setup
dtnewman Oct 15, 2024
fa24f46
Merge pull request #27 from dtnewman/main
risingsunomi Oct 15, 2024
5c69f3f
Merge remote-tracking branch 'origin/main' into HEAD
AlexCheema Oct 16, 2024
f5a1cef
handle range not satisfiable edge case
AlexCheema Oct 16, 2024
751bd1c
updating to use automodelforcausallm instead of autoconfig
risingsunomi Oct 16, 2024
7d866d8
removing meta model
risingsunomi Oct 16, 2024
253237b
updating split model test
risingsunomi Oct 16, 2024
e46ffa4
updating split model test
risingsunomi Oct 16, 2024
476b6ba
automodel fix
risingsunomi Oct 16, 2024
f7e02e9
fixing split model test
risingsunomi Oct 16, 2024
bd6322f
pytorch offload buffers error
risingsunomi Oct 17, 2024
c51bd91
device_map any issue with split model
risingsunomi Oct 17, 2024
4a2aef4
updating split model test
risingsunomi Oct 17, 2024
79f0763
fixing split model issue
risingsunomi Oct 17, 2024
cbbc9cf
fixing node issues
risingsunomi Oct 17, 2024
58cebab
fixing node issues
risingsunomi Oct 17, 2024
7f9b1bb
fixing node issues
risingsunomi Oct 17, 2024
c3adec5
fixing node issues
risingsunomi Oct 17, 2024
c8e6acc
fixing node issues
risingsunomi Oct 17, 2024
df028e2
fixing node issues, range issue
risingsunomi Oct 17, 2024
e5a1939
fixing node issues, range issue
risingsunomi Oct 17, 2024
d03a85c
Merge branch 'main' into pr139-dev-oct24
risingsunomi Oct 17, 2024
69a8955
Merge pull request #28 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 17, 2024
d07b825
adding num hidden layers manipulation for all models
risingsunomi Oct 18, 2024
a840e7f
updating to use shard_num_hidden_layers
risingsunomi Oct 18, 2024
bf5f22d
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Oct 18, 2024
52fa3f8
adding in better layer manipulation
risingsunomi Oct 18, 2024
ec49e31
adding in safe tensor sharding, generate model.safetensors.index.json…
risingsunomi Oct 19, 2024
f45b514
implementing sharding tests, fixing bugs with safetensor recompile
risingsunomi Oct 19, 2024
f90c24a
adding safetensor sharding, implementing it into model inference engine
risingsunomi Oct 20, 2024
696c264
updating backup and backup restore
risingsunomi Oct 20, 2024
9514e92
added removing backup when restoring
risingsunomi Oct 20, 2024
d65505e
added generating weight map if none, did updates to backup and restor…
risingsunomi Oct 20, 2024
d5b6113
cleaning up logging
risingsunomi Oct 20, 2024
d2302cc
updating docstring in newest class file
risingsunomi Oct 20, 2024
35c32eb
Merge pull request #29 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 20, 2024
291aa10
Merge branch 'exo-explore:main' into main
risingsunomi Oct 23, 2024
fcb298b
Merge branch 'main' into main
AlexCheema Oct 23, 2024
6e32be6
merge fixing
risingsunomi Oct 27, 2024
df13fbc
Merge branch 'main' into pr/risingsunomi/30
risingsunomi Oct 27, 2024
109 changes: 108 additions & 1 deletion .circleci/config.yml
@@ -178,6 +178,50 @@ jobs:
inference_engine: mlx
model_id: llama-3.2-1b

chatgpt_api_integration_test_torch_linux_cpu:
machine:
image: ubuntu-2404:2024.08.1
resource_class: large
steps:
- checkout
- run:
name: Set up Python
command: |
sudo apt-get update
sudo apt-get install -y python3.12 python3.12-venv
python3.12 -m venv env
source env/bin/activate
- run:
name: Install dependencies
command: |
source env/bin/activate
pip install --upgrade pip
pip install .
- run_chatgpt_api_test:
inference_engine: torch
model_id: llama-3.2-1b

chatgpt_api_integration_test_torch_mac:
macos:
xcode: "15.4.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Set up Python
command: |
brew install python@3.12
python3.12 -m venv env
source env/bin/activate
- run:
name: Install dependencies
command: |
source env/bin/activate
pip install --upgrade pip
pip install .
- run_chatgpt_api_test:
inference_engine: torch
model_id: llama-3.2-1b

test_macos_m1:
macos:
xcode: "16.0.0"
@@ -211,9 +255,72 @@ jobs:
workflows:
version: 2
build_and_test:
jobs:
- approve_run:
type: approval
requires: []
filters:
branches:
ignore: main
- unit_test:
requires:
- approve_run
- discovery_integration_test:
requires:
- approve_run
- chatgpt_api_integration_test_mlx:
requires:
- approve_run
- test_macos_m1:
requires:
- approve_run
- chatgpt_api_integration_test_torch_linux_cpu:
requires:
- approve_run
- chatgpt_api_integration_test_torch_mac:
requires:
- approve_run

# Workflow for forked PRs without approval
forked_pr_workflow:
jobs:
- unit_test
- discovery_integration_test
- chatgpt_api_integration_test_mlx
- test_macos_m1
# - chatgpt_api_integration_test_tinygrad
- chatgpt_api_integration_test_torch_linux_cpu
- chatgpt_api_integration_test_torch_mac
# The trigger condition ensures this workflow runs for forked PRs
triggers:
- type: pull_request
filters:
branches:
ignore: main

# Existing workflow for main branch
main_branch_workflow:
jobs:
- unit_test:
filters:
branches:
only: main
- discovery_integration_test:
filters:
branches:
only: main
- chatgpt_api_integration_test_mlx:
filters:
branches:
only: main
- test_macos_m1:
filters:
branches:
only: main
- chatgpt_api_integration_test_torch_linux_cpu:
filters:
branches:
only: main
- chatgpt_api_integration_test_torch_mac:
filters:
branches:
only: main
6 changes: 6 additions & 0 deletions .gitignore
@@ -170,3 +170,9 @@ cython_debug/
#.idea/

**/*.xcodeproj/*

# PyTorch interface
.offload

# neovim/vim settings
.vimrc
4 changes: 3 additions & 1 deletion exo/api/chatgpt_api.py
@@ -17,7 +17,6 @@
from exo.models import model_base_shards
from typing import Callable


class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
self.role = role
@@ -60,6 +59,9 @@ def generate_completion(
"finish_reason": finish_reason,
}],
}

if DEBUG >= 3:
print(f"completion: {completion}")

if not stream:
completion["usage"] = {
1 change: 0 additions & 1 deletion exo/helpers.py
@@ -33,7 +33,6 @@ def get_system_info():
return "Linux"
return "Non-Mac, non-Linux system"


def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")

5 changes: 4 additions & 1 deletion exo/inference/inference_engine.py
@@ -8,7 +8,7 @@

class InferenceEngine(ABC):
@abstractmethod
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass

@abstractmethod
@@ -27,5 +27,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

return TinygradDynamicShardInferenceEngine(shard_downloader)
elif inference_engine_name == "torch":
from exo.inference.torch.inference import TorchDynamicShardInferenceEngine
return TorchDynamicShardInferenceEngine(shard_downloader)
else:
raise ValueError(f"Inference engine {inference_engine_name} not supported")
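
A minimal usage sketch of the new "torch" branch — not code from this PR; the `shard_downloader` and `shard` arguments are assumed to come from exo's existing setup:

```python
import asyncio

from exo.inference.inference_engine import get_inference_engine

async def run(shard_downloader, shard):
  # "torch" now resolves to TorchDynamicShardInferenceEngine (see the diff above)
  engine = get_inference_engine("torch", shard_downloader)
  # infer_prompt returns Tuple[np.ndarray, str, bool]:
  # (output array, serialized inference state, is_finished flag)
  output, state, finished = await engine.infer_prompt("req-1", shard, "What is 2 + 2?")
  print(output.shape, finished)

# asyncio.run(run(shard_downloader, shard))  # given a ShardDownloader and a Shard
```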
2 changes: 2 additions & 0 deletions exo/inference/torch/.gitignore
@@ -0,0 +1,2 @@
data/
model/archive/
9 changes: 9 additions & 0 deletions exo/inference/torch/README.md
@@ -0,0 +1,9 @@
# PyTorch & HuggingFace inference engine

## Notes/Issues
### 10/10/2024
- To select a PyTorch device via environment variables, set the variable TORCH_DEVICE (see the sketch after these notes)
- XLA is currently not installed and will need to be added to inference.py; looking into doing this on a TPU VM
- With PyTorch, ROCm builds use the CUDA device interface, so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373)
- Looking into adding mobile device support properly
- If the device is not CPU, the data type defaults to float32; otherwise it defaults to float16.
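
A minimal sketch of the device and dtype selection these notes describe; the helper name is illustrative, and the actual logic in inference.py may differ:

```python
import os
import torch

def resolve_device() -> torch.device:
  # TORCH_DEVICE overrides automatic selection, per the note above
  name = os.getenv("TORCH_DEVICE")
  if name:
    return torch.device(name)
  # On ROCm builds this same check also covers AMD GPUs
  if torch.cuda.is_available():
    return torch.device("cuda")
  return torch.device("cpu")

device = resolve_device()
# Per the note above: non-CPU devices default to float32, CPU to float16
dtype = torch.float32 if device.type != "cpu" else torch.float16
```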
Empty file added exo/inference/torch/__init__.py