Fine-tuning GPT-J-6B in colab: 8-bit weights with low-rank adaptors #14839

Open
dvmazur opened this issue Dec 19, 2021 · 34 comments
@dvmazur

dvmazur commented Dec 19, 2021

🌟 New model addition

Model description

This is a version of EleutherAI's GPT-J with 6 billion parameters, modified so that you can generate from and fine-tune the model in Colab or on an equivalent desktop GPU (e.g. a single 1080 Ti).

The original GPT-J takes 22+ GB of memory for float32 parameters. Even if you cast everything to 16-bit, it still won't fit onto most single-GPU setups short of an A6000 or A100. You can run inference on TPUs or CPUs, but fine-tuning is way more expensive.
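
Back-of-the-envelope arithmetic behind those numbers (illustrative only):

params = 6e9                   # roughly 6 billion parameters
print(params * 4 / 2**30)      # float32: ~22.4 GiB
print(params * 2 / 2**30)      # float16: ~11.2 GiB
print(params * 1 / 2**30)      # int8 weights: ~5.6 GiB, plus quantization constants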

Implementation

A proof-of-concept notebook is available here: colab

The model card has more detailed explanations and auxiliary notebooks (e.g. model conversion and perplexity check).

The current implementation is somewhat hacky, but it can be integrated easily with modeling_gptj.py if you like the idea.

Open source status

@dvmazur changed the title from "8bit GPT-J-6B" to "Fine-tuning GPT-J-6B in colab: 8-bit weights with low-rank adaptors" on Dec 20, 2021
@stas00
Contributor

stas00 commented Dec 31, 2021

Just read the LoRA paper, and your implementation combined with weight quantization is very neat, @deniskamazur. Thank you!

A few comments:

  1. If and when this is integrated into transformers, may I suggest not overriding __repr__ in the frozen modules? I puzzled for a while over why the adapter weights don't show up when dumping the model, until I noticed a custom __repr__ that hides them. I totally get that it was added for brevity of the demo. Nothing needs to be changed in the demo notebook.
  2. With the little I know of BNB, Adam8bit usually requires a StableEmbedding - which is the same as nn.Embedding but with layer_norm inited to kaiming uniform at the end of forward. Do you think it's not needed for LoRA? We will be discussing this in RFC: Integrating bitsandbytes 8-bit optimizer / adding Embedding Norm #14819 as well, once Tim is back from vacation. But I thought I'd bring it up here as well, as it's relevant.
  3. It would be good to fine-tune it for a bit and see that LoRA actually delivers on the promise. Unless someone has already done so, in which case it's not needed.
  4. How do we decide on a good adapter_dim (rank) to recommend to users? What should be the default? This hyperparameter should definitely be user-configurable (see the sketch after this list).
  5. We will surely want to make this available to more than just GPT-J if it works well, but it's good to start with one model.
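
For reference on point 4, a minimal sketch of a configurable-rank adapter on a frozen linear layer (illustrative names only, not the notebook's exact API; adapter_dim is the rank in question):

import torch.nn as nn

def add_low_rank_adapter(linear: nn.Linear, adapter_dim: int = 16) -> None:
    # Attach a LoRA-style low-rank adapter to a (frozen) linear layer.
    linear.adapter = nn.Sequential(
        nn.Linear(linear.in_features, adapter_dim, bias=False),
        nn.Linear(adapter_dim, linear.out_features, bias=False),
    )
    # Standard LoRA initialization: the output projection starts at zero,
    # so the adapter initially adds nothing to the frozen layer's output.
    nn.init.zeros_(linear.adapter[1].weight)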

@henk717

henk717 commented Dec 31, 2021

The KoboldAI community is really looking forward to seeing these 8-bit models implemented. Since many users of our software run it on their own home computers, this would allow more people to run 6B at good speeds.
Ideally we'd see a way to easily convert the models for our users on the fly, similar to how .half() works, so they can load unconverted versions and still get the gains this brings.

If this is not possible, I hope it will be easy to detect that a model is the 8-bit variant, so we can avoid executing half() on the model.

@TimDettmers
Contributor

Thank you so much for creating this, Denis!

  2. With the little I know of BNB, Adam8bit usually requires a StableEmbedding - which is the same as nn.Embedding but with layer_norm inited to kaiming uniform at the end of forward. Do you think it's not needed for LoRA? We will be discussing this in RFC: Integrating bitsandbytes 8-bit optimizer / adding Embedding Norm #14819 as well, once Tim is back from vacation. But I thought I'd bring it up here as well, as it's relevant.

See the discussion in this issue for more information, but in short, the StableEmbedding layer is only required if the model was pretrained with it.

In the case of regular finetuning with 8-bit Adam, it is better to have 32-bit optimizers for the embedding layer. It is currently unclear if this is required for LoRA since the frozen 8-bit weights will provide some stability. Just to be sure, it is probably better to optimize the LoRA embedding layer in 32-bit (no change to the model). You can integrate this in your embedding layer class as shown here. Optionally, you can use the bnb.nn.StableEmbedding in place of the LoRA embedding layer and optimize the linear projection normally:

        elif isinstance(module, FrozenBNBEmbedding):
            module.adapter = nn.Sequential(
                bnb.nn.StableEmbedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
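
Alternatively, a rough sketch of keeping the adapter embedding's optimizer state in 32-bit via bitsandbytes' GlobalOptimManager (assuming the FrozenBNBEmbedding/adapter layout and the gpt model object from the notebook):

import bitsandbytes as bnb

# Register a 32-bit override for the adapter embedding weights before the
# optimizer is created, so only they get 32-bit optimizer state while
# everything else uses 8-bit Adam states.
manager = bnb.optim.GlobalOptimManager.get_instance()
for module in gpt.modules():
    if isinstance(module, FrozenBNBEmbedding) and getattr(module, "adapter", None) is not None:
        # adapter[0] is the adapter's embedding layer, as in the snippet above
        manager.register_module_override(module.adapter[0], "weight", {"optim_bits": 32})

optimizer = bnb.optim.Adam8bit(gpt.parameters(), lr=1e-5)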
  3. It would be good to fine-tune it for a bit and see that LoRA actually delivers on the promise. Unless someone has already done so, in which case it's not needed.

This is definitely a good idea. In my experience, 8-bit weights work fine as long as they are not optimized over time. So keeping them frozen and optimizing the low-rank matrices should work just fine and produce results similar to the LoRA paper. However, I have never tried the setup of 8-bit weights + 16/32-bit low-rank matrices, so it's better to check this.

@stas00
Contributor

stas00 commented Dec 31, 2021

If this is not possible, I hope it will be easy to detect that a model is the 8-bit variant, so we can avoid executing half() on the model.

  1. This quant+LoRA will most likely require a new architecture, in which case the model should automatically do the right thing on load. Or at the very least there should be a config entry which will tell transformers what to do.

  2. I'm trying to find a way to automatically detect the dtype here https://github.com/stas00/ml-ways/blob/master/numbers/detect-model-pretrained-in-bf16-fp16-fp32.ipynb, so now we can try int8 as well - I could use more input to help with that work (rough sketch below).

  3. Also, I proposed a while ago that models save how they were trained in their config.json ([RFC] introduce config.trained_precision #11209). My proposal didn't go far, but perhaps this new development will give it some push.
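
For item 2, a rough pure-PyTorch way to inspect which dtypes a checkpoint actually contains (illustrative only; the checkpoint filename is the standard one):

from collections import Counter
import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu")
dtype_counts = Counter(t.dtype for t in state_dict.values() if torch.is_tensor(t))
print(dtype_counts)  # e.g. Counter({torch.float16: 283, torch.uint8: 168}) for a mixed checkpoint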

@dvmazur
Author

dvmazur commented Jan 1, 2022

Hey, everyone! Thanks for your interest and comments!

  1. I'd like to discuss whether we actually need LoRA adapters in the possible implementation. As I see it, they are not necessarily a part of the 8-bit model. Maybe we could just add an add_low_rank_adaptors_ function or method.

  2. @stas00, I like your idea of generalizing this to other models, though I don't have any ideas regarding a possible implementation. I'd be glad to hear yours.

  3. I could open a PR with the 8-bit GPT-J without adapters as early as tomorrow. Should I do it, or is there anything we should discuss before that?

@stas00
Contributor

stas00 commented Jan 1, 2022

  • I'd like to discuss whether we actually need LoRA adapters in the possible implementation. As I see it, they are not necessarily a part of the 8-bit model. Maybe we could just add an add_low_rank_adaptors_ function or method.

These are orthogonal features, so they can probably be implemented separately. Separating them would surely make the PRs simpler to manage. But it'd be good to keep the ensemble in mind from the get-go.

  • @stas00, I like your idea of generalizing this to other models, though I don't have any ideas regarding a possible implementation. I'd be glad to hear yours.

Since you're overriding PyTorch components, this is already generic enough.

So the model-specific changes are the post-init code where you call convert_to_int8 once or twice. By post init I mean literally post __init__ (we don't have such a method yet, I think).

So here we need a sort of map/policy per architecture that runs the right post-init conversions after a map lookup, if the model config says so. E.g.:

This is the hardcoded way (replacing the monkeypatch):

class GPTJBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        [...]
        if config.use_8bit:  # illustrative config flag name
            convert_to_int8(self.attn)
            convert_to_int8(self.mlp)

and the more generic way which can then be expanded to other archs easily:

# in another file
int8_map = dict(
    gptj=dict(
        GPTJBlock       = ["self.attn", "self.mlp"],
        GPTJModel       = ["self"],
        GPTJForCausalLM = ["self"],
    ),
    gptneo=dict(),
    gpt2=dict(),
)

# gptj_modeling
class GPTJBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        #[...]
        if config.use_8bit:
            to_int8_params = int8_map["gptj"]["GPTJBlock"]
            for param in to_int8_params:
                # XXX: figure out the getattr for "self" vs "self.foo"
                target = self if param == "self" else getattr(self, param.removeprefix("self."))
                convert_to_int8(target)

which of course should be refactored into a simple:

        if config.use_8bit:
            self.to_int8()  # do all of the above

and since we will likely have other similar maps as we try to integrate all the new developments, this can again be abstracted away:

post_init_maps = dict(
    int8=int8_map,
    featureX=featureX_map,  # doesn't exist yet
)
[....]
        self.post_init()

which will handle this and other future feature enabling without making the code noisy.

On the other hand, it's possible that other maintainers will prefer explicit code for each class/arch over this proposal.

This is all very incomplete pseudocode, just to show what I'm trying to propose conceptually.

Here is another example where a policy map is created for different archs:

# XXX: Reza - need the rest of the map
inference_custom_map = dict(
    electra=dict(ElectraLayer=("output.dense")),
    roberta=dict(RobertaLayer=("output.dense")),
    t5=dict(T5Block=("SelfAttention.o", "EncDecAttention.o", "DenseReluDense.wo")),
)

This is from a very early deepspeed-inference PR #14426

  • I could open a PR with the 8-bit GPT-J without adapters as early as tomorrow. Should I do it, or is there anything we should discuss before that?

I'm sure others will have a lot more to say, but since you have the code written already, probably the best way is to just open a PR and go from there.

You can start with the hardcoded version, or you can try to do something like I suggested, which will immediately prepare a foundation to support other architectures. As I said earlier, without hearing from other maintainers I'm not sure what the best first step is. The lowest-risk option is the hardcoded one, I'd say.

Alternatively, you can wait until Monday, when many devs should be back and may have a chance to comment.

@dvmazur
Author

dvmazur commented Jan 1, 2022

I like your suggestion with the policy map. I think I'll wait for the other maintainers' opinions before opening the PR. Thanks!

@stas00
Contributor

stas00 commented Jan 2, 2022

I also wonder whether the policy should be arch-specific or model-specific - what if someone wants to do 8-bit only for the FFN or only for the Embedding? If model-specific, then the specific params to convert to 8-bit can be declared in the model config. Or perhaps there could be an arch-specific default that the model-specific config could override? Not sure...

@TimDettmers
Contributor

TimDettmers commented Jan 2, 2022

  1. I'd like to discuss whether we actually need LoRA adapters in the possible implementation. As I see it, they are not necessarily a part of the 8-bit model. Maybe we could just add an add_low_rank_adaptors_ function or method.

In my experience, training with 8-bit dynamic block-wise quantization degrades performance over time, but it is fine if the quantization is applied only once and used for inference or, as in the case of LoRA, as some sort of "base output" that is adapted. As such, I think that LoRA might be required to maintain good performance. That being said, I have never tried fine-tuning a model - I have only worked on pretraining - so it might be that fine-tuning with 8-bit weights works just fine.

I think the solution with a map to specify 8-bit parameters would be very handy; that would give the flexibility that is needed. What I would add is a way to specify which kind of int8 data type is used.

I also wonder whether the policy should be arch-specific or model-specific - what if someone wants to do 8-bit only for the FFN or only for the Embedding? If model-specific, then the specific params to convert to 8-bit can be declared in the model config. Or perhaps there could be an arch-specific default that the model-specific config could override? Not sure...

I think it should be model-specific. There are tradeoffs and important differences between having certain things in 8-bit and others in 16-bit for the same model architecture, so it would be very useful to have more flexibility overall to accommodate that.

@stas00
Contributor

stas00 commented Jan 2, 2022

What I would add is what kind of int8 data type is used.

Did you mean to say something different here, Tim? Unless I misunderstood, int8 is already a single data type.

Perhaps you meant having flexibility in how many quantization bits are used for different components, so it's not always 8 but can be 4, 16, etc.? Same as the optim_bits param in BNB's optimizers:

 GlobalOptimManager.get_instance().register_module_override(module, 'weight', {'optim_bits': 32})

@TimDettmers
Contributor

Did you mean to say something different here, Tim? Unless I misunderstood, int8 is already a single data type.

Currently, bnb quantization uses dynamic block-wise quantization by default, so the int8 data type is really defined by the int8 data + an int-to-float map + normalization constants. This data type is storage-optimized. Soon I will also add another data type to bnb which will be compute-optimized. It is still represented by int8 data + an int-to-float map + normalization constants, but these will be different from, and incompatible with, the storage-optimized variant.
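
For intuition, a rough pure-PyTorch sketch of that idea (not bnb's actual implementation): what gets stored is the uint8 indices plus the int-to-float code and the per-block absmax constants.

import torch
import torch.nn.functional as F

def quantize_blockwise(x, code, block_size=4096):
    # Split into blocks, store per-block absmax, and map each normalized
    # value to the index of the nearest entry in the int-to-float code.
    flat = x.flatten()
    pad = (-flat.numel()) % block_size
    blocks = F.pad(flat, (0, pad)).reshape(-1, block_size)
    absmax = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    idx = (blocks / absmax).unsqueeze(-1).sub(code).abs().argmin(dim=-1).to(torch.uint8)
    return idx, absmax

def dequantize_blockwise(idx, absmax, code):
    return code[idx.long()] * absmax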

At this point, it is already clear to me that the storage data type can be improved quite easily. So it might also be helpful to support that to make sure future variants can be supported easily.

On the other hand, it might be better to define this separately: one field defines int8, plus a quantization method that is defined somewhere else.

@stas00
Contributor

stas00 commented Jan 3, 2022

Sounds good, Tim. I trust you will come up with the different names then. We just need to think about how to make it easily expandable in the future to support other types.

My thinking is that perhaps BNB won't be the only library providing quantization support so the more generic it is the better.

We can start with one model, flag it experimental until we sort out the config.

@dvmazur
Author

dvmazur commented Jan 11, 2022

Hi, everyone! Thank you for your suggestions. I'm currently busy with my uni exams, but I'll be back with a PR in a couple of weeks.

@BangDaeng

I have a question.

Comparing the quantized model (hivemind/gpt-j-6B-8bit) with the original model (EleutherAI/gpt-j-6B), generation is almost twice as slow for the quantized model.
I wonder if it is normal for it to be that slow, or if the overhead can be reduced.

@justheuristic
Contributor

justheuristic commented Jan 24, 2022

Hi! Inference is indeed slower because the weight matrices are de-quantized for every token.
You can increase the batch size (i.e. generate several sequences in parallel) to reduce that overhead.

This is less of an issue for training: the fine-tuning speed is not significantly different from the original model, because training is parallel over sequence_length tokens (while inference is inherently sequential).

You can combine the two setups (vanilla and 8-bit) to better fit your hardware.
For instance, if you have a T4 or RTX 3090 GPU, it is enough to run inference with the model but not enough to fine-tune it. The optimal pipeline would be to fine-tune using 8-bit weights, then de-quantize for inference. In turn, if you have a 10-12 GB GPU such as an RTX 2080 Ti or 3080, inference should run in 8-bit mode as well.
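
For illustration, a rough sketch of batched generation to amortize the de-quantization overhead (assuming the gpt model and tokenizer from the notebook; prompts and generation settings are arbitrary):

prompts = ["The Eiffel Tower is located in", "GPT-J is a language model that"]

tokenizer.padding_side = "left"            # pad on the left so generation continues from each prompt
tokenizer.pad_token = tokenizer.eos_token  # GPT-J's tokenizer has no pad token by default
batch = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

# One forward pass per generated token now serves the whole batch,
# so the per-token de-quantization cost is shared across sequences.
out = gpt.generate(**batch, max_new_tokens=32, do_sample=True, temperature=0.9)
print(tokenizer.batch_decode(out, skip_special_tokens=True))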

@Ontopic

Ontopic commented Feb 5, 2022

I just really can not hold back from saying, this is awesome! Thank you 🙏 Good luck on your studies, hope when you're finished I can assist you somehow with next steps.

@dvmazur
Author

dvmazur commented Feb 11, 2022

Hey, everyone! I've implemented the «hardcoded» version of this issue. You can verify it's functional over here. Should I add any tests before opening a PR?

I'd also be glad to implement LoRA and a generalized version of this issue in future PRs.

@stas00
Contributor

stas00 commented Feb 11, 2022

I've implemented the «hardcoded» version of this issue.

Awesome news, @deniskamazur!

I won't have time at this moment to support this process very closely but I trust there will be other maintainers who will have a closer look and provide feedback once you open a PR.

Should I add any tests before opening a PR?

Definitely, and you can use this tiny model for functionality tests: https://huggingface.co/hf-internal-testing/tiny-random-gptj
but I guess you will need the 8-bit version, which we currently don't have; perhaps start with what you have, and then we can reduce it to a tiny size at the end of the PR process (we want functional tests to run fast).

As we have a massive test suite it should be relatively easy to build upon/mimic some of the existing tests. And if you get stuck please don't hesitate to ask in the PR.
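
A rough sketch of what such a functional test could look like (convert_to_int8 is the helper from the proof-of-concept notebook; the exact API will depend on the PR):

import torch
from transformers import GPTJForCausalLM

def test_int8_gptj_forward():
    model = GPTJForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj")

    # Hypothetical conversion step, mirroring the proof-of-concept notebook.
    for block in model.transformer.h:
        convert_to_int8(block.attn)
        convert_to_int8(block.mlp)

    input_ids = torch.tensor([[1, 2, 3, 4]])
    with torch.no_grad():
        logits = model(input_ids).logits

    assert logits.shape[:2] == input_ids.shape
    assert torch.isfinite(logits).all()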

@dvmazur
Author

dvmazur commented Feb 11, 2022

Great, thanks! I'll open a PR as soon as I write the test then.

@dvmazur
Author

dvmazur commented Jul 7, 2022

Hey! I've noticed this PR, which seems to generalize what we are doing with gpt-j-8bit. What should I do with this issue?

@stas00
Contributor

stas00 commented Jul 8, 2022

Hi Denis, it has been a long time... Perhaps there has been a misunderstanding: we have been waiting for you to complete the PR, so nothing has happened here until now.

Let's tag @younesbelkada, whose PR you linked to. Younes, not to load more work on you, but a quick question: does your PR supersede Denis' work, or is there some collaboration that can happen here?

@younesbelkada
Contributor

Hi @deniskamazur @stas00
Sorry for getting back late on this!
I don't think the two methods will conflict, but our PR aims to support all models in transformers by replacing their Linear layers with the one provided by bitsandbytes - so naturally GPT-J should be supported too. But I am not sure the quantization method you want to integrate here is the same as the one we are aiming to integrate in the other PR. In our implementation the weights do not need to be loaded/pushed in int8 and can be cast directly from any fp16 weights, so we could just do something like AutoModel.from_pretrained(load_in_8bit=True) and it should be fine (which is different from what is described here?). That said, I will definitely be happy to discuss any possible collaboration if you see one! Feel free to jump into the other PR.
tagging also @TimDettmers

@stas00
Contributor

stas00 commented Jul 12, 2022

I suppose the advantage of loading in int8 is that with fp16 you need 2x the memory upfront. But since we now have sharded checkpoints, this can be overcome by sharding into smaller shards if someone is really tight on memory; then only the embedding will be the largest param.

But otherwise I'll let you guys discuss the pros and cons of each approach, as I'm still busy with BigScience, but I would love to study this closer / support you guys once the marathon is over.

Maybe let's also cc @justheuristic into this discussion. Between the four of you, this domain is in good hands.

@parastooAflaki

Hey, thanks for your notebook!

I am trying to run this notebook; however, I am getting the following error when installing bitsandbytes-cuda111 with your specified version 0.26.0:
ERROR: Could not find a version that satisfies the requirement bitsandbytes-cuda111==0.26.0 (from versions: 0.26.0.post2)
ERROR: No matching distribution found for bitsandbytes-cuda111==0.26.0

Please let me know if another version should be used instead. Thanks!

@petertjmills

Hey, thanks for your notebook!

I am trying to run this notebook; however, I am getting the following error when installing bitsandbytes-cuda111 with your specified version 0.26.0: ERROR: Could not find a version that satisfies the requirement bitsandbytes-cuda111==0.26.0 (from versions: 0.26.0.post2) ERROR: No matching distribution found for bitsandbytes-cuda111==0.26.0

Please let me know if another version should be used instead. Thanks!

Change !pip install bitsandbytes-cuda111==0.26.0 to !pip install bitsandbytes and this notebook works for now.

I suppose the advantage of loading in int8 is that with fp16 you need 2x the memory upfront. But since we now have sharded checkpoints, this can be overcome by sharding into smaller shards if someone is really tight on memory; then only the embedding will be the largest param.

But otherwise I'll let you guys discuss the pros and cons of each approach, as I'm still busy with BigScience, but I would love to study this closer / support you guys once the marathon is over.

Maybe let's also cc @justheuristic into this discussion. Between the four of you, this domain is in good hands.

As per the hivemind/gpt-j-6b-8bit model card, I'm trying to use load_in_8bit=True with EleutherAI/gpt-j-6B but I can't seem to get it to work without crashing due to too much RAM usage.
What would the RAM requirements be?

@mommi84

mommi84 commented Oct 4, 2022

@petertjmills Same here.

Using int8, the original model fits on an 8 GB NVIDIA GeForce GTX 1080, but crashes after the first generation. The Hivemind model uses float16 or float32 for computation, so it's even more unlikely to succeed. Probably at least 9-10 GB VRAM are needed.

@calix

calix commented Nov 24, 2022

I am getting the following error when attempting to fine-tune:

Traceback (most recent call last):
  File "/opt/gpt-j-8bit/gpt-j-6b-8-bit.py", line 242, in <module>
    out = gpt.forward(**batch,)
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/transformers/models/gptj/modeling_gptj.py", line 782, in forward
    transformer_outputs = self.transformer(
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/transformers/models/gptj/modeling_gptj.py", line 636, in forward
    outputs = block(
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/transformers/models/gptj/modeling_gptj.py", line 291, in forward
    feed_forward_hidden_states = self.mlp(hidden_states)
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/transformers/models/gptj/modeling_gptj.py", line 254, in forward
    hidden_states = self.fc_in(hidden_states)
  File "/opt/gpt-j-8bit/.env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/gpt-j-8bit/gpt-j-6b-8-bit.py", line 48, in forward
    output += self.adapter(input)
RuntimeError: Output 0 of DequantizeAndLinearBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

Any idea on how to solve this?

Edit: I was able to get the fine-tuning going by modifying the following part:

def forward(self, input):
    output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
    if self.adapter:
        output += self.adapter(input)
    return output

To:

def forward(self, input):
    output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
    if self.adapter:
        output_cloned = torch.clone(output + self.adapter(input))
        return output_cloned
    else:
        return output

@oobabooga
Contributor

oobabooga commented Dec 19, 2022

After training the model with this notebook, how can it be saved and loaded back? If I try

gpt.save_pretrained(some_folder)

I can save the model, but then if I try to load it back in another script with

model = AutoModelForCausalLM.from_pretrained(some_folder).cuda()

I get the following warning:

Some weights of the model checkpoint at some_folder were not used when initializing GPTJForCausalLM: ['transformer.h.0.mlp.fc_in.code', 'transformer.h.21.attn.k_proj.adapter.1.weight', 'transformer.h.17.attn.k_proj.code', 'transformer.h.12.attn.v_proj.absmax', 'transformer.h.0.attn.q_proj.absmax', 'transformer.h.2.attn.out_proj.code (...)

And the loaded model only produces garbage output.

Alternatively, if I try to load it with

model = AutoModelForCausalLM.from_pretrained(some_folder, load_in_8bit=True, device_map='auto')

I get an error:

RuntimeError: Only Tensors of floating point and complex dtype can require gradients

@justheuristic
Contributor

To the best of my knowledge, you will need to manually extract and save the model state dict -- containing only the modules you have trained -- and then load it with model.load_state_dict.
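
A minimal sketch of that approach, assuming the trainable modules carry "adapter" in their parameter names as in the notebook (illustrative only):

import torch

# Save only the trained adapter parameters.
adapter_state = {
    name: param.detach().cpu()
    for name, param in gpt.named_parameters()
    if "adapter" in name
}
torch.save(adapter_state, "gptj_8bit_adapters.pt")

# Later: rebuild the 8-bit model with adapters as in the notebook,
# then load just the trained weights on top of it.
gpt.load_state_dict(torch.load("gptj_8bit_adapters.pt"), strict=False)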

@andreo73

andreo73 commented Jan 9, 2023

Hi, thanks for your very nice work!

I tried to almost blindly copy-paste your notebook into a blank Colab notebook (simple standard free account).

I only encountered one error, almost at the beginning:

CUDA SETUP: Required library version not found: libsbitsandbytes_cpu.so. Maybe you need to compile it from source?
CUDA SETUP: Defaulting to libbitsandbytes_cpu.so...
/usr/local/lib/python3.8/dist-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.
warn("The installed version of bitsandbytes was compiled without GPU support. "

I ignored it (of course I felt that something was not right) and arrived with no other errors at the point where the notebook tries to gpt.generate new text, i.e. before fine-tuning.

The command gpt.generate has been running for 25 minutes. I suspect this slowness is not normal, but rather an effect of not using the GPU. Is that correct? Any suggestions on how to solve it?

Thanks
Andrea

@ajugjacob

Hi @andreo73,

You need to install the CUDA version of bitsandbytes,
pip install bitsandbytes-cuda111

@crescedo

Has anyone already tried fine-tuning this with the Alpaca approach?

@jbingel

jbingel commented Mar 22, 2023

Runtime error when batching

I'm having issues getting the proof-of-concept notebook to work with a batch size > 1. The original notebook just iterates over the sample dataset row by row (one example at a time), which also works fine for my dataset.

However, when I feed batches of more than one example to the model (in out = gpt.forward(**batch,)), I get a RuntimeError: The size of tensor a (64) must match the size of tensor b (4) at non-singleton dimension 3.

The same happens when I use the Trainer API. Does anyone have an idea what's going on here?

My batches are of the form

{
  "input_ids": [[123, 456, ...], [321, 654, ...], ...],
  "attention_mask": [[1,1,1, ...0], [1,1,1,...0], ...]
}

@vinnitu

vinnitu commented Apr 27, 2023


out of memory on colab every time ((
