Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/runtime/program/zk_elgamal/execute.zig
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ fn processVerifyProof(
break :cd proof_data.context;
};

// create context state if additional accounts are provided with the instruction
// Create context state if additional accounts are provided with the instruction.
if (ic.ixn_info.account_metas.items.len >= accessed_accounts + 2) {
const context_authority_key = blk: {
const context_state_authority = try ic.borrowInstructionAccount(accessed_accounts + 1);
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/program/zk_elgamal/lib.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const sig = @import("../../../sig.zig");

/// [agave] https://github.com/solana-program/zk-elgamal-proof/blob/zk-sdk%40v5.0.0/zk-sdk/src/zk_elgamal_proof_program/proof_data/mod.rs#L48
pub const ProofType = enum(u8) {
/// Empty proof type used to distinguish if a proof context account is initialized
uninitialized,
Expand Down Expand Up @@ -35,6 +36,7 @@ pub fn ProofContextState(C: type) type {

pub const ID: sig.core.Pubkey = .parse("ZkE1Gama1Proof11111111111111111111111111111");

// [agave] https://github.com/anza-xyz/agave/blob/master/programs/zk-elgamal-proof/src/lib.rs#L19-L31
pub const CLOSE_CONTEXT_STATE_COMPUTE_UNITS: u64 = 3_300;
pub const VERIFY_ZERO_CIPHERTEXT_COMPUTE_UNITS: u64 = 6_000;
pub const VERIFY_CIPHERTEXT_CIPHERTEXT_EQUALITY_COMPUTE_UNITS: u64 = 8_000;
Expand Down
27 changes: 25 additions & 2 deletions src/zksdk/elgamal.zig
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ pub const Pubkey = struct {
);
return fromBytes(buffer);
}

pub fn rejectIdentity(self: *const Pubkey) error{IdentityElement}!void {
try self.point.rejectIdentity();
}
};

pub const Keypair = struct {
Expand Down Expand Up @@ -107,6 +111,11 @@ pub const Ciphertext = struct {
);
return fromBytes(buffer);
}

pub fn rejectIdentity(self: *const Ciphertext) error{IdentityElement}!void {
try self.commitment.point.rejectIdentity();
try self.handle.point.rejectIdentity();
}
};

pub fn encrypt(comptime T: type, value: T, pubkey: *const Pubkey) Ciphertext {
Expand Down Expand Up @@ -168,13 +177,27 @@ pub fn GroupedElGamalCiphertext(comptime N: u64) type {
};
}

pub fn fromBase64(string: []const u8) !Self {
const base64 = std.base64.standard;
var buffer: [BYTE_LEN]u8 = @splat(0);
const decoded_length = try base64.Decoder.calcSizeForSlice(string);
try std.base64.standard.Decoder.decode(
buffer[0..decoded_length],
string,
);
return fromBytes(buffer);
}

pub fn toBytes(self: Self) [BYTE_LEN]u8 {
var handles: [N * 32]u8 = undefined;
for (self.handles, 0..) |handle, i| {
const position = i * 32;
handles[position..][0..32].* = handle.point.toBytes();
handles[i * 32 ..][0..32].* = handle.point.toBytes();
}
return self.commitment.point.toBytes() ++ handles;
}

pub fn rejectIdentity(self: *const Self) error{IdentityElement}!void {
try self.commitment.rejectIdentity();
}
};
}
16 changes: 8 additions & 8 deletions src/zksdk/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ pub const PercentageWithCapData = percentage.Data;
pub const PubkeyProofData = pubkey_validity.Data;
pub const ZeroCiphertextData = zero_ciphertext.Data;

// grouped ciphertext validity
const grouped_cipher_handles_2 = @import("sigma_proofs/grouped_ciphertext/handles_2.zig");
const grouped_cipher_handles_3 = @import("sigma_proofs/grouped_ciphertext/handles_3.zig");
// // grouped ciphertext validity
const grouped_cipher_2_handles = @import("sigma_proofs/grouped_ciphertext/2_handles.zig");
const grouped_cipher_3_handles = @import("sigma_proofs/grouped_ciphertext/3_handles.zig");

pub const GroupedCiphertext2HandlesData = grouped_cipher_handles_2.Data;
pub const BatchedGroupedCiphertext2HandlesData = grouped_cipher_handles_2.BatchedData;
pub const GroupedCiphertext3HandlesData = grouped_cipher_handles_3.Data;
pub const BatchedGroupedCiphertext3HandlesData = grouped_cipher_handles_3.BatchedData;
pub const GroupedCiphertext2HandlesData = grouped_cipher_2_handles.Data;
pub const BatchedGroupedCiphertext2HandlesData = grouped_cipher_2_handles.BatchedData;
pub const GroupedCiphertext3HandlesData = grouped_cipher_3_handles.Data;
pub const BatchedGroupedCiphertext3HandlesData = grouped_cipher_3_handles.BatchedData;

// range proof
// // range proof
pub const bulletproofs = @import("range_proof/bulletproofs.zig");

pub const RangeProofU64Data = bulletproofs.Data(64);
Expand Down
150 changes: 100 additions & 50 deletions src/zksdk/merlin.zig
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,26 @@ pub const Transcript = struct {
ciphertext: zksdk.elgamal.Ciphertext,
commitment: zksdk.pedersen.Commitment,
u64: u64,
domsep: DomainSeperator,

grouped_2: zksdk.elgamal.GroupedElGamalCiphertext(2),
grouped_3: zksdk.elgamal.GroupedElGamalCiphertext(3),
};

pub fn init(comptime seperator: DomainSeperator, inputs: []const TranscriptInput) Transcript {
/// [agave] https://github.com/solana-program/zk-elgamal-proof/blob/zk-sdk%40v5.0.0/zk-sdk/src/lib.rs#L36
const TRANSCRIPT_DOMAIN = "solana-zk-elgamal-proof-program-v1";

pub fn init(comptime seperator: DomainSeperator) Transcript {
var transcript: Transcript = .{ .strobe = Strobe128.init("Merlin v1.0") };
transcript.appendDomSep(seperator);
for (inputs) |input| transcript.appendMessage(input.label, input.message);
transcript.appendBytes("dom-sep", TRANSCRIPT_DOMAIN);
transcript.appendBytes("dom-sep", @tagName(seperator));
return transcript;
}

pub fn initTest(label: []const u8) Transcript {
comptime if (!builtin.is_test) @compileError("should only be used during tests");
var transcript: Transcript = .{ .strobe = Strobe128.init("Merlin v1.0") };
transcript.appendBytes("dom-sep", TRANSCRIPT_DOMAIN);
transcript.appendBytes("dom-sep", label);
return transcript;
}
Expand All @@ -264,6 +269,7 @@ pub const Transcript = struct {
.point => |*point| &point.toBytes(),
.pubkey => |*pubkey| &pubkey.toBytes(),
.scalar => |*scalar| &scalar.toBytes(),
.domsep => |t| @tagName(t),
.ciphertext => |*ct| b: {
@memcpy(buffer[0..32], &ct.commitment.point.toBytes());
@memcpy(buffer[32..64], &ct.handle.point.toBytes());
Expand All @@ -284,26 +290,33 @@ pub const Transcript = struct {
comptime session: *Session,
comptime t: Input.Type,
comptime label: []const u8,
data: t.Data(),
) if (t == .validate_point) error{IdentityElement}!void else void {
// if validate_point fails to validate, we no longer want to check the contract
data: @FieldType(Message, @tagName(t.base())),
) if (t.validates()) error{IdentityElement}!void else void {
// If validate_point fails to validate, we no longer want to check the contract
// because the function calling append will now return early.
errdefer session.cancel();

if (t == .bytes and !builtin.is_test)
@compileError("message type `bytes` only allowed in tests");

// assert correctness
// Get the next expected input, and inside we verify that it matches
// the type we're about to append to the transcript.
const input = comptime session.nextInput(t, label);
if (t == .validate_point) try data.rejectIdentity();
// If the input requires validation, we perform it here.
if (comptime t.validates()) try data.rejectIdentity();
// Ensure that the domain seperators are added with the correct label.
// They should always be added through the `appendDomSep` helper function.
switch (t) {
.domsep => comptime {
std.debug.assert(input.seperator.? == data);
std.debug.assert(std.mem.eql(u8, label, "dom-sep"));
},
else => {},
}

// add the message
self.appendMessage(input.label, @unionInit(
Message,
@tagName(switch (t) {
.validate_point => .point,
else => t,
}),
@tagName(t.base()),
data,
));
}
Expand All @@ -314,12 +327,16 @@ pub const Transcript = struct {
pub inline fn appendNoValidate(
self: *Transcript,
comptime session: *Session,
comptime t: Input.Type,
comptime label: []const u8,
point: Ristretto255,
data: @FieldType(Message, @tagName(t.base())),
) void {
const input = comptime session.nextInput(.validate_point, label);
point.rejectIdentity() catch {}; // ignore the error
self.appendMessage(input.label, .{ .point = point });
const input = comptime session.nextInput(
@field(Input.Type, "validate_" ++ @tagName(t)),
label,
);
data.rejectIdentity() catch {}; // ignore the error
self.appendMessage(input.label, @unionInit(Message, @tagName(t), data));
}

fn challengeBytes(
Expand All @@ -329,7 +346,6 @@ pub const Transcript = struct {
) void {
var data_len: [4]u8 = undefined;
std.mem.writeInt(u32, &data_len, @intCast(destination.len), .little);

self.strobe.metaAd(label, false);
self.strobe.metaAd(&data_len, true);
self.strobe.prf(destination, false);
Expand All @@ -351,64 +367,89 @@ pub const Transcript = struct {

// domain seperation helpers

pub fn appendDomSep(self: *Transcript, comptime seperator: DomainSeperator) void {
self.appendBytes("dom-sep", @tagName(seperator));
}

pub fn appendHandleDomSep(
pub inline fn appendDomSep(
self: *Transcript,
comptime mode: enum { batched, unbatched },
comptime handles: enum { two, three },
comptime session: *Session,
comptime seperator: DomainSeperator,
) void {
self.appendDomSep(switch (mode) {
.batched => .@"batched-validity-proof",
.unbatched => .@"validity-proof",
});
self.appendMessage("handles", .{ .u64 = switch (handles) {
.two => 2,
.three => 3,
} });
self.append(session, .domsep, "dom-sep", seperator);
}

pub fn appendRangeProof(
pub inline fn appendRangeProof(
self: *Transcript,
comptime session: *Session,
comptime mode: enum { range, inner },
n: comptime_int,
) void {
self.appendDomSep(switch (mode) {
self.appendDomSep(session, switch (mode) {
.range => .@"range-proof",
.inner => .@"inner-product",
});
self.appendMessage("n", .{ .u64 = n });
self.append(session, .u64, "n", n);
}

// sessions

pub const Input = struct {
label: []const u8,
type: Type,
seperator: ?DomainSeperator = null,

const Type = enum {
bytes,
scalar,
challenge,
u64,

point,
validate_point,
pubkey,
ciphertext,
commitment,
grouped_2,
grouped_3,

pub fn Data(comptime t: Type) type {
validate_point,
validate_pubkey,
validate_ciphertext,
validate_commitment,
validate_grouped_2,
validate_grouped_3,

domsep,
challenge,

/// Returns whether this input type performs identity validation.
fn validates(t: Type) bool {
return switch (t) {
.bytes => []const u8,
.scalar => Scalar,
.validate_point, .point => Ristretto255,
.pubkey => zksdk.elgamal.Pubkey,
.challenge => unreachable, // call `challenge*`
.validate_point,
.validate_pubkey,
.validate_ciphertext,
.validate_commitment,
.validate_grouped_2,
.validate_grouped_3,
=> true,
else => false,
};
}

/// For a given input type, returns the base type.
/// E.g. `validate_point` -> `point`
/// E.g. `point` -> `point`
fn base(t: Type) Type {
if (t.validates()) {
return @field(Type, @tagName(t)["validate_".len..]);
}
return t;
}
};

pub fn domain(sep: DomainSeperator) Input {
return .{ .label = "dom-sep", .type = .domsep, .seperator = sep };
}

fn check(self: Input, t: Type, label: []const u8) void {
std.debug.assert(self.type == t);
if (self.type != t) {
@compileError("expected: " ++ @tagName(self.type) ++ ", found: " ++ @tagName(t));
}
std.debug.assert(std.mem.eql(u8, self.label, label));
}
};
Expand All @@ -418,7 +459,8 @@ pub const Transcript = struct {
pub const Session = struct {
i: u8,
contract: Contract,
err: bool, // if validate_point errors, we skip the finish() check
// If an identity validation errors, we skip the finish() check.
err: bool,

pub inline fn nextInput(comptime self: *Session, t: Input.Type, label: []const u8) Input {
comptime {
Expand Down Expand Up @@ -453,6 +495,14 @@ pub const Transcript = struct {
return .{ .i = 0, .contract = contract, .err = false };
}
}

/// The same as `getSession`, but does not check that it ends with a challenge.
/// Only used in certain cases when we need an "init" contract, such as `percentage_with_cap`.
pub inline fn getInitSession(comptime contract: []const Input) Session {
comptime {
return .{ .i = 0, .contract = contract, .err = false };
}
}
};

test "equivalence" {
Expand All @@ -468,9 +518,9 @@ test "equivalence" {
transcript.challengeBytes("challenge", &bytes);

try std.testing.expectEqualSlices(u8, &.{
0xd5, 0xa2, 0x19, 0x72, 0xd0, 0xd5, 0xfe, 0x32,
0xc, 0xd, 0x26, 0x3f, 0xac, 0x7f, 0xff, 0xb8,
0x14, 0x5a, 0xa6, 0x40, 0xaf, 0x6e, 0x9b, 0xca,
0x17, 0x7c, 0x3, 0xc7, 0xef, 0xcf, 0x6, 0x15,
159, 115, 74, 116, 119, 227, 89, 42,
108, 83, 69, 218, 43, 29, 11, 79,
117, 141, 121, 172, 163, 50, 123, 92,
25, 21, 111, 177, 11, 232, 4, 35,
}, &bytes);
}
4 changes: 4 additions & 0 deletions src/zksdk/pedersen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ pub const Commitment = struct {
);
return fromBytes(buffer);
}

pub fn rejectIdentity(self: *const Commitment) error{IdentityElement}!void {
try self.point.rejectIdentity();
}
};

pub const DecryptHandle = struct {
Expand Down
Loading