diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 30b7f22669..ba724c0387 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -64,10 +64,14 @@ def from_files( Examples:: - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) + datamodule = TranslationData.from_files( + train_file="data/wmt_en_ro/train.csv", + val_file="data/wmt_en_ro/valid.csv", + test_file="data/wmt_en_ro/test.csv", + input="input", + target="target", + batch_size=1, + ) """ return super().from_files(