Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Taproot Compiler #291

Merged
merged 5 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions src/policy/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::marker::PhantomData;
use std::sync::Arc;
use std::{cmp, error, f64, fmt, hash, mem};

use crate::miniscript::context::SigType;
use crate::miniscript::limits::MAX_PUBKEYS_PER_MULTISIG;
use crate::miniscript::types::{self, ErrorKind, ExtData, Property, Type};
use crate::miniscript::ScriptContext;
Expand All @@ -35,7 +36,7 @@ type PolicyCache<Pk, Ctx> =

///Ordered f64 for comparison
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
struct OrdF64(f64);
pub(crate) struct OrdF64(pub f64);

impl Eq for OrdF64 {}
impl Ord for OrdF64 {
Expand Down Expand Up @@ -987,18 +988,23 @@ where
})
.collect();

if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG {
insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec)));
}
// Not a threshold, it's always more optimal to translate it to and()s as we save the
// resulting threshold check (N EQUAL) in any case.
else if k == subs.len() {
let mut policy = subs.first().expect("No sub policy in thresh() ?").clone();
for sub in &subs[1..] {
policy = Concrete::And(vec![sub.clone(), policy]);
match Ctx::sig_type() {
SigType::Schnorr if key_vec.len() == subs.len() => {
insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec)))
}
SigType::Ecdsa
if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG =>
{
insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec)))
}
_ if k == subs.len() => {
let mut it = subs.iter();
let mut policy = it.next().expect("No sub policy in thresh() ?").clone();
policy = it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]));

ret = best_compilations(policy_cache, &policy, sat_prob, dissat_prob)?;
ret = best_compilations(policy_cache, &policy, sat_prob, dissat_prob)?;
}
_ => {}
}

// FIXME: Should we also optimize thresh(1, subs) ?
Expand Down Expand Up @@ -1549,6 +1555,17 @@ mod tests {
))
);
}

#[test]
fn compile_tr_thresh() {
for k in 1..4 {
let small_thresh: Concrete<String> =
policy_str!("{}", &format!("thresh({},pk(B),pk(C),pk(D))", k));
let small_thresh_ms: Miniscript<String, Tap> = small_thresh.compile().unwrap();
let small_thresh_ms_expected: Miniscript<String, Tap> = ms_str!("multi_a({},B,C,D)", k);
assert_eq!(small_thresh_ms, small_thresh_ms_expected);
}
}
}

#[cfg(all(test, feature = "unstable"))]
Expand Down
208 changes: 200 additions & 8 deletions src/policy/concrete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,27 @@ use std::{error, fmt, str};

use bitcoin::hashes::hex::FromHex;
use bitcoin::hashes::{hash160, ripemd160, sha256, sha256d};
#[cfg(feature = "compiler")]
use {
crate::descriptor::TapTree,
crate::miniscript::ScriptContext,
crate::policy::compiler::CompilerError,
crate::policy::compiler::OrdF64,
crate::policy::{compiler, Concrete, Liftable, Semantic},
crate::Descriptor,
crate::Miniscript,
crate::Tap,
std::cmp::Reverse,
std::collections::{BinaryHeap, HashMap},
std::sync::Arc,
};

use super::ENTAILMENT_MAX_TERMINALS;
use crate::expression::{self, FromTree};
use crate::miniscript::limits::{HEIGHT_TIME_THRESHOLD, SEQUENCE_LOCKTIME_TYPE_FLAG};
use crate::miniscript::types::extra_props::TimeLockInfo;
#[cfg(feature = "compiler")]
use crate::miniscript::ScriptContext;
#[cfg(feature = "compiler")]
use crate::policy::compiler;
#[cfg(feature = "compiler")]
use crate::policy::compiler::CompilerError;
#[cfg(feature = "compiler")]
use crate::Miniscript;
use crate::{errstr, Error, ForEach, ForEachKey, MiniscriptKey};

