
[Question]: Resume training #3458

Open · alfredwallace7 opened this issue May 17, 2024 · 5 comments
Labels: question (Further information is requested)

Comments

@alfredwallace7

Question

I'm trying to resume training according to this code, where it says:

    # 7. continue training at later point. Load previously trained model checkpoint, then resume
    trained_model = SequenceTagger.load(path + '/checkpoint.pt')

    # resume training best model, but this time until epoch 25
    trainer.resume(trained_model,
                   base_path=path + '-resume',
                   max_epochs=25,
                   )

but resume is not defined in class ModelTrainer(Pluggable).

I'm sure it's a common task with your awesome library, yet I cannot get it working. Any information would be much appreciated.

alfredwallace7 added the question label on May 17, 2024

nturusin commented Jun 7, 2024

Hi guys. I've run into this exact issue too. Was a solution ever found?

@helpmefindaname (Collaborator)

Hi @alfredwallace7,
I'm sorry, but resuming is currently not possible; that feature was removed when the trainer was reworked in 0.13.0. We might reimplement it, but there are no plans to do so soon.

Regarding the documentation: please refer to the doc page, which is maintained and up to date. The /resources/docs/ folder is outdated and only there for legacy reasons.
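
A manual warm-start is still an option, though: load the saved weights and start a new training run from them. A minimal sketch, with illustrative paths and values (note that optimizer and scheduler state are not restored this way):

    # Warm-start sketch (not the removed trainer.resume API): reload saved
    # weights and launch a fresh training run from them. The corpus object
    # and all paths/values here are assumptions for illustration.
    from flair.models import SequenceTagger
    from flair.trainers import ModelTrainer

    model = SequenceTagger.load("resources/taggers/my-run/final-model.pt")
    trainer = ModelTrainer(model, corpus)  # corpus: your prepared Corpus
    trainer.train(
        "resources/taggers/my-run-resume",  # new base path for this run
        learning_rate=0.05,
        max_epochs=25,
    )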


nturusin commented Jun 14, 2024

Hi again. To my own surprise, I managed to do it, @alfredwallace7.
Unfortunately, I had to add an ugly hack to handle the w2v embeddings (during the process, the trainer tries to save files to a temporary folder whose path is defined dynamically).

import logging
import os

from flair.embeddings import (
    BytePairEmbeddings,
    FlairEmbeddings,
    StackedEmbeddings,
    WordEmbeddings,
)
from flair.models import SequenceTagger

logger = logging.getLogger(__name__)


def get_tagger(tag_dictionary, tag_type, path_to_checkpoint=None):
    if path_to_checkpoint is not None:
        # Resuming: reload the full tagger (embeddings included) from disk.
        tagger = SequenceTagger.load(path_to_checkpoint)
        # Ugly hack: the loaded BytePairEmbeddings still reference the
        # dynamically generated temp folder from save time, so recreate it
        # before training tries to write there.
        path_to_w2v_file = tagger.embeddings.list_embedding_0.embedder.emb_file
        path_to_w2v = str(path_to_w2v_file).rsplit('/', 1)[0]
        if not os.path.exists(path_to_w2v):
            logger.info(f'Create folder for w2v: {path_to_w2v}')
            os.makedirs(path_to_w2v)
        logger.info(f'Loaded tagger from {path_to_checkpoint}')
    else:
        # Fresh start: build the embedding stack and a new tagger on top.
        embeddings = StackedEmbeddings([
            BytePairEmbeddings(
                language="en",
                dim=25,
                syllables=50000,
            ),
            WordEmbeddings(embeddings="en"),
            FlairEmbeddings(model="news-forward-fast"),
            FlairEmbeddings(model="news-backward-fast"),
        ])
        tagger = SequenceTagger(
            hidden_size=256,
            embeddings=embeddings,
            tag_dictionary=tag_dictionary,
            tag_type=tag_type,
            word_dropout=0.1,
            dropout=0.2,
            rnn_layers=2,
            use_crf=True,
            train_initial_hidden_state=True,
        )

    return tagger

Then you can create the trainer object as usual

trainer: ModelTrainer = ModelTrainer(tagger, column_corpus)
...
trainer.train(args.model_folder, **learning_params)
...
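
For readers following along, the elided learning_params might look something like this; the values below are illustrative assumptions, not the ones from the original comment:

    # Illustrative hyperparameters only -- assumptions, not the originals.
    learning_params = {
        "learning_rate": 0.1,
        "mini_batch_size": 32,
        "max_epochs": 25,
    }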

@alfredwallace7 (Author)

Thanks for your replies. I'll read the docs fully and try the hack!

@david-waterworth

I would add that resuming is important if you're training models on AWS and want to use spot instances: jobs need to survive interruption and continue from a checkpoint automatically.
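
One way to approximate this today, assuming flair's save_model_each_k_epochs training option is available in your version: write periodic snapshots and point a restarted job at the newest one. A sketch reusing the get_tagger() helper from above; latest_snapshot_or_none is a hypothetical variable holding the newest snapshot path (or None on a fresh start):

    # Hedged sketch for interruptible (e.g. spot-instance) runs; all paths
    # and values are illustrative.
    from flair.trainers import ModelTrainer

    tagger = get_tagger(tag_dictionary, tag_type,
                        path_to_checkpoint=latest_snapshot_or_none)
    trainer = ModelTrainer(tagger, column_corpus)
    trainer.train(
        "resources/taggers/spot-run",
        max_epochs=50,
        save_model_each_k_epochs=1,  # periodic model_epoch_<k>.pt snapshots
    )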
