11"""
22Defines EIP-2537 specification constants and functions.
33"""
4- import json
54from dataclasses import dataclass
6- from typing import Callable , List , Sized , SupportsBytes , Tuple
7-
8- from .helpers import current_python_script_directory
5+ from enum import Enum , auto
6+ from typing import Callable , Sized , SupportsBytes , Tuple
97
108
119@dataclass (frozen = True )
@@ -109,11 +107,6 @@ def __bytes__(self) -> bytes:
109107 return self .x .to_bytes (32 , byteorder = "big" )
110108
111109
112- with open (current_python_script_directory ("msm_discount_table.json" )) as f :
113- MSM_DISCOUNT_TABLE : List [int ] = json .load (f )
114- assert type (MSM_DISCOUNT_TABLE ) is list
115-
116-
117110@dataclass (frozen = True )
118111class Spec :
119112 """
@@ -149,7 +142,30 @@ class Spec:
149142 P = (X - 1 ) ** 2 * Q // 3 + X
150143 LEN_PER_PAIR = len (PointG1 () + PointG2 ())
151144 MSM_MULTIPLIER = 1_000
152- MSM_DISCOUNT_TABLE = MSM_DISCOUNT_TABLE
145+ # fmt: off
146+ G1MSM_DISCOUNT_TABLE = [
147+ 0 ,
148+ 1000 , 949 , 848 , 797 , 764 , 750 , 738 , 728 , 719 , 712 , 705 , 698 , 692 , 687 , 682 , 677 , 673 , 669 ,
149+ 665 , 661 , 658 , 654 , 651 , 648 , 645 , 642 , 640 , 637 , 635 , 632 , 630 , 627 , 625 , 623 , 621 , 619 ,
150+ 617 , 615 , 613 , 611 , 609 , 608 , 606 , 604 , 603 , 601 , 599 , 598 , 596 , 595 , 593 , 592 , 591 , 589 ,
151+ 588 , 586 , 585 , 584 , 582 , 581 , 580 , 579 , 577 , 576 , 575 , 574 , 573 , 572 , 570 , 569 , 568 , 567 ,
152+ 566 , 565 , 564 , 563 , 562 , 561 , 560 , 559 , 558 , 557 , 556 , 555 , 554 , 553 , 552 , 551 , 550 , 549 ,
153+ 548 , 547 , 547 , 546 , 545 , 544 , 543 , 542 , 541 , 540 , 540 , 539 , 538 , 537 , 536 , 536 , 535 , 534 ,
154+ 533 , 532 , 532 , 531 , 530 , 529 , 528 , 528 , 527 , 526 , 525 , 525 , 524 , 523 , 522 , 522 , 521 , 520 ,
155+ 520 , 519
156+ ]
157+ G2MSM_DISCOUNT_TABLE = [
158+ 0 ,
159+ 1000 , 1000 , 923 , 884 , 855 , 832 , 812 , 796 , 782 , 770 , 759 , 749 , 740 , 732 , 724 , 717 , 711 , 704 ,
160+ 699 , 693 , 688 , 683 , 679 , 674 , 670 , 666 , 663 , 659 , 655 , 652 , 649 , 646 , 643 , 640 , 637 , 634 ,
161+ 632 , 629 , 627 , 624 , 622 , 620 , 618 , 615 , 613 , 611 , 609 , 607 , 606 , 604 , 602 , 600 , 598 , 597 ,
162+ 595 , 593 , 592 , 590 , 589 , 587 , 586 , 584 , 583 , 582 , 580 , 579 , 578 , 576 , 575 , 574 , 573 , 571 ,
163+ 570 , 569 , 568 , 567 , 566 , 565 , 563 , 562 , 561 , 560 , 559 , 558 , 557 , 556 , 555 , 554 , 553 , 552 ,
164+ 552 , 551 , 550 , 549 , 548 , 547 , 546 , 545 , 545 , 544 , 543 , 542 , 541 , 541 , 540 , 539 , 538 , 537 ,
165+ 537 , 536 , 535 , 535 , 534 , 533 , 532 , 532 , 531 , 530 , 530 , 529 , 528 , 528 , 527 , 526 , 526 , 525 ,
166+ 524 , 524
167+ ]
168+ # fmt: on
153169
154170 # Test constants (from https://github.com/ethereum/bls12-381-tests/tree/eip-2537)
155171 P1 = PointG1 ( # random point in G1
@@ -217,17 +233,34 @@ class Spec:
217233 INVALID = b""
218234
219235
220- def msm_discount ( k : int ) -> int :
236+ class BLS12Group ( Enum ) :
221237 """
222- Returns the discount for the G1MSM and G2MSM precompiles .
238+ Helper enum to specify the BLS12 group in discount table helpers .
223239 """
224- return Spec .MSM_DISCOUNT_TABLE [min (k , 128 )]
225240
241+ G1 = auto ()
242+ G2 = auto ()
226243
227- def msm_gas_func_gen (len_per_pair : int , multiplication_cost : int ) -> Callable [[int ], int ]:
244+
245+ def msm_discount (group : BLS12Group , k : int ) -> int :
246+ """
247+ Returns the discount for the G1MSM and G2MSM precompiles.
248+ """
249+ assert k >= 1 , "k must be greater than or equal to 1"
250+ match group :
251+ case BLS12Group .G1 :
252+ return Spec .G1MSM_DISCOUNT_TABLE [min (k , 128 )]
253+ case BLS12Group .G2 :
254+ return Spec .G2MSM_DISCOUNT_TABLE [min (k , 128 )]
255+ case _:
256+ raise ValueError (f"Unsupported group: { group } " )
257+
258+
259+ def msm_gas_func_gen (
260+ group : BLS12Group , len_per_pair : int , multiplication_cost : int
261+ ) -> Callable [[int ], int ]:
228262 """
229- Generates a function that calculates the gas cost for the G1MSM and G2MSM
230- precompiles.
263+ Generate a function that calculates the gas cost for the G1MSM and G2MSM precompiles.
231264 """
232265
233266 def msm_gas (input_length : int ) -> int :
@@ -238,7 +271,7 @@ def msm_gas(input_length: int) -> int:
238271 if k == 0 :
239272 return 0
240273
241- gas_cost = k * multiplication_cost * msm_discount (k ) // Spec .MSM_MULTIPLIER
274+ gas_cost = k * multiplication_cost * msm_discount (group , k ) // Spec .MSM_MULTIPLIER
242275
243276 return gas_cost
244277
@@ -256,10 +289,10 @@ def pairing_gas(input_length: int) -> int:
256289GAS_CALCULATION_FUNCTION_MAP = {
257290 Spec .G1ADD : lambda _ : Spec .G1ADD_GAS ,
258291 Spec .G1MUL : lambda _ : Spec .G1MUL_GAS ,
259- Spec .G1MSM : msm_gas_func_gen (len (PointG1 () + Scalar ()), Spec .G1MUL_GAS ),
292+ Spec .G1MSM : msm_gas_func_gen (BLS12Group . G1 , len (PointG1 () + Scalar ()), Spec .G1MUL_GAS ),
260293 Spec .G2ADD : lambda _ : Spec .G2ADD_GAS ,
261294 Spec .G2MUL : lambda _ : Spec .G2MUL_GAS ,
262- Spec .G2MSM : msm_gas_func_gen (len (PointG2 () + Scalar ()), Spec .G2MUL_GAS ),
295+ Spec .G2MSM : msm_gas_func_gen (BLS12Group . G2 , len (PointG2 () + Scalar ()), Spec .G2MUL_GAS ),
263296 Spec .PAIRING : pairing_gas ,
264297 Spec .MAP_FP_TO_G1 : lambda _ : Spec .MAP_FP_TO_G1_GAS ,
265298 Spec .MAP_FP2_TO_G2 : lambda _ : Spec .MAP_FP2_TO_G2_GAS ,
0 commit comments