/// Concrete policy which corresponds directly to a Miniscript structure,
/// and whose disjunctions are annotated with satisfaction probabilities
/// to assist the compiler
Expand Down Expand Up @@ -128,6 +135,136 @@ impl fmt::Display for PolicyError {
}

impl<Pk: MiniscriptKey> Policy<Pk> {
/// Flatten the [`Policy`] tree structure into a Vector of tuple `(leaf script, leaf probability)`
/// with leaf probabilities corresponding to odds for sub-branch in the policy.
/// We calculate the probability of selecting the sub-branch at every level and calculate the
/// leaf probabilities as the probability of traversing through required branches to reach the
/// leaf node, i.e. multiplication of the respective probabilities.
///
/// For example, the policy tree: OR
/// / \
/// 2 1 odds
/// / \
/// A OR
/// / \
/// 3 1 odds
/// / \
/// B C
///
/// gives the vector [(2/3, A), (1/3 * 3/4, B), (1/3 * 1/4, C)].
#[cfg(feature = "compiler")]
fn to_tapleaf_prob_vec(&self, prob: f64) -> Vec<(f64, Policy<Pk>)> {
match *self {
Policy::Or(ref subs) => {
let total_odds: usize = subs.iter().map(|(ref k, _)| k).sum();
subs.iter()
.map(|(k, ref policy)| {
policy.to_tapleaf_prob_vec(prob * *k as f64 / total_odds as f64)
})
.flatten()
.collect::<Vec<_>>()
}
Policy::Threshold(k, ref subs) if k == 1 => {
let total_odds = subs.len();
subs.iter()
.map(|policy| policy.to_tapleaf_prob_vec(prob / total_odds as f64))
.flatten()
.collect::<Vec<_>>()
}
ref x => vec![(prob, x.clone())],
}
}

/// Compile [`Policy::Or`] and [`Policy::Threshold`] according to odds
#[cfg(feature = "compiler")]
fn compile_tr_policy(&self) -> Result<TapTree<Pk>, Error> {
let leaf_compilations: Vec<_> = self
.to_tapleaf_prob_vec(1.0)
.into_iter()
.filter(|x| x.1 != Policy::Unsatisfiable)
.map(|(prob, ref policy)| (OrdF64(prob), compiler::best_compilation(policy).unwrap()))
.collect();
let taptree = with_huffman_tree::<Pk>(leaf_compilations).unwrap();
Ok(taptree)
}

/// Extract the internal_key from policy tree.
#[cfg(feature = "compiler")]
fn extract_key(self, unspendable_key: Option<Pk>) -> Result<(Pk, Policy<Pk>), Error> {
let mut internal_key: Option<Pk> = None;
{
let mut prob = 0.;
let semantic_policy = self.lift()?;
let concrete_keys = self.keys();
let key_prob_map: HashMap<_, _> = self
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic of internal-key extraction has been changed from selecting first-encountered key satisfying the policy to most-probable key satisfying the policy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the ertract key is updated in next commit. No need to fix it here.

.to_tapleaf_prob_vec(1.0)
.into_iter()
.filter(|(_, ref pol)| match *pol {
Concrete::Key(..) => true,
_ => false,
})
.map(|(prob, key)| (key, prob))
.collect();

for key in concrete_keys.into_iter() {
if semantic_policy
.clone()
.satisfy_constraint(&Semantic::KeyHash(key.to_pubkeyhash()), true)
== Semantic::Trivial
{
match key_prob_map.get(&Concrete::Key(key.clone())) {
Some(val) => {
if *val > prob {
prob = *val;
internal_key = Some(key.clone());
}
}
None => return Err(errstr("Key should have existed in the HashMap!")),
}
}
}
}
match (internal_key, unspendable_key) {
(Some(ref key), _) => Ok((key.clone(), self.translate_unsatisfiable_pk(&key))),
(_, Some(key)) => Ok((key, self)),
_ => Err(errstr("No viable internal key found.")),
}
}

/// Compile the [`Policy`] into a [`Tr`][`Descriptor::Tr`] Descriptor.
///
/// ### TapTree compilation
///
/// The policy tree constructed by root-level disjunctions over [`Or`][`Policy::Or`] and
/// [`Thresh`][`Policy::Threshold`](1, ..) which is flattened into a vector (with respective
/// probabilities derived from odds) of policies.
/// For example, the policy `thresh(1,or(pk(A),pk(B)),and(or(pk(C),pk(D)),pk(E)))` gives the vector
/// `[pk(A),pk(B),and(or(pk(C),pk(D)),pk(E)))]`. Each policy in the vector is compiled into
/// the respective miniscripts. A Huffman Tree is created from this vector which optimizes over
/// the probabilitity of satisfaction for the respective branch in the TapTree.
// TODO: We might require other compile errors for Taproot.
#[cfg(feature = "compiler")]
pub fn compile_tr(&self, unspendable_key: Option<Pk>) -> Result<Descriptor<Pk>, Error> {
self.is_valid()?; // Check for validity
match self.is_safe_nonmalleable() {
(false, _) => Err(Error::from(CompilerError::TopLevelNonSafe)),
(_, false) => Err(Error::from(
CompilerError::ImpossibleNonMalleableCompilation,
)),
_ => {
let (internal_key, policy) = self.clone().extract_key(unspendable_key)?;
let tree = Descriptor::new_tr(
internal_key,
match policy {
Policy::Trivial => None,
policy => Some(policy.compile_tr_policy()?),
},
)?;
Ok(tree)
}
}
}

/// Compile the descriptor into an optimized `Miniscript` representation
#[cfg(feature = "compiler")]
pub fn compile<Ctx: ScriptContext>(&self) -> Result<Miniscript<Pk, Ctx>, CompilerError> {
Expand Down Expand Up @@ -226,6 +363,30 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
}
}

/// Translate `Concrete::Key(key)` to `Concrete::Unsatisfiable` when extracting TapKey
pub fn translate_unsatisfiable_pk(self, key: &Pk) -> Policy<Pk> {
match self {
Policy::Key(ref k) if k.clone() == *key => Policy::Unsatisfiable,
Policy::And(subs) => Policy::And(
subs.into_iter()
.map(|sub| sub.translate_unsatisfiable_pk(key))
.collect::<Vec<_>>(),
),
Policy::Or(subs) => Policy::Or(
subs.into_iter()
.map(|(k, sub)| (k, sub.translate_unsatisfiable_pk(key)))
.collect::<Vec<_>>(),
),
Policy::Threshold(k, subs) => Policy::Threshold(
k,
subs.into_iter()
.map(|sub| sub.translate_unsatisfiable_pk(key))
.collect::<Vec<_>>(),
),
x => x,
}
}

/// Get all keys in the policy
pub fn keys(&self) -> Vec<&Pk> {
match *self {
Expand Down Expand Up @@ -645,3 +806,34 @@ where
Policy::from_tree_prob(top, false).map(|(_, result)| result)
}
}

/// Create a Huffman Tree from compiled [Miniscript] nodes
#[cfg(feature = "compiler")]
fn with_huffman_tree<Pk: MiniscriptKey>(
ms: Vec<(OrdF64, Miniscript<Pk, Tap>)>,
) -> Result<TapTree<Pk>, Error> {
let mut node_weights = BinaryHeap::<(Reverse<OrdF64>, TapTree<Pk>)>::new();
for (prob, script) in ms {
node_weights.push((Reverse(prob), TapTree::Leaf(Arc::new(script))));
}
if node_weights.is_empty() {
return Err(errstr("Empty Miniscript compilation"));
}
while node_weights.len() > 1 {
let (p1, s1) = node_weights.pop().expect("len must atleast be two");
let (p2, s2) = node_weights.pop().expect("len must atleast be two");

let p = (p1.0).0 + (p2.0).0;
node_weights.push((
Reverse(OrdF64(p)),
TapTree::Tree(Arc::from(s1), Arc::from(s2)),
));
}

debug_assert!(node_weights.len() == 1);
let node = node_weights
.pop()
.expect("huffman tree algorithm is broken")
.1;
Ok(node)
}
Loading