diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 983b21acb2..ba724c0387 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -63,13 +63,14 @@ def from_files( TranslateData: The constructed data module. Examples:: + datamodule = TranslationData.from_files( train_file="data/wmt_en_ro/train.csv", val_file="data/wmt_en_ro/valid.csv", test_file="data/wmt_en_ro/test.csv", input="input", target="target", - batch_size=1 + batch_size=1, ) """