diff --git a/.gitignore b/.gitignore index 7ee2d42e2..9d97d6312 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,12 @@ build docs/_build venv .venv +**/*.so +**/*.dll +**/*.dylib + + +# Added by cargo + +/target +Cargo.lock diff --git a/.readthedocs.yml b/.readthedocs.yml index 3811c44a8..4bc3a0a44 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,10 +6,7 @@ build: os: ubuntu-22.04 tools: python: "3.10" - jobs: - pre_install: - - pip install git+https://github.com/angr/archinfo.git - - pip install git+https://github.com/angr/pyvex.git + rust: "latest" python: install: diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..a550702a3 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "claripy" +version = "0.1.0" +edition = "2021" + +[lib] +name = "clarirs" +crate-type = ["cdylib"] + +[dependencies] +md5 = "0.7.0" +pyo3 = { version = "0.22.0", features = ["extension-module"] } diff --git a/claripy/ast/base.py b/claripy/ast/base.py index d0efe4904..a41e686de 100644 --- a/claripy/ast/base.py +++ b/claripy/ast/base.py @@ -1,3 +1,4 @@ +import builtins import itertools import logging import math @@ -7,26 +8,17 @@ from collections import OrderedDict, deque from collections.abc import Iterable, Iterator from itertools import chain -from typing import TYPE_CHECKING, Generic, NoReturn, TypeVar +from typing import TYPE_CHECKING, NoReturn, TypeVar +import claripy.clarirs as clarirs from claripy import operations, simplifications from claripy.backend_manager import backends +from claripy.clarirs import ASTCacheKey from claripy.errors import BackendError, ClaripyOperationError, ClaripyReplacementError if TYPE_CHECKING: from claripy.annotation import Annotation -try: - import _pickle as pickle -except ImportError: - import pickle - -try: - # Python's build-in MD5 is about 2x faster than hashlib.md5 on short bytestrings - import _md5 as md5 -except ImportError: - import hashlib as md5 - l = logging.getLogger("claripy.ast") WORKER = bool(os.environ.get("WORKER", False)) @@ -39,20 +31,6 @@ T = TypeVar("T", bound="Base") -class ASTCacheKey(Generic[T]): - def __init__(self, a: T): - self.ast: T = a - - def __hash__(self): - return hash(self.ast) - - def __eq__(self, other): - return type(self) is type(other) and self.ast._hash == other.ast._hash - - def __repr__(self): - return f"" - - # # AST variable naming # @@ -68,7 +46,7 @@ def _make_name(name: str, size: int, explicit_name: bool = False, prefix: str = return name -def _d(h, cls, state): +def _unpickle(h, cls, state): """ This function is the deserializer for ASTs. It exists to work around the fact that pickle will (normally) call __new__() with no arguments during @@ -80,7 +58,7 @@ def _d(h, cls, state): ) -class Base: +class Base(clarirs.Base, metaclass=type): """ This is the base class of all claripy ASTs. An AST tracks a tree of operations on arguments. @@ -99,29 +77,7 @@ class Base: :ivar args: The arguments that are being used """ - __slots__ = [ - "op", - "args", - "variables", - "symbolic", - "_hash", - "_simplified", - "_cached_encoded_name", - "_cache_key", - "_errored", - "_eager_backends", - "length", - "_excavated", - "_burrowed", - "_uninitialized", - "_uc_alloc_depth", - "annotations", - "simplifiable", - "_uneliminatable_annotations", - "_relocatable_annotations", - "depth", - "__weakref__", - ] + __slots__ = () _hash_cache = weakref.WeakValueDictionary() _leaf_cache = weakref.WeakValueDictionary() @@ -153,9 +109,6 @@ def __new__(cls, op, args, add_variables=None, hash=None, **kwargs): # pylint:d :param annotations: A frozenset of annotations applied onto this AST. """ - # if any(isinstance(a, BackendObject) for a in args): - # raise Exception('asdf') - a_args = args if type(args) is tuple else tuple(args) # initialize the following properties: symbolic, variables and errored @@ -273,30 +226,34 @@ def __new__(cls, op, args, add_variables=None, hash=None, **kwargs): # pylint:d elif op in {"BVS", "BVV", "BoolS", "BoolV", "FPS", "FPV"} and not annotations: if op == "FPV" and a_args[0] == 0.0 and math.copysign(1, a_args[0]) < 0: # Python does not distinguish between +0.0 and -0.0 so we add sign to tuple to distinguish - h = (op, kwargs.get("length", None), ("-", *a_args)) + h = builtins.hash((op, kwargs.get("length", None), ("-", *a_args))) elif op == "FPV" and math.isnan(a_args[0]): # cannot compare nans - h = (op, kwargs.get("length", None), ("nan",) + a_args[1:]) + h = builtins.hash((op, kwargs.get("length", None), ("nan",) + a_args[1:])) else: - h = (op, kwargs.get("length", None), a_args) + h = builtins.hash((op, kwargs.get("length", None), a_args)) cache = cls._leaf_cache else: h = Base._calc_hash(op, a_args, kwargs) if hash is None else hash - self = cache.get(h, None) + self = cache.get(h & 0x7FFF_FFFF_FFFF_FFFF, None) if self is None: - self = super().__new__(cls) - depth = arg_max_depth + 1 - self.__a_init__( + # depth = arg_max_depth + 1 + self = super().__new__( + cls, op, - a_args, - depth=depth, + tuple(args), + kwargs.pop("length", 1), + frozenset(kwargs.pop("variables")), + kwargs.pop("symbolic"), + # annotations, + depth=arg_max_depth + 1, uneliminatable_annotations=uneliminatable_annotations, relocatable_annotations=relocatable_annotations, **kwargs, ) - self._hash = h - cache[h] = self + self._hash = h & 0x7FFF_FFFF_FFFF_FFFF + cache[self._hash] = self # else: # if self.args != a_args or self.op != op or self.variables != kwargs['variables']: # raise Exception("CRAP -- hash collision") @@ -309,29 +266,34 @@ def __init_with_annotations__( ): cache = cls._hash_cache h = Base._calc_hash(op, a_args, kwargs) - self = cache.get(h, None) + self = cache.get(h & 0x7FFF_FFFF_FFFF_FFFF, None) if self is not None: return self - self = super().__new__(cls) - self.__a_init__( + print("aaa") + self = super().__new__( + cls, op, - a_args, + tuple(a_args), + kwargs.pop("length", None), + frozenset(kwargs.pop("variables")), + kwargs.pop("symbolic"), + tuple(kwargs.pop("annotations", ())), depth=depth, uneliminatable_annotations=uneliminatable_annotations, relocatable_annotations=relocatable_annotations, **kwargs, ) - self._hash = h - cache[h] = self + self._hash = h & 0x7FFF_FFFF_FFFF_FFFF + cache[self._hash] = self return self def __reduce__(self): # HASHCONS: these attributes key the cache # BEFORE CHANGING THIS, SEE ALL OTHER INSTANCES OF "HASHCONS" IN THIS FILE - return _d, ( + return _unpickle, ( self._hash, self.__class__, (self.op, self.args, self.length, self.variables, self.symbolic, self.annotations), @@ -340,170 +302,6 @@ def __reduce__(self): def __init__(self, *args, **kwargs): pass - @staticmethod - def _calc_hash(op, args, keywords): - """ - Calculates the hash of an AST, given the operation, args, and kwargs. - - :param op: The operation. - :param args: The arguments to the operation. - :param keywords: A dict including the 'symbolic', 'variables', and 'length' items. - :returns: a hash. - - We do it using md5 to avoid hash collisions. - (hash(-1) == hash(-2), for example) - """ - args_tup = tuple(a if type(a) in (int, float) else getattr(a, "_hash", hash(a)) for a in args) - # HASHCONS: these attributes key the cache - # BEFORE CHANGING THIS, SEE ALL OTHER INSTANCES OF "HASHCONS" IN THIS FILE - - to_hash = Base._ast_serialize(op, args_tup, keywords) - if to_hash is None: - # fall back to pickle.dumps - to_hash = ( - op, - args_tup, - str(keywords.get("length", None)), - hash(keywords["variables"]), - keywords["symbolic"], - hash(keywords.get("annotations", None)), - ) - to_hash = pickle.dumps(to_hash, -1) - - # Why do we use md5 when it's broken? Because speed is more important - # than cryptographic integrity here. Then again, look at all those - # allocations we're doing here... fast python is painful. - hd = md5.md5(to_hash).digest() - return md5_unpacker.unpack(hd)[0] # 64 bits - - @staticmethod - def _arg_serialize(arg) -> bytes | None: - if arg is None: - return b"\x0f" - elif arg is True: - return b"\x1f" - elif arg is False: - return b"\x2e" - elif isinstance(arg, int): - if arg < 0: - if arg >= -0x7FFF: - return b"-" + struct.pack("= -0x7FFF_FFFF: - return b"-" + struct.pack("= -0x7FFF_FFFF_FFFF_FFFF: - return b"-" + struct.pack(" bytes | None: - """ - Serialize the AST and get a bytestring for hashing. - - :param op: The operator. - :param args_tup: A tuple of arguments. - :param keywords: A dict of keywords. - :return: The serialized bytestring. - """ - - serialized_args = Base._arg_serialize(args_tup) - if serialized_args is None: - return None - - if "length" in keywords: - length = Base._arg_serialize(keywords["length"]) - if length is None: - return None - else: - length = b"none" - - variables = struct.pack(" ASTCacheKey[T]: + def cache_key(self: T) -> ASTCacheKey: """ A key that refers to this AST - this value is appropriate for usage as a key in dictionaries. """ + if self._cache_key is None: + self._cache_key = ASTCacheKey(self) return self._cache_key @property @@ -527,17 +327,12 @@ def _encoded_name(self): # Collapsing and simplification # - # def _models_for(self, backend): - # for a in self.args: - # backend.convert_expr(a) - # else: - # yield backend.convert(a) - def make_like(self: T, op: str, args: Iterable, **kwargs) -> T: # Try to simplify the expression again simplified = simplifications.simpleton.simplify(op, args) if kwargs.pop("simplify", False) is True else None if simplified is not None: op = simplified.op + if ( simplified is None and len(kwargs) == 3 diff --git a/claripy/backends/__init__.py b/claripy/backends/__init__.py index baa65183a..397f0f565 100644 --- a/claripy/backends/__init__.py +++ b/claripy/backends/__init__.py @@ -180,7 +180,7 @@ def convert(self, expr): # pylint:disable=R0201 ) if self._cache_objects: - cached_obj = self._object_cache.get(ast._cache_key, None) + cached_obj = self._object_cache.get(ast.cache_key, None) if cached_obj is not None: arg_queue.append(cached_obj) continue @@ -214,7 +214,7 @@ def convert(self, expr): # pylint:disable=R0201 r = self.apply_annotation(r, a) if self._cache_objects: - self._object_cache[ast._cache_key] = r + self._object_cache[ast.cache_key] = r arg_queue.append(r) diff --git a/claripy/backends/backend_concrete.py b/claripy/backends/backend_concrete.py index c145109ca..1d674aaca 100644 --- a/claripy/backends/backend_concrete.py +++ b/claripy/backends/backend_concrete.py @@ -113,10 +113,10 @@ def convert(self, expr): Override Backend.convert() to add fast paths for BVVs and BoolVs. """ if type(expr) is BV and expr.op == "BVV": - cached_obj = self._object_cache.get(expr._cache_key, None) + cached_obj = self._object_cache.get(expr.cache_key, None) if cached_obj is None: cached_obj = self.BVV(*expr.args) - self._object_cache[expr._cache_key] = cached_obj + self._object_cache[expr.cache_key] = cached_obj return cached_obj if type(expr) is Bool and expr.op == "BoolV": return expr.args[0] diff --git a/pyproject.toml b/pyproject.toml index f365de0ec..c097b9111 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,55 @@ [build-system] -requires = ["setuptools>=46.4.0", "wheel"] +requires = ["setuptools", "setuptools-rust"] build-backend = "setuptools.build_meta" +[project] +name = "claripy" +version = "9.2.108.dev0" +description = "An abstraction layer for constraint solvers" +license = {text = "BSD-2-Clause"} +classifiers = [ + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +urls = {Homepage = "https://github.com/angr/claripy"} +requires-python = ">=3.10" +dependencies = [ + "cachetools", + "decorator", + "pysmt>=0.9.5", + "z3-solver==4.13.0.0", +] + +[project.readme] +file = "README.md" +content-type = "text/markdown" + +[project.optional-dependencies] +cvc4_solver = ["cvc4-solver"] +docs = [ + "furo", + "myst-parser", + "sphinx", + "sphinx-autodoc-typehints", +] +testing = [ + "pytest", + "pytest-xdist", + 'cvc4-solver;platform_system == "linux"', +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["claripy*"] +namespaces = false + +[[tool.setuptools-rust.ext-modules]] +target = "claripy.clarirs" + [tool.black] line-length = 120 target-version = ['py310'] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 2ba020a1a..000000000 --- a/setup.cfg +++ /dev/null @@ -1,41 +0,0 @@ -[metadata] -name = claripy -version = attr: claripy.__version__ -description = An abstraction layer for constraint solvers -long_description = file: README.md -long_description_content_type = text/markdown -url = https://github.com/angr/claripy -license = BSD-2-Clause -license_files = LICENSE -classifiers = - License :: OSI Approved :: BSD License - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Programming Language :: Python :: 3.12 - -[options] -packages = find: -install_requires = - cachetools - decorator - pysmt>=0.9.5 - z3-solver==4.13.0.0 -python_requires = >=3.10 - -[options.extras_require] -cvc4_solver = - cvc4-solver -docs = - furo - myst-parser - sphinx - sphinx-autodoc-typehints -testing = - pytest - pytest-xdist - cvc4-solver;platform_system == "linux" - -[options.package_data] -claripy = py.typed diff --git a/src/ast/base.rs b/src/ast/base.rs new file mode 100644 index 000000000..8b4a8148d --- /dev/null +++ b/src/ast/base.rs @@ -0,0 +1,356 @@ +use std::borrow::Cow; + +use pyo3::{ + exceptions::PyValueError, + prelude::*, + types::{PyAnyMethods, PyBool, PyBytes, PyDict, PyFloat, PyInt, PySet, PyString, PyTuple}, +}; + +#[pyclass(weakref)] +pub struct ASTCacheKey { + #[pyo3(get)] + ast: PyObject, + #[pyo3(get)] + hash: isize, +} + +#[pymethods] +impl ASTCacheKey { + #[new] + pub fn new(ast: Bound) -> PyResult { + Ok(ASTCacheKey { + hash: ast.as_any().hash()?, + ast: ast.into(), + }) + } + + pub fn __hash__(&self) -> isize { + self.hash + } + + pub fn __eq__(&self, other: &Self) -> bool { + self.hash == other.hash + } + + pub fn __repr__(&self) -> String { + format!("", self.ast) + } +} + +#[pyclass(subclass, weakref)] +pub struct Base { + // Hashcons + #[pyo3(get, set)] + op: String, + #[pyo3(get, set)] + args: Py, + #[pyo3(get, set)] + length: usize, + #[pyo3(get, set)] + variables: PyObject, // TODO: This should be a HashSet, leave opaque for now + #[pyo3(get, set)] + symbolic: bool, + #[pyo3(get, set)] + annotations: Py, + + // Not Hashcons + #[pyo3(get)] + depth: usize, + + #[pyo3(get, set)] + _hash: Option, + #[pyo3(get, set)] + _simplified: Option, + #[pyo3(get, set)] + _cache_key: Option>, + #[pyo3(get, set)] + _cached_encoded_name: Option, + #[pyo3(get, set)] + _errored: Py, + #[pyo3(get, set)] + _eager_backends: Option, + #[pyo3(get, set)] + _excavated: Option, + #[pyo3(get, set)] + _burrowed: Option, + #[pyo3(get, set)] + _uninitialized: Option, + #[pyo3(get, set)] + _uc_alloc_depth: Option, + #[pyo3(get, set)] + _uneliminatable_annotations: Option, + #[pyo3(get, set)] + _relocatable_annotations: Option, +} + +#[pymethods] +impl Base { + #[new] + #[pyo3(signature = ( + op, + args, + length, + variables, + symbolic, + annotations=None, + simplified=None, + errored=None, + eager_backends=None, + uninitialized=None, + uc_alloc_depth=None, + encoded_name=None, + depth=None, + uneliminatable_annotations=None, + relocatable_annotations=None + ))] + fn new( + py: Python, + op: String, + args: Bound, + length: usize, + variables: Bound, + symbolic: bool, + annotations: Option>, + // New stuff + simplified: Option>, + errored: Option>, + eager_backends: Option>, + uninitialized: Option>, + uc_alloc_depth: Option>, + encoded_name: Option>, + depth: Option, + uneliminatable_annotations: Option>, + relocatable_annotations: Option>, + ) -> PyResult { + if args.len() == 0 { + return Err(PyValueError::new_err("AST with no arguments!")); // TODO: This should be a custom error + } + + let depth = depth.unwrap_or( + *args + .iter() + .map(|arg| { + arg.getattr("depth") + .and_then(|p| p.extract::()) + .or_else(|_| Ok(1)) + }) + .collect::, PyErr>>()? + .iter() + .max() + .unwrap_or(&0) + + 1, + ); + + Ok(Base { + op, + args: args.into(), + length, + variables: variables.into(), + symbolic, + annotations: annotations + .unwrap_or_else(|| PyTuple::empty_bound(py)) + .into(), + + depth, + + _hash: None, + _simplified: simplified.map(|s| s.into()), + _cache_key: None, + _cached_encoded_name: encoded_name.map(|s| s.into()), + _errored: errored.unwrap_or(PySet::empty_bound(py)?).into(), + _eager_backends: eager_backends.map(|s| s.into()), + _excavated: None, + _burrowed: None, + _uninitialized: uninitialized.map(|s| s.into()), + _uc_alloc_depth: uc_alloc_depth.map(|s| s.into()), + _uneliminatable_annotations: uneliminatable_annotations.map(|s| s.into()), + _relocatable_annotations: relocatable_annotations.map(|s| s.into()), + }) + } + + #[staticmethod] + fn _arg_serialize<'py>( + py: Python<'py>, + arg: Bound<'_, PyAny>, + ) -> PyResult>> { + if arg.is_none() { + return Ok(Some(Cow::from(vec![b'\x0f']))); + } + if arg.is(&*PyBool::new_bound(py, true)) { + return Ok(Some(Cow::from(vec![b'\x1f']))); + } + if arg.is(&*PyBool::new_bound(py, false)) { + return Ok(Some(Cow::from(vec![b'\x2e']))); + } + if arg.is_instance(&py.get_type_bound::())? { + let arg = arg.downcast::()?.extract::()?; + let mut result = Vec::new(); + if arg < 0 { + result.push(b'-'); + if arg >= -0x7FFF { + result.extend_from_slice(&(arg as i16).to_le_bytes()); + } else if arg >= -0x7FFF_FFFF { + result.extend_from_slice(&(arg as i32).to_le_bytes()); + } else if arg >= -0x7FFF_FFFF_FFFF_FFFF { + result.extend_from_slice(&(arg as i64).to_le_bytes()); + } else { + return Ok(None); + } + } else { + if arg <= 0xFFFF { + result.extend_from_slice(&(arg as i16).to_le_bytes()); + } else if arg <= 0xFFFF_FFFF { + result.extend_from_slice(&(arg as i32).to_le_bytes()); + } else if arg <= 0xFFFF_FFFF_FFFF_FFFF { + result.extend_from_slice(&(arg as i64).to_le_bytes()); + } else { + return Ok(None); + } + } + return Ok(Some(Cow::from(result))); + } + if arg.is_instance(&py.get_type_bound::())? { + let arg: String = arg.downcast::()?.extract()?; + return Ok(Some(Cow::from(arg.into_bytes()))); + } + if arg.is_instance(&py.get_type_bound::())? { + return Ok(Some(Cow::from(Vec::from( + arg.downcast::()?.extract::()?.to_le_bytes(), + )))); + } + if arg.is_instance(&py.get_type_bound::())? { + let mut result = Vec::new(); + for item in arg.downcast::()?.iter() { + if let Some(sub_result) = Self::_arg_serialize(py, item)? { + result.extend(sub_result.iter()); + } else { + return Ok(None); // Do we really want to return None here? + } + } + return Ok(Some(Cow::from(result))); + } + Ok(None) + } + + #[staticmethod] + fn _ast_serialize<'py>( + py: Python<'py>, + op: String, + args_tuple: Bound<'_, PyTuple>, + keywords: Bound<'_, PyDict>, // TODO: This should be a struct or seperate args + ) -> PyResult>> { + let serailized_args = match Base::_arg_serialize(py, args_tuple.into_any())? { + Some(args) => args, + None => return Ok(None), + }; + + let length = match keywords.contains("length")? { + true => match Base::_arg_serialize(py, keywords.get_item("length")?.unwrap())? { + Some(length) => length, + None => return Ok(None), + }, + false => Cow::from(Vec::from(b"none")), + }; + + // get_item was unchecked in the python version too + let variables = (keywords.get_item("variables")?.unwrap().hash()? as u64).to_le_bytes(); + // this one was unchecked too + let symbolic = match keywords.get_item("symbolic")?.unwrap().is_truthy()? { + true => Cow::from(Vec::from(b"\x01")), + false => Cow::from(Vec::from(b"\x00")), + }; + let annotations = match keywords.get_item("annotations")? { + Some(item) => Cow::from(Vec::from((item.hash()? as u64).to_le_bytes())), + None => Cow::from(Vec::from(b"\xf9")), + }; + + Ok(Some(Cow::from( + [ + op.as_bytes(), + &serailized_args, + &length, + &variables, + &symbolic, + &annotations, + ] + .concat(), + ))) + } + + #[staticmethod] + fn _calc_hash<'py>( + py: Python<'py>, + op: String, + args: Bound, + keywords: Bound, + ) -> PyResult { + let mut args_tuple = Vec::new(); + for arg in args.iter() { + if arg.is_instance(&py.get_type_bound::())? + || arg.is_instance(&py.get_type_bound::())? + { + args_tuple.push(arg); + } else { + if arg.hasattr("_hash")? { + args_tuple.push( + arg.getattr("_hash")? + .downcast::() + .unwrap() + .clone() + .into_any(), + ); + } else { + args_tuple.push( + // Call hash on the object + arg.call_method0("__hash__")? + .downcast::() + .unwrap() + .clone() + .into_any(), + ); + } + } + } + + let to_hash = match Base::_ast_serialize(py, op.clone(), args, keywords.clone())? { + Some(to_hash) => to_hash, + None => { + let hash_tuple: Bound = PyTuple::new_bound( + py, + vec![ + op.to_object(py).bind(py).as_ref(), + args_tuple.to_object(py).bind(py).as_ref(), + keywords + .get_item("length")? + .unwrap_or(py.None().into_bound(py)) + .str()? + .as_ref(), + keywords + .get_item("variables")? + .unwrap() // Unchecked unwrap in python version + .hash()? + .to_object(py) + .bind(py), + keywords.get_item("symbolic")?.unwrap().as_ref(), // Unchecked unwrap in python version + keywords + .get_item("annotations")? + .unwrap_or(py.None().into_bound(py)) + .hash()? + .to_object(py) + .bind(py), + ], + ); + Cow::from(Vec::from( + py.import_bound("pickle")? + .getattr("dumps")? + .call1(PyTuple::new_bound(py, vec![&hash_tuple]))? + .downcast_into::()? + .as_bytes(), + )) + } + }; + Ok(isize::from_be_bytes( + (md5::compute(to_hash).0)[..8].try_into().unwrap(), + )) + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs new file mode 100644 index 000000000..6cf245d4d --- /dev/null +++ b/src/ast/mod.rs @@ -0,0 +1 @@ +pub mod base; diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 000000000..b651dfda6 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,10 @@ +mod ast; + +use pyo3::prelude::*; + +#[pymodule] +fn clarirs(_py: Python, m: Bound) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/tests/common_backend_smt_solver.py b/tests/common_backend_smt_solver.py index 7dc3c16c4..bf8007763 100644 --- a/tests/common_backend_smt_solver.py +++ b/tests/common_backend_smt_solver.py @@ -1,29 +1,27 @@ -from unittest import skip +import typing +import unittest -from decorator import decorator from test_backend_smt import TestSMTLibBackend import claripy -# use of decorator instead of the usual pattern is important because nose2 will check the argspec and wraps does not -# preserve that! -@decorator -def if_installed(f, *args, **kwargs): - try: - return f(*args, **kwargs) - except claripy.errors.MissingSolverError: - return skip("Missing Solver")(f) +def if_installed(test_func: typing.Callable): + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except claripy.errors.MissingSolverError as exception: + raise unittest.SkipTest from exception + + return wrapper KEEP_TEST_PERFORMANT = True class SmtLibSolverTestBase(TestSMTLibBackend): - @skip def get_solver(self): - pass - # raise NotImplementedError + raise NotImplementedError @if_installed def test_concat(self): diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 92f31d14c..badbc0cb2 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -1,4 +1,6 @@ -# pylint:disable=missing-class-docstring,multiple-statements +# pylint:disable=missing-class-docstring +import unittest + import claripy @@ -55,149 +57,137 @@ def apply_annotation(self, o, a): return o + a.number -def test_backend(): - x = claripy.BVV(10, 32).annotate(AnnotationA("a", 1)) - assert BackendA().convert(x) == 11 - - -def test_simplification(): - x = claripy.BVS("x", 32).annotate(AnnotationA("a", 1)) - y = x ^ x - assert y.depth == 1 - assert len(y.annotations) == 0 - - x = claripy.BVS("x", 32).annotate(AnnotationB("a", 1)) - y = x ^ x - assert y.depth == 2 - - x = claripy.BVS("x", 32).annotate(AnnotationC("a", 1)) - y = x ^ x - assert y.depth == 1 - assert len(y.annotations) == 1 - assert y.annotations[0].number == 2 - - -def test_missing_annotations_from_simplification(): - relocatable_anno = AnnotationC("a", 2) - - x0 = claripy.BVS("x", 32) - x1 = claripy.BVV(24, 32) - k = (x1 + x0).annotate(relocatable_anno) - - x3 = claripy.simplify(k) - - assert len(x3.annotations) == 1 - - -def test_annotations(): - x = claripy.BVS("x", 32) + 1 - xx = x._apply_to_annotations(lambda a: a) - assert x is xx - - a1 = AnnotationA("a", 1) - a2 = AnnotationA("a", 1) - - x1 = x.annotate(a1) - x2 = x1.annotate(a2) - x2a = x.annotate(a1, a2) - x3 = x2.remove_annotation(a1) - x4 = x3.remove_annotation(a2) - x5 = x2.remove_annotations({a1, a2}) - - assert x.variables == x1.variables - assert x.variables == x2.variables - assert x.variables == x2a.variables - assert x.variables == x3.variables - assert x.variables == x4.variables - assert x.variables == x5.variables - - assert x is not x1 - assert x is not x2 - assert x is not x3 - assert x1 is not x2 - assert x1 is not x3 - assert x2 is not x3 - assert x2 is x2a - assert x is x4 - assert x is x5 - - assert x.op == x1.op - assert x.annotations == () - assert x1.annotations == (a1,) - assert x2.annotations == (a1, a2) - assert x3.annotations == (a2,) - - assert claripy.backends.z3.convert(x).eq(claripy.backends.z3.convert(x3)) - - const = claripy.BVV(1, 32) - consta = const.annotate(AnnotationB("a", 0)) - const1 = consta + 1 - const1a = const1.annotate(AnnotationB("b", 1)) - const2 = const1a + 1 - # const2 should be (const1a + 1), instead of (1 + 1 + 1) - # the flatten simplifier for __add__ should not be applied as AnnotationB is not relocatable (and not eliminatable) - assert const2.depth == 3 - - -def test_eagerness(): - x = claripy.BVV(10, 32).annotate(AnnotationD()) - y = x + 1 - assert y.annotations == x.annotations - - -def test_ast_hash_should_consider_relocatable_annotations(): - relocatable_anno = AnnotationC("a", 2) - const = claripy.BVV(1337, 32) - x0 = claripy.BVS("x", 32).annotate(relocatable_anno) - y0 = claripy.Concat(x0, const) - - # make the annotation not relocatable - # this is of course a hack, but it can demonstrate the problem - relocatable_anno._relocatable = False - x0._relocatable_annotations = frozenset() - - y1 = claripy.Concat(x0, const) - - assert len(y0.annotations) == 1 - assert len(y1.annotations) == 0 - assert y0._hash != y1._hash - - -def test_remove_relocatable_annotations(): - relocatable_anno = AnnotationC("a", 2) - const = claripy.BVV(1337, 32) - - x0 = claripy.BVS("x", 32).annotate(relocatable_anno) - y0 = claripy.Concat(x0, const) - assert len(y0.annotations) == 1 - assert y0.annotations == (relocatable_anno,) - - y1 = y0.remove_annotation(relocatable_anno) - - assert len(y1.annotations) == 0 - - -def test_duplicated_annotations_from_makelike(): - relocatable_anno = AnnotationC("a", 2) - - x0 = claripy.BVS("x", 32).annotate(relocatable_anno) - x1 = claripy.BVV(24, 32) - - # make_like() should not re-apply child annotations if the child is the make_like target - x2 = x0 + x1 - assert len(x2.annotations) == 1 - - # simplify() should not re-apply annotations since annotations are kept during the simplification process by - # make_like(). - x3 = claripy.simplify(x0 + x1) - assert len(x3.annotations) == 1 +class TestAnnotation(unittest.TestCase): + def test_backend(self): + x = claripy.BVV(10, 32).annotate(AnnotationA("a", 1)) + assert BackendA().convert(x) == 11 + + def test_simplification(self): + x = claripy.BVS("x", 32).annotate(AnnotationA("a", 1)) + y = x ^ x + assert y.depth == 1 + assert len(y.annotations) == 0 + + x = claripy.BVS("x", 32).annotate(AnnotationB("a", 1)) + y = x ^ x + assert y.depth == 2 + + x = claripy.BVS("x", 32).annotate(AnnotationC("a", 1)) + y = x ^ x + assert y.depth == 1 + assert len(y.annotations) == 1 + assert y.annotations[0].number == 2 + + def test_missing_annotations_from_simplification(self): + relocatable_anno = AnnotationC("a", 2) + + x0 = claripy.BVS("x", 32) + x1 = claripy.BVV(24, 32) + k = (x1 + x0).annotate(relocatable_anno) + + x3 = claripy.simplify(k) + + assert len(x3.annotations) == 1 + + def test_annotations(self): + x = claripy.BVS("x", 32) + 1 + xx = x._apply_to_annotations(lambda a: a) + assert x is xx + + a1 = AnnotationA("a", 1) + a2 = AnnotationA("a", 1) + + x1 = x.annotate(a1) + x2 = x1.annotate(a2) + x2a = x.annotate(a1, a2) + x3 = x2.remove_annotation(a1) + x4 = x3.remove_annotation(a2) + x5 = x2.remove_annotations({a1, a2}) + + assert x.variables == x1.variables + assert x.variables == x2.variables + assert x.variables == x2a.variables + assert x.variables == x3.variables + assert x.variables == x4.variables + assert x.variables == x5.variables + + assert x is not x1 + assert x is not x2 + assert x is not x3 + assert x1 is not x2 + assert x1 is not x3 + assert x2 is not x3 + assert x2 is x2a + assert x is x4 + assert x is x5 + + assert x.op == x1.op + assert x.annotations == () + assert x1.annotations == (a1,) + assert x2.annotations == (a1, a2) + assert x3.annotations == (a2,) + + assert claripy.backends.z3.convert(x).eq(claripy.backends.z3.convert(x3)) + + const = claripy.BVV(1, 32) + consta = const.annotate(AnnotationB("a", 0)) + const1 = consta + 1 + const1a = const1.annotate(AnnotationB("b", 1)) + const2 = const1a + 1 + # const2 should be (const1a + 1), instead of (1 + 1 + 1) + # the flatten simplifier for __add__ should not be applied as AnnotationB is not relocatable (and not eliminatable) + assert const2.depth == 3 + + def test_eagerness(self): + x = claripy.BVV(10, 32).annotate(AnnotationD()) + y = x + 1 + assert y.annotations == x.annotations + + def test_ast_hash_should_consider_relocatable_annotations(self): + relocatable_anno = AnnotationC("a", 2) + const = claripy.BVV(1337, 32) + x0 = claripy.BVS("x", 32).annotate(relocatable_anno) + y0 = claripy.Concat(x0, const) + + # make the annotation not relocatable + # this is of course a hack, but it can demonstrate the problem + relocatable_anno._relocatable = False + x0._relocatable_annotations = frozenset() + + y1 = claripy.Concat(x0, const) + + assert len(y0.annotations) == 1 + assert len(y1.annotations) == 0 + assert y0._hash != y1._hash + + def test_remove_relocatable_annotations(self): + relocatable_anno = AnnotationC("a", 2) + const = claripy.BVV(1337, 32) + + x0 = claripy.BVS("x", 32).annotate(relocatable_anno) + y0 = claripy.Concat(x0, const) + assert len(y0.annotations) == 1 + assert y0.annotations == (relocatable_anno,) + + y1 = y0.remove_annotation(relocatable_anno) + + assert len(y1.annotations) == 0 + + def test_duplicated_annotations_from_makelike(self): + relocatable_anno = AnnotationC("a", 2) + + x0 = claripy.BVS("x", 32).annotate(relocatable_anno) + x1 = claripy.BVV(24, 32) + + # make_like() should not re-apply child annotations if the child is the make_like target + x2 = x0 + x1 + assert len(x2.annotations) == 1 + + # simplify() should not re-apply annotations since annotations are kept during the simplification process by + # make_like(). + x3 = claripy.simplify(x0 + x1) + assert len(x3.annotations) == 1 if __name__ == "__main__": - test_annotations() - test_backend() - test_eagerness() - test_ast_hash_should_consider_relocatable_annotations() - test_remove_relocatable_annotations() - test_duplicated_annotations_from_makelike() - test_simplification() + unittest.main() diff --git a/tests/test_ast.py b/tests/test_ast.py index e02edfb61..489ad7022 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,36 +1,37 @@ -import claripy - - -def test_lite_repr(): - one = claripy.BVV(1, 8) - two = claripy.BVV(2, 8) - a = claripy.BVS("a", 8, explicit_name=True) - b = claripy.BVS("b", 8, explicit_name=True) - - assert (a + one * b + two).shallow_repr() == "" - assert ((a + one) * (b + two)).shallow_repr() == "" - assert (a * one + b * two).shallow_repr() == "" - assert ( - (one + a) * (two + b) + (two + a) * (one + b) - ).shallow_repr() == "" +import unittest +import claripy -def test_associativity(): - x = claripy.BVS("x", 8, explicit_name=True) - y = claripy.BVS("y", 8, explicit_name=True) - z = claripy.BVS("z", 8, explicit_name=True) - w = claripy.BVS("w", 8, explicit_name=True) - assert (x - (y - (z - w))).shallow_repr() == "" - assert (x - y - z - w).shallow_repr() == "" - assert (x * (y * (z * w))).shallow_repr() == (x * y * z * w).shallow_repr() - assert (x * y * z * w).shallow_repr() == "" - assert (x + y - z - w).shallow_repr() == "" - assert (x + (y - (z - w))).shallow_repr() == "" - assert (x * y / z % w).shallow_repr() == "" - assert (x * (y / (z % w))).shallow_repr() == "" +class TestAST(unittest.TestCase): + def test_lite_repr(self): + one = claripy.BVV(1, 8) + two = claripy.BVV(2, 8) + a = claripy.BVS("a", 8, explicit_name=True) + b = claripy.BVS("b", 8, explicit_name=True) + + assert (a + one * b + two).shallow_repr() == "" + assert ((a + one) * (b + two)).shallow_repr() == "" + assert (a * one + b * two).shallow_repr() == "" + assert ( + (one + a) * (two + b) + (two + a) * (one + b) + ).shallow_repr() == "" + + def test_associativity(self): + x = claripy.BVS("x", 8, explicit_name=True) + y = claripy.BVS("y", 8, explicit_name=True) + z = claripy.BVS("z", 8, explicit_name=True) + w = claripy.BVS("w", 8, explicit_name=True) + + assert (x - (y - (z - w))).shallow_repr() == "" + assert (x - y - z - w).shallow_repr() == "" + assert (x * (y * (z * w))).shallow_repr() == (x * y * z * w).shallow_repr() + assert (x * y * z * w).shallow_repr() == "" + assert (x + y - z - w).shallow_repr() == "" + assert (x + (y - (z - w))).shallow_repr() == "" + assert (x * y // z % w).shallow_repr() == "" + assert (x * (y // (z % w))).shallow_repr() == "" if __name__ == "__main__": - test_lite_repr() - test_associativity() + unittest.main() diff --git a/tests/test_backend_smt.py b/tests/test_backend_smt.py index cde8e7b0d..055f2da1b 100644 --- a/tests/test_backend_smt.py +++ b/tests/test_backend_smt.py @@ -40,8 +40,6 @@ def test_concat(self): res = str_concrete + str_symbol solver.add(res == claripy.StringV("concrete")) script = solver.get_smtlib_script_satisfiability() - # with open("dump_concat.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_concat_simplification(self): @@ -69,8 +67,6 @@ def test_substr(self): res = claripy.StrSubstr(1, 2, str_symbol) == claripy.StringV("on") solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_substr.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_substr_BV_concrete_index(self): @@ -86,8 +82,6 @@ def test_substr_BV_concrete_index(self): res = claripy.StrSubstr(bv1, bv2, str_symbol) == claripy.StringV("on") solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_substr_bv_concrete.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_substr_BV_symbolic_index(self): @@ -105,8 +99,6 @@ def test_substr_BV_symbolic_index(self): res = claripy.StrSubstr(bv1, bv2, str_symbol) == claripy.StringV("on") solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_substr_bv_symbolic.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_substr_BV_mixed_index(self): @@ -123,8 +115,6 @@ def test_substr_BV_mixed_index(self): res = claripy.StrSubstr(bv1, bv2, str_symbol) == claripy.StringV("on") solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_substr_bv_symbolic.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_substr_simplification(self): @@ -152,8 +142,6 @@ def test_replace(self): repl_stringa = claripy.StrReplace(str_to_replace_symb, sub_str_to_repl, replacement) solver.add(repl_stringa == claripy.StringV("cbne")) script = solver.get_smtlib_script_satisfiability() - # with open("dump_replace.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_replace_simplification(self): @@ -182,8 +170,6 @@ def test_ne(self): solver = self.get_solver() solver.add(str_symb != claripy.StringV("concrete")) script = solver.get_smtlib_script_satisfiability() - # with open("dump_ne.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_length(self): @@ -197,8 +183,6 @@ def test_length(self): # TODO: How do we want to dela with the size of a symbolic string? solver.add(claripy.StrLen(str_symb, 32) == 14) script = solver.get_smtlib_script_satisfiability() - # with open("dump_length.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_length_simplification(self): @@ -224,8 +208,6 @@ def test_or(self): res = claripy.Or((str_symb == claripy.StringV("abc")), (str_symb == claripy.StringV("ciao"))) solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_or.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_lt_etc(self): @@ -248,8 +230,6 @@ def test_lt_etc(self): solver.add(c3) solver.add(c4) script = solver.get_smtlib_script_satisfiability() - # with open("dump_lt_etc.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_contains(self): @@ -263,8 +243,6 @@ def test_contains(self): solver = self.get_solver() solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_contains.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_contains_simplification(self): @@ -291,7 +269,6 @@ def test_prefix(self): solver = self.get_solver() solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_prefix.smt2", "w") as dump_f: # dump_f.write(script) self.assertEqual(correct_script, script) @@ -306,8 +283,6 @@ def test_suffix(self): solver = self.get_solver() solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_suffix.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_prefix_simplification(self): @@ -347,8 +322,6 @@ def test_index_of(self): solver = self.get_solver() solver.add(res) script = solver.get_smtlib_script_satisfiability() - # with open("dump_suffix.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_index_of_simplification(self): @@ -384,8 +357,6 @@ def test_index_of_symbolic_start_idx(self): solver.add(res == 33) script = solver.get_smtlib_script_satisfiability() - # with open("dump_suffix.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_str_to_int(self): @@ -399,8 +370,6 @@ def test_str_to_int(self): solver = self.get_solver() solver.add(res == 12) script = solver.get_smtlib_script_satisfiability() - # with open("dump_strtoint.smt2", "w") as dump_f: - # dump_f.write(script) self.assertEqual(correct_script, script) def test_str_to_int_simplification(self): @@ -470,5 +439,4 @@ def test_int_to_str_simplification(self): if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(TestSMTLibBackend) - unittest.TextTestRunner(verbosity=2).run(suite) + unittest.main() diff --git a/tests/test_backend_smt_abc.py b/tests/test_backend_smt_abc.py index 7d3f66aa6..79e3ea710 100644 --- a/tests/test_backend_smt_abc.py +++ b/tests/test_backend_smt_abc.py @@ -5,7 +5,7 @@ import claripy -class SmtLibSolverTest_ABC(common_backend_smt_solver.SmtLibSolverTestBase): +class TestSmtLibSolverTest_ABC(common_backend_smt_solver.SmtLibSolverTestBase): @common_backend_smt_solver.if_installed def get_solver(self): from claripy.backends.backend_smtlib_solvers.abc_popen import SolverBackendABC @@ -16,5 +16,4 @@ def get_solver(self): if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(SmtLibSolverTest_ABC) - unittest.TextTestRunner(verbosity=2).run(suite) + unittest.main() diff --git a/tests/test_backend_smt_composite.py b/tests/test_backend_smt_composite.py index 08668b23d..8c82da2a9 100644 --- a/tests/test_backend_smt_composite.py +++ b/tests/test_backend_smt_composite.py @@ -5,7 +5,7 @@ import claripy -class SmtLibSolverTest_Z3(common_backend_smt_solver.SmtLibSolverTestBase): +class TestSmtLibSolverTest_Z3(common_backend_smt_solver.SmtLibSolverTestBase): @unittest.skip("Skip these test for now because of a problem with pysmt") def get_solver(self): solver = claripy.SolverPortfolio( @@ -31,5 +31,4 @@ def get_solver(self): if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(SmtLibSolverTest_Z3) - unittest.TextTestRunner(verbosity=2).run(suite) + unittest.main() diff --git a/tests/test_backend_smt_congruency.py b/tests/test_backend_smt_congruency.py index 5a5a44d8a..a196c7b89 100644 --- a/tests/test_backend_smt_congruency.py +++ b/tests/test_backend_smt_congruency.py @@ -21,7 +21,7 @@ def all_equal(vals): return all(v == v0 for v in vals) -class SmtLibSolverTestCongruency(unittest.TestCase): +class TestSmtLibSolverTestCongruency(unittest.TestCase): @unittest.skip("Skip these test for now because of a problem with pysmt") def get_solvers(self): solvers = [ @@ -166,5 +166,4 @@ def field_sep_idx(s, start_idx=0): if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(SmtLibSolverTestCongruency) - unittest.TextTestRunner(verbosity=2).run(suite) + unittest.main() diff --git a/tests/test_backend_smt_cvc4.py b/tests/test_backend_smt_cvc4.py index 8a6c08440..691aa91c2 100644 --- a/tests/test_backend_smt_cvc4.py +++ b/tests/test_backend_smt_cvc4.py @@ -6,6 +6,7 @@ from claripy.backends.backend_smtlib_solvers.cvc4_popen import SolverBackendCVC4 +@unittest.skip class SmtLibSolverTest_CVC4(common_backend_smt_solver.SmtLibSolverTestBase): @common_backend_smt_solver.if_installed def get_solver(self): @@ -15,5 +16,4 @@ def get_solver(self): if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(SmtLibSolverTest_CVC4) - unittest.TextTestRunner(verbosity=2).run(suite) + unittest.main() diff --git a/tests/test_backend_smt_z3.py b/tests/test_backend_smt_z3.py index 37cbe83b5..383c5e2e7 100644 --- a/tests/test_backend_smt_z3.py +++ b/tests/test_backend_smt_z3.py @@ -6,7 +6,7 @@ from claripy.backends.backend_smtlib_solvers.z3_popen import SolverBackendZ3 -class SmtLibSolverTest_Z3(common_backend_smt_solver.SmtLibSolverTestBase): +class TestSmtLibSolverTest_Z3(common_backend_smt_solver.SmtLibSolverTestBase): @unittest.skip("Skip these test for now because of a problem with pysmt") def get_solver(self): backend = SolverBackendZ3(daggify=True) @@ -15,5 +15,4 @@ def get_solver(self): if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(SmtLibSolverTest_Z3) - unittest.TextTestRunner(verbosity=2).run(suite) + unittest.main() diff --git a/tests/test_balancer.py b/tests/test_balancer.py index a1b582bcc..bc9c23a7c 100644 --- a/tests/test_balancer.py +++ b/tests/test_balancer.py @@ -1,20 +1,11 @@ +import unittest + import claripy -def test_complex_guy(): - guy_wide = claripy.widen( - claripy.union( - claripy.union( - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), - ), - claripy.union( - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), - ) - + claripy.BVV(1, 32), - ), - claripy.union( +class TestBalancer(unittest.TestCase): + def test_complex(self): + guy_wide = claripy.widen( claripy.union( claripy.union( claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), @@ -28,213 +19,197 @@ def test_complex_guy(): ), claripy.union( claripy.union( - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), + claripy.union( + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), + ), + claripy.union( + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), + ) + + claripy.BVV(1, 32), ), claripy.union( - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), - claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), + claripy.union( + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), + ), + claripy.union( + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)), + claripy.union(claripy.BVV(0, 32), claripy.BVV(1, 32)) + claripy.BVV(1, 32), + ) + + claripy.BVV(1, 32), ) + claripy.BVV(1, 32), - ) - + claripy.BVV(1, 32), - ), - ) - guy_inc = guy_wide + claripy.BVV(1, 32) - guy_zx = claripy.ZeroExt(32, guy_inc) - - s, r = claripy.balancer.Balancer(claripy.backends.vsa, guy_inc <= claripy.BVV(39, 32)).compat_ret - assert s - assert r[0][0] is guy_wide - assert claripy.backends.vsa.min(r[0][1]) == 0 - assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == {4294967295, *list(range(39))} - - s, r = claripy.balancer.Balancer(claripy.backends.vsa, guy_zx <= claripy.BVV(39, 64)).compat_ret - assert r[0][0] is guy_wide - assert claripy.backends.vsa.min(r[0][1]) == 0 - assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == {4294967295, *list(range(39))} - - -def test_simple_guy(): - x = claripy.BVS("x", 32) - s, r = claripy.balancer.Balancer(claripy.backends.vsa, x <= claripy.BVV(39, 32)).compat_ret - assert s - assert r[0][0] is x - assert claripy.backends.vsa.min(r[0][1]) == 0 - assert claripy.backends.vsa.max(r[0][1]) == 39 - - s, r = claripy.balancer.Balancer(claripy.backends.vsa, x + 1 <= claripy.BVV(39, 32)).compat_ret - assert s - assert r[0][0] is x - all_vals = r[0][1]._model_vsa.eval(1000) - assert len(all_vals) - assert min(all_vals) == 0 - assert max(all_vals) == 4294967295 - all_vals.remove(4294967295) - assert max(all_vals) == 38 - - -def test_widened_guy(): - w = claripy.widen(claripy.BVV(1, 32), claripy.BVV(0, 32)) - s, r = claripy.balancer.Balancer(claripy.backends.vsa, w <= claripy.BVV(39, 32)).compat_ret - assert s - assert r[0][0] is w - assert claripy.backends.vsa.min(r[0][1]) == 0 - assert claripy.backends.vsa.max(r[0][1]) == 1 # used to be 39, but that was a bug in the VSA widening - - s, r = claripy.balancer.Balancer(claripy.backends.vsa, w + 1 <= claripy.BVV(39, 32)).compat_ret - assert s - assert r[0][0] is w - assert claripy.backends.vsa.min(r[0][1]) == 0 - all_vals = r[0][1]._model_vsa.eval(1000) - assert set(all_vals) == {4294967295, 0, 1} - - -def test_overflow(): - x = claripy.BVS("x", 32) - - print("x + 10 <= 20") - s, r = claripy.balancer.Balancer(claripy.backends.vsa, x + 10 <= claripy.BVV(20, 32)).compat_ret - # mn,mx = claripy.backends.vsa.min(r[0][1]), claripy.backends.vsa.max(r[0][1]) - assert s - assert r[0][0] is x - assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == { - 4294967286, - 4294967287, - 4294967288, - 4294967289, - 4294967290, - 4294967291, - 4294967292, - 4294967293, - 4294967294, - 4294967295, - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - } - - # print("0 <= x + 10") - # s,r = claripy.balancer.Balancer(claripy.backends.vsa, 0 <= x + 10).compat_ret - # assert s - # assert r[0][0] is x - - print("x - 10 <= 20") - s, r = claripy.balancer.Balancer(claripy.backends.vsa, x - 10 <= claripy.BVV(20, 32)).compat_ret - assert s - assert r[0][0] is x - assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == set(range(10, 31)) - - # print("0 <= x - 10") - # s,r = claripy.balancer.Balancer(claripy.backends.vsa, 0 <= x - 10).compat_ret - # assert s - # assert r[0][0] is x - - -def test_extract_zeroext(): - x = claripy.BVS("x", 8) - expr = claripy.Extract(31, 0, claripy.ZeroExt(56, x)) <= claripy.BVV(0xE, 32) - s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret - - assert s is True - assert len(r) == 1 - assert r[0][0] is x - - -def test_complex_case_0(): - # - - """ - - - Created by VEX running on the following x86_64 assembly: - cmp word ptr [rdi], 40h - ja skip - """ - - x = claripy.BVS("x", 16) - expr = (claripy.ZeroExt(48, claripy.Reverse(x)) << 0x30) <= 0x40000000000000 - - s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret - - assert s - assert r[0][0] is x - assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == set(range(0, 65 * 0x100, 0x100)) - - -def test_complex_case_1(): - # - - """ - - - Created by VEX running on the following S390X assembly: - 0x40062c: ahik %r2, %r11, -2 - 0x400632: clijh %r2, 8, 0x40065c - - IRSB { - t0:Ity_I32 t1:Ity_I32 t2:Ity_I32 t3:Ity_I32 t4:Ity_I32 t5:Ity_I32 t6:Ity_I64 t7:Ity_I64 t8:Ity_I64 t9:Ity_I64 t10:Ity_I32 t11:Ity_I1 t12:Ity_I64 t13:Ity_I64 t14:Ity_I64 t15:Ity_I32 t16:Ity_I1 - - 00 | ------ IMark(0x40062c, 6, 0) ------ - 01 | t0 = GET:I32(r11_32) - 02 | t1 = Add32(0xfffffffe,t0) - 03 | PUT(352) = 0x0000000000000003 - 04 | PUT(360) = 0xfffffffffffffffe - 05 | t13 = 32Sto64(t0) - 06 | t7 = t13 - 07 | PUT(368) = t7 - 08 | PUT(376) = 0x0000000000000000 - 09 | PUT(r2_32) = t1 - 10 | PUT(ia) = 0x0000000000400632 - 11 | ------ IMark(0x400632, 6, 0) ------ - 12 | t14 = 32Uto64(t1) - 13 | t8 = t14 - 14 | t16 = CmpLT64U(0x0000000000000008,t8) - 15 | t15 = 1Uto32(t16) - 16 | t10 = t15 - 17 | t11 = CmpNE32(t10,0x00000000) - 18 | if (t11) { PUT(ia) = 0x40065c; Ijk_Boring } - NEXT: PUT(ia) = 0x0000000000400638; Ijk_Boring - } - """ - - x = claripy.BVS("x", 32) - expr = claripy.ZeroExt( - 31, claripy.If(claripy.BVV(0x8, 32) < claripy.BVV(0xFFFFFFFE, 32) + x, claripy.BVV(1, 1), claripy.BVV(0, 1)) - ) == claripy.BVV(0, 32) - s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret - - assert s is True - assert len(r) == 1 - assert r[0][0] is x - - -def test_complex_case_2(): - x = claripy.BVS("x", 32) - expr = claripy.ZeroExt( - 31, claripy.If(claripy.BVV(0xC, 32) < x, claripy.BVV(1, 1), claripy.BVV(0, 1)) - ) == claripy.BVV(0, 32) - s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret - - assert s is True - assert len(r) == 1 - assert r[0][0] is x + ), + ) + guy_inc = guy_wide + claripy.BVV(1, 32) + guy_zx = claripy.ZeroExt(32, guy_inc) + + s, r = claripy.balancer.Balancer(claripy.backends.vsa, guy_inc <= claripy.BVV(39, 32)).compat_ret + assert s + assert r[0][0] is guy_wide + assert claripy.backends.vsa.min(r[0][1]) == 0 + assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == {4294967295, *list(range(39))} + + s, r = claripy.balancer.Balancer(claripy.backends.vsa, guy_zx <= claripy.BVV(39, 64)).compat_ret + assert r[0][0] is guy_wide + assert claripy.backends.vsa.min(r[0][1]) == 0 + assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == {4294967295, *list(range(39))} + + def test_simple(self): + x = claripy.BVS("x", 32) + s, r = claripy.balancer.Balancer(claripy.backends.vsa, x <= claripy.BVV(39, 32)).compat_ret + assert s + assert r[0][0] is x + assert claripy.backends.vsa.min(r[0][1]) == 0 + assert claripy.backends.vsa.max(r[0][1]) == 39 + + s, r = claripy.balancer.Balancer(claripy.backends.vsa, x + 1 <= claripy.BVV(39, 32)).compat_ret + assert s + assert r[0][0] is x + all_vals = r[0][1]._model_vsa.eval(1000) + assert len(all_vals) + assert min(all_vals) == 0 + assert max(all_vals) == 4294967295 + all_vals.remove(4294967295) + assert max(all_vals) == 38 + + def test_widened(self): + w = claripy.widen(claripy.BVV(1, 32), claripy.BVV(0, 32)) + s, r = claripy.balancer.Balancer(claripy.backends.vsa, w <= claripy.BVV(39, 32)).compat_ret + assert s + assert r[0][0] is w + assert claripy.backends.vsa.min(r[0][1]) == 0 + assert claripy.backends.vsa.max(r[0][1]) == 1 # used to be 39, but that was a bug in the VSA widening + + s, r = claripy.balancer.Balancer(claripy.backends.vsa, w + 1 <= claripy.BVV(39, 32)).compat_ret + assert s + assert r[0][0] is w + assert claripy.backends.vsa.min(r[0][1]) == 0 + all_vals = r[0][1]._model_vsa.eval(1000) + assert set(all_vals) == {4294967295, 0, 1} + + def test_overflow(self): + x = claripy.BVS("x", 32) + + print("x + 10 <= 20") + s, r = claripy.balancer.Balancer(claripy.backends.vsa, x + 10 <= claripy.BVV(20, 32)).compat_ret + # mn,mx = claripy.backends.vsa.min(r[0][1]), claripy.backends.vsa.max(r[0][1]) + assert s + assert r[0][0] is x + assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == { + 4294967286, + 4294967287, + 4294967288, + 4294967289, + 4294967290, + 4294967291, + 4294967292, + 4294967293, + 4294967294, + 4294967295, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + } + + print("x - 10 <= 20") + s, r = claripy.balancer.Balancer(claripy.backends.vsa, x - 10 <= claripy.BVV(20, 32)).compat_ret + assert s + assert r[0][0] is x + assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == set(range(10, 31)) + + def test_extract_zeroext(self): + x = claripy.BVS("x", 8) + expr = claripy.Extract(31, 0, claripy.ZeroExt(56, x)) <= claripy.BVV(0xE, 32) + s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret + + assert s is True + assert len(r) == 1 + assert r[0][0] is x + + def test_complex_case_0(self): + """ + + + Created by VEX running on the following x86_64 assembly: + cmp word ptr [rdi], 40h + ja skip + """ + + x = claripy.BVS("x", 16) + expr = (claripy.ZeroExt(48, claripy.Reverse(x)) << 0x30) <= 0x40000000000000 + + s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret + + assert s + assert r[0][0] is x + assert set(claripy.backends.vsa.eval(r[0][1], 1000)) == set(range(0, 65 * 0x100, 0x100)) + + def test_complex_case_1(self): + """ + + + Created by VEX running on the following S390X assembly: + 0x40062c: ahik %r2, %r11, -2 + 0x400632: clijh %r2, 8, 0x40065c + + IRSB { + t0:Ity_I32 t1:Ity_I32 t2:Ity_I32 t3:Ity_I32 t4:Ity_I32 t5:Ity_I32 t6:Ity_I64 t7:Ity_I64 t8:Ity_I64 t9:Ity_I64 t10:Ity_I32 t11:Ity_I1 t12:Ity_I64 t13:Ity_I64 t14:Ity_I64 t15:Ity_I32 t16:Ity_I1 + + 00 | ------ IMark(0x40062c, 6, 0) ------ + 01 | t0 = GET:I32(r11_32) + 02 | t1 = Add32(0xfffffffe,t0) + 03 | PUT(352) = 0x0000000000000003 + 04 | PUT(360) = 0xfffffffffffffffe + 05 | t13 = 32Sto64(t0) + 06 | t7 = t13 + 07 | PUT(368) = t7 + 08 | PUT(376) = 0x0000000000000000 + 09 | PUT(r2_32) = t1 + 10 | PUT(ia) = 0x0000000000400632 + 11 | ------ IMark(0x400632, 6, 0) ------ + 12 | t14 = 32Uto64(t1) + 13 | t8 = t14 + 14 | t16 = CmpLT64U(0x0000000000000008,t8) + 15 | t15 = 1Uto32(t16) + 16 | t10 = t15 + 17 | t11 = CmpNE32(t10,0x00000000) + 18 | if (t11) { PUT(ia) = 0x40065c; Ijk_Boring } + NEXT: PUT(ia) = 0x0000000000400638; Ijk_Boring + } + """ + + x = claripy.BVS("x", 32) + expr = claripy.ZeroExt( + 31, claripy.If(claripy.BVV(0x8, 32) < claripy.BVV(0xFFFFFFFE, 32) + x, claripy.BVV(1, 1), claripy.BVV(0, 1)) + ) == claripy.BVV(0, 32) + s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret + + assert s is True + assert len(r) == 1 + assert r[0][0] is x + + def test_complex_case_2(self): + x = claripy.BVS("x", 32) + expr = claripy.ZeroExt( + 31, claripy.If(claripy.BVV(0xC, 32) < x, claripy.BVV(1, 1), claripy.BVV(0, 1)) + ) == claripy.BVV(0, 32) + s, r = claripy.balancer.Balancer(claripy.backends.vsa, expr).compat_ret + + assert s is True + assert len(r) == 1 + assert r[0][0] is x if __name__ == "__main__": - test_overflow() - test_simple_guy() - test_widened_guy() - test_complex_guy() - test_complex_case_0() - test_complex_case_1() - test_complex_case_2() - test_extract_zeroext() + unittest.main() diff --git a/tests/test_concrete.py b/tests/test_concrete.py index cca4f1b6f..92674c909 100644 --- a/tests/test_concrete.py +++ b/tests/test_concrete.py @@ -1,31 +1,32 @@ -import claripy +import unittest +import claripy -def test_concrete(): - a = claripy.BVV(10, 32) - b = claripy.BoolV(True) - assert isinstance(claripy.backends.concrete.convert(a), claripy.bv.BVV) - assert isinstance(claripy.backends.concrete.convert(b), bool) +class TestConcreteBackend(unittest.TestCase): + def test_concrete(self): + a = claripy.BVV(10, 32) + b = claripy.BoolV(True) - a = claripy.BVV(1337, 32) - b = a[31:16] - c = claripy.BVV(0, 16) - assert b is c + assert isinstance(claripy.backends.concrete.convert(a), claripy.bv.BVV) + assert isinstance(claripy.backends.concrete.convert(b), bool) - bc = claripy.backends.concrete - d = claripy.BVV(-1, 32) - assert bc.convert(d) == 0xFFFFFFFF + a = claripy.BVV(1337, 32) + b = a[31:16] + c = claripy.BVV(0, 16) + assert b is c - e = claripy.BVV(2**32 + 1337, 32) - assert bc.convert(e) == 1337 + bc = claripy.backends.concrete + d = claripy.BVV(-1, 32) + assert bc.convert(d) == 0xFFFFFFFF + e = claripy.BVV(2**32 + 1337, 32) + assert bc.convert(e) == 1337 -def test_concrete_fp(): - f = claripy.FPV(1.0, claripy.FSORT_FLOAT) - assert claripy.backends.concrete.eval(f, 2) == (1.0,) + def test_concrete_fp(self): + f = claripy.FPV(1.0, claripy.FSORT_FLOAT) + assert claripy.backends.concrete.eval(f, 2) == (1.0,) if __name__ == "__main__": - test_concrete() - test_concrete_fp() + unittest.main() diff --git a/tests/test_expression.py b/tests/test_expression.py index e9bedbe94..927e225ce 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -283,7 +283,7 @@ def raw_ite(self, solver_type): ss = s.branch() ss.add(z == itz) ss.add(itz != 0) - self.assertEqual(ss.eval(y / x, 100), (2,)) + self.assertEqual(ss.eval(y // x, 100), (2,)) self.assertEqual(sorted(ss.eval(x, 100)), [1, 10, 100]) self.assertEqual(sorted(ss.eval(y, 100)), [2, 20, 200]) @@ -447,10 +447,10 @@ def test_signed_concrete(self): d = claripy.BVV(-3, 32) # test unsigned - assert bc.convert(a / c) == 1 - assert bc.convert(a / d) == 0 - assert bc.convert(b / c) == 0x55555553 - assert bc.convert(b / d) == 0 + assert bc.convert(a // c) == 1 + assert bc.convert(a // d) == 0 + assert bc.convert(b // c) == 0x55555553 + assert bc.convert(b // d) == 0 assert bc.convert(a % c) == 2 assert bc.convert(a % d) == 5 assert bc.convert(b % c) == 2 @@ -478,10 +478,10 @@ def test_signed_symbolic(self): solver.add(d == -3) # test unsigned - assert list(solver.eval(a / c, 2)) == [1] - assert list(solver.eval(a / d, 2)) == [0] - assert list(solver.eval(b / c, 2)) == [0x55555553] - assert list(solver.eval(b / d, 2)) == [0] + assert list(solver.eval(a // c, 2)) == [1] + assert list(solver.eval(a // d, 2)) == [0] + assert list(solver.eval(b // c, 2)) == [0x55555553] + assert list(solver.eval(b // d, 2)) == [0] assert list(solver.eval(a % c, 2)) == [2] assert list(solver.eval(a % d, 2)) == [5] assert list(solver.eval(b % c, 2)) == [2] diff --git a/tests/test_regressions.py b/tests/test_regressions.py index d3bbd6bcd..2e7ae0040 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -1,19 +1,22 @@ +import unittest + import claripy -def test_issue16(): - s = claripy.SolverComposite() +class TestRegressions(unittest.TestCase): + def test_issue16(self): + s = claripy.SolverComposite() - c = claripy.BVS("test", 32) - s.add(c[7:0] != 0) + c = claripy.BVS("test", 32) + s.add(c[7:0] != 0) - assert s.satisfiable() - s.add(c == 0) + assert s.satisfiable() + s.add(c == 0) - # print(s.satisfiable()) - assert not s.satisfiable(extra_constraints=[claripy.BVS("lol", 32) == 0]) - assert not s.satisfiable() + # print(s.satisfiable()) + assert not s.satisfiable(extra_constraints=[claripy.BVS("lol", 32) == 0]) + assert not s.satisfiable() if __name__ == "__main__": - test_issue16() + unittest.main() diff --git a/tests/test_replacements.py b/tests/test_replacements.py index 67eead2c8..1075b893e 100644 --- a/tests/test_replacements.py +++ b/tests/test_replacements.py @@ -1,85 +1,81 @@ -import logging +import unittest import claripy -l = logging.getLogger("claripy.test.replacements") +class TestReplacements(unittest.TestCase): + def test_replacement_solver(self): + sr = claripy.SolverReplacement(claripy.SolverVSA(), replace_constraints=True, complex_auto_replace=True) + x = claripy.BVS("x", 32) -def test_replacement_solver(): - sr = claripy.SolverReplacement(claripy.SolverVSA(), replace_constraints=True, complex_auto_replace=True) - x = claripy.BVS("x", 32) + sr.add(x + 8 == 10) + assert sr.max(x) == sr.min(x) - sr.add(x + 8 == 10) - assert sr.max(x) == sr.min(x) + sr2 = sr.branch() + sr2.add(x + 8 < 2000) + assert sr2.max(x) == sr2.min(x) == sr.max(x) - sr2 = sr.branch() - sr2.add(x + 8 < 2000) - assert sr2.max(x) == sr2.min(x) == sr.max(x) + def test_contradiction(self): + sr = claripy.SolverReplacement(claripy.Solver(), replace_constraints=True) + x = claripy.BVS("x", 32) + sr.add(x == 10) + assert sr.satisfiable() + assert sr.eval(x, 10) == (10,) -def test_contradiction(): - sr = claripy.SolverReplacement(claripy.Solver(), replace_constraints=True) - x = claripy.BVS("x", 32) + sr.add(x == 100) + assert not sr.satisfiable() - sr.add(x == 10) - assert sr.satisfiable() - assert sr.eval(x, 10) == (10,) + def test_branching_replacement_solver(self): + # + # Simple case: replaceable thing first + # - sr.add(x == 100) - assert not sr.satisfiable() + x = claripy.BVS("x", 32) + s0 = claripy.SolverReplacement(claripy.Solver()) + s0.add(x == 0) + s1a = s0.branch() + s1b = s0.branch() -def test_branching_replacement_solver(): - # - # Simple case: replaceable thing first - # + s1a.add(x == 0) + s1b.add(x != 0) - x = claripy.BVS("x", 32) - s0 = claripy.SolverReplacement(claripy.Solver()) - s0.add(x == 0) + assert s1a.satisfiable() + assert not s1b.satisfiable() - s1a = s0.branch() - s1b = s0.branch() + # + # Slightly more complex: different == + # - s1a.add(x == 0) - s1b.add(x != 0) + x = claripy.BVS("x", 32) + s0 = claripy.SolverReplacement(claripy.Solver()) + s0.add(x == 0) - assert s1a.satisfiable() - assert not s1b.satisfiable() + s1a = s0.branch() + s1b = s0.branch() - # - # Slightly more complex: different == - # + s1a.add(x == 0) + s1b.add(x == 1) - x = claripy.BVS("x", 32) - s0 = claripy.SolverReplacement(claripy.Solver()) - s0.add(x == 0) + assert s1a.satisfiable() + assert not s1b.satisfiable() - s1a = s0.branch() - s1b = s0.branch() + # + # Complex case: non-replaceable thing first + # - s1a.add(x == 0) - s1b.add(x == 1) - - assert s1a.satisfiable() - assert not s1b.satisfiable() - - # - # Complex case: non-replaceable thing first - # - - # x = claripy.BVS('x', 32) - # s0 = claripy.SolverReplacement(claripy.Solver()) - # s0.add(x != 0) - # s1a = s0.branch() - # s1b = s0.branch() - # s1a.add(x != 0) - # s1b.add(x == 0) - # assert s1a.satisfiable() - # assert not s1b.satisfiable() + # FIXME: Figure this out and uncomment it + # x = claripy.BVS('x', 32) + # s0 = claripy.SolverReplacement(claripy.Solver()) + # s0.add(x != 0) + # s1a = s0.branch() + # s1b = s0.branch() + # s1a.add(x != 0) + # s1b.add(x == 0) + # assert s1a.satisfiable() + # assert not s1b.satisfiable() if __name__ == "__main__": - test_branching_replacement_solver() - test_replacement_solver() - test_contradiction() + unittest.main() diff --git a/tests/test_serial.py b/tests/test_serial.py index 244e71d44..1b8fdb89f 100644 --- a/tests/test_serial.py +++ b/tests/test_serial.py @@ -1,137 +1,74 @@ -import logging import pickle +import unittest import claripy -l = logging.getLogger("claripy.test.serial") - - -def test_pickle_ast(): - bz = claripy.backends.z3 - - a = claripy.BVV(1, 32) - b = claripy.BVS("x", 32, explicit_name=True) - - c = a + b - assert bz.convert(c).__module__ == "z3.z3" - assert str(bz.convert(c)), "1 + x" - - c_copy = pickle.loads(pickle.dumps(c, -1)) - assert c_copy is c - assert bz.convert(c_copy).__module__ == "z3.z3" - assert str(bz.convert(c_copy)) == "1 + x" - - -def test_pickle_frontend(): - s = claripy.Solver() - x = claripy.BVS("x", 32) - - s.add(x == 1) - assert s.eval(x, 10), (1,) - - ss = pickle.dumps(s) - del s - import gc - - gc.collect() - - s = pickle.loads(ss) - assert s.eval(x, 10), (1,) - - -def test_identity(): - l.info("Running test_identity") - - a = claripy.BVV(1, 32) - b = claripy.BVS("x", 32) - c = a + b - d = ( - a - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - + b - ) - - l.debug("Storing!") - c_info = pickle.dumps(c) - d_info = pickle.dumps(d) - - cc = pickle.loads(c_info) - assert str(cc) == str(c) - cd = pickle.loads(d_info) - assert str(cd) == str(d) - assert c.args[0] is d.args[0] - - l.debug("Time to test some solvers!") - s = claripy.Solver() - x = claripy.BVS("x", 32) - s.add(x == 3) - s.finalize() - ss = pickle.loads(pickle.dumps(s)) - assert str(s.constraints) == str(ss.constraints) - assert str(s.variables) == str(ss.variables) - - s = claripy.SolverComposite() - x = claripy.BVS("x", 32) - s.add(x == 3) - s.finalize() - ss = pickle.loads(pickle.dumps(s)) - old_constraint_sets = [[hash(j) for j in k.constraints] for k in s._solver_list] - new_constraint_sets = [[hash(j) for j in k.constraints] for k in ss._solver_list] - assert old_constraint_sets == new_constraint_sets - assert str(s.variables) == str(ss.variables) + +class TestPickle(unittest.TestCase): + def test_pickle_ast(self): + bz = claripy.backends.z3 + + a = claripy.BVV(1, 32) + b = claripy.BVS("x", 32, explicit_name=True) + + c = a + b + assert bz.convert(c).__module__ == "z3.z3" + assert str(bz.convert(c)), "1 + x" + + c_copy = pickle.loads(pickle.dumps(c, -1)) + assert c_copy is c + assert bz.convert(c_copy).__module__ == "z3.z3" + assert str(bz.convert(c_copy)) == "1 + x" + + def test_pickle_frontend(self): + s = claripy.Solver() + x = claripy.BVS("x", 32) + + s.add(x == 1) + assert s.eval(x, 10), (1,) + + ss = pickle.dumps(s) + del s + import gc + + gc.collect() + + s = pickle.loads(ss) + assert s.eval(x, 10), (1,) + + def test_identity(self): + a = claripy.BVV(1, 32) + b = claripy.BVS("x", 32) + c = a + b + d = a + b * 50 + + c_info = pickle.dumps(c) + d_info = pickle.dumps(d) + + cc = pickle.loads(c_info) + assert str(cc) == str(c) + cd = pickle.loads(d_info) + assert str(cd) == str(d) + assert c.args[0] is d.args[0] + + s = claripy.Solver() + x = claripy.BVS("x", 32) + s.add(x == 3) + s.finalize() + ss = pickle.loads(pickle.dumps(s)) + assert str(s.constraints) == str(ss.constraints) + assert str(s.variables) == str(ss.variables) + + s = claripy.SolverComposite() + x = claripy.BVS("x", 32) + s.add(x == 3) + s.finalize() + ss = pickle.loads(pickle.dumps(s)) + old_constraint_sets = [[hash(j) for j in k.constraints] for k in s._solver_list] + new_constraint_sets = [[hash(j) for j in k.constraints] for k in ss._solver_list] + assert old_constraint_sets == new_constraint_sets + assert str(s.variables) == str(ss.variables) if __name__ == "__main__": - test_pickle_ast() - test_pickle_frontend() - test_identity() + unittest.main() diff --git a/tests/test_simplify.py b/tests/test_simplify.py index ac53dedc8..036a13cf0 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -1,298 +1,254 @@ -import claripy - - -def test_bool_simplification(): - def assert_correct(a, b): - assert claripy.backends.z3.identical(claripy.simplify(a), b) - - a, b, c = (claripy.BoolS(name) for name in ("a", "b", "c")) - - assert_correct(claripy.And(a, claripy.Not(a)), claripy.false) - assert_correct(claripy.Or(a, claripy.Not(a)), claripy.true) - - complex_true_expression = claripy.Or( - claripy.And(a, b), - claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), c)), - claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), claripy.Not(c))), - ) - assert_correct(complex_true_expression, claripy.true) - - -def test_simplification(): - def assert_correct(a, b): - assert claripy.backends.z3.identical(a, b) - - x, y, z = (claripy.BVS(name, 32) for name in ("x", "y", "z")) - - # test extraction of concatted values - concatted = claripy.Concat(x, y, z) - - assert_correct(concatted[95:64], x) - assert_correct(concatted[63:32], y) - assert_correct(concatted[31:0], z) - - assert_correct(concatted[95:32], claripy.Concat(x, y)) - assert_correct(concatted[63:0], claripy.Concat(y, z)) - - assert_correct(concatted[95:0], concatted) - - assert_correct(concatted[47:0], claripy.Concat(y, z)[47:0]) - assert_correct(concatted[70:0], concatted[70:0]) - assert_correct(concatted[70:15], concatted[70:15]) - assert_correct(concatted[70:35], claripy.Concat(x, y)[38:3]) - - # test extraction of nested concats - concatted_nested = claripy.Concat(claripy.Reverse(claripy.Concat(x, y)), z) - assert_correct(concatted_nested[63:0], claripy.Concat(claripy.Reverse(x), z)) - - # make sure the division simplification works - assert_correct(2 + x, claripy.backends.z3.simplify(1 + x + 1)) - assert_correct(x / y, claripy.backends.z3.simplify(x / y)) - assert_correct(x % y, claripy.backends.z3.simplify(x % y)) - - -def test_rotate_shift_mask_simplification(): - a = claripy.BVS("N", 32, max=0xC, min=0x1) - extend_ = claripy.BVS("extend", 32, uninitialized=True) - a_ext = extend_.concat(a) - expr = ((a_ext << 3) | (claripy.LShR(a_ext, 61))) & 0x7FFFFFFF8 - # print(expr) - # print(expr._model_vsa) - model_vsa = expr._model_vsa - assert model_vsa.lower_bound == 8 - assert model_vsa.upper_bound == 0x60 - assert model_vsa.cardinality == 12 - - -def test_reverse_extract_reverse_simplification(): - # without the reverse_extract_reverse simplifier, loading dx from rdx will result in the following complicated - # expression: - # Reverse(Extract(63, 48, Reverse(BVS('rdx', 64)))) - - a = claripy.BVS("rdx", 64) - dx = claripy.Reverse(claripy.Extract(63, 48, claripy.Reverse(a))) - - # simplification should have kicked in at this moment - assert dx.op == "Extract" - assert dx.args[0] == 15 - assert dx.args[1] == 0 - assert dx.args[2] is a - - -def test_reverse_concat_reverse_simplification(): - # Reverse(Concat(Reverse(a), Reverse(b))) = Concat(b, a) - - a = claripy.BVS("a", 32) - b = claripy.BVS("b", 32) - x = claripy.Reverse(claripy.Concat(claripy.Reverse(a), claripy.Reverse(b))) - - assert x.op == "Concat" - assert x.args[0] is b - assert x.args[1] is a - - -def perf_boolean_and_simplification_0(): - # Create a gigantic And AST with many operands, one variable at a time - bool_vars = [claripy.BoolS("b%d" % i) for i in range(1500)] - v = bool_vars[0] - for i in range(1, len(bool_vars)): - v = claripy.And(v, bool_vars[i]) - - -def perf_boolean_and_simplification_1(): - # Create a gigantic And AST with many operands, many variables at a time - bool_vars = [claripy.BoolS("b%d" % i) for i in range(500)] - v = bool_vars[0] - for i in range(1, len(bool_vars)): - v = claripy.And(*((*v.args, bool_vars[i] is False))) if v.op == "And" else claripy.And(v, bool_vars[i]) - - -def test_concrete_flatten(): - a = claripy.BVS("a", 32) - b = a + 10 - c = 10 + b - d = a + 20 - assert c is d - - # to future test writers or debuggers: whether the answer is b_neg or b is not particularly important - e = a - 10 - f = e + 20 - b_neg = a - -10 - assert f is b_neg - - g = e - 10 - h = a - 20 - assert g is h - - i = d - 10 - assert i is b - - -def test_mask_eq_constant(): - # - - a = claripy.BVS("sim_data", 8, explicit_name=True) - expr = (claripy.ZeroExt(48, claripy.Extract(15, 0, claripy.Concat(claripy.BVV(0, 63), a[0:0]))) & 0xFFFF) == 0x0 - - assert expr.op == "__eq__" - assert expr.args[0].op == "Extract" - assert expr.args[0].args[0] == 0 and expr.args[0].args[1] == 0 - assert expr.args[0].args[2] is a - assert expr.args[1].op == "BVV" and expr.args[1].args == (0, 1) - - # the highest bit of the mask (0x1fff) is not aligned to 8 - # we want the mask to be BVV(16, 0x1fff) instead of BVV(13, 0x1fff) - a = claripy.BVS("sim_data", 8, explicit_name=True) - expr = (claripy.ZeroExt(48, claripy.Extract(15, 0, claripy.Concat(claripy.BVV(0, 63), a[0:0]))) & 0x1FFF) == 0x0 - - assert expr.op == "__eq__" - assert expr.args[0].op == "__and__" - _, arg1 = expr.args[0].args - assert arg1.size() == 16 - assert arg1.args[0] == 0x1FFF - - -def test_and_mask_comparing_against_constant_simplifier(): - # A & mask == b ==> Extract(_, _, A) == Extract(_, _, b) iff high bits of a and b are zeros - a = claripy.BVS("a", 8) - b = claripy.BVV(0x10, 32) - - expr = claripy.ZeroExt(24, a) & 0xFFFF == b - assert expr is (a == 16) - - expr = claripy.Concat(claripy.BVV(0, 24), a) & 0xFFFF == b - assert expr is (a == 16) - - # A & mask != b ==> Extract(_, _, A) != Extract(_, _, b) iff high bits of a and b are zeros - a = claripy.BVS("a", 8) - b = claripy.BVV(0x102000AA, 32) - - expr = claripy.ZeroExt(24, a) & 0xFFFF == b - assert expr.is_false() - - expr = claripy.Concat(claripy.BVV(0, 24), a) & 0xFFFF == b - assert expr.is_false() - - # A & 0 == 0 ==> true - a = claripy.BVS("a", 32) - b = claripy.BVV(0, 32) - expr = (a & 0) == b - assert expr.is_true() - expr = (a & 0) == claripy.BVV(1, 32) - assert expr.is_false() - - -def test_zeroext_extract_comparing_against_constant_simplifier(): - a = claripy.BVS("a", 8, explicit_name=True) - b = claripy.BVV(0x28, 16) - - expr = claripy.Extract(15, 0, claripy.ZeroExt(24, a)) == b - assert expr is (a == claripy.BVV(0x28, 8)) - - expr = claripy.Extract(7, 0, claripy.ZeroExt(24, a)) == claripy.BVV(0x28, 8) - assert expr is (a == claripy.BVV(0x28, 8)) - - expr = claripy.Extract(7, 0, claripy.ZeroExt(1, a)) == claripy.BVV(0x28, 8) - assert expr is (a == claripy.BVV(0x28, 8)) - - expr = claripy.Extract(6, 0, claripy.ZeroExt(24, a)) == claripy.BVV(0x28, 7) - assert expr.op == "__eq__" - assert expr.args[0].op == "Extract" and expr.args[0].args[0] == 6 and expr.args[0].args[1] == 0 - assert expr.args[0].args[2] is a - assert expr.args[1].args == (0x28, 7) - - expr = claripy.Extract(15, 0, claripy.Concat(claripy.BVV(0, 48), a)) == b - assert expr is (a == claripy.BVV(0x28, 8)) - - bb = claripy.BVV(0x28, 24) - d = claripy.BVS("d", 8, explicit_name=True) - expr = claripy.Extract(23, 0, claripy.Concat(claripy.BVV(0, 24), d)) == bb - assert expr is (d == claripy.BVV(0x28, 8)) - - dd = claripy.BVS("dd", 23, explicit_name=True) - expr = claripy.Extract(23, 0, claripy.Concat(claripy.BVV(0, 2), dd)) == bb - assert expr is (dd == claripy.BVV(0x28, 23)) - - # this was incorrect before - # claripy issue #201 - expr = claripy.Extract(31, 8, claripy.Concat(claripy.BVV(0, 24), dd)) == claripy.BVV(0xFFFF, 24) - assert expr is not (dd == claripy.BVV(0xFFFF, 23)) - - -def test_one_xor_exp_eq_zero(): - var1 = claripy.FPV(150, claripy.fp.FSORT_DOUBLE) - var2 = claripy.FPS("test", claripy.fp.FSORT_DOUBLE) - result = var1 <= var2 - expr = claripy.BVV(1, 1) ^ (claripy.If(result, claripy.BVV(1, 1), claripy.BVV(0, 1))) == claripy.BVV(0, 1) - - assert expr is result - - -def test_bitwise_and_if(): - e = claripy.BVS("e", 8) - cond1 = e >= 5 - cond2 = e != 5 - ifcond1 = claripy.If(cond1, claripy.BVV(1, 1), claripy.BVV(0, 1)) - ifcond2 = claripy.If(cond2, claripy.BVV(1, 1), claripy.BVV(0, 1)) - result = claripy.If(e > 5, claripy.BVV(1, 1), claripy.BVV(0, 1)) - assert ifcond1 & ifcond2 is result - - -def test_invert_if(): - cond = claripy.BoolS("cond") - expr = ~(claripy.If(cond, claripy.BVV(1, 1), claripy.BVV(0, 1))) - result = claripy.If(claripy.Not(cond), claripy.BVV(1, 1), claripy.BVV(0, 1)) - assert expr is result - - -def test_sub_constant(): - expr = claripy.BVS("expr", 32) - assert (expr - 5 == 0) is (expr == 5) +import unittest +import claripy -def test_extract(): - cond = claripy.BoolS("cond") - expr = claripy.If(cond, claripy.BVV(1, 32), claripy.BVV(0, 32))[0:0] - result = claripy.If(cond, claripy.BVV(1, 1), claripy.BVV(0, 1)) - assert expr is result - e = claripy.BVS("e", 32) - expr2 = (~e)[0:0] # pylint:disable=unsubscriptable-object - result2 = ~(e[0:0]) - assert expr2 is result2 +class TestSimplify(unittest.TestCase): + def test_bool_simplification(self): + def assert_correct(a, b): + assert claripy.backends.z3.identical(claripy.simplify(a), b) + a, b, c = (claripy.BoolS(name) for name in ("a", "b", "c")) -def perf(): - import timeit # pylint:disable=import-outside-toplevel + assert_correct(claripy.And(a, claripy.Not(a)), claripy.false) + assert_correct(claripy.Or(a, claripy.Not(a)), claripy.true) - print( - timeit.timeit( - "perf_boolean_and_simplification_0()", - number=10, - setup="from __main__ import perf_boolean_and_simplification_0", - ) - ) - print( - timeit.timeit( - "perf_boolean_and_simplification_1()", - number=10, - setup="from __main__ import perf_boolean_and_simplification_1", + complex_true_expression = claripy.Or( + claripy.And(a, b), + claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), c)), + claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), claripy.Not(c))), ) - ) + assert_correct(complex_true_expression, claripy.true) + + def test_simplification(self): + def assert_correct(a, b): + assert claripy.backends.z3.identical(a, b) + + x, y, z = (claripy.BVS(name, 32) for name in ("x", "y", "z")) + + # test extraction of concatted values + concatted = claripy.Concat(x, y, z) + + assert_correct(concatted[95:64], x) + assert_correct(concatted[63:32], y) + assert_correct(concatted[31:0], z) + + assert_correct(concatted[95:32], claripy.Concat(x, y)) + assert_correct(concatted[63:0], claripy.Concat(y, z)) + + assert_correct(concatted[95:0], concatted) + + assert_correct(concatted[47:0], claripy.Concat(y, z)[47:0]) + assert_correct(concatted[70:0], concatted[70:0]) + assert_correct(concatted[70:15], concatted[70:15]) + assert_correct(concatted[70:35], claripy.Concat(x, y)[38:3]) + + # test extraction of nested concats + concatted_nested = claripy.Concat(claripy.Reverse(claripy.Concat(x, y)), z) + assert_correct(concatted_nested[63:0], claripy.Concat(claripy.Reverse(x), z)) + + # make sure the division simplification works + assert_correct(2 + x, claripy.backends.z3.simplify(1 + x + 1)) + assert_correct(x // y, claripy.backends.z3.simplify(x // y)) + assert_correct(x % y, claripy.backends.z3.simplify(x % y)) + + def test_rotate_shift_mask_simplification(self): + a = claripy.BVS("N", 32, max=0xC, min=0x1) + extend_ = claripy.BVS("extend", 32, uninitialized=True) + a_ext = extend_.concat(a) + expr = ((a_ext << 3) | (claripy.LShR(a_ext, 61))) & 0x7FFFFFFF8 + # print(expr) + # print(expr._model_vsa) + model_vsa = expr._model_vsa + assert model_vsa.lower_bound == 8 + assert model_vsa.upper_bound == 0x60 + assert model_vsa.cardinality == 12 + + def test_reverse_extract_reverse_simplification(self): + # without the reverse_extract_reverse simplifier, loading dx from rdx will result in the following complicated + # expression: + # Reverse(Extract(63, 48, Reverse(BVS('rdx', 64)))) + + a = claripy.BVS("rdx", 64) + dx = claripy.Reverse(claripy.Extract(63, 48, claripy.Reverse(a))) + + # simplification should have kicked in at this moment + assert dx.op == "Extract" + assert dx.args[0] == 15 + assert dx.args[1] == 0 + assert dx.args[2] is a + + def test_reverse_concat_reverse_simplification(self): + # Reverse(Concat(Reverse(a), Reverse(b))) = Concat(b, a) + + a = claripy.BVS("a", 32) + b = claripy.BVS("b", 32) + x = claripy.Reverse(claripy.Concat(claripy.Reverse(a), claripy.Reverse(b))) + + assert x.op == "Concat" + assert x.args[0] is b + assert x.args[1] is a + + def perf_boolean_and_simplification_0(self): + # Create a gigantic And AST with many operands, one variable at a time + bool_vars = [claripy.BoolS("b%d" % i) for i in range(1500)] + v = bool_vars[0] + for i in range(1, len(bool_vars)): + v = claripy.And(v, bool_vars[i]) + + def perf_boolean_and_simplification_1(self): + # Create a gigantic And AST with many operands, many variables at a time + bool_vars = [claripy.BoolS("b%d" % i) for i in range(500)] + v = bool_vars[0] + for i in range(1, len(bool_vars)): + v = claripy.And(*((*v.args, bool_vars[i] is False))) if v.op == "And" else claripy.And(v, bool_vars[i]) + + def test_concrete_flatten(self): + a = claripy.BVS("a", 32) + b = a + 10 + c = 10 + b + d = a + 20 + assert c is d + + # to future test writers or debuggers: whether the answer is b_neg or b is not particularly important + e = a - 10 + f = e + 20 + b_neg = a - -10 + assert f is b_neg + + g = e - 10 + h = a - 20 + assert g is h + + i = d - 10 + assert i is b + + def test_mask_eq_constant(self): + # + + a = claripy.BVS("sim_data", 8, explicit_name=True) + expr = (claripy.ZeroExt(48, claripy.Extract(15, 0, claripy.Concat(claripy.BVV(0, 63), a[0:0]))) & 0xFFFF) == 0x0 + + assert expr.op == "__eq__" + assert expr.args[0].op == "Extract" + assert expr.args[0].args[0] == 0 and expr.args[0].args[1] == 0 + assert expr.args[0].args[2] is a + assert expr.args[1].op == "BVV" and expr.args[1].args == (0, 1) + + # the highest bit of the mask (0x1fff) is not aligned to 8 + # we want the mask to be BVV(16, 0x1fff) instead of BVV(13, 0x1fff) + a = claripy.BVS("sim_data", 8, explicit_name=True) + expr = (claripy.ZeroExt(48, claripy.Extract(15, 0, claripy.Concat(claripy.BVV(0, 63), a[0:0]))) & 0x1FFF) == 0x0 + + assert expr.op == "__eq__" + assert expr.args[0].op == "__and__" + _, arg1 = expr.args[0].args + assert arg1.size() == 16 + assert arg1.args[0] == 0x1FFF + + def test_and_mask_comparing_against_constant_simplifier(self): + # A & mask == b ==> Extract(_, _, A) == Extract(_, _, b) iff high bits of a and b are zeros + a = claripy.BVS("a", 8) + b = claripy.BVV(0x10, 32) + + expr = claripy.ZeroExt(24, a) & 0xFFFF == b + assert expr is (a == 16) + + expr = claripy.Concat(claripy.BVV(0, 24), a) & 0xFFFF == b + assert expr is (a == 16) + + # A & mask != b ==> Extract(_, _, A) != Extract(_, _, b) iff high bits of a and b are zeros + a = claripy.BVS("a", 8) + b = claripy.BVV(0x102000AA, 32) + + expr = claripy.ZeroExt(24, a) & 0xFFFF == b + assert expr.is_false() + + expr = claripy.Concat(claripy.BVV(0, 24), a) & 0xFFFF == b + assert expr.is_false() + + # A & 0 == 0 ==> true + a = claripy.BVS("a", 32) + b = claripy.BVV(0, 32) + expr = (a & 0) == b + assert expr.is_true() + expr = (a & 0) == claripy.BVV(1, 32) + assert expr.is_false() + + def test_zeroext_extract_comparing_against_constant_simplifier(self): + a = claripy.BVS("a", 8, explicit_name=True) + b = claripy.BVV(0x28, 16) + + expr = claripy.Extract(15, 0, claripy.ZeroExt(24, a)) == b + assert expr is (a == claripy.BVV(0x28, 8)) + + expr = claripy.Extract(7, 0, claripy.ZeroExt(24, a)) == claripy.BVV(0x28, 8) + assert expr is (a == claripy.BVV(0x28, 8)) + + expr = claripy.Extract(7, 0, claripy.ZeroExt(1, a)) == claripy.BVV(0x28, 8) + assert expr is (a == claripy.BVV(0x28, 8)) + + expr = claripy.Extract(6, 0, claripy.ZeroExt(24, a)) == claripy.BVV(0x28, 7) + assert expr.op == "__eq__" + assert expr.args[0].op == "Extract" and expr.args[0].args[0] == 6 and expr.args[0].args[1] == 0 + assert expr.args[0].args[2] is a + assert expr.args[1].args == (0x28, 7) + + expr = claripy.Extract(15, 0, claripy.Concat(claripy.BVV(0, 48), a)) == b + assert expr is (a == claripy.BVV(0x28, 8)) + + bb = claripy.BVV(0x28, 24) + d = claripy.BVS("d", 8, explicit_name=True) + expr = claripy.Extract(23, 0, claripy.Concat(claripy.BVV(0, 24), d)) == bb + assert expr is (d == claripy.BVV(0x28, 8)) + + dd = claripy.BVS("dd", 23, explicit_name=True) + expr = claripy.Extract(23, 0, claripy.Concat(claripy.BVV(0, 2), dd)) == bb + assert expr is (dd == claripy.BVV(0x28, 23)) + + # this was incorrect before + # claripy issue #201 + expr = claripy.Extract(31, 8, claripy.Concat(claripy.BVV(0, 24), dd)) == claripy.BVV(0xFFFF, 24) + assert expr is not (dd == claripy.BVV(0xFFFF, 23)) + + def test_one_xor_exp_eq_zero(self): + var1 = claripy.FPV(150, claripy.fp.FSORT_DOUBLE) + var2 = claripy.FPS("test", claripy.fp.FSORT_DOUBLE) + result = var1 <= var2 + expr = claripy.BVV(1, 1) ^ (claripy.If(result, claripy.BVV(1, 1), claripy.BVV(0, 1))) == claripy.BVV(0, 1) + + assert expr is result + + def test_bitwise_and_if(self): + e = claripy.BVS("e", 8) + cond1 = e >= 5 + cond2 = e != 5 + ifcond1 = claripy.If(cond1, claripy.BVV(1, 1), claripy.BVV(0, 1)) + ifcond2 = claripy.If(cond2, claripy.BVV(1, 1), claripy.BVV(0, 1)) + result = claripy.If(e > 5, claripy.BVV(1, 1), claripy.BVV(0, 1)) + assert ifcond1 & ifcond2 is result + + def test_invert_if(self): + cond = claripy.BoolS("cond") + expr = ~(claripy.If(cond, claripy.BVV(1, 1), claripy.BVV(0, 1))) + result = claripy.If(claripy.Not(cond), claripy.BVV(1, 1), claripy.BVV(0, 1)) + assert expr is result + + def test_sub_constant(self): + expr = claripy.BVS("expr", 32) + assert (expr - 5 == 0) is (expr == 5) + + def test_extract(self): + cond = claripy.BoolS("cond") + expr = claripy.If(cond, claripy.BVV(1, 32), claripy.BVV(0, 32))[0:0] + result = claripy.If(cond, claripy.BVV(1, 1), claripy.BVV(0, 1)) + assert expr is result + + e = claripy.BVS("e", 32) + expr2 = (~e)[0:0] + result2 = ~(e[0:0]) + assert expr2 is result2 if __name__ == "__main__": - test_simplification() - test_bool_simplification() - test_rotate_shift_mask_simplification() - test_reverse_extract_reverse_simplification() - test_reverse_concat_reverse_simplification() - test_concrete_flatten() - test_mask_eq_constant() - test_and_mask_comparing_against_constant_simplifier() - test_zeroext_extract_comparing_against_constant_simplifier() - test_one_xor_exp_eq_zero() - test_bitwise_and_if() - test_invert_if() - test_sub_constant() - test_extract() + unittest.main() diff --git a/tests/test_smart_join.py b/tests/test_smart_join.py index 7c8028d2c..3c190e24e 100644 --- a/tests/test_smart_join.py +++ b/tests/test_smart_join.py @@ -1,9 +1,7 @@ -import logging +import unittest from claripy.vsa import StridedInterval -l = logging.getLogger("angr_tests") - def check_si_fields(si, stride, lb, ub): if si.stride != stride: @@ -13,27 +11,28 @@ def check_si_fields(si, stride, lb, ub): return si.upper_bound == ub -def test_smart_join(): - s1 = StridedInterval(bits=4, stride=3, lower_bound=9, upper_bound=12) - s2 = StridedInterval(bits=4, stride=3, lower_bound=0, upper_bound=3) - j = StridedInterval.pseudo_join(s1, s2) - u = StridedInterval.least_upper_bound(s1, s2) - assert check_si_fields(u, 3, 0, 12) - assert check_si_fields(j, 3, 0, 12) +class TestStridedInterval(unittest.TestCase): + def test_smart_join(self): + s1 = StridedInterval(bits=4, stride=3, lower_bound=9, upper_bound=12) + s2 = StridedInterval(bits=4, stride=3, lower_bound=0, upper_bound=3) + j = StridedInterval.pseudo_join(s1, s2) + u = StridedInterval.least_upper_bound(s1, s2) + assert check_si_fields(u, 3, 0, 12) + assert check_si_fields(j, 3, 0, 12) - s1 = StridedInterval(bits=4, stride=0, lower_bound=8, upper_bound=8) - s2 = StridedInterval(bits=4, stride=1, lower_bound=14, upper_bound=15) - s3 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=4) - u = StridedInterval.least_upper_bound(s1, s2, s3) - assert check_si_fields(u, 1, 14, 8) + s1 = StridedInterval(bits=4, stride=0, lower_bound=8, upper_bound=8) + s2 = StridedInterval(bits=4, stride=1, lower_bound=14, upper_bound=15) + s3 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=4) + u = StridedInterval.least_upper_bound(s1, s2, s3) + assert check_si_fields(u, 1, 14, 8) - s1 = StridedInterval(bits=4, stride=3, lower_bound=2, upper_bound=8) - s2 = StridedInterval(bits=4, stride=0, lower_bound=1, upper_bound=1) - j = StridedInterval.pseudo_join(s1, s2) - u = StridedInterval.least_upper_bound(s1, s2) - assert check_si_fields(u, 3, 2, 1) - assert check_si_fields(j, 3, 2, 1) + s1 = StridedInterval(bits=4, stride=3, lower_bound=2, upper_bound=8) + s2 = StridedInterval(bits=4, stride=0, lower_bound=1, upper_bound=1) + j = StridedInterval.pseudo_join(s1, s2) + u = StridedInterval.least_upper_bound(s1, s2) + assert check_si_fields(u, 3, 2, 1) + assert check_si_fields(j, 3, 2, 1) if __name__ == "__main__": - test_smart_join() + unittest.main() diff --git a/tests/test_solver.py b/tests/test_solver.py index e15288214..88c14e2b5 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -544,7 +544,7 @@ def test_zero_division_in_cache_mixin(self): s = claripy.Solver() s.add(e == 8) assert s.satisfiable() - s.add(claripy.If(denum == 0, 0, num / denum) == e) + s.add(claripy.If(denum == 0, 0, num // denum) == e) assert s.satisfiable() # As a bonus: s.add(num == 16) diff --git a/tests/test_strided_intervals.py b/tests/test_strided_intervals.py index e24486199..0a7038d8e 100644 --- a/tests/test_strided_intervals.py +++ b/tests/test_strided_intervals.py @@ -1,9 +1,7 @@ -import logging +import unittest from claripy.vsa import StridedInterval -l = logging.getLogger("angr_tests") - def check_si_fields(si, stride, lb, ub): lb &= si.max_int(si.bits) @@ -15,332 +13,365 @@ def check_si_fields(si, stride, lb, ub): return si.upper_bound == ub -def test_division(): +class TestDivision(unittest.TestCase): # non-overlapping - - # simple case 1 - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) - # <4>0x1[0x0, 0x1] - assert check_si_fields(op1.sdiv(op2), 1, 0, 1) - assert check_si_fields(op1.udiv(op2), 1, 0, 1) - - # simple case 2 - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=1) - # <4>0x1[0xa, 0xc] - assert check_si_fields(op1.sdiv(op2), 1, 10, 12) - assert check_si_fields(op1.udiv(op2), 1, 10, 12) - - # simple case 3 - op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-2) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=-2) - # udiv: <4>0x1[0x0, 0x1] - # sdiv: <4>0x1[0xc, 0x4] - assert check_si_fields(op1.udiv(op2), 1, 0, 1) - assert check_si_fields(op1.sdiv(op2), 1, 12, 4) - - # simple case 4 : Result should be zero - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) - # BOT - assert op1.sdiv(op2).is_empty - assert op1.udiv(op2).is_empty - - # Both in 0-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=4, upper_bound=6) - # udiv: <4>0x0[0x0, 0x0] - # sdiv: <4>0x0[0x0, 0x0] - assert check_si_fields(op1.udiv(op2), 0, 0, 0) - assert check_si_fields(op1.sdiv(op2), 0, 0, 0) - - # Both in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-3, upper_bound=-1) - # sdiv: <4>0x1[0x1, 0x6] - # udiv: <4>0x0[0x0, 0x0] - assert check_si_fields(op1.sdiv(op2), 1, 1, 6) - assert check_si_fields(op1.udiv(op2), 0, 0, 0) - - # one in 0-hemisphere and one in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) - # sdiv: <4>0x1[0xc, 0xf] - # udiv: <4>0x0[0x0, 0x0] - assert check_si_fields(op1.sdiv(op2), 1, 12, 15) - assert check_si_fields(op1.udiv(op2), 0, 0, 0) + def test_simple_1(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) + # <4>0x1[0x0, 0x1] + assert check_si_fields(op1.sdiv(op2), 1, 0, 1) + assert check_si_fields(op1.udiv(op2), 1, 0, 1) + + def test_simple_2(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=1) + # <4>0x1[0xa, 0xc] + assert check_si_fields(op1.sdiv(op2), 1, 10, 12) + assert check_si_fields(op1.udiv(op2), 1, 10, 12) + + def test_simple_3(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-2) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=-2) + # udiv: <4>0x1[0x0, 0x1] + # sdiv: <4>0x1[0xc, 0x4] + assert check_si_fields(op1.udiv(op2), 1, 0, 1) + assert check_si_fields(op1.sdiv(op2), 1, 12, 4) + + def test_simple_4(self): + # simple case 4 : Result should be zero + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) + # BOT + assert op1.sdiv(op2).is_empty + assert op1.udiv(op2).is_empty + + def test_both_0_hemisphere(self): + # Both in 0-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=4, upper_bound=6) + # udiv: <4>0x0[0x0, 0x0] + # sdiv: <4>0x0[0x0, 0x0] + assert check_si_fields(op1.udiv(op2), 0, 0, 0) + assert check_si_fields(op1.sdiv(op2), 0, 0, 0) + + def test_both_1_hemisphere(self): + # Both in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-3, upper_bound=-1) + # sdiv: <4>0x1[0x1, 0x6] + # udiv: <4>0x0[0x0, 0x0] + assert check_si_fields(op1.sdiv(op2), 1, 1, 6) + assert check_si_fields(op1.udiv(op2), 0, 0, 0) + + def test_one_0_one_1_hemisphere(self): + # one in 0-hemisphere and one in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) + # sdiv: <4>0x1[0xc, 0xf] + # udiv: <4>0x0[0x0, 0x0] + assert check_si_fields(op1.sdiv(op2), 1, 12, 15) + assert check_si_fields(op1.udiv(op2), 0, 0, 0) # Overlapping - # case a of figure 2 - # Both in 0-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=6) - # sdiv: <4>0x1[0x0, 0x3] - # udiv: <4>0x1[0x0, 0x3] - assert check_si_fields(op1.sdiv(op2), 1, 0, 3) - assert check_si_fields(op1.udiv(op2), 1, 0, 3) - - # Both in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-7, upper_bound=-1) - # sdiv: <4>0x1[0x0, 0x6] - # udiv: <4>0x1[0x0, 0x1] - assert check_si_fields(op1.sdiv(op2), 1, 0, 6) - assert check_si_fields(op1.udiv(op2), 1, 0, 1) - - # case b Fig 2 - op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) - op2 = StridedInterval(bits=4, stride=1, lower_bound=7, upper_bound=5) - # sdiv: <4>0x1[0x0, 0xf] - # udiv: <4>0x1[0x0, 0xa] - assert check_si_fields(op1.sdiv(op2), 1, 0, 15) - assert check_si_fields(op1.udiv(op2), 1, 0, 10) - - # case c Fig 2 - op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=5) - # sdiv: <4>0x1[0x0, 0xe] - # udiv: <4>0x1[0x0, 0xa] - assert check_si_fields(op1.sdiv(op2), 1, 0, 14) - assert check_si_fields(op1.udiv(op2), 1, 0, 10) + def test_overlapping_both_0_hemisphere(self): + # case a of figure 2 + # Both in 0-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=6) + # sdiv: <4>0x1[0x0, 0x3] + # udiv: <4>0x1[0x0, 0x3] + assert check_si_fields(op1.sdiv(op2), 1, 0, 3) + assert check_si_fields(op1.udiv(op2), 1, 0, 3) + + def test_overlapping_both_1_hemisphere(self): + # Both in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-7, upper_bound=-1) + # sdiv: <4>0x1[0x0, 0x6] + # udiv: <4>0x1[0x0, 0x1] + assert check_si_fields(op1.sdiv(op2), 1, 0, 6) + assert check_si_fields(op1.udiv(op2), 1, 0, 1) + + def test_overlapping_case_b(self): + # case b Fig 2 + op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) + op2 = StridedInterval(bits=4, stride=1, lower_bound=7, upper_bound=5) + # sdiv: <4>0x1[0x0, 0xf] + # udiv: <4>0x1[0x0, 0xa] + assert check_si_fields(op1.sdiv(op2), 1, 0, 15) + assert check_si_fields(op1.udiv(op2), 1, 0, 10) + + def test_overlapping_case_c(self): + # case c Fig 2 + op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=5) + # sdiv: <4>0x1[0x0, 0xe] + # udiv: <4>0x1[0x0, 0xa] + assert check_si_fields(op1.sdiv(op2), 1, 0, 14) + assert check_si_fields(op1.udiv(op2), 1, 0, 10) # Strided Tests - # Both in 0-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=2, lower_bound=4, upper_bound=6) - # sdiv: <4>0x0[0x0, 0x0] - # udiv: <4>0x0[0x0, 0x0] - assert check_si_fields(op1.sdiv(op2), 0, 0, 0) - assert check_si_fields(op1.udiv(op2), 0, 0, 0) - - # Both in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=2, lower_bound=-3, upper_bound=-1) - # sdiv: <4>0x1[0x1, 0x6] - # udiv: <4>0x0[0x0, 0x0] - assert check_si_fields(op1.sdiv(op2), 1, 1, 6) - assert check_si_fields(op1.udiv(op2), 0, 0, 0) - - # Overlapping case 1 - op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) - op2 = StridedInterval(bits=4, stride=3, lower_bound=7, upper_bound=3) - # sdiv: <4>0x1[0x9, 0x7] - # udiv: <4>0x1[0x0, 0x9] - assert check_si_fields(op1.sdiv(op2), 1, 9, 7) - assert check_si_fields(op1.udiv(op2), 1, 0, 9) - - # Overlapping case 2 - op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) - op2 = StridedInterval(bits=4, stride=2, lower_bound=1, upper_bound=3) - # sdiv: <4>0x1[0x1, 0xd] - # udiv: <4>0x1[0x1, 0x9] - assert check_si_fields(op1.sdiv(op2), 1, 1, 13) - assert check_si_fields(op1.udiv(op2), 1, 1, 9) - - -def test_multiplication(): + def test_strided_both_0_hemisphere(self): + # Both in 0-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=2, lower_bound=4, upper_bound=6) + # sdiv: <4>0x0[0x0, 0x0] + # udiv: <4>0x0[0x0, 0x0] + assert check_si_fields(op1.sdiv(op2), 0, 0, 0) + assert check_si_fields(op1.udiv(op2), 0, 0, 0) + + def test_strided_both_1_hemisphere(self): + # Both in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=2, lower_bound=-3, upper_bound=-1) + # sdiv: <4>0x1[0x1, 0x6] + # udiv: <4>0x0[0x0, 0x0] + assert check_si_fields(op1.sdiv(op2), 1, 1, 6) + assert check_si_fields(op1.udiv(op2), 0, 0, 0) + + def test_strided_overlapping_case_1(self): + # Overlapping case 1 + op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) + op2 = StridedInterval(bits=4, stride=3, lower_bound=7, upper_bound=3) + # sdiv: <4>0x1[0x9, 0x7] + # udiv: <4>0x1[0x0, 0x9] + assert check_si_fields(op1.sdiv(op2), 1, 9, 7) + assert check_si_fields(op1.udiv(op2), 1, 0, 9) + + def test_strided_overlapping_case_2(self): + # Overlapping case 2 + op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) + op2 = StridedInterval(bits=4, stride=2, lower_bound=1, upper_bound=3) + # sdiv: <4>0x1[0x1, 0xd] + # udiv: <4>0x1[0x1, 0x9] + assert check_si_fields(op1.sdiv(op2), 1, 1, 13) + assert check_si_fields(op1.udiv(op2), 1, 1, 9) + + +class TestMultiplication(unittest.TestCase): # non-overlapping - - # simple case 1 - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) - # <4>0x2[0x2, 0x6] - assert check_si_fields(op1.mul(op2), 2, 2, 6) - - # simple case 2 - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=1) - # <4>0x1[0xa, 0xc] - assert check_si_fields(op1.mul(op2), 1, 10, 12) - - # simple case 3 - op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-2) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=-2) - # Stride should be 2. - # NOTE: previous result was: <4>0x1[0x4, 0x0] which is wrong. - # possible values of 1[3,e] * 0[e,e] on 4 bits are [a, 8, 6, 4, 2, 0, e, c] - # in the previous SI 2 was not present. - assert check_si_fields(op1.mul(op2), 2, 2, 0) - - # simple case 4 : Result should be zero - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) - # <4>0x0[0x0, 0x0] - assert check_si_fields(op1.mul(op2), 0, 0, 0) - - # Both in 0-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=4, upper_bound=6) - # Result: <4>0x1[0x4, 0x2] - assert check_si_fields(op1.mul(op2), 1, 4, 2) - - # Both in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-3, upper_bound=-1) - # Result <4>0x1[0x4, 0x2] - assert check_si_fields(op1.mul(op2), 1, 4, 2) - - # one in 0-hemisphere and one in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) - # TOP - assert check_si_fields(op1.mul(op2), 1, 0, 15) + def test_simple_1(self): + # simple case 1 + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) + # <4>0x2[0x2, 0x6] + assert check_si_fields(op1.mul(op2), 2, 2, 6) + + def test_simple_2(self): + # simple case 2 + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=1) + # <4>0x1[0xa, 0xc] + assert check_si_fields(op1.mul(op2), 1, 10, 12) + + def test_simple_3(self): + # simple case 3 + op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-2) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=-2) + # Stride should be 2. + # NOTE: previous result was: <4>0x1[0x4, 0x0] which is wrong. + # possible values of 1[3,e] * 0[e,e] on 4 bits are [a, 8, 6, 4, 2, 0, e, c] + # in the previous SI 2 was not present. + assert check_si_fields(op1.mul(op2), 2, 2, 0) + + def test_simple_4(self): + # simple case 4 : Result should be zero + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) + # <4>0x0[0x0, 0x0] + assert check_si_fields(op1.mul(op2), 0, 0, 0) + + def test_both_0_hemisphere(self): + # simple case 4 : Result should be zero + # Both in 0-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=4, upper_bound=6) + # Result: <4>0x1[0x4, 0x2] + assert check_si_fields(op1.mul(op2), 1, 4, 2) + + def test_both_1_hemisphere(self): + # Both in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-3, upper_bound=-1) + # Result <4>0x1[0x4, 0x2] + assert check_si_fields(op1.mul(op2), 1, 4, 2) + + def test_one_0_one_1_hemisphere(self): + # one in 0-hemisphere and one in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) + # TOP + assert check_si_fields(op1.mul(op2), 1, 0, 15) # Overlapping - # case a of figure 2 - # Both in 0-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=6) - # TOP - assert check_si_fields(op1.mul(op2), 1, 0, 15) - - # Both in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-7, upper_bound=-1) - # TOP - assert check_si_fields(op1.mul(op2), 1, 0, 15) - - # case b Fig 2 - op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) - op2 = StridedInterval(bits=4, stride=1, lower_bound=7, upper_bound=5) - # <4>0x1[0x0, 0xf] - assert check_si_fields(op1.mul(op2), 1, 0, 15) - - # case c Fig 2 - op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=5) - # <4>0x1[0x0, 0xf] - assert check_si_fields(op1.mul(op2), 1, 0, 15) + def test_overlapping_both_0_hemisphere(self): + # case a of figure 2 + # Both in 0-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=6) + # TOP + assert check_si_fields(op1.mul(op2), 1, 0, 15) + + def test_overlapping_both_1_hemisphere(self): + # Both in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-7, upper_bound=-1) + # TOP + assert check_si_fields(op1.mul(op2), 1, 0, 15) + + def test_overlapping_case_b(self): + # case b Fig 2 + op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) + op2 = StridedInterval(bits=4, stride=1, lower_bound=7, upper_bound=5) + # <4>0x1[0x0, 0xf] + assert check_si_fields(op1.mul(op2), 1, 0, 15) + + def test_overlapping_case_c(self): + # case c Fig 2 + op1 = StridedInterval(bits=4, stride=1, lower_bound=3, upper_bound=-6) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=5) + # <4>0x1[0x0, 0xf] + assert check_si_fields(op1.mul(op2), 1, 0, 15) # Strided Tests - # Both in 0-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) - op2 = StridedInterval(bits=4, stride=2, lower_bound=4, upper_bound=6) - # <4>0x1[0x4, 0x2] - assert check_si_fields(op1.mul(op2), 1, 4, 2) - - # Both in 1-hemisphere - op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) - op2 = StridedInterval(bits=4, stride=2, lower_bound=-3, upper_bound=-1) - # <4>0x1[0x4, 0x2] - assert check_si_fields(op1.mul(op2), 1, 4, 2) - - # Overlapping case 1 - op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) - op2 = StridedInterval(bits=4, stride=3, lower_bound=7, upper_bound=3) - # <4>0x1[0x0, 0xf] - assert check_si_fields(op1.mul(op2), 1, 0, 15) - - # Overlapping case 2 - op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) - op2 = StridedInterval(bits=4, stride=2, lower_bound=1, upper_bound=3) - # TOP - assert check_si_fields(op1.mul(op2), 1, 0, 15) - - -def test_subtraction(): - # Basic Interval Tests - op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) - # Result should be TOP - assert check_si_fields(op1.sub(op2), 1, 0, 15) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=6) - # Result should be 1,[-5, 5] - # print(str(op1.sub(op2))) - assert check_si_fields(op1.sub(op2), 1, -5, 5) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) - # Result should be 1,[15, 5] - assert check_si_fields(op1.sub(op2), 1, 15, 5) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) - # Result should be 1,[-4, 5] - assert check_si_fields(op1.sub(op2), 1, -4, 5) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) - # Result should be 1,[1, 7] - assert check_si_fields(op1.sub(op2), 1, 1, 7) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) - # Result should be 1,[2, 12] - # print(str(op1.sub(op2))) - assert check_si_fields(op1.sub(op2), 1, 2, 12) + def test_strided_both_0_hemisphere(self): + # Both in 0-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=3) + op2 = StridedInterval(bits=4, stride=2, lower_bound=4, upper_bound=6) + # <4>0x1[0x4, 0x2] + assert check_si_fields(op1.mul(op2), 1, 4, 2) + + def test_strided_both_1_hemisphere(self): + # Both in 1-hemisphere + op1 = StridedInterval(bits=4, stride=1, lower_bound=-6, upper_bound=-4) + op2 = StridedInterval(bits=4, stride=2, lower_bound=-3, upper_bound=-1) + # <4>0x1[0x4, 0x2] + assert check_si_fields(op1.mul(op2), 1, 4, 2) + + def test_strided_overlapping_case_1(self): + # Overlapping case 1 + op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) + op2 = StridedInterval(bits=4, stride=3, lower_bound=7, upper_bound=3) + # <4>0x1[0x0, 0xf] + assert check_si_fields(op1.mul(op2), 1, 0, 15) + + def test_strided_overlapping_case_2(self): + # Overlapping case 2 + op1 = StridedInterval(bits=4, stride=2, lower_bound=3, upper_bound=-7) + op2 = StridedInterval(bits=4, stride=2, lower_bound=1, upper_bound=3) + # TOP + assert check_si_fields(op1.mul(op2), 1, 0, 15) + + +class TestSubtraction(unittest.TestCase): + def test_basic_interval_1(self): + # Basic Interval Tests + op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) + # Result should be TOP + assert check_si_fields(op1.sub(op2), 1, 0, 15) + + def test_basic_interval_2(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=6) + # Result should be 1,[-5, 5] + # print(str(op1.sub(op2))) + assert check_si_fields(op1.sub(op2), 1, -5, 5) + + def test_basic_interval_3(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) + # Result should be 1,[15, 5] + assert check_si_fields(op1.sub(op2), 1, 15, 5) + + def test_basic_interval_4(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) + # Result should be 1,[-4, 5] + assert check_si_fields(op1.sub(op2), 1, -4, 5) + + def test_basic_interval_5(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) + # Result should be 1,[1, 7] + assert check_si_fields(op1.sub(op2), 1, 1, 7) + + def test_basic_interval_6(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) + # Result should be 1,[2, 12] + # print(str(op1.sub(op2))) + assert check_si_fields(op1.sub(op2), 1, 2, 12) # Strided Tests - op1 = StridedInterval(bits=4, stride=2, lower_bound=-2, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) - # Result should be TOP - assert check_si_fields(op1.sub(op2), 1, 0, 15) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=2, lower_bound=2, upper_bound=6) - # Result should be 1,[11, 5] - assert check_si_fields(op1.sub(op2), 1, 11, 5) - - -def test_add(): - # Basic Interval Tests - op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) - # Result should be TOP - assert check_si_fields(op1.add(op2), 1, 0, 15) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=6) - # Result should be 1,[3, 13] - assert check_si_fields(op1.add(op2), 1, 3, 13) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) - # Result should be 1,[3, 9] - assert check_si_fields(op1.add(op2), 1, 3, 9) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) - # Result should be 1,[0,9] - assert check_si_fields(op1.add(op2), 1, 0, 9) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) - # Result should be 1,[1,7] - assert check_si_fields(op1.add(op2), 1, 1, 7) - - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) - # Result should be 1,[-4, 6] - assert check_si_fields(op1.add(op2), 1, -4, 6) + def test_strided_1(self): + op1 = StridedInterval(bits=4, stride=2, lower_bound=-2, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) + # Result should be TOP + assert check_si_fields(op1.sub(op2), 1, 0, 15) + + def test_strided_2(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=2, lower_bound=2, upper_bound=6) + # Result should be 1,[11, 5] + assert check_si_fields(op1.sub(op2), 1, 11, 5) + + +class TestAddition(unittest.TestCase): + def test_interval_1(self): + # Basic Interval Tests + op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) + # Result should be TOP + assert check_si_fields(op1.add(op2), 1, 0, 15) + + def test_interval_2(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=6) + # Result should be 1,[3, 13] + assert check_si_fields(op1.add(op2), 1, 3, 13) + + def test_interval_3(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) + # Result should be 1,[3, 9] + assert check_si_fields(op1.add(op2), 1, 3, 9) + + def test_interval_4(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=-2, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=2, upper_bound=2) + # Result should be 1,[0,9] + assert check_si_fields(op1.add(op2), 1, 0, 9) + + def test_interval_5(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=0) + # Result should be 1,[1,7] + assert check_si_fields(op1.add(op2), 1, 1, 7) + + def test_interval_6(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=-5, upper_bound=-1) + # Result should be 1,[-4, 6] + assert check_si_fields(op1.add(op2), 1, -4, 6) # Strided Tests - op1 = StridedInterval(bits=4, stride=2, lower_bound=-2, upper_bound=7) - op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) - # Result should be TOP - assert check_si_fields(op1.add(op2), 1, 0, 15) + def test_strided_1(self): + op1 = StridedInterval(bits=4, stride=2, lower_bound=-2, upper_bound=7) + op2 = StridedInterval(bits=4, stride=1, lower_bound=0, upper_bound=-6) + # Result should be TOP + assert check_si_fields(op1.add(op2), 1, 0, 15) - op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) - op2 = StridedInterval(bits=4, stride=2, lower_bound=2, upper_bound=6) - # Result should be 1,[3, 13] - assert check_si_fields(op1.add(op2), 1, 3, 13) + def test_strided_2(self): + op1 = StridedInterval(bits=4, stride=1, lower_bound=1, upper_bound=7) + op2 = StridedInterval(bits=4, stride=2, lower_bound=2, upper_bound=6) + # Result should be 1,[3, 13] + assert check_si_fields(op1.add(op2), 1, 3, 13) if __name__ == "__main__": - # Addition tests - l.info("Performing Add Tests") - test_add() - l.info("Performing Subtraction Tests") - test_subtraction() - l.info("Performing Multiplication Tests") - test_multiplication() - l.info("Performing Division Tests") - test_division() - print("[+] All Tests Passed") + unittest.main() diff --git a/tests/test_vsa.py b/tests/test_vsa.py index cf21e93d0..28645372c 100644 --- a/tests/test_vsa.py +++ b/tests/test_vsa.py @@ -1,3 +1,5 @@ +import unittest + import claripy from claripy.vsa import ( BoolResult, @@ -11,1111 +13,1092 @@ def vsa_model(a): return claripy.backends.vsa.convert(a) -def test_fucked_extract(): - not_fucked = claripy.Reverse( - claripy.Concat( - claripy.BVS("file_/dev/stdin_6_0_16_8", 8, explicit_name=True), - claripy.BVS("file_/dev/stdin_6_1_17_8", 8, explicit_name=True), +class TestVSA(unittest.TestCase): + def test_fucked_extract(self): + not_fucked = claripy.Reverse( + claripy.Concat( + claripy.BVS("file_/dev/stdin_6_0_16_8", 8, explicit_name=True), + claripy.BVS("file_/dev/stdin_6_1_17_8", 8, explicit_name=True), + ) ) - ) - m = claripy.backends.vsa.max(not_fucked) - assert m > 0 - - zx = claripy.ZeroExt(16, not_fucked) - pre_fucked = claripy.Reverse(zx) - m = claripy.backends.vsa.max(pre_fucked) - assert m > 0 - - # print(zx, claripy.backends.vsa.convert(zx)) - # print(pre_fucked, claripy.backends.vsa.convert(pre_fucked)) - fucked = pre_fucked[31:16] - m = claripy.backends.vsa.max(fucked) - assert m > 0 - - # here's another case - wtf = ( - ( - claripy.Reverse( - claripy.Concat( - claripy.BVS("w", 8), - claripy.BVS("x", 8), - claripy.BVS("y", 8), - claripy.BVS("z", 8), + m = claripy.backends.vsa.max(not_fucked) + assert m > 0 + + zx = claripy.ZeroExt(16, not_fucked) + pre_fucked = claripy.Reverse(zx) + m = claripy.backends.vsa.max(pre_fucked) + assert m > 0 + + fucked = pre_fucked[31:16] + m = claripy.backends.vsa.max(fucked) + assert m > 0 + + # here's another case + wtf = ( + ( + claripy.Reverse( + claripy.Concat( + claripy.BVS("w", 8), + claripy.BVS("x", 8), + claripy.BVS("y", 8), + claripy.BVS("z", 8), + ) ) + & claripy.BVV(15, 32) ) - & claripy.BVV(15, 32) + + claripy.BVV(48, 32) + )[7:0] + + m = claripy.backends.vsa.max(wtf) + assert m > 0 + + def test_reversed_concat(self): + a = claripy.SI("a", 32, lower_bound=10, upper_bound=0x80, stride=10) + b = claripy.SI("b", 32, lower_bound=1, upper_bound=0xFF, stride=1) + + reversed_a = claripy.Reverse(a) + reversed_b = claripy.Reverse(b) + + # First let's check if the reversing makes sense + assert claripy.backends.vsa.min(reversed_a) == 0xA000000 + assert claripy.backends.vsa.max(reversed_a) == 0x80000000 + assert claripy.backends.vsa.min(reversed_b) == 0x1000000 + assert claripy.backends.vsa.max(reversed_b) == 0xFF000000 + + a_concat_b = claripy.Concat(a, b) + assert a_concat_b._model_vsa._reversed is False + + ra_concat_b = claripy.Concat(reversed_a, b) + assert ra_concat_b._model_vsa._reversed is False + + a_concat_rb = claripy.Concat(a, reversed_b) + assert a_concat_rb._model_vsa._reversed is False + + ra_concat_rb = claripy.Concat(reversed_a, reversed_b) + assert ra_concat_rb._model_vsa._reversed is False + + def test_simple_cardinality(self): + x = claripy.BVS("x", 32, 0xA, 0x14, 0xA) + assert x.cardinality == 2 + + def test_wrapped_intervals(self): + # SI = claripy.StridedInterval + + # Disable the use of DiscreteStridedIntervalSet + claripy.vsa.strided_interval.allow_dsis = False + + # + # Signedness/unsignedness conversion + # + + si1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) + assert vsa_model(si1)._signed_bounds() == [(0x0, 0x7FFFFFFF), (-0x80000000, -0x1)] + assert vsa_model(si1)._unsigned_bounds() == [(0x0, 0xFFFFFFFF)] + + # + # Pole-splitting + # + + # south-pole splitting + si1 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=1) + si_list = vsa_model(si1)._ssplit() + assert len(si_list) == 2 + assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=-1))) + assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1))) + + # north-pole splitting + si1 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=-3) + si_list = vsa_model(si1)._nsplit() + assert len(si_list) == 2 + assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=0x7FFFFFFF))) + assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0x80000000, upper_bound=-3))) + + # north-pole splitting, episode 2 + si1 = claripy.SI(bits=32, stride=3, lower_bound=3, upper_bound=0) + si_list = vsa_model(si1)._nsplit() + assert len(si_list) == 2 + assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=3, lower_bound=3, upper_bound=0x7FFFFFFE))) + assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=3, lower_bound=0x80000001, upper_bound=0))) + + # bipolar splitting + si1 = claripy.SI(bits=32, stride=1, lower_bound=-2, upper_bound=-8) + si_list = vsa_model(si1)._psplit() + assert len(si_list) == 3 + assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=-2, upper_bound=-1))) + assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0x7FFFFFFF))) + assert si_list[2].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0x80000000, upper_bound=-8))) + + # + # Addition + # + + # Plain addition + si1 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=1) + si2 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=1) + si3 = claripy.SI(bits=32, stride=1, lower_bound=-2, upper_bound=2) + assert claripy.backends.vsa.identical(si1 + si2, si3) + si4 = claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFFFE, upper_bound=2) + assert claripy.backends.vsa.identical(si1 + si2, si4) + si5 = claripy.SI(bits=32, stride=1, lower_bound=2, upper_bound=-2) + assert not claripy.backends.vsa.identical(si1 + si2, si5) + + # Addition with overflowing cardinality + si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFE) + si2 = claripy.SI(bits=8, stride=1, lower_bound=0xFE, upper_bound=0xFF) + assert vsa_model(si1 + si2).is_top + + # Addition that shouldn't get a TOP + si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFE) + si2 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0) + assert not vsa_model(si1 + si2).is_top + + # + # Subtraction + # + + si1 = claripy.SI(bits=8, stride=1, lower_bound=10, upper_bound=15) + si2 = claripy.SI(bits=8, stride=1, lower_bound=11, upper_bound=12) + si3 = claripy.SI(bits=8, stride=1, lower_bound=-2, upper_bound=4) + assert claripy.backends.vsa.identical(si1 - si2, si3) + + # + # Multiplication + # + + # integer multiplication + si1 = claripy.SI(bits=32, to_conv=0xFFFF) + si2 = claripy.SI(bits=32, to_conv=0x10000) + si3 = claripy.SI(bits=32, to_conv=0xFFFF0000) + assert claripy.backends.vsa.identical(si1 * si2, si3) + + # intervals multiplication + si1 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=15) + si2 = claripy.SI(bits=32, stride=1, lower_bound=20, upper_bound=30) + si3 = claripy.SI(bits=32, stride=1, lower_bound=200, upper_bound=450) + assert claripy.backends.vsa.identical(si1 * si2, si3) + + # + # Division + # + + # integer division + si1 = claripy.SI(bits=32, to_conv=10) + si2 = claripy.SI(bits=32, to_conv=5) + si3 = claripy.SI(bits=32, to_conv=2) + assert claripy.backends.vsa.identical(si1 // si2, si3) + + si3 = claripy.SI(bits=32, to_conv=0) + assert claripy.backends.vsa.identical(si2 // si1, si3) + + # intervals division + si1 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=100) + si2 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=20) + si3 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) + assert claripy.backends.vsa.identical(si1 // si2, si3) + + # + # Extension + # + + # zero-extension + si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFD) + si_zext = si1.zero_extend(32 - 8) + si_zext_ = claripy.SI(bits=32, stride=1, lower_bound=0x0, upper_bound=0xFD) + assert claripy.backends.vsa.identical(si_zext, si_zext_) + + # sign-extension + si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFD) + si_sext = si1.sign_extend(32 - 8) + si_sext_ = claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFF80, upper_bound=0x7F) + assert claripy.backends.vsa.identical(si_sext, si_sext_) + + # + # Comparisons + # + + # -1 == 0xff + si1 = claripy.SI(bits=8, stride=1, lower_bound=-1, upper_bound=-1) + si2 = claripy.SI(bits=8, stride=1, lower_bound=0xFF, upper_bound=0xFF) + assert claripy.backends.vsa.is_true(si1 == si2) + + # -2 != 0xff + si1 = claripy.SI(bits=8, stride=1, lower_bound=-2, upper_bound=-2) + si2 = claripy.SI(bits=8, stride=1, lower_bound=0xFF, upper_bound=0xFF) + assert claripy.backends.vsa.is_true(si1 != si2) + + # [-2, -1] < [1, 2] (signed arithmetic) + si1 = claripy.SI(bits=8, stride=1, lower_bound=1, upper_bound=2) + si2 = claripy.SI(bits=8, stride=1, lower_bound=-2, upper_bound=-1) + assert claripy.backends.vsa.is_true(si2.SLT(si1)) + + # [-2, -1] <= [1, 2] (signed arithmetic) + assert claripy.backends.vsa.is_true(si2.SLE(si1)) + + # [0xfe, 0xff] > [1, 2] (unsigned arithmetic) + assert claripy.backends.vsa.is_true(si2.UGT(si1)) + + # [0xfe, 0xff] >= [1, 2] (unsigned arithmetic) + assert claripy.backends.vsa.is_true(si2.UGE(si1)) + + def test_join(self): + # Set backend + b = claripy.backends.vsa + claripy.solver_backends = [] + + SI = claripy.SI + + a = claripy.SI(bits=8, to_conv=2) + b = claripy.SI(bits=8, to_conv=10) + c = claripy.SI(bits=8, to_conv=120) + d = claripy.SI(bits=8, to_conv=130) + e = claripy.SI(bits=8, to_conv=132) + f = claripy.SI(bits=8, to_conv=135) + + # union a, b, c, d, e => [2, 132] with a stride of 2 + tmp1 = a.union(b) + assert claripy.backends.vsa.identical(tmp1, SI(bits=8, stride=8, lower_bound=2, upper_bound=10)) + tmp2 = tmp1.union(c) + assert claripy.backends.vsa.identical(tmp2, SI(bits=8, stride=2, lower_bound=2, upper_bound=120)) + tmp3 = tmp2.union(d).union(e) + assert claripy.backends.vsa.identical(tmp3, SI(bits=8, stride=2, lower_bound=2, upper_bound=132)) + + # union a, b, c, d, e, f => [2, 135] with a stride of 1 + tmp = a.union(b).union(c).union(d).union(e).union(f) + assert claripy.backends.vsa.identical(tmp, SI(bits=8, stride=1, lower_bound=2, upper_bound=135)) + + a = claripy.SI(bits=8, to_conv=1) + b = claripy.SI(bits=8, to_conv=10) + c = claripy.SI(bits=8, to_conv=120) + d = claripy.SI(bits=8, to_conv=130) + e = claripy.SI(bits=8, to_conv=132) + f = claripy.SI(bits=8, to_conv=135) + g = claripy.SI(bits=8, to_conv=220) + h = claripy.SI(bits=8, to_conv=50) + + # union a, b, c, d, e, f, g, h => [220, 135] with a stride of 1 + tmp = a.union(b).union(c).union(d).union(e).union(f).union(g).union(h) + assert claripy.backends.vsa.identical(tmp, SI(bits=8, stride=1, lower_bound=220, upper_bound=135)) + assert 220 in vsa_model(tmp).eval(255) + assert 225 in vsa_model(tmp).eval(255) + assert 0 in vsa_model(tmp).eval(255) + assert 135 in vsa_model(tmp).eval(255) + assert 138 not in vsa_model(tmp).eval(255) + + def test_vsa(self): + # Set backend + b = claripy.backends.vsa + + SI = claripy.SI + VS = claripy.ValueSet + BVV = claripy.BVV + + # Disable the use of DiscreteStridedIntervalSet + claripy.vsa.strided_interval.allow_dsis = False + + def is_equal(ast_0, ast_1): + return claripy.backends.vsa.identical(ast_0, ast_1) + + si1 = claripy.TSI(32, name="foo", explicit_name=True) + assert vsa_model(si1).name == "foo" + + # Normalization + si1 = SI(bits=32, stride=1, lower_bound=10, upper_bound=10) + assert vsa_model(si1).stride == 0 + + # Integers + si1 = claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10) + si2 = claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10) + si3 = claripy.SI(bits=32, stride=0, lower_bound=28, upper_bound=28) + # Strided intervals + si_a = claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=20) + si_b = claripy.SI(bits=32, stride=2, lower_bound=-100, upper_bound=200) + si_c = claripy.SI(bits=32, stride=3, lower_bound=-100, upper_bound=200) + si_d = claripy.SI(bits=32, stride=2, lower_bound=50, upper_bound=60) + si_e = claripy.SI(bits=16, stride=1, lower_bound=0x2000, upper_bound=0x3000) + si_f = claripy.SI(bits=16, stride=1, lower_bound=0, upper_bound=255) + si_g = claripy.SI(bits=16, stride=1, lower_bound=0, upper_bound=0xFF) + si_h = claripy.SI(bits=32, stride=0, lower_bound=0x80000000, upper_bound=0x80000000) + + assert is_equal(si1, claripy.SI(bits=32, to_conv=10)) + assert is_equal(si2, claripy.SI(bits=32, to_conv=10)) + assert is_equal(si1, si2) + # __add__ + si_add_1 = si1 + si2 + assert is_equal(si_add_1, claripy.SI(bits=32, stride=0, lower_bound=20, upper_bound=20)) + si_add_2 = si1 + si_a + assert is_equal(si_add_2, claripy.SI(bits=32, stride=2, lower_bound=20, upper_bound=30)) + si_add_3 = si_a + si_b + assert is_equal(si_add_3, claripy.SI(bits=32, stride=2, lower_bound=-90, upper_bound=220)) + si_add_4 = si_b + si_c + assert is_equal(si_add_4, claripy.SI(bits=32, stride=1, lower_bound=-200, upper_bound=400)) + # __add__ with overflow + si_add_5 = si_h + 0xFFFFFFFF + assert is_equal( + si_add_5, + claripy.SI(bits=32, stride=0, lower_bound=0x7FFFFFFF, upper_bound=0x7FFFFFFF), + ) + # __sub__ + si_minus_1 = si1 - si2 + assert is_equal(si_minus_1, claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0)) + si_minus_2 = si_a - si_b + assert is_equal(si_minus_2, claripy.SI(bits=32, stride=2, lower_bound=-190, upper_bound=120)) + si_minus_3 = si_b - si_c + assert is_equal(si_minus_3, claripy.SI(bits=32, stride=1, lower_bound=-300, upper_bound=300)) + # __neg__ / __invert__ / bitwise not + si_neg_1 = ~si1 + assert is_equal(si_neg_1, claripy.SI(bits=32, to_conv=-11)) + si_neg_2 = ~si_b + assert is_equal(si_neg_2, claripy.SI(bits=32, stride=2, lower_bound=-201, upper_bound=99)) + # __or__ + si_or_1 = si1 | si3 + assert is_equal(si_or_1, claripy.SI(bits=32, to_conv=30)) + si_or_2 = si1 | si2 + assert is_equal(si_or_2, claripy.SI(bits=32, to_conv=10)) + si_or_3 = si1 | si_a # An integer | a strided interval + assert is_equal(si_or_3, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=30)) + si_or_3 = si_a | si1 # Exchange the operands + assert is_equal(si_or_3, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=30)) + si_or_4 = si_a | si_d # A strided interval | another strided interval + assert is_equal(si_or_4, claripy.SI(bits=32, stride=2, lower_bound=50, upper_bound=62)) + si_or_4 = si_d | si_a # Exchange the operands + assert is_equal(si_or_4, claripy.SI(bits=32, stride=2, lower_bound=50, upper_bound=62)) + si_or_5 = si_e | si_f # + assert is_equal(si_or_5, claripy.SI(bits=16, stride=1, lower_bound=0x2000, upper_bound=0x30FF)) + si_or_6 = si_e | si_g # + assert is_equal(si_or_6, claripy.SI(bits=16, stride=1, lower_bound=0x2000, upper_bound=0x30FF)) + # Shifting + si_shl_1 = si1 << 3 + assert si_shl_1.size() == 32 + assert is_equal(si_shl_1, claripy.SI(bits=32, stride=0, lower_bound=80, upper_bound=80)) + # Multiplication + si_mul_1 = si1 * 3 + assert si_mul_1.size() == 32 + assert is_equal(si_mul_1, claripy.SI(bits=32, stride=0, lower_bound=30, upper_bound=30)) + si_mul_2 = si_a * 3 + assert si_mul_2.size() == 32 + assert is_equal(si_mul_2, claripy.SI(bits=32, stride=6, lower_bound=30, upper_bound=60)) + si_mul_3 = si_a * si_b + assert si_mul_3.size() == 32 + assert is_equal(si_mul_3, claripy.SI(bits=32, stride=2, lower_bound=-2000, upper_bound=4000)) + # Division + si_div_1 = si1 // 3 + assert si_div_1.size() == 32 + assert is_equal(si_div_1, claripy.SI(bits=32, stride=0, lower_bound=3, upper_bound=3)) + si_div_2 = si_a // 3 + assert si_div_2.size() == 32 + assert is_equal(si_div_2, claripy.SI(bits=32, stride=1, lower_bound=3, upper_bound=6)) + # Modulo + si_mo_1 = si1 % 3 + assert si_mo_1.size() == 32 + assert is_equal(si_mo_1, claripy.SI(bits=32, stride=0, lower_bound=1, upper_bound=1)) + si_mo_2 = si_a % 3 + assert si_mo_2.size() == 32 + assert is_equal(si_mo_2, claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2)) + + # + # Extracting the sign bit + # + + # a negative integer + si = claripy.SI(bits=64, stride=0, lower_bound=-1, upper_bound=-1) + sb = si[63:63] + assert is_equal(sb, claripy.SI(bits=1, to_conv=1)) + + # non-positive integers + si = claripy.SI(bits=64, stride=1, lower_bound=-1, upper_bound=0) + sb = si[63:63] + assert is_equal(sb, claripy.SI(bits=1, stride=1, lower_bound=0, upper_bound=1)) + + # Extracting an integer + si = claripy.SI( + bits=64, + stride=0, + lower_bound=0x7FFFFFFFFFFF0000, + upper_bound=0x7FFFFFFFFFFF0000, + ) + part1 = si[63:32] + part2 = si[31:0] + assert is_equal( + part1, + claripy.SI(bits=32, stride=0, lower_bound=0x7FFFFFFF, upper_bound=0x7FFFFFFF), + ) + assert is_equal( + part2, + claripy.SI(bits=32, stride=0, lower_bound=0xFFFF0000, upper_bound=0xFFFF0000), + ) + + # Concatenating two integers + si_concat = part1.concat(part2) + assert is_equal(si_concat, si) + + # Extracting a claripy.SI + si = claripy.SI(bits=64, stride=0x9, lower_bound=0x1, upper_bound=0xA) + part1 = si[63:32] + part2 = si[31:0] + assert is_equal(part1, claripy.SI(bits=32, stride=0, lower_bound=0x0, upper_bound=0x0)) + assert is_equal(part2, claripy.SI(bits=32, stride=9, lower_bound=1, upper_bound=10)) + + # Concatenating two claripy.SIs + si_concat = part1.concat(part2) + assert is_equal(si_concat, si) + + # Concatenating two SIs that are of different sizes + si_1 = SI(bits=64, stride=1, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFF) + si_2 = SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) + si_concat = si_1.concat(si_2) + assert is_equal( + si_concat, + SI(bits=96, stride=1, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFFFFFFFFFF), ) - + claripy.BVV(48, 32) - )[7:0] - - m = claripy.backends.vsa.max(wtf) - assert m > 0 - - -def test_reversed_concat(): - a = claripy.SI("a", 32, lower_bound=10, upper_bound=0x80, stride=10) - b = claripy.SI("b", 32, lower_bound=1, upper_bound=0xFF, stride=1) - - reversed_a = claripy.Reverse(a) - reversed_b = claripy.Reverse(b) - - # First let's check if the reversing makes sense - assert claripy.backends.vsa.min(reversed_a) == 0xA000000 - assert claripy.backends.vsa.max(reversed_a) == 0x80000000 - assert claripy.backends.vsa.min(reversed_b) == 0x1000000 - assert claripy.backends.vsa.max(reversed_b) == 0xFF000000 - - a_concat_b = claripy.Concat(a, b) - assert a_concat_b._model_vsa._reversed is False - - ra_concat_b = claripy.Concat(reversed_a, b) - assert ra_concat_b._model_vsa._reversed is False - - a_concat_rb = claripy.Concat(a, reversed_b) - assert a_concat_rb._model_vsa._reversed is False - - ra_concat_rb = claripy.Concat(reversed_a, reversed_b) - assert ra_concat_rb._model_vsa._reversed is False - - -def test_simple_cardinality(): - x = claripy.BVS("x", 32, 0xA, 0x14, 0xA) - assert x.cardinality == 2 - - -def test_wrapped_intervals(): - # SI = claripy.StridedInterval - - # Disable the use of DiscreteStridedIntervalSet - claripy.vsa.strided_interval.allow_dsis = False - - # - # Signedness/unsignedness conversion - # - - si1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) - assert vsa_model(si1)._signed_bounds() == [(0x0, 0x7FFFFFFF), (-0x80000000, -0x1)] - assert vsa_model(si1)._unsigned_bounds() == [(0x0, 0xFFFFFFFF)] - - # - # Pole-splitting - # - - # south-pole splitting - si1 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=1) - si_list = vsa_model(si1)._ssplit() - assert len(si_list) == 2 - assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=-1))) - assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1))) - - # north-pole splitting - si1 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=-3) - si_list = vsa_model(si1)._nsplit() - assert len(si_list) == 2 - assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=0x7FFFFFFF))) - assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0x80000000, upper_bound=-3))) - - # north-pole splitting, episode 2 - si1 = claripy.SI(bits=32, stride=3, lower_bound=3, upper_bound=0) - si_list = vsa_model(si1)._nsplit() - assert len(si_list) == 2 - assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=3, lower_bound=3, upper_bound=0x7FFFFFFE))) - assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=3, lower_bound=0x80000001, upper_bound=0))) - - # bipolar splitting - si1 = claripy.SI(bits=32, stride=1, lower_bound=-2, upper_bound=-8) - si_list = vsa_model(si1)._psplit() - assert len(si_list) == 3 - assert si_list[0].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=-2, upper_bound=-1))) - assert si_list[1].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0x7FFFFFFF))) - assert si_list[2].identical(vsa_model(claripy.SI(bits=32, stride=1, lower_bound=0x80000000, upper_bound=-8))) - - # - # Addition - # - - # Plain addition - si1 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=1) - si2 = claripy.SI(bits=32, stride=1, lower_bound=-1, upper_bound=1) - si3 = claripy.SI(bits=32, stride=1, lower_bound=-2, upper_bound=2) - assert claripy.backends.vsa.identical(si1 + si2, si3) - si4 = claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFFFE, upper_bound=2) - assert claripy.backends.vsa.identical(si1 + si2, si4) - si5 = claripy.SI(bits=32, stride=1, lower_bound=2, upper_bound=-2) - assert not claripy.backends.vsa.identical(si1 + si2, si5) - - # Addition with overflowing cardinality - si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFE) - si2 = claripy.SI(bits=8, stride=1, lower_bound=0xFE, upper_bound=0xFF) - assert vsa_model(si1 + si2).is_top - - # Addition that shouldn't get a TOP - si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFE) - si2 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0) - assert not vsa_model(si1 + si2).is_top - - # - # Subtraction - # - - si1 = claripy.SI(bits=8, stride=1, lower_bound=10, upper_bound=15) - si2 = claripy.SI(bits=8, stride=1, lower_bound=11, upper_bound=12) - si3 = claripy.SI(bits=8, stride=1, lower_bound=-2, upper_bound=4) - assert claripy.backends.vsa.identical(si1 - si2, si3) - - # - # Multiplication - # - - # integer multiplication - si1 = claripy.SI(bits=32, to_conv=0xFFFF) - si2 = claripy.SI(bits=32, to_conv=0x10000) - si3 = claripy.SI(bits=32, to_conv=0xFFFF0000) - assert claripy.backends.vsa.identical(si1 * si2, si3) - - # intervals multiplication - si1 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=15) - si2 = claripy.SI(bits=32, stride=1, lower_bound=20, upper_bound=30) - si3 = claripy.SI(bits=32, stride=1, lower_bound=200, upper_bound=450) - assert claripy.backends.vsa.identical(si1 * si2, si3) - - # - # Division - # - - # integer division - si1 = claripy.SI(bits=32, to_conv=10) - si2 = claripy.SI(bits=32, to_conv=5) - si3 = claripy.SI(bits=32, to_conv=2) - assert claripy.backends.vsa.identical(si1 / si2, si3) - - si3 = claripy.SI(bits=32, to_conv=0) - assert claripy.backends.vsa.identical(si2 / si1, si3) - - # intervals division - si1 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=100) - si2 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=20) - si3 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) - assert claripy.backends.vsa.identical(si1 / si2, si3) - - # - # Extension - # - - # zero-extension - si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFD) - si_zext = si1.zero_extend(32 - 8) - si_zext_ = claripy.SI(bits=32, stride=1, lower_bound=0x0, upper_bound=0xFD) - assert claripy.backends.vsa.identical(si_zext, si_zext_) - - # sign-extension - si1 = claripy.SI(bits=8, stride=1, lower_bound=0, upper_bound=0xFD) - si_sext = si1.sign_extend(32 - 8) - si_sext_ = claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFF80, upper_bound=0x7F) - assert claripy.backends.vsa.identical(si_sext, si_sext_) - - # - # Comparisons - # - - # -1 == 0xff - si1 = claripy.SI(bits=8, stride=1, lower_bound=-1, upper_bound=-1) - si2 = claripy.SI(bits=8, stride=1, lower_bound=0xFF, upper_bound=0xFF) - assert claripy.backends.vsa.is_true(si1 == si2) - - # -2 != 0xff - si1 = claripy.SI(bits=8, stride=1, lower_bound=-2, upper_bound=-2) - si2 = claripy.SI(bits=8, stride=1, lower_bound=0xFF, upper_bound=0xFF) - assert claripy.backends.vsa.is_true(si1 != si2) - - # [-2, -1] < [1, 2] (signed arithmetic) - si1 = claripy.SI(bits=8, stride=1, lower_bound=1, upper_bound=2) - si2 = claripy.SI(bits=8, stride=1, lower_bound=-2, upper_bound=-1) - assert claripy.backends.vsa.is_true(si2.SLT(si1)) - - # [-2, -1] <= [1, 2] (signed arithmetic) - assert claripy.backends.vsa.is_true(si2.SLE(si1)) - - # [0xfe, 0xff] > [1, 2] (unsigned arithmetic) - assert claripy.backends.vsa.is_true(si2.UGT(si1)) - - # [0xfe, 0xff] >= [1, 2] (unsigned arithmetic) - assert claripy.backends.vsa.is_true(si2.UGE(si1)) - - -def test_join(): - # Set backend - b = claripy.backends.vsa - claripy.solver_backends = [] - - SI = claripy.SI - - a = claripy.SI(bits=8, to_conv=2) - b = claripy.SI(bits=8, to_conv=10) - c = claripy.SI(bits=8, to_conv=120) - d = claripy.SI(bits=8, to_conv=130) - e = claripy.SI(bits=8, to_conv=132) - f = claripy.SI(bits=8, to_conv=135) - - # union a, b, c, d, e => [2, 132] with a stride of 2 - tmp1 = a.union(b) - assert claripy.backends.vsa.identical(tmp1, SI(bits=8, stride=8, lower_bound=2, upper_bound=10)) - tmp2 = tmp1.union(c) - assert claripy.backends.vsa.identical(tmp2, SI(bits=8, stride=2, lower_bound=2, upper_bound=120)) - tmp3 = tmp2.union(d).union(e) - assert claripy.backends.vsa.identical(tmp3, SI(bits=8, stride=2, lower_bound=2, upper_bound=132)) - - # union a, b, c, d, e, f => [2, 135] with a stride of 1 - tmp = a.union(b).union(c).union(d).union(e).union(f) - assert claripy.backends.vsa.identical(tmp, SI(bits=8, stride=1, lower_bound=2, upper_bound=135)) - - a = claripy.SI(bits=8, to_conv=1) - b = claripy.SI(bits=8, to_conv=10) - c = claripy.SI(bits=8, to_conv=120) - d = claripy.SI(bits=8, to_conv=130) - e = claripy.SI(bits=8, to_conv=132) - f = claripy.SI(bits=8, to_conv=135) - g = claripy.SI(bits=8, to_conv=220) - h = claripy.SI(bits=8, to_conv=50) - - # union a, b, c, d, e, f, g, h => [220, 135] with a stride of 1 - tmp = a.union(b).union(c).union(d).union(e).union(f).union(g).union(h) - assert claripy.backends.vsa.identical(tmp, SI(bits=8, stride=1, lower_bound=220, upper_bound=135)) - assert 220 in vsa_model(tmp).eval(255) - assert 225 in vsa_model(tmp).eval(255) - assert 0 in vsa_model(tmp).eval(255) - assert 135 in vsa_model(tmp).eval(255) - assert 138 not in vsa_model(tmp).eval(255) - - -def test_vsa(): - # Set backend - b = claripy.backends.vsa - - SI = claripy.SI - VS = claripy.ValueSet - BVV = claripy.BVV - - # Disable the use of DiscreteStridedIntervalSet - claripy.vsa.strided_interval.allow_dsis = False - - def is_equal(ast_0, ast_1): - return claripy.backends.vsa.identical(ast_0, ast_1) - - si1 = claripy.TSI(32, name="foo", explicit_name=True) - assert vsa_model(si1).name == "foo" - - # Normalization - si1 = SI(bits=32, stride=1, lower_bound=10, upper_bound=10) - assert vsa_model(si1).stride == 0 - - # Integers - si1 = claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10) - si2 = claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10) - si3 = claripy.SI(bits=32, stride=0, lower_bound=28, upper_bound=28) - # Strided intervals - si_a = claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=20) - si_b = claripy.SI(bits=32, stride=2, lower_bound=-100, upper_bound=200) - si_c = claripy.SI(bits=32, stride=3, lower_bound=-100, upper_bound=200) - si_d = claripy.SI(bits=32, stride=2, lower_bound=50, upper_bound=60) - si_e = claripy.SI(bits=16, stride=1, lower_bound=0x2000, upper_bound=0x3000) - si_f = claripy.SI(bits=16, stride=1, lower_bound=0, upper_bound=255) - si_g = claripy.SI(bits=16, stride=1, lower_bound=0, upper_bound=0xFF) - si_h = claripy.SI(bits=32, stride=0, lower_bound=0x80000000, upper_bound=0x80000000) - - assert is_equal(si1, claripy.SI(bits=32, to_conv=10)) - assert is_equal(si2, claripy.SI(bits=32, to_conv=10)) - assert is_equal(si1, si2) - # __add__ - si_add_1 = si1 + si2 - assert is_equal(si_add_1, claripy.SI(bits=32, stride=0, lower_bound=20, upper_bound=20)) - si_add_2 = si1 + si_a - assert is_equal(si_add_2, claripy.SI(bits=32, stride=2, lower_bound=20, upper_bound=30)) - si_add_3 = si_a + si_b - assert is_equal(si_add_3, claripy.SI(bits=32, stride=2, lower_bound=-90, upper_bound=220)) - si_add_4 = si_b + si_c - assert is_equal(si_add_4, claripy.SI(bits=32, stride=1, lower_bound=-200, upper_bound=400)) - # __add__ with overflow - si_add_5 = si_h + 0xFFFFFFFF - assert is_equal( - si_add_5, - claripy.SI(bits=32, stride=0, lower_bound=0x7FFFFFFF, upper_bound=0x7FFFFFFF), - ) - # __sub__ - si_minus_1 = si1 - si2 - assert is_equal(si_minus_1, claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0)) - si_minus_2 = si_a - si_b - assert is_equal(si_minus_2, claripy.SI(bits=32, stride=2, lower_bound=-190, upper_bound=120)) - si_minus_3 = si_b - si_c - assert is_equal(si_minus_3, claripy.SI(bits=32, stride=1, lower_bound=-300, upper_bound=300)) - # __neg__ / __invert__ / bitwise not - si_neg_1 = ~si1 - assert is_equal(si_neg_1, claripy.SI(bits=32, to_conv=-11)) - si_neg_2 = ~si_b - assert is_equal(si_neg_2, claripy.SI(bits=32, stride=2, lower_bound=-201, upper_bound=99)) - # __or__ - si_or_1 = si1 | si3 - assert is_equal(si_or_1, claripy.SI(bits=32, to_conv=30)) - si_or_2 = si1 | si2 - assert is_equal(si_or_2, claripy.SI(bits=32, to_conv=10)) - si_or_3 = si1 | si_a # An integer | a strided interval - assert is_equal(si_or_3, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=30)) - si_or_3 = si_a | si1 # Exchange the operands - assert is_equal(si_or_3, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=30)) - si_or_4 = si_a | si_d # A strided interval | another strided interval - assert is_equal(si_or_4, claripy.SI(bits=32, stride=2, lower_bound=50, upper_bound=62)) - si_or_4 = si_d | si_a # Exchange the operands - assert is_equal(si_or_4, claripy.SI(bits=32, stride=2, lower_bound=50, upper_bound=62)) - si_or_5 = si_e | si_f # - assert is_equal(si_or_5, claripy.SI(bits=16, stride=1, lower_bound=0x2000, upper_bound=0x30FF)) - si_or_6 = si_e | si_g # - assert is_equal(si_or_6, claripy.SI(bits=16, stride=1, lower_bound=0x2000, upper_bound=0x30FF)) - # Shifting - si_shl_1 = si1 << 3 - assert si_shl_1.size() == 32 - assert is_equal(si_shl_1, claripy.SI(bits=32, stride=0, lower_bound=80, upper_bound=80)) - # Multiplication - si_mul_1 = si1 * 3 - assert si_mul_1.size() == 32 - assert is_equal(si_mul_1, claripy.SI(bits=32, stride=0, lower_bound=30, upper_bound=30)) - si_mul_2 = si_a * 3 - assert si_mul_2.size() == 32 - assert is_equal(si_mul_2, claripy.SI(bits=32, stride=6, lower_bound=30, upper_bound=60)) - si_mul_3 = si_a * si_b - assert si_mul_3.size() == 32 - assert is_equal(si_mul_3, claripy.SI(bits=32, stride=2, lower_bound=-2000, upper_bound=4000)) - # Division - si_div_1 = si1 / 3 - assert si_div_1.size() == 32 - assert is_equal(si_div_1, claripy.SI(bits=32, stride=0, lower_bound=3, upper_bound=3)) - si_div_2 = si_a / 3 - assert si_div_2.size() == 32 - assert is_equal(si_div_2, claripy.SI(bits=32, stride=1, lower_bound=3, upper_bound=6)) - # Modulo - si_mo_1 = si1 % 3 - assert si_mo_1.size() == 32 - assert is_equal(si_mo_1, claripy.SI(bits=32, stride=0, lower_bound=1, upper_bound=1)) - si_mo_2 = si_a % 3 - assert si_mo_2.size() == 32 - assert is_equal(si_mo_2, claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2)) - - # - # Extracting the sign bit - # - - # a negative integer - si = claripy.SI(bits=64, stride=0, lower_bound=-1, upper_bound=-1) - sb = si[63:63] - assert is_equal(sb, claripy.SI(bits=1, to_conv=1)) - - # non-positive integers - si = claripy.SI(bits=64, stride=1, lower_bound=-1, upper_bound=0) - sb = si[63:63] - assert is_equal(sb, claripy.SI(bits=1, stride=1, lower_bound=0, upper_bound=1)) - - # Extracting an integer - si = claripy.SI( - bits=64, - stride=0, - lower_bound=0x7FFFFFFFFFFF0000, - upper_bound=0x7FFFFFFFFFFF0000, - ) - part1 = si[63:32] - part2 = si[31:0] - assert is_equal( - part1, - claripy.SI(bits=32, stride=0, lower_bound=0x7FFFFFFF, upper_bound=0x7FFFFFFF), - ) - assert is_equal( - part2, - claripy.SI(bits=32, stride=0, lower_bound=0xFFFF0000, upper_bound=0xFFFF0000), - ) - - # Concatenating two integers - si_concat = part1.concat(part2) - assert is_equal(si_concat, si) - - # Extracting a claripy.SI - si = claripy.SI(bits=64, stride=0x9, lower_bound=0x1, upper_bound=0xA) - part1 = si[63:32] - part2 = si[31:0] - assert is_equal(part1, claripy.SI(bits=32, stride=0, lower_bound=0x0, upper_bound=0x0)) - assert is_equal(part2, claripy.SI(bits=32, stride=9, lower_bound=1, upper_bound=10)) - - # Concatenating two claripy.SIs - si_concat = part1.concat(part2) - assert is_equal(si_concat, si) - - # Concatenating two SIs that are of different sizes - si_1 = SI(bits=64, stride=1, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFF) - si_2 = SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) - si_concat = si_1.concat(si_2) - assert is_equal( - si_concat, - SI(bits=96, stride=1, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFFFFFFFFFF), - ) - - # Zero-Extend the low part - si_zeroextended = part2.zero_extend(32) - assert is_equal(si_zeroextended, claripy.SI(bits=64, stride=9, lower_bound=1, upper_bound=10)) - - # Sign-extension - si_signextended = part2.sign_extend(32) - assert is_equal(si_signextended, claripy.SI(bits=64, stride=9, lower_bound=1, upper_bound=10)) - - # Extract from the result above - si_extracted = si_zeroextended[31:0] - assert is_equal(si_extracted, claripy.SI(bits=32, stride=9, lower_bound=1, upper_bound=10)) - - # Union - si_union_1 = si1.union(si2) - assert is_equal(si_union_1, claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10)) - si_union_2 = si1.union(si3) - assert is_equal(si_union_2, claripy.SI(bits=32, stride=18, lower_bound=10, upper_bound=28)) - si_union_3 = si1.union(si_a) - assert is_equal(si_union_3, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=20)) - si_union_4 = si_a.union(si_b) - assert is_equal(si_union_4, claripy.SI(bits=32, stride=2, lower_bound=-100, upper_bound=200)) - si_union_5 = si_b.union(si_c) - assert is_equal(si_union_5, claripy.SI(bits=32, stride=1, lower_bound=-100, upper_bound=200)) - - # Intersection - si_intersection_1 = si1.intersection(si1) - assert is_equal(si_intersection_1, si2) - si_intersection_2 = si1.intersection(si2) - assert is_equal(si_intersection_2, claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10)) - si_intersection_3 = si1.intersection(si_a) - assert is_equal(si_intersection_3, claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10)) - - si_intersection_4 = si_a.intersection(si_b) - - assert is_equal(si_intersection_4, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=20)) - si_intersection_5 = si_b.intersection(si_c) - assert is_equal( - si_intersection_5, - claripy.SI(bits=32, stride=6, lower_bound=-100, upper_bound=200), - ) - - # More intersections - t0 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0x27) - t1 = claripy.SI(bits=32, stride=0x7FFFFFFF, lower_bound=0x80000002, upper_bound=1) - - si_is_6 = t0.intersection(t1) - assert is_equal(si_is_6, claripy.SI(bits=32, stride=0, lower_bound=1, upper_bound=1)) - - t2 = claripy.SI(bits=32, stride=5, lower_bound=20, upper_bound=30) - t3 = claripy.SI(bits=32, stride=1, lower_bound=27, upper_bound=0xFFFFFFFF) - - si_is_7 = t2.intersection(t3) - assert is_equal(si_is_7, claripy.SI(bits=32, stride=0, lower_bound=30, upper_bound=30)) - - t4 = claripy.SI(bits=32, stride=5, lower_bound=-400, upper_bound=400) - t5 = claripy.SI(bits=32, stride=1, lower_bound=395, upper_bound=-395) - si_is_8 = t4.intersection(t5) - assert is_equal(si_is_8, claripy.SI(bits=32, stride=5, lower_bound=-400, upper_bound=400)) - - # Sign-extension - si = claripy.SI(bits=1, stride=0, lower_bound=1, upper_bound=1) - si_signextended = si.sign_extend(31) - assert is_equal( - si_signextended, - claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), - ) - - # Comparison between claripy.SI and BVV - si = claripy.SI(bits=32, stride=1, lower_bound=-0x7F, upper_bound=0x7F) - si._model_vsa.uninitialized = True - bvv = BVV(0x30, 32) - comp = si < bvv - assert vsa_model(comp).identical(MaybeResult()) - - # Better extraction - # si = <32>0x1000000[0xcffffff, 0xdffffff]R - si = claripy.SI(bits=32, stride=0x1000000, lower_bound=0xCFFFFFF, upper_bound=0xDFFFFFF) - si_byte0 = si[7:0] - si_byte1 = si[15:8] - si_byte2 = si[23:16] - si_byte3 = si[31:24] - assert is_equal(si_byte0, claripy.SI(bits=8, stride=0, lower_bound=0xFF, upper_bound=0xFF)) - assert is_equal(si_byte1, claripy.SI(bits=8, stride=0, lower_bound=0xFF, upper_bound=0xFF)) - assert is_equal(si_byte2, claripy.SI(bits=8, stride=0, lower_bound=0xFF, upper_bound=0xFF)) - assert is_equal(si_byte3, claripy.SI(bits=8, stride=1, lower_bound=0xC, upper_bound=0xD)) - - # Optimization on bitwise-and - si_1 = claripy.SI(bits=32, stride=1, lower_bound=0x0, upper_bound=0xFFFFFFFF) - si_2 = claripy.SI(bits=32, stride=0, lower_bound=0x80000000, upper_bound=0x80000000) - si = si_1 & si_2 - assert is_equal( - si, - claripy.SI(bits=32, stride=0x80000000, lower_bound=0, upper_bound=0x80000000), - ) - - si_1 = claripy.SI(bits=32, stride=1, lower_bound=0x0, upper_bound=0x7FFFFFFF) - si_2 = claripy.SI(bits=32, stride=0, lower_bound=0x80000000, upper_bound=0x80000000) - si = si_1 & si_2 - assert is_equal(si, claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0)) - - # Concatenation: concat with zeros only increases the stride - si_1 = claripy.SI(bits=8, stride=0xFF, lower_bound=0x0, upper_bound=0xFF) - si_2 = claripy.SI(bits=8, stride=0, lower_bound=0, upper_bound=0) - si = si_1.concat(si_2) - assert is_equal(si, claripy.SI(bits=16, stride=0xFF00, lower_bound=0, upper_bound=0xFF00)) - - # Extract from a reversed value - si_1 = claripy.SI(bits=64, stride=0xFF, lower_bound=0x0, upper_bound=0xFF) - si_2 = si_1.reversed[63:56] - assert is_equal(si_2, claripy.SI(bits=8, stride=0xFF, lower_bound=0x0, upper_bound=0xFF)) - - # - # ValueSet - # - - def VS(name=None, bits=None, region=None, val=None): # noqa: F811 # TODO: Refactor this test case - region = "foobar" if region is None else region - return claripy.ValueSet(bits, region=region, region_base_addr=0, value=val, name=name) - - vs_1 = VS(bits=32, val=0) - vs_1 = vs_1.intersection(VS(bits=32, val=1)) - assert vsa_model(vs_1).is_empty - # Test merging two addresses - vsa_model(vs_1)._merge_si("global", 0, vsa_model(si1)) - vsa_model(vs_1)._merge_si("global", 0, vsa_model(si3)) - assert vsa_model(vs_1).get_si("global").identical(vsa_model(SI(bits=32, stride=18, lower_bound=10, upper_bound=28))) - # Length of this ValueSet - assert len(vsa_model(vs_1)) == 32 - - vs_1 = VS(name="boo", bits=32, val=0).intersection(VS(name="makeitempty", bits=32, val=1)) - vs_2 = VS(name="foo", bits=32, val=0).intersection(VS(name="makeitempty", bits=32, val=1)) - assert claripy.backends.vsa.identical(vs_1, vs_1) - assert claripy.backends.vsa.identical(vs_2, vs_2) - vsa_model(vs_1)._merge_si("global", 0, vsa_model(si1)) - assert not claripy.backends.vsa.identical(vs_1, vs_2) - vsa_model(vs_2)._merge_si("global", 0, vsa_model(si1)) - assert claripy.backends.vsa.identical(vs_1, vs_2) - assert claripy.backends.vsa.is_true((vs_1 & vs_2) == vs_1) - vsa_model(vs_1)._merge_si("global", 0, vsa_model(si3)) - assert not claripy.backends.vsa.identical(vs_1, vs_2) - - # Subtraction - # Subtraction of two pointers yields a concrete value - - vs_1 = VS(name="foo", region="global", bits=32, val=0x400010) - vs_2 = VS(name="bar", region="global", bits=32, val=0x400000) - si = vs_1 - vs_2 - assert type(vsa_model(si)) is StridedInterval - assert claripy.backends.vsa.identical(si, claripy.SI(bits=32, stride=0, lower_bound=0x10, upper_bound=0x10)) - - # - # IfProxy - # - - si = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=0xFFFFFFFF) - if_0 = claripy.If(si == 0, si, si - 1) - assert claripy.backends.vsa.identical(if_0, if_0) - assert not claripy.backends.vsa.identical(if_0, si) - - # max and min on IfProxy - si = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) - if_0 = claripy.If(si == 0, si, si - 1) - max_val = b.max(if_0) - min_val = b.min(if_0) - assert max_val == 0xFFFFFFFF - assert min_val == 0x00000000 - - # identical - assert claripy.backends.vsa.identical(if_0, if_0) - assert claripy.backends.vsa.identical(if_0, si) - if_0_copy = claripy.If(si == 0, si, si - 1) - assert claripy.backends.vsa.identical(if_0, if_0_copy) - if_1 = claripy.If(si == 1, si, si - 1) - assert claripy.backends.vsa.identical(if_0, if_1) - - si = SI(bits=32, stride=0, lower_bound=1, upper_bound=1) - if_0 = claripy.If(si == 0, si, si - 1) - if_0_copy = claripy.If(si == 0, si, si - 1) - assert claripy.backends.vsa.identical(if_0, if_0_copy) - if_1 = claripy.If(si == 1, si, si - 1) - assert not claripy.backends.vsa.identical(if_0, if_1) - if_1 = claripy.If(si == 0, si + 1, si - 1) - assert claripy.backends.vsa.identical(if_0, if_1) - if_1 = claripy.If(si == 0, si, si) - assert not claripy.backends.vsa.identical(if_0, if_1) - - # if_1 = And(VS_2, IfProxy(si == 0, 0, 1)) - vs_2 = VS(region="global", bits=32, val=0xFA7B00B) - si = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1) - if_1 = vs_2 & claripy.If( - si == 0, - claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0), - claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), - ) - assert claripy.backends.vsa.is_true( - vsa_model(if_1.ite_excavated.args[1]) == vsa_model(VS(region="global", bits=32, val=0)) - ) - assert claripy.backends.vsa.is_true(vsa_model(if_1.ite_excavated.args[2]) == vsa_model(vs_2)) - - # if_2 = And(VS_3, IfProxy(si != 0, 0, 1) - vs_3 = VS(region="global", bits=32, val=0xDEADCA7) - si = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1) - if_2 = vs_3 & claripy.If( - si != 0, - claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0), - claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), - ) - assert claripy.backends.vsa.is_true( - vsa_model(if_2.ite_excavated.args[1]) == vsa_model(VS(region="global", bits=32, val=0)) - ) - assert claripy.backends.vsa.is_true(vsa_model(if_2.ite_excavated.args[2]) == vsa_model(vs_3)) - - # Something crazy is gonna happen... - # if_3 = if_1 + if_2 - # assert claripy.backends.vsa.is_true(vsa_model(if_3.ite_excavated.args[1]) == vsa_model(vs_3))) - # assert claripy.backends.vsa.is_true(vsa_model(if_3.ite_excavated.args[1]) == vsa_model(vs_2))) - - -def test_vsa_constraint_to_si(): - # Set backend - b = claripy.backends.vsa - s = claripy.SolverVSA() # pylint:disable=unused-variable - - SI = claripy.SI - BVV = claripy.BVV - - claripy.vsa.strided_interval.allow_dsis = False - - # - # If(SI == 0, 1, 0) == 1 - # - - s1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2) - ast_true = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) == BVV(1, 1) - ast_false = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) != BVV(1, 1) - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s1 - # True side: claripy.SI<32>0[0, 0] - assert claripy.backends.vsa.is_true( - trueside_replacement[0][1] == claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0) - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat is True - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s1 - # False side; claripy.SI<32>1[1, 2] - - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) - ) - # - # If(SI == 0, 1, 0) <= 1 - # - - s1 = SI(bits=32, stride=1, lower_bound=0, upper_bound=2) - ast_true = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) <= BVV(1, 1) - ast_false = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) > BVV(1, 1) - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat # Always satisfiable - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert not falseside_sat # Not sat - - # - # If(SI == 0, 20, 10) > 15 - # - - s1 = SI(bits=32, stride=1, lower_bound=0, upper_bound=2) - ast_true = claripy.If(s1 == BVV(0, 32), BVV(20, 32), BVV(10, 32)) > BVV(15, 32) - ast_false = claripy.If(s1 == BVV(0, 32), BVV(20, 32), BVV(10, 32)) <= BVV(15, 32) - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s1 - # True side: SI<32>0[0, 0] - assert claripy.backends.vsa.identical( - trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s1 - # False side; SI<32>1[1, 2] - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) - ) - - # - # If(SI == 0, 20, 10) >= 15 - # - - s1 = SI(bits=32, stride=1, lower_bound=0, upper_bound=2) - ast_true = claripy.If(s1 == BVV(0, 32), BVV(15, 32), BVV(10, 32)) >= BVV(15, 32) - ast_false = claripy.If(s1 == BVV(0, 32), BVV(15, 32), BVV(10, 32)) < BVV(15, 32) - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s1 - # True side: SI<32>0[0, 0] - assert claripy.backends.vsa.identical( - trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s1 - # False side; SI<32>0[0,0] - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) - ) - - # - # Extract(0, 0, Concat(BVV(0, 63), If(SI == 0, 1, 0))) == 1 - # - - s2 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2) - ast_true = claripy.Extract(0, 0, claripy.Concat(BVV(0, 63), claripy.If(s2 == 0, BVV(1, 1), BVV(0, 1)))) == 1 - ast_false = claripy.Extract(0, 0, claripy.Concat(BVV(0, 63), claripy.If(s2 == 0, BVV(1, 1), BVV(0, 1)))) != 1 - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s2 - # True side: claripy.SI<32>0[0, 0] - assert claripy.backends.vsa.identical( - trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s2 - # False side; claripy.SI<32>1[1, 2] - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) - ) - - # - # Extract(0, 0, ZeroExt(32, If(SI == 0, BVV(1, 32), BVV(0, 32)))) == 1 - # - - s3 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2) - ast_true = claripy.Extract(0, 0, claripy.ZeroExt(32, claripy.If(s3 == 0, BVV(1, 32), BVV(0, 32)))) == 1 - ast_false = claripy.Extract(0, 0, claripy.ZeroExt(32, claripy.If(s3 == 0, BVV(1, 32), BVV(0, 32)))) != 1 - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s3 - # True side: claripy.SI<32>0[0, 0] - assert claripy.backends.vsa.identical( - trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s3 - # False side; claripy.SI<32>1[1, 2] - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) - ) - - # - # Extract(0, 0, ZeroExt(32, If(Extract(32, 0, (SI & claripy.SI)) < 0, BVV(1, 1), BVV(0, 1)))) - # - - s4 = claripy.SI(bits=64, stride=1, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFF) - ast_true = ( - claripy.Extract( - 0, - 0, - claripy.ZeroExt( - 32, - claripy.If(claripy.Extract(31, 0, (s4 & s4)).SLT(0), BVV(1, 32), BVV(0, 32)), - ), + + # Zero-Extend the low part + si_zeroextended = part2.zero_extend(32) + assert is_equal(si_zeroextended, claripy.SI(bits=64, stride=9, lower_bound=1, upper_bound=10)) + + # Sign-extension + si_signextended = part2.sign_extend(32) + assert is_equal(si_signextended, claripy.SI(bits=64, stride=9, lower_bound=1, upper_bound=10)) + + # Extract from the result above + si_extracted = si_zeroextended[31:0] + assert is_equal(si_extracted, claripy.SI(bits=32, stride=9, lower_bound=1, upper_bound=10)) + + # Union + si_union_1 = si1.union(si2) + assert is_equal(si_union_1, claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10)) + si_union_2 = si1.union(si3) + assert is_equal(si_union_2, claripy.SI(bits=32, stride=18, lower_bound=10, upper_bound=28)) + si_union_3 = si1.union(si_a) + assert is_equal(si_union_3, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=20)) + si_union_4 = si_a.union(si_b) + assert is_equal(si_union_4, claripy.SI(bits=32, stride=2, lower_bound=-100, upper_bound=200)) + si_union_5 = si_b.union(si_c) + assert is_equal(si_union_5, claripy.SI(bits=32, stride=1, lower_bound=-100, upper_bound=200)) + + # Intersection + si_intersection_1 = si1.intersection(si1) + assert is_equal(si_intersection_1, si2) + si_intersection_2 = si1.intersection(si2) + assert is_equal(si_intersection_2, claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10)) + si_intersection_3 = si1.intersection(si_a) + assert is_equal(si_intersection_3, claripy.SI(bits=32, stride=0, lower_bound=10, upper_bound=10)) + + si_intersection_4 = si_a.intersection(si_b) + + assert is_equal(si_intersection_4, claripy.SI(bits=32, stride=2, lower_bound=10, upper_bound=20)) + si_intersection_5 = si_b.intersection(si_c) + assert is_equal( + si_intersection_5, + claripy.SI(bits=32, stride=6, lower_bound=-100, upper_bound=200), + ) + + # More intersections + t0 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0x27) + t1 = claripy.SI(bits=32, stride=0x7FFFFFFF, lower_bound=0x80000002, upper_bound=1) + + si_is_6 = t0.intersection(t1) + assert is_equal(si_is_6, claripy.SI(bits=32, stride=0, lower_bound=1, upper_bound=1)) + + t2 = claripy.SI(bits=32, stride=5, lower_bound=20, upper_bound=30) + t3 = claripy.SI(bits=32, stride=1, lower_bound=27, upper_bound=0xFFFFFFFF) + + si_is_7 = t2.intersection(t3) + assert is_equal(si_is_7, claripy.SI(bits=32, stride=0, lower_bound=30, upper_bound=30)) + + t4 = claripy.SI(bits=32, stride=5, lower_bound=-400, upper_bound=400) + t5 = claripy.SI(bits=32, stride=1, lower_bound=395, upper_bound=-395) + si_is_8 = t4.intersection(t5) + assert is_equal(si_is_8, claripy.SI(bits=32, stride=5, lower_bound=-400, upper_bound=400)) + + # Sign-extension + si = claripy.SI(bits=1, stride=0, lower_bound=1, upper_bound=1) + si_signextended = si.sign_extend(31) + assert is_equal( + si_signextended, + claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), + ) + + # Comparison between claripy.SI and BVV + si = claripy.SI(bits=32, stride=1, lower_bound=-0x7F, upper_bound=0x7F) + si._model_vsa.uninitialized = True + bvv = BVV(0x30, 32) + comp = si < bvv + assert vsa_model(comp).identical(MaybeResult()) + + # Better extraction + # si = <32>0x1000000[0xcffffff, 0xdffffff]R + si = claripy.SI(bits=32, stride=0x1000000, lower_bound=0xCFFFFFF, upper_bound=0xDFFFFFF) + si_byte0 = si[7:0] + si_byte1 = si[15:8] + si_byte2 = si[23:16] + si_byte3 = si[31:24] + assert is_equal(si_byte0, claripy.SI(bits=8, stride=0, lower_bound=0xFF, upper_bound=0xFF)) + assert is_equal(si_byte1, claripy.SI(bits=8, stride=0, lower_bound=0xFF, upper_bound=0xFF)) + assert is_equal(si_byte2, claripy.SI(bits=8, stride=0, lower_bound=0xFF, upper_bound=0xFF)) + assert is_equal(si_byte3, claripy.SI(bits=8, stride=1, lower_bound=0xC, upper_bound=0xD)) + + # Optimization on bitwise-and + si_1 = claripy.SI(bits=32, stride=1, lower_bound=0x0, upper_bound=0xFFFFFFFF) + si_2 = claripy.SI(bits=32, stride=0, lower_bound=0x80000000, upper_bound=0x80000000) + si = si_1 & si_2 + assert is_equal( + si, + claripy.SI(bits=32, stride=0x80000000, lower_bound=0, upper_bound=0x80000000), + ) + + si_1 = claripy.SI(bits=32, stride=1, lower_bound=0x0, upper_bound=0x7FFFFFFF) + si_2 = claripy.SI(bits=32, stride=0, lower_bound=0x80000000, upper_bound=0x80000000) + si = si_1 & si_2 + assert is_equal(si, claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0)) + + # Concatenation: concat with zeros only increases the stride + si_1 = claripy.SI(bits=8, stride=0xFF, lower_bound=0x0, upper_bound=0xFF) + si_2 = claripy.SI(bits=8, stride=0, lower_bound=0, upper_bound=0) + si = si_1.concat(si_2) + assert is_equal(si, claripy.SI(bits=16, stride=0xFF00, lower_bound=0, upper_bound=0xFF00)) + + # Extract from a reversed value + si_1 = claripy.SI(bits=64, stride=0xFF, lower_bound=0x0, upper_bound=0xFF) + si_2 = si_1.reversed[63:56] + assert is_equal(si_2, claripy.SI(bits=8, stride=0xFF, lower_bound=0x0, upper_bound=0xFF)) + + # + # ValueSet + # + + def VS(name=None, bits=None, region=None, val=None): # noqa: F811 + region = "foobar" if region is None else region + return claripy.ValueSet(bits, region=region, region_base_addr=0, value=val, name=name) + + vs_1 = VS(bits=32, val=0) + vs_1 = vs_1.intersection(VS(bits=32, val=1)) + assert vsa_model(vs_1).is_empty + # Test merging two addresses + vsa_model(vs_1)._merge_si("global", 0, vsa_model(si1)) + vsa_model(vs_1)._merge_si("global", 0, vsa_model(si3)) + assert ( + vsa_model(vs_1) + .get_si("global") + .identical(vsa_model(SI(bits=32, stride=18, lower_bound=10, upper_bound=28))) + ) + # Length of this ValueSet + assert len(vsa_model(vs_1)) == 32 + + vs_1 = VS(name="boo", bits=32, val=0).intersection(VS(name="makeitempty", bits=32, val=1)) + vs_2 = VS(name="foo", bits=32, val=0).intersection(VS(name="makeitempty", bits=32, val=1)) + assert claripy.backends.vsa.identical(vs_1, vs_1) + assert claripy.backends.vsa.identical(vs_2, vs_2) + vsa_model(vs_1)._merge_si("global", 0, vsa_model(si1)) + assert not claripy.backends.vsa.identical(vs_1, vs_2) + vsa_model(vs_2)._merge_si("global", 0, vsa_model(si1)) + assert claripy.backends.vsa.identical(vs_1, vs_2) + assert claripy.backends.vsa.is_true((vs_1 & vs_2) == vs_1) + vsa_model(vs_1)._merge_si("global", 0, vsa_model(si3)) + assert not claripy.backends.vsa.identical(vs_1, vs_2) + + # Subtraction + # Subtraction of two pointers yields a concrete value + + vs_1 = VS(name="foo", region="global", bits=32, val=0x400010) + vs_2 = VS(name="bar", region="global", bits=32, val=0x400000) + si = vs_1 - vs_2 + assert type(vsa_model(si)) is StridedInterval + assert claripy.backends.vsa.identical(si, claripy.SI(bits=32, stride=0, lower_bound=0x10, upper_bound=0x10)) + + # + # IfProxy + # + + si = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=0xFFFFFFFF) + if_0 = claripy.If(si == 0, si, si - 1) + assert claripy.backends.vsa.identical(if_0, if_0) + assert not claripy.backends.vsa.identical(if_0, si) + + # max and min on IfProxy + si = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) + if_0 = claripy.If(si == 0, si, si - 1) + max_val = b.max(if_0) + min_val = b.min(if_0) + assert max_val == 0xFFFFFFFF + assert min_val == 0x00000000 + + # identical + assert claripy.backends.vsa.identical(if_0, if_0) + assert claripy.backends.vsa.identical(if_0, si) + if_0_copy = claripy.If(si == 0, si, si - 1) + assert claripy.backends.vsa.identical(if_0, if_0_copy) + if_1 = claripy.If(si == 1, si, si - 1) + assert claripy.backends.vsa.identical(if_0, if_1) + + si = SI(bits=32, stride=0, lower_bound=1, upper_bound=1) + if_0 = claripy.If(si == 0, si, si - 1) + if_0_copy = claripy.If(si == 0, si, si - 1) + assert claripy.backends.vsa.identical(if_0, if_0_copy) + if_1 = claripy.If(si == 1, si, si - 1) + assert not claripy.backends.vsa.identical(if_0, if_1) + if_1 = claripy.If(si == 0, si + 1, si - 1) + assert claripy.backends.vsa.identical(if_0, if_1) + if_1 = claripy.If(si == 0, si, si) + assert not claripy.backends.vsa.identical(if_0, if_1) + + # if_1 = And(VS_2, IfProxy(si == 0, 0, 1)) + vs_2 = VS(region="global", bits=32, val=0xFA7B00B) + si = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1) + if_1 = vs_2 & claripy.If( + si == 0, + claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0), + claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), + ) + assert claripy.backends.vsa.is_true( + vsa_model(if_1.ite_excavated.args[1]) == vsa_model(VS(region="global", bits=32, val=0)) + ) + assert claripy.backends.vsa.is_true(vsa_model(if_1.ite_excavated.args[2]) == vsa_model(vs_2)) + + # if_2 = And(VS_3, IfProxy(si != 0, 0, 1) + vs_3 = VS(region="global", bits=32, val=0xDEADCA7) + si = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1) + if_2 = vs_3 & claripy.If( + si != 0, + claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0), + claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), + ) + assert claripy.backends.vsa.is_true( + vsa_model(if_2.ite_excavated.args[1]) == vsa_model(VS(region="global", bits=32, val=0)) + ) + assert claripy.backends.vsa.is_true(vsa_model(if_2.ite_excavated.args[2]) == vsa_model(vs_3)) + + # Something crazy is gonna happen... + # if_3 = if_1 + if_2 + # assert claripy.backends.vsa.is_true(vsa_model(if_3.ite_excavated.args[1]) == vsa_model(vs_3))) + # assert claripy.backends.vsa.is_true(vsa_model(if_3.ite_excavated.args[1]) == vsa_model(vs_2))) + + def test_vsa_constraint_to_si(self): + # Set backend + b = claripy.backends.vsa + s = claripy.SolverVSA() # pylint:disable=unused-variable + + SI = claripy.SI + BVV = claripy.BVV + + claripy.vsa.strided_interval.allow_dsis = False + + # + # If(SI == 0, 1, 0) == 1 + # + + s1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2) + ast_true = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) == BVV(1, 1) + ast_false = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) != BVV(1, 1) + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s1 + # True side: claripy.SI<32>0[0, 0] + assert claripy.backends.vsa.is_true( + trueside_replacement[0][1] == claripy.SI(bits=32, stride=0, lower_bound=0, upper_bound=0) + ) + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat is True + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s1 + # False side; claripy.SI<32>1[1, 2] + + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) ) - == 1 - ) - ast_false = ( - claripy.Extract( - 0, - 0, - claripy.ZeroExt( - 32, - claripy.If(claripy.Extract(31, 0, (s4 & s4)).SLT(0), BVV(1, 32), BVV(0, 32)), - ), + # + # If(SI == 0, 1, 0) <= 1 + # + + s1 = SI(bits=32, stride=1, lower_bound=0, upper_bound=2) + ast_true = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) <= BVV(1, 1) + ast_false = claripy.If(s1 == BVV(0, 32), BVV(1, 1), BVV(0, 1)) > BVV(1, 1) + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat # Always satisfiable + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert not falseside_sat # Not sat + + # + # If(SI == 0, 20, 10) > 15 + # + + s1 = SI(bits=32, stride=1, lower_bound=0, upper_bound=2) + ast_true = claripy.If(s1 == BVV(0, 32), BVV(20, 32), BVV(10, 32)) > BVV(15, 32) + ast_false = claripy.If(s1 == BVV(0, 32), BVV(20, 32), BVV(10, 32)) <= BVV(15, 32) + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s1 + # True side: SI<32>0[0, 0] + assert claripy.backends.vsa.identical( + trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) ) - != 1 - ) - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s4[31:0] - # True side: claripy.SI<32>0[0, 0] - assert claripy.backends.vsa.identical( - trueside_replacement[0][1], - SI(bits=32, stride=1, lower_bound=-0x80000000, upper_bound=-1), - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s4[31:0] - # False side; claripy.SI<32>1[1, 2] - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], - SI(bits=32, stride=1, lower_bound=0, upper_bound=0x7FFFFFFF), - ) - - # - # TOP_SI != -1 - # - - s5 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) - ast_true = s5 == claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF) - ast_false = s5 != claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF) - - trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) - assert trueside_sat - assert len(trueside_replacement) == 1 - assert trueside_replacement[0][0] is s5 - assert claripy.backends.vsa.identical( - trueside_replacement[0][1], - SI(bits=32, stride=1, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), - ) - - falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) - assert falseside_sat - assert len(falseside_replacement) == 1 - assert falseside_replacement[0][0] is s5 - assert claripy.backends.vsa.identical( - falseside_replacement[0][1], - SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFE), - ) - - # TODO: Add some more insane test cases - - -def test_vsa_discrete_value_set(): - """ - Test cases for DiscreteStridedIntervalSet. - """ - # Set backend - b = claripy.backends.vsa - - s = claripy.SolverVSA() # pylint:disable=unused-variable - - SI = claripy.SI - BVV = claripy.BVV - - # Allow the use of DiscreteStridedIntervalSet (cuz we wanna test it!) - claripy.vsa.strided_interval.allow_dsis = True - - # - # Union - # - val_1 = BVV(0, 32) - val_2 = BVV(1, 32) - r = val_1.union(val_2) - assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) - assert vsa_model(r).collapse(), claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1) - - r = r.union(BVV(3, 32)) - ints = b.eval(r, 4) - assert len(ints) == 3 - assert ints == [0, 1, 3] - - # - # Intersection - # - - val_1 = BVV(0, 32) - val_2 = BVV(1, 32) - r = val_1.intersection(val_2) - assert isinstance(vsa_model(r), StridedInterval) - assert vsa_model(r).is_empty - - val_1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) - val_2 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=20) - val_3 = claripy.SI(bits=32, stride=1, lower_bound=15, upper_bound=50) - r = val_1.union(val_2) - assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) - r = r.intersection(val_3) - assert sorted(b.eval(r, 100)) == [15, 16, 17, 18, 19, 20] - - # - # Some logical operations - # - - val_1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) - val_2 = claripy.SI(bits=32, stride=1, lower_bound=5, upper_bound=20) - r_1 = val_1.union(val_2) - val_3 = claripy.SI(bits=32, stride=1, lower_bound=20, upper_bound=30) - val_4 = claripy.SI(bits=32, stride=1, lower_bound=25, upper_bound=35) - r_2 = val_3.union(val_4) - assert isinstance(vsa_model(r_1), DiscreteStridedIntervalSet) - assert isinstance(vsa_model(r_2), DiscreteStridedIntervalSet) - # r_1 < r_2 - assert BoolResult.is_maybe(vsa_model(r_1 < r_2)) - # r_1 <= r_2 - assert BoolResult.is_true(vsa_model(r_1 <= r_2)) - # r_1 >= r_2 - assert BoolResult.is_maybe(vsa_model(r_1 >= r_2)) - # r_1 > r_2 - assert BoolResult.is_false(vsa_model(r_1 > r_2)) - # r_1 == r_2 - assert BoolResult.is_maybe(vsa_model(r_1 == r_2)) - # r_1 != r_2 - assert BoolResult.is_maybe(vsa_model(r_1 != r_2)) - - # - # Some arithmetic operations - # - - val_1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) - val_2 = claripy.SI(bits=32, stride=1, lower_bound=5, upper_bound=20) - r_1 = val_1.union(val_2) - val_3 = claripy.SI(bits=32, stride=1, lower_bound=20, upper_bound=30) - val_4 = claripy.SI(bits=32, stride=1, lower_bound=25, upper_bound=35) - r_2 = val_3.union(val_4) - assert isinstance(vsa_model(r_1), DiscreteStridedIntervalSet) - assert isinstance(vsa_model(r_2), DiscreteStridedIntervalSet) - # r_1 + r_2 - r = r_1 + r_2 - assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) - assert vsa_model(r).collapse().identical(vsa_model(SI(bits=32, stride=1, lower_bound=20, upper_bound=55))) - # r_2 - r_1 - r = r_2 - r_1 - assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) - assert vsa_model(r).collapse().identical(vsa_model(SI(bits=32, stride=1, lower_bound=0, upper_bound=35))) - - # Disable it in the end - claripy.vsa.strided_interval.allow_dsis = False - - -def test_solution(): - # Set backend - solver_type = claripy.SolverVSA - s = solver_type() - - def VS(name=None, bits=None, region=None, val=None): - region = "foobar" if region is None else region - return claripy.ValueSet(bits, region=region, region_base_addr=0, value=val, name=name) - - si = claripy.SI(bits=32, stride=10, lower_bound=32, upper_bound=320) - assert s.solution(si, si) - assert s.solution(si, 32) - assert not s.solution(si, 31) - - si2 = claripy.SI(bits=32, stride=0, lower_bound=3, upper_bound=3) - assert s.solution(si2, si2) - assert s.solution(si2, 3) - assert not s.solution(si2, 18) - assert not s.solution(si2, si) - - vs = VS(region="global", bits=32, val=0xDEADCA7) - assert s.solution(vs, 0xDEADCA7) - assert not s.solution(vs, 0xDEADBEEF) - - si = claripy.SI(bits=32, stride=0, lower_bound=3, upper_bound=3) - si2 = claripy.SI(bits=32, stride=10, lower_bound=32, upper_bound=320) - - vs = VS(bits=si.size(), region="foo", val=si._model_vsa) - # vs = vs.annotate(RegionAnnotation('foo', 0, si2)) - vs2 = VS(bits=si2.size(), region="foo", val=si2._model_vsa) - vs = vs.union(vs2) - - assert s.solution(vs, 3) - assert s.solution(vs, 122) - assert s.solution(vs, si) - assert not s.solution(vs, 2) - assert not s.solution(vs, 322) - - -def test_reasonable_bounds(): - si = claripy.SI(bits=32, stride=1, lower_bound=-20, upper_bound=-10) - b = claripy.backends.vsa - assert b.max(si) == 0xFFFFFFF6 - assert b.min(si) == 0xFFFFFFEC - - si = claripy.SI(bits=32, stride=1, lower_bound=-20, upper_bound=10) - b = claripy.backends.vsa - assert b.max(si) == 0xFFFFFFFF - assert b.min(si) == 0 - - -def test_shifting(): - SI = claripy.SI - identical = claripy.backends.vsa.identical - - # <32>1[2,4] LShR 1 = <32>1[1,2] - si = SI(bits=32, stride=1, lower_bound=2, upper_bound=4) - r = si.LShR(1) - assert identical(r, SI(bits=32, stride=1, lower_bound=1, upper_bound=2)) - - # <32>4[15,11] LShR 4 = <32>1[0, 0xfffffff] - si = SI(bits=32, stride=4, lower_bound=15, upper_bound=11) - r = si.LShR(4) - assert identical(r, SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFF)) - - # Extract - si = SI(bits=32, stride=4, lower_bound=15, upper_bound=11) - r = si[31:4] - assert identical(r, SI(bits=28, stride=1, lower_bound=0, upper_bound=0xFFFFFFF)) - - # <32>1[-4,-2] >> 1 = <32>1[-2,-1] - si = SI(bits=32, stride=1, lower_bound=-4, upper_bound=-2) - r = si >> 1 - assert identical(r, SI(bits=32, stride=1, lower_bound=-2, upper_bound=-1)) - - # <32>1[-4,-2] LShR 1 = <32>1[0x7ffffffe,0x7fffffff] - si = SI(bits=32, stride=1, lower_bound=-4, upper_bound=-2) - r = si.LShR(1) - assert identical(r, SI(bits=32, stride=1, lower_bound=0x7FFFFFFE, upper_bound=0x7FFFFFFF)) - - -def test_reverse(): - x = claripy.SI(name="TOP", bits=64, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFF, stride=1) # TOP - y = claripy.SI(name="range", bits=64, lower_bound=0, upper_bound=1337, stride=1) # [0, 1337] - - r0 = x.intersection(y) - r1 = x.reversed.intersection(y) - r2 = x.intersection(y.reversed).reversed - r3 = x.reversed.intersection(y.reversed).reversed - - assert r0._model_vsa.max == 1337 - assert r1._model_vsa.max == 1337 - assert r2._model_vsa.max == 1337 - assert r3._model_vsa.max == 1337 - # See claripy issue #95 for details. - si0 = StridedInterval(name="a", bits=32, stride=0, lower_bound=0xFFFF0000, upper_bound=0xFFFF0000) - si1 = StridedInterval(name="a", bits=32, stride=0, lower_bound=0xFFFF0001, upper_bound=0xFFFF0001) - dsis = DiscreteStridedIntervalSet(name="b", bits=32, si_set={si0, si1}) + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s1 + # False side; SI<32>1[1, 2] + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) + ) + + # + # If(SI == 0, 20, 10) >= 15 + # + + s1 = SI(bits=32, stride=1, lower_bound=0, upper_bound=2) + ast_true = claripy.If(s1 == BVV(0, 32), BVV(15, 32), BVV(10, 32)) >= BVV(15, 32) + ast_false = claripy.If(s1 == BVV(0, 32), BVV(15, 32), BVV(10, 32)) < BVV(15, 32) + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s1 + # True side: SI<32>0[0, 0] + assert claripy.backends.vsa.identical( + trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) + ) + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s1 + # False side; SI<32>0[0,0] + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) + ) + + # + # Extract(0, 0, Concat(BVV(0, 63), If(SI == 0, 1, 0))) == 1 + # + + s2 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2) + ast_true = claripy.Extract(0, 0, claripy.Concat(BVV(0, 63), claripy.If(s2 == 0, BVV(1, 1), BVV(0, 1)))) == 1 + ast_false = claripy.Extract(0, 0, claripy.Concat(BVV(0, 63), claripy.If(s2 == 0, BVV(1, 1), BVV(0, 1)))) != 1 + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s2 + # True side: claripy.SI<32>0[0, 0] + assert claripy.backends.vsa.identical( + trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) + ) + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s2 + # False side; claripy.SI<32>1[1, 2] + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) + ) + + # + # Extract(0, 0, ZeroExt(32, If(SI == 0, BVV(1, 32), BVV(0, 32)))) == 1 + # + + s3 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=2) + ast_true = claripy.Extract(0, 0, claripy.ZeroExt(32, claripy.If(s3 == 0, BVV(1, 32), BVV(0, 32)))) == 1 + ast_false = claripy.Extract(0, 0, claripy.ZeroExt(32, claripy.If(s3 == 0, BVV(1, 32), BVV(0, 32)))) != 1 + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s3 + # True side: claripy.SI<32>0[0, 0] + assert claripy.backends.vsa.identical( + trueside_replacement[0][1], SI(bits=32, stride=0, lower_bound=0, upper_bound=0) + ) + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s3 + # False side; claripy.SI<32>1[1, 2] + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], SI(bits=32, stride=1, lower_bound=1, upper_bound=2) + ) + + # + # Extract(0, 0, ZeroExt(32, If(Extract(32, 0, (SI & claripy.SI)) < 0, BVV(1, 1), BVV(0, 1)))) + # + + s4 = claripy.SI(bits=64, stride=1, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFF) + ast_true = ( + claripy.Extract( + 0, + 0, + claripy.ZeroExt( + 32, + claripy.If(claripy.Extract(31, 0, (s4 & s4)).SLT(0), BVV(1, 32), BVV(0, 32)), + ), + ) + == 1 + ) + ast_false = ( + claripy.Extract( + 0, + 0, + claripy.ZeroExt( + 32, + claripy.If(claripy.Extract(31, 0, (s4 & s4)).SLT(0), BVV(1, 32), BVV(0, 32)), + ), + ) + != 1 + ) + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s4[31:0] + # True side: claripy.SI<32>0[0, 0] + assert claripy.backends.vsa.identical( + trueside_replacement[0][1], + SI(bits=32, stride=1, lower_bound=-0x80000000, upper_bound=-1), + ) + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s4[31:0] + # False side; claripy.SI<32>1[1, 2] + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], + SI(bits=32, stride=1, lower_bound=0, upper_bound=0x7FFFFFFF), + ) + + # + # TOP_SI != -1 + # + + s5 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFF) + ast_true = s5 == claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF) + ast_false = s5 != claripy.SI(bits=32, stride=1, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF) + + trueside_sat, trueside_replacement = b.constraint_to_si(ast_true) + assert trueside_sat + assert len(trueside_replacement) == 1 + assert trueside_replacement[0][0] is s5 + assert claripy.backends.vsa.identical( + trueside_replacement[0][1], + SI(bits=32, stride=1, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), + ) + + falseside_sat, falseside_replacement = b.constraint_to_si(ast_false) + assert falseside_sat + assert len(falseside_replacement) == 1 + assert falseside_replacement[0][0] is s5 + assert claripy.backends.vsa.identical( + falseside_replacement[0][1], + SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFFE), + ) - dsis_r = dsis.reverse() - solver = claripy.SolverVSA() - assert set(solver.eval(dsis_r, 3)) == {0xFFFF, 0x100FFFF} + # TODO: Add some more insane test cases + + def test_vsa_discrete_value_set(self): + """ + Test cases for DiscreteStridedIntervalSet. + """ + # Set backend + b = claripy.backends.vsa + + s = claripy.SolverVSA() # pylint:disable=unused-variable + + SI = claripy.SI + BVV = claripy.BVV + + # Allow the use of DiscreteStridedIntervalSet (cuz we wanna test it!) + claripy.vsa.strided_interval.allow_dsis = True + + # + # Union + # + val_1 = BVV(0, 32) + val_2 = BVV(1, 32) + r = val_1.union(val_2) + assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) + assert vsa_model(r).collapse(), claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=1) + + r = r.union(BVV(3, 32)) + ints = b.eval(r, 4) + assert len(ints) == 3 + assert ints == [0, 1, 3] + + # + # Intersection + # + + val_1 = BVV(0, 32) + val_2 = BVV(1, 32) + r = val_1.intersection(val_2) + assert isinstance(vsa_model(r), StridedInterval) + assert vsa_model(r).is_empty + + val_1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) + val_2 = claripy.SI(bits=32, stride=1, lower_bound=10, upper_bound=20) + val_3 = claripy.SI(bits=32, stride=1, lower_bound=15, upper_bound=50) + r = val_1.union(val_2) + assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) + r = r.intersection(val_3) + assert sorted(b.eval(r, 100)) == [15, 16, 17, 18, 19, 20] + + # + # Some logical operations + # + + val_1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) + val_2 = claripy.SI(bits=32, stride=1, lower_bound=5, upper_bound=20) + r_1 = val_1.union(val_2) + val_3 = claripy.SI(bits=32, stride=1, lower_bound=20, upper_bound=30) + val_4 = claripy.SI(bits=32, stride=1, lower_bound=25, upper_bound=35) + r_2 = val_3.union(val_4) + assert isinstance(vsa_model(r_1), DiscreteStridedIntervalSet) + assert isinstance(vsa_model(r_2), DiscreteStridedIntervalSet) + # r_1 < r_2 + assert BoolResult.is_maybe(vsa_model(r_1 < r_2)) + # r_1 <= r_2 + assert BoolResult.is_true(vsa_model(r_1 <= r_2)) + # r_1 >= r_2 + assert BoolResult.is_maybe(vsa_model(r_1 >= r_2)) + # r_1 > r_2 + assert BoolResult.is_false(vsa_model(r_1 > r_2)) + # r_1 == r_2 + assert BoolResult.is_maybe(vsa_model(r_1 == r_2)) + # r_1 != r_2 + assert BoolResult.is_maybe(vsa_model(r_1 != r_2)) + + # + # Some arithmetic operations + # + + val_1 = claripy.SI(bits=32, stride=1, lower_bound=0, upper_bound=10) + val_2 = claripy.SI(bits=32, stride=1, lower_bound=5, upper_bound=20) + r_1 = val_1.union(val_2) + val_3 = claripy.SI(bits=32, stride=1, lower_bound=20, upper_bound=30) + val_4 = claripy.SI(bits=32, stride=1, lower_bound=25, upper_bound=35) + r_2 = val_3.union(val_4) + assert isinstance(vsa_model(r_1), DiscreteStridedIntervalSet) + assert isinstance(vsa_model(r_2), DiscreteStridedIntervalSet) + # r_1 + r_2 + r = r_1 + r_2 + assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) + assert vsa_model(r).collapse().identical(vsa_model(SI(bits=32, stride=1, lower_bound=20, upper_bound=55))) + # r_2 - r_1 + r = r_2 - r_1 + assert isinstance(vsa_model(r), DiscreteStridedIntervalSet) + assert vsa_model(r).collapse().identical(vsa_model(SI(bits=32, stride=1, lower_bound=0, upper_bound=35))) + + # Disable it in the end + claripy.vsa.strided_interval.allow_dsis = False + + def test_solution(self): + # Set backend + solver_type = claripy.SolverVSA + s = solver_type() + + def VS(name=None, bits=None, region=None, val=None): + region = "foobar" if region is None else region + return claripy.ValueSet(bits, region=region, region_base_addr=0, value=val, name=name) + + si = claripy.SI(bits=32, stride=10, lower_bound=32, upper_bound=320) + assert s.solution(si, si) + assert s.solution(si, 32) + assert not s.solution(si, 31) + + si2 = claripy.SI(bits=32, stride=0, lower_bound=3, upper_bound=3) + assert s.solution(si2, si2) + assert s.solution(si2, 3) + assert not s.solution(si2, 18) + assert not s.solution(si2, si) + + vs = VS(region="global", bits=32, val=0xDEADCA7) + assert s.solution(vs, 0xDEADCA7) + assert not s.solution(vs, 0xDEADBEEF) + + si = claripy.SI(bits=32, stride=0, lower_bound=3, upper_bound=3) + si2 = claripy.SI(bits=32, stride=10, lower_bound=32, upper_bound=320) + + vs = VS(bits=si.size(), region="foo", val=si._model_vsa) + # vs = vs.annotate(RegionAnnotation('foo', 0, si2)) + vs2 = VS(bits=si2.size(), region="foo", val=si2._model_vsa) + vs = vs.union(vs2) + + assert s.solution(vs, 3) + assert s.solution(vs, 122) + assert s.solution(vs, si) + assert not s.solution(vs, 2) + assert not s.solution(vs, 322) + + def test_reasonable_bounds(self): + si = claripy.SI(bits=32, stride=1, lower_bound=-20, upper_bound=-10) + b = claripy.backends.vsa + assert b.max(si) == 0xFFFFFFF6 + assert b.min(si) == 0xFFFFFFEC + + si = claripy.SI(bits=32, stride=1, lower_bound=-20, upper_bound=10) + b = claripy.backends.vsa + assert b.max(si) == 0xFFFFFFFF + assert b.min(si) == 0 + + def test_shifting(self): + SI = claripy.SI + identical = claripy.backends.vsa.identical + + # <32>1[2,4] LShR 1 = <32>1[1,2] + si = SI(bits=32, stride=1, lower_bound=2, upper_bound=4) + r = si.LShR(1) + assert identical(r, SI(bits=32, stride=1, lower_bound=1, upper_bound=2)) + + # <32>4[15,11] LShR 4 = <32>1[0, 0xfffffff] + si = SI(bits=32, stride=4, lower_bound=15, upper_bound=11) + r = si.LShR(4) + assert identical(r, SI(bits=32, stride=1, lower_bound=0, upper_bound=0xFFFFFFF)) + + # Extract + si = SI(bits=32, stride=4, lower_bound=15, upper_bound=11) + r = si[31:4] + assert identical(r, SI(bits=28, stride=1, lower_bound=0, upper_bound=0xFFFFFFF)) + + # <32>1[-4,-2] >> 1 = <32>1[-2,-1] + si = SI(bits=32, stride=1, lower_bound=-4, upper_bound=-2) + r = si >> 1 + assert identical(r, SI(bits=32, stride=1, lower_bound=-2, upper_bound=-1)) + + # <32>1[-4,-2] LShR 1 = <32>1[0x7ffffffe,0x7fffffff] + si = SI(bits=32, stride=1, lower_bound=-4, upper_bound=-2) + r = si.LShR(1) + assert identical(r, SI(bits=32, stride=1, lower_bound=0x7FFFFFFE, upper_bound=0x7FFFFFFF)) + + def test_reverse(self): + x = claripy.SI(name="TOP", bits=64, lower_bound=0, upper_bound=0xFFFFFFFFFFFFFFFF, stride=1) # TOP + y = claripy.SI(name="range", bits=64, lower_bound=0, upper_bound=1337, stride=1) # [0, 1337] + + r0 = x.intersection(y) + r1 = x.reversed.intersection(y) + r2 = x.intersection(y.reversed).reversed + r3 = x.reversed.intersection(y.reversed).reversed + + assert r0._model_vsa.max == 1337 + assert r1._model_vsa.max == 1337 + assert r2._model_vsa.max == 1337 + assert r3._model_vsa.max == 1337 + + # See claripy issue #95 for details. + si0 = StridedInterval(name="a", bits=32, stride=0, lower_bound=0xFFFF0000, upper_bound=0xFFFF0000) + si1 = StridedInterval(name="a", bits=32, stride=0, lower_bound=0xFFFF0001, upper_bound=0xFFFF0001) + dsis = DiscreteStridedIntervalSet(name="b", bits=32, si_set={si0, si1}) + + dsis_r = dsis.reverse() + solver = claripy.SolverVSA() + assert set(solver.eval(dsis_r, 3)) == {0xFFFF, 0x100FFFF} if __name__ == "__main__": - test_reasonable_bounds() - test_reversed_concat() - test_fucked_extract() - test_simple_cardinality() - test_wrapped_intervals() - test_join() - test_vsa() - test_vsa_constraint_to_si() - test_vsa_discrete_value_set() - test_solution() - test_shifting() - test_reverse() + unittest.main() diff --git a/tests/test_z3.py b/tests/test_z3.py index 16e71aae5..62ab77d89 100644 --- a/tests/test_z3.py +++ b/tests/test_z3.py @@ -8,8 +8,7 @@ class TestZ3(unittest.TestCase): A class used for testing z3 """ - @staticmethod - def test_extrema(): + def test_extrema(self): """ Test the _extrema function within the z3 backend """