Merged

39 commits
942a7c6
add: the contrastive search for generation_utils
gmftbyGMFTBY Oct 10, 2022
5909423
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 10, 2022
3e71819
add: testing scripts for contrastive search under examples/text-gener…
gmftbyGMFTBY Oct 10, 2022
47b2b3b
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 11, 2022
41e37a5
update the code quality
gmftbyGMFTBY Oct 11, 2022
e278a46
Merge branch 'csearch-pr-v2' of https://github.com/gmftbyGMFTBY/trans…
gmftbyGMFTBY Oct 11, 2022
38b100f
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
9abd1bb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
32e2a30
revise the docstring; make the generation_contrastive_search.py script;
gmftbyGMFTBY Oct 12, 2022
e9e2b26
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
1f1dac2
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
6226b9a
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
f3bfd87
revise the examples/pytorch/text-generation/run_generation_contrastiv…
gmftbyGMFTBY Oct 12, 2022
d3a91b8
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
ce26f9f
revise the necessary documents
gmftbyGMFTBY Oct 12, 2022
e801c6f
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
c78cf91
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
fb4174e
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
68429ad
fix: revise the docstring of generation_contrastive_search.py
gmftbyGMFTBY Oct 13, 2022
e1f0db9
Fix the code indentation
gmftbyGMFTBY Oct 13, 2022
42d78be
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
d2a5e02
fix: revise the nits and examples in contrastive_search docstring.
gmftbyGMFTBY Oct 13, 2022
1d4f782
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 14, 2022
d5f90fb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 14, 2022
d5d30b7
fix the copyright
gmftbyGMFTBY Oct 14, 2022
49000c6
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 17, 2022
3058e1c
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 17, 2022
c344a0a
delete generation_contrastive_search.py
gmftbyGMFTBY Oct 17, 2022
628ecda
Merge branch 'csearch-pr-v2' of https://github.com/gmftbyGMFTBY/trans…
gmftbyGMFTBY Oct 17, 2022
5ae4ce2
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
7af4cbb
revise the logic in contrastive_search
gmftbyGMFTBY Oct 18, 2022
b219a17
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
183d7cc
update the integration test and the docstring
gmftbyGMFTBY Oct 18, 2022
65a1ebd
run the tests over
gmftbyGMFTBY Oct 18, 2022
4972bfb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
e11d342
add the slow decorator to the contrastive_search integration test
gmftbyGMFTBY Oct 18, 2022
da014bb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 19, 2022
2aa768c
add more tests
gmftbyGMFTBY Oct 19, 2022
ced9f70
do the style, quality, consistency checks
gmftbyGMFTBY Oct 19, 2022
1 change: 1 addition & 0 deletions docs/source/en/internal/generation_utils.mdx
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.

This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
[`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.contrastive_search`],
[`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`],
1 change: 1 addition & 0 deletions docs/source/en/main_classes/text_generation.mdx
@@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation implemented in their respective `GenerationMixin` class:
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search

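For orientation, contrastive search is reached through the regular `generate` API once both `penalty_alpha` and `top_k` are set. A minimal sketch, assuming a standard causal LM checkpoint; the prompt is illustrative, not taken from this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# penalty_alpha > 0 together with top_k > 1 routes generate() into the
# contrastive_search decoding loop added by this PR
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))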
138 changes: 138 additions & 0 deletions examples/pytorch/text-generation/run_generation_contrastive_search.py
@@ -0,0 +1,138 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The examples of running contrastive search on the auto-APIs;

Running this example:
python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
"""


import argparse
import logging

import numpy as np
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tends toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--penalty_alpha", type=float, default=0.0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

    # Specific classes can be used instead of the auto classes, e.g.:
    # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    # model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    if args.fp16:
        model.half()

    logger.info(args)
    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    inputs = {key: value.to(args.device) for key, value in inputs.items()}

    # Passing penalty_alpha and top_k activates contrastive search decoding
    output_sequences = model.generate(
        **inputs,
        max_length=args.length + len(inputs["input_ids"][0]),
        penalty_alpha=args.penalty_alpha,
        top_k=args.k,
    )

    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()
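For reference, the decoding rule this PR implements (from "A Contrastive Framework for Neural Text Generation", Su et al., 2022) picks, among the top-k candidates, the token that trades model confidence off against a degeneration penalty. Below is a minimal sketch of one selection step, assuming the candidate hidden states were already computed by a forward pass; the function and tensor names are illustrative, not taken from the PR's code:

import torch
import torch.nn.functional as F

def contrastive_step(logits, context_hidden, candidate_hidden, top_k=4, penalty_alpha=0.6):
    # logits: (vocab_size,) next-token logits for the current prefix
    # context_hidden: (seq_len, dim) hidden states of the tokens generated so far
    # candidate_hidden: (top_k, dim) hidden state of each candidate, aligned with the top-k order
    top_probs, top_ids = torch.softmax(logits, dim=-1).topk(top_k)

    # degeneration penalty: max cosine similarity between each candidate's
    # hidden state and every hidden state already in the context
    context = F.normalize(context_hidden, dim=-1)
    cand = F.normalize(candidate_hidden, dim=-1)
    max_sim = (cand @ context.T).max(dim=-1).values  # (top_k,)

    # trade confidence off against similarity; larger alpha penalizes repetition harder
    scores = (1 - penalty_alpha) * top_probs - penalty_alpha * max_sim
    return top_ids[scores.argmax()]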