Commit
add support for max_total_length=4096 for 43b (#6763)
* add support for max_total_length=4096 for 43b

Signed-off-by: Zhilin Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Zhilin Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Gerald Shen <[email protected]>
2 people authored and gshennvm committed Jul 12, 2023
1 parent dd1d48f commit b5699b4
Showing 1 changed file with 13 additions and 1 deletion.
@@ -41,6 +41,7 @@
1. `--drop_duplicates` : Use this flag to drop rows that are exactly the same for both prompt and completion
2. `--split_train_validation` : Use this flag to split one file into separate train and validation files.
3. `--val_proportion 0.1`: Use a float (default 0.1) between 0 and 1 to control how much of the dataset to allocate to the validation set; the remainder goes to the train dataset.
4. `--short_context_model`: Use this flag to prepare data for models that have a shorter context length of 2048 tokens (e.g. the 5B and 20B models); a condensed sketch of these flags follows this list.
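The following is a condensed, hedged sketch of how these flags could be wired with argparse. It mirrors the `add_argument` calls visible in the diff further down; the float type and 0.1 default on `--val_proportion` are assumptions taken from the description above rather than from the visible diff.

```python
import argparse

# Condensed sketch of the documented CLI flags (assumption: the real script
# defines more options; --val_proportion's type/default follow the docs above).
parser = argparse.ArgumentParser(description="Prepare prompt/completion data")
parser.add_argument("--drop_duplicates", "-dd", action="store_true")
parser.add_argument("--split_train_validation", "-stv", action="store_true")
parser.add_argument("--val_proportion", "-vp", type=float, default=0.1)
parser.add_argument("--short_context_model", "-scm", action="store_true")

args = parser.parse_args(["--split_train_validation", "--val_proportion", "0.2"])
print(args.val_proportion, args.short_context_model)  # -> 0.2 False
```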
What to expect
@@ -396,6 +397,12 @@ def print_all_messages(messages):
parser.add_argument("--completion_template", "-ct", default="{completion}")
parser.add_argument("--drop_duplicates", "-dd", action="store_true")
parser.add_argument("--split_train_validation", "-stv", action="store_true")
parser.add_argument(
    "--short_context_model",
    "-scm",
    action="store_true",
    help="Specifies whether to use models with a shorter context length of 2048 tokens, e.g. the 5B and 20B models",
)
parser.add_argument(
    "--val_proportion",
    "-vp",
@@ -409,8 +416,13 @@ def print_all_messages(messages):
messages = []
messages.append(str(args))

if args.short_context_model:
    MAX_TOKEN_LENGTH = 2048
else:
    MAX_TOKEN_LENGTH = 4096

# every token is around 4 chars
-MAX_TOTAL_CHAR_LENGTH = 4 * 2048
+MAX_TOTAL_CHAR_LENGTH = 4 * MAX_TOKEN_LENGTH

df, message = load_file_into_df(args.filename)
messages.append(message)
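To make the effect of the new constant concrete, here is a minimal, runnable sketch of how the character budget could be used to drop over-long rows. The `drop_overlong_rows` helper and the `prompt`/`completion` column names are illustrative assumptions; only `MAX_TOKEN_LENGTH` and `MAX_TOTAL_CHAR_LENGTH` come from the diff above.

```python
import pandas as pd

MAX_TOKEN_LENGTH = 4096  # 2048 when --short_context_model is passed
MAX_TOTAL_CHAR_LENGTH = 4 * MAX_TOKEN_LENGTH  # every token is around 4 chars


def drop_overlong_rows(df: pd.DataFrame) -> pd.DataFrame:
    # Hypothetical helper: keep only rows whose combined prompt + completion
    # text fits within the character budget derived above.
    total_chars = df["prompt"].str.len() + df["completion"].str.len()
    return df[total_chars <= MAX_TOTAL_CHAR_LENGTH]


df = pd.DataFrame({"prompt": ["hi", "x" * 20000], "completion": ["ok", "y"]})
print(len(drop_overlong_rows(df)))  # -> 1 (the 20,001-char row is dropped)
```

The 4-chars-per-token heuristic keeps the filter tokenizer-free, trading exactness for speed around the true token limit.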
