@@ -10,7 +10,9 @@ use serde::{Deserialize, Serialize};
10
10
use std:: {
11
11
cmp,
12
12
convert:: { TryFrom , TryInto } ,
13
- fmt, i128, i64, iter, ops, str, u64,
13
+ fmt, i128, i64, iter, ops,
14
+ ops:: Sub ,
15
+ str, u64,
14
16
} ;
15
17
use thiserror:: Error ;
16
18
@@ -935,6 +937,35 @@ impl I256 {
935
937
let ( result, _) = self . overflowing_pow ( exp) ;
936
938
result
937
939
}
940
+
941
+ /// Arithmetic Shift Right operation. Shifts `shift` number of times to the right maintaining
942
+ /// the original sign. If the number is positive this is the same as logic shift right.
943
+ pub fn asr ( self , shift : u32 ) -> Self {
944
+ // Avoid shifting if we are going to know the result regardless of the value.
945
+ if shift == 0 {
946
+ self
947
+ } else if shift >= 255u32 {
948
+ match self . sign ( ) {
949
+ // It's always going to be zero (i.e. 00000000...00000000)
950
+ Sign :: Positive => Self :: zero ( ) ,
951
+ // It's always going to be -1 (i.e. 11111111...11111111)
952
+ Sign :: Negative => Self :: from ( -1i8 ) ,
953
+ }
954
+ } else {
955
+ // Perform the shift.
956
+ match self . sign ( ) {
957
+ Sign :: Positive => self >> shift,
958
+ // We need to do: `for 0..shift { self >> 1 | 2^255 }`
959
+ // We can avoid the loop by doing: `self >> shift | ~(2^(255 - shift) - 1)`
960
+ // where '~' represents ones complement
961
+ Sign :: Negative => {
962
+ let bitwise_or =
963
+ Self :: from_raw ( !U256 :: from ( 2u8 ) . pow ( U256 :: from ( 255u32 - shift) ) . sub ( 1u8 ) ) ;
964
+ ( self >> shift) | bitwise_or
965
+ }
966
+ }
967
+ }
968
+ }
938
969
}
939
970
940
971
macro_rules! impl_from {
@@ -1276,6 +1307,7 @@ mod tests {
1276
1307
use crate :: abi:: Tokenizable ;
1277
1308
use once_cell:: sync:: Lazy ;
1278
1309
use serde_json:: json;
1310
+ use std:: ops:: Neg ;
1279
1311
1280
1312
static MIN_ABS : Lazy < U256 > = Lazy :: new ( || U256 :: from ( 1 ) << 255 ) ;
1281
1313
@@ -1521,6 +1553,33 @@ mod tests {
1521
1553
assert_eq ! ( I256 :: MIN >> 255 , I256 :: one( ) ) ;
1522
1554
}
1523
1555
1556
+ #[ test]
1557
+ fn arithmetic_shift_right ( ) {
1558
+ let value = I256 :: from_raw ( U256 :: from ( 2u8 ) . pow ( U256 :: from ( 254u8 ) ) ) . neg ( ) ;
1559
+ let expected_result = I256 :: from_raw ( U256 :: MAX . sub ( 1u8 ) ) ;
1560
+ assert_eq ! ( value. asr( 253u32 ) , expected_result, "1011...1111 >> 253 was not 1111...1110" ) ;
1561
+
1562
+ let value = I256 :: from ( -1i8 ) ;
1563
+ let expected_result = I256 :: from ( -1i8 ) ;
1564
+ assert_eq ! ( value. asr( 250u32 ) , expected_result, "-1 >> any_amount was not -1" ) ;
1565
+
1566
+ let value = I256 :: from_raw ( U256 :: from ( 2u8 ) . pow ( U256 :: from ( 254u8 ) ) ) . neg ( ) ;
1567
+ let expected_result = I256 :: from ( -1i8 ) ;
1568
+ assert_eq ! ( value. asr( 255u32 ) , expected_result, "1011...1111 >> 255 was not -1" ) ;
1569
+
1570
+ let value = I256 :: from_raw ( U256 :: from ( 2u8 ) . pow ( U256 :: from ( 254u8 ) ) ) . neg ( ) ;
1571
+ let expected_result = I256 :: from ( -1i8 ) ;
1572
+ assert_eq ! ( value. asr( 1024u32 ) , expected_result, "1011...1111 >> 1024 was not -1" ) ;
1573
+
1574
+ let value = I256 :: from ( 1024i32 ) ;
1575
+ let expected_result = I256 :: from ( 32i32 ) ;
1576
+ assert_eq ! ( value. asr( 5u32 ) , expected_result, "1024 >> 5 was not 32" ) ;
1577
+
1578
+ let value = I256 :: MAX ;
1579
+ let expected_result = I256 :: zero ( ) ;
1580
+ assert_eq ! ( value. asr( 255u32 ) , expected_result, "I256::MAX >> 255 was not 0" ) ;
1581
+ }
1582
+
1524
1583
#[ test]
1525
1584
fn addition ( ) {
1526
1585
assert_eq ! ( I256 :: MIN . overflowing_add( I256 :: MIN ) , ( I256 :: zero( ) , true ) ) ;
0 commit comments