-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 0271f6e
Showing
66 changed files
with
6,288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
*.pt | ||
|
||
*.egg-info/ | ||
dist/ | ||
|
||
.vscode | ||
.ipynb_checkpoints/ | ||
|
||
.env* | ||
env*/ | ||
|
||
explorations/ | ||
models/ | ||
|
||
.DS_Store | ||
.neptune |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Minimal makefile for Sphinx documentation | ||
# | ||
|
||
# You can set these variables from the command line, and also | ||
# from the environment for the first two. | ||
SPHINXOPTS ?= | ||
SPHINXBUILD ?= sphinx-build | ||
SOURCEDIR = doc | ||
BUILDDIR = docbuild | ||
|
||
# Put it first so that "make" without argument is like "make help". | ||
help: | ||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | ||
|
||
.PHONY: help Makefile | ||
|
||
# Catch-all target: route all unknown targets to Sphinx using the new | ||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). | ||
%: Makefile | ||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
# 🕺🏽 disco: A Toolkit for Distributional Control of Generative Models | ||
|
||
The 🕺🏽 **disco** toolkit allows to control the properties of the generations by language models and other generative systems to match human preferences while avoiding catastrophic forgetting. | ||
|
||
To achieve this in **disco**, we first represent in what ways we want to update original model as a target distribution and then, generate samples from this new distribution through a combination of learning or monte-carlo methods, as follows. | ||
|
||
**Step 1: We express how the target distribution *should* be** | ||
|
||
To have a handle on the generative model, we define some feature over the generated samples. It can be anything we can compute. For example, on a language model it can be as simple as whether the generated text contains a certain word or as complex as the compilability of some generated piece of code. Importantly, there is no need for the feature to be differentiable. | ||
Then, we can express our preferences on the target distribution by defining the target *moments* of this feature. For example, we might want to ask that a certain word appears 50% of the time when sampling from the model; or that 100% of the generated code is compilable. The resulting target distribution is expressed as an energy-based model or EBM, which is an unnormalized probability distribution that respects the desired moments while avoiding catastrophic forgetting, as a result of having minimal KL divergence to the original model. | ||
This representation of the target distribution can *score* samples, but cannot directly be used to *generate* them. | ||
|
||
**Step 2: We generate samples from the target distribution** | ||
|
||
To generate samples from the target distribution, if not perfectly, we can tune a model to approximate it. The resulting model can generate samples directly from a close approximation of the target distribution. Furthermore, it can be used jointly with Quasi-Rejection Sampling (QRS), a Monte Carlo sampling technique that allows the generation of samples that are even more representative of the target distribution. | ||
Alternatively, it is then possible to use decoding methods such as nucleus sampling, top-k sampling, or beam search, which would return samples from a further updated target distribution. | ||
|
||
See the references below for more theoretical and technical details. | ||
|
||
## Installation | ||
|
||
### Standard installation | ||
|
||
The easiest way to install **disco** is to rely on pip, asking for the ```disco-generation``` package: | ||
|
||
``` | ||
pip install disco-generation | ||
``` | ||
|
||
Note that the toolkit: | ||
- depends on PyTorch, | ||
- uses HuggingFace's Transformers library for the generic handling of language models, as well as the generation of samples from them. | ||
|
||
### Toolkit Developers | ||
|
||
If we plan to extend the toolkit we will need to clone and to install it as a local package. | ||
From the toolkit top folder, once we've git-cloned the repository and activated our development environment, we simply do: | ||
``` | ||
pip install -e . | ||
``` | ||
|
||
## Quick introduction | ||
|
||
|
||
### Distributions | ||
|
||
The generative model that we want to tune must be wrapped by a `Distribution` object. For example, for a (causal or seq2seq) language model compatible with the 🤗 Hugging Face interface use an `LMDistribution`. | ||
|
||
A valid `Distribution` must have the following two methods: | ||
- `.sample(context)` that given an optional `context` on which the distribution can be conditioned, returns a list of samples from the underlying distribution and a tensor with their corresponding log-probabilities; | ||
- `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities. | ||
|
||
```python | ||
from disco.distributions import LMDistribution | ||
distribution = LMDistribution() | ||
|
||
incipit = "It was a cold and stormy night" | ||
samples, log_scores = distribution.sample(context=incipit) | ||
|
||
distribution.log_score(samples, context=incipit) | ||
``` | ||
|
||
`LMDistribution` generate samples, with the `TextSample` type, which are named tuples with both a `text` and `token_ids` fields. | ||
|
||
<small>From now on, after this initial example, imports will be skipped for clarity.</small> | ||
|
||
### Features | ||
|
||
Features are represented by an object with the method | ||
|
||
- `.score(samples, context)` which given a list of samples and an eventual context returns a tensor of real-valued scores. | ||
|
||
A convenient way to define one is using the `Scorer` class, which accepts a function or a lambda abstraction that takes sample and a context, and vectorizes it. For example, we can compute the effective length of a GPT-2 text sample by finding the eos token: | ||
|
||
```python | ||
sequence_length = Scorer(lambda s, c: s.text.index("<|endoftext|>")) | ||
``` | ||
|
||
Where `s` is the sample (assumed to be a `TextSample`) and `c` is an eventual context. | ||
|
||
#### Boolean Features | ||
|
||
An important class of features are *boolean* features. While general features can only be used to define *distributional* constraints, boolean features can also be used to define *pointwise* constraints, see below. To define one, we can use the `BooleanScorer` helper class, which takes a function as an argument. For example, we can score the presence of the string "amazing", as follows: | ||
|
||
```python | ||
amazing = BooleanScorer(lambda s, c: "amazing" in s.text) | ||
``` | ||
|
||
The ```False```/```True``` results from the lambda are casted to `0.0`/`1.0` float values so that they can be used in the EBM definition. | ||
|
||
`BooleanScorer` belongs to the more general `PositiveScorer` class, which can be used to construct EBMs. The main properties of a `PostiveScorer` are that first, it returns positive scorers, and second that it provides the method: | ||
|
||
- `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities. | ||
|
||
As a consequence, we can see that a ```Distribution``` is also a ```PositiveScorer``` that is able to sample as well. | ||
|
||
|
||
### Controlling Generation | ||
|
||
#### Expressing preferences through an EBM | ||
|
||
We express preferences over the distribution by defining target moments for specific features. This results in a target distribution that matches the desired moments while minimizing the KL divergence to the original distribution. In other words, it incorporates the preferences while avoiding catastrophic forgetting. This distribution is represented as an EBM, which can be used to score samples, in other words it is a `PositiveScorer`, but cannot be used to sample, we'll see how to sample below. | ||
|
||
We can express either *pointwise* or *distributional* constraints on a distribution and compose them at will. The former expresses a (boolean) property that must apply to *all* sequences, whereas the latter represents properties at the distributional level. | ||
|
||
To obtain the target distribution that incorporates our constraints, we use the `constraint` method of the corresponding `Distribution`. This method takes a list of features and their corresponding target moments. | ||
|
||
For example, we can define an EBM with a *pointwise* constraint requiring that all our samples must include "amazing" by setting the target moment to `1` on a `BooleanFeature`: | ||
|
||
```python | ||
target_ebm = base.constrain([amazing], [1]) | ||
``` | ||
|
||
Or we can ask for a _distributional_ constraint requiring that _half_ of the samples include "amazing": | ||
|
||
```python | ||
target_ebm = base.constrain([amazing], [1/2]) | ||
``` | ||
|
||
|
||
#### Approximating the target EBM | ||
|
||
|
||
Given an EBM target distribution, we now want to train a model to approximate it so that we can use it to generate samples. In the _unconditional_ case, namely when there is a single fixed context used in generation, then we can use a `Tuner`, more specifically a ```DPGTuner```, as follows. | ||
|
||
|
||
```python | ||
target_ebm = base.constrain([amazing], [1]) | ||
|
||
model = LMDistribution(freeze=False) | ||
incipit = "It was a cold and stormy night" | ||
|
||
tuner = DPGTuner(model, target_ebm, context=incipit) | ||
tuner.tune() | ||
``` | ||
|
||
And we can sample _amazing_ sequences from the tuned model. | ||
```python | ||
samples, log_scores = model.sample(context=incipit) | ||
for s in samples: | ||
print(incipit + s.text) | ||
``` | ||
|
||
##### Tuning parameters | ||
|
||
Important parameters of the `Tuner` include: | ||
|
||
- `n_gradient_steps`: number of total gradient steps in the full tuning process; | ||
- `n_samples_per_step`: total number of samples used in performing a gradient step (aka batch size); | ||
- `scoring_size`: number of samples sent in a batch to the `.score` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results; | ||
- `sampling_size`: number of samples obtained from a single call to the `.sample` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results; | ||
- `features`: list of pairs (`name`, `feature`) so that the `feature` moments will be computed by importance sampling (and reported using the key given by `name`); | ||
- `track_divergence_from_base`: set to True to track the reverse KL divergence from the original model —this requires an additional round of samples' scoring). | ||
|
||
#### Logging | ||
|
||
The Tuner reports a number of metrics that are useful to monitor the training progress. A number of `Logger` classes are provided to keep track of these metrics. Basic logging is provided though the console, as follows: | ||
|
||
```python | ||
console_logger = ConsoleLogger(tuner) | ||
``` | ||
|
||
However, more detailed statistics can be kept trhough a JSON/WandB/Neptune loggers: | ||
|
||
```python | ||
project = "example_project" | ||
name = "run_01" | ||
json_logger = JSONLogger(tuner, project, name) | ||
neptune_logger = NeptuneLogger(tuner, project, name) | ||
wandb_logger = WandBLogger(tuner, project, name) | ||
``` | ||
|
||
where `project` and `name` refer to the project and run name, respectively. | ||
|
||
##### Logged Metrics | ||
|
||
Loggers store a number of metrics about the training process. Here we list a few of the most relevant ones: | ||
|
||
- `kl_target_model` and `kl_target_proposal`: estimates of the forward KL divergence to the target EBM from the tuned model and the proposal distribution, respectively. In the case of using online training, the two are equivalent with the only caveat that `kl_target_model` is computed —this is the metric being optimized, and not the value reported as `loss`; | ||
- `kl_model_base` and `kl_proposal_base`: estimates of the reverse KL divergence to the original model of the tuned model and the proposal distribution, respectively —only reported if `track_divergence_from_base` is set to True; | ||
- Feature moments: Estimate of the features' moments for those features specified with the `features` parameter at the Tuner's construction time. | ||
|
||
### Controlled Conditional Generation | ||
|
||
The _conditional_ case is superficially very similar, with an extra step needed to instantiate a `ContextDistribution`, which allows to sample contexts that can then be used to condition the model. Furthermore, we use the more general ```CDPGTuner``` class. | ||
|
||
Assuming we have a file of incipits, one per line, in a `data/incipits.txt` file, we could do: | ||
|
||
```python | ||
target_ebm = base.constrain([amazing], [1]) | ||
|
||
model = LMDistribution(freeze=False) | ||
|
||
tuner = CDPGTuner(model, target_ebm, | ||
context_distribution=ContextDistribution("data/incipits.txt"), | ||
context_sampling_size=2**3) | ||
tuner.tune() | ||
``` | ||
|
||
Note that while we have used a decoder-only model here for illustrative purposes, the real power of the CDPGTuner is that it allows to control _seq2seq models_ such as those used in NMT, summarization, etc... Please refer to the dedicated [tutorial notebook](tutorials/4.conditional_tuning.ipynb) for an example of how to control an actual conditional model. | ||
|
||
|
||
#### Monte-Carlo sampling to improve the approximation | ||
|
||
After the tuning is done, `model` is now a better approximation to the target EBM, but it is not guaranteed to perfectly match this distribution. While further training can improve the situation, another alternative is using [quasi-rejection sampling (QRS)](https://disco.europe.naverlabs.com/QRS/), a Monte-Carlo sampling technique that allows to trade-off sampling efficiency for a higher fidelity to the target distribution —a higher value of `beta` yields a better fidelity although at a higher computational cost. | ||
|
||
```python | ||
beta=0.5 | ||
sampler = QuasiRejectionSampler(target_ebm, model, beta=beta) | ||
samples, log_scores = sampler.sample(sampling_size=2**7) | ||
``` | ||
|
||
#### In summary | ||
|
||
To put some of this (distributional constraint, tuning in the unconditional case and using QRS) together: | ||
|
||
```python | ||
base = LMDistribution() | ||
target_ebm = base.constrain([amazing], [1/2]) | ||
|
||
model = LMDistribution(freeze=False) | ||
|
||
tuner = DPGTuner(model, target_ebm) | ||
tuner.tune() | ||
|
||
beta=0.5 | ||
sampler = QuasiRejectionSampler(target_ebm, model, beta=beta) | ||
samples, log_scores = sampler.sample(context=incipit, sampling_size=2**7) | ||
``` | ||
|
||
### Going further | ||
|
||
A few things to keep in mind while reading the following paragraphs showing the principles of **disco**: | ||
1. this is only an introduction, skipping some details and relying on toyish use cases; | ||
1. the notebooks in the tutorials folder go in more depth, on more use cases; | ||
1. the focus here and in most notebooks is on natural language, but again the toolkit can be used to control distributions over sequences such as code or chess moves, or even other data types, as long as they respect the basic assumptions of a disco `Distribution` object. | ||
|
||
## References | ||
|
||
The **disco** toolkit implements the theoretical framework presented in the following works: | ||
- A Distributional Approach to Controlled Text Generation, Khalifa et al., 2021, <https://openreview.net/forum?id=jWkw45-9AbL>, ICLR; | ||
- An approximate sampler for energy-based models with divergence diagnostics, Eikema et al., 2022, <https://openreview.net/forum?id=VW4IrC0n0M>, TMLR; | ||
- Energy-Based Models for Code Generation under Compilability Constraints, Korbak et al., 2021, <https://arxiv.org/abs/2106.04985>, ACL (Workshop on Natural Language Processing for Programming); | ||
- Controlling Conditional Language Models without Catastrophic Forgetting, Korbak et al., 2022, <https://proceedings.mlr.press/v162/korbak22a.html>, ICML; | ||
- On Reinforcement Learning and Distribution Matching for Fine-Tuning Language Models with no Catastrophic Forgetting, Korbak et al., 2022, <https://openreview.net/forum?id=XvI6h-s4un>, NeurIPS. | ||
|
||
## License | ||
|
||
See [LICENSE](LICENSE) file. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# disco | ||
# Copyright (C) 2022-present NAVER Corp. | ||
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license | ||
|
||
__version__ = "1.0.0" | ||
__author__ = 'Naver Labs Europe' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# disco | ||
# Copyright (C) 2022-present NAVER Corp. | ||
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license | ||
|
||
from .distribution import Distribution | ||
from .base_distribution import BaseDistribution | ||
from .lm_distribution import LMDistribution | ||
from .single_context_distribution import SingleContextDistribution | ||
from .context_distribution import ContextDistribution | ||
from .dataset_context_distribution import DatasetContextDistribution |
Oops, something went wrong.