Skip to content

Commit

Permalink
Merge pull request #91 from msgpack/ext-type-jruby
Browse files Browse the repository at this point in the history
add full spec ext type support to JRuby implementation
  • Loading branch information
tagomoris committed Oct 23, 2015
2 parents 6014ece + 654de1b commit 540fdb9
Show file tree
Hide file tree
Showing 22 changed files with 1,210 additions and 637 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Gemfile*
pkg
test/debug.log
*~
*.swp
/rdoc
tmp
.classpath
Expand Down
71 changes: 55 additions & 16 deletions ext/java/org/msgpack/jruby/Decoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.jruby.RubyClass;
import org.jruby.RubyBignum;
import org.jruby.RubyString;
import org.jruby.RubyArray;
import org.jruby.RubyHash;
import org.jruby.exceptions.RaiseException;
import org.jruby.runtime.builtin.IRubyObject;
Expand All @@ -29,33 +30,58 @@ public class Decoder implements Iterator<IRubyObject> {
private final Encoding utf8Encoding;
private final RubyClass unpackErrorClass;
private final RubyClass underflowErrorClass;
private final RubyClass malformedFormatErrorClass;
private final RubyClass stackErrorClass;
private final RubyClass unexpectedTypeErrorClass;
private final RubyClass unknownExtTypeErrorClass;

private ExtensionRegistry registry;
private ByteBuffer buffer;
private boolean symbolizeKeys;
private boolean allowUnknownExt;

public Decoder(Ruby runtime) {
this(runtime, new byte[] {}, 0, 0);
this(runtime, null, new byte[] {}, 0, 0, false, false);
}

/**
 * Creates a decoder that uses the given extension registry and starts with an
 * empty input (a zero-length byte array, offset 0, length 0).
 * Both {@code symbolizeKeys} and {@code allowUnknownExt} are passed as {@code false}.
 *
 * @param runtime the JRuby runtime forwarded to the full constructor
 * @param registry the extension registry forwarded to the full constructor
 */
public Decoder(Ruby runtime, ExtensionRegistry registry) {
this(runtime, registry, new byte[] {}, 0, 0, false, false);
}

public Decoder(Ruby runtime, byte[] bytes) {
this(runtime, bytes, 0, bytes.length);
this(runtime, null, bytes, 0, bytes.length, false, false);
}

/**
 * Creates a decoder over the whole of {@code bytes} (offset 0, length
 * {@code bytes.length}) using the given extension registry.
 * Both {@code symbolizeKeys} and {@code allowUnknownExt} are passed as {@code false}.
 *
 * @param runtime the JRuby runtime forwarded to the full constructor
 * @param registry the extension registry forwarded to the full constructor
 * @param bytes the input; must not be {@code null} because {@code bytes.length} is read here
 */
public Decoder(Ruby runtime, ExtensionRegistry registry, byte[] bytes) {
this(runtime, registry, bytes, 0, bytes.length, false, false);
}

/**
 * Creates a decoder over the whole of {@code bytes} (offset 0, length
 * {@code bytes.length}) with explicit option flags.
 *
 * @param runtime the JRuby runtime forwarded to the full constructor
 * @param registry the extension registry forwarded to the full constructor
 * @param bytes the input; must not be {@code null} because {@code bytes.length} is read here
 * @param symbolizeKeys forwarded unchanged to the full constructor
 * @param allowUnknownExt forwarded unchanged to the full constructor
 */
public Decoder(Ruby runtime, ExtensionRegistry registry, byte[] bytes, boolean symbolizeKeys, boolean allowUnknownExt) {
this(runtime, registry, bytes, 0, bytes.length, symbolizeKeys, allowUnknownExt);
}

/**
 * Creates a decoder over the given slice of {@code bytes} using the given
 * extension registry.
 * Both {@code symbolizeKeys} and {@code allowUnknownExt} are passed as {@code false}.
 *
 * @param runtime the JRuby runtime forwarded to the full constructor
 * @param registry the extension registry forwarded to the full constructor
 * @param bytes the input array
 * @param offset forwarded unchanged to the full constructor
 * @param length forwarded unchanged to the full constructor
 */
public Decoder(Ruby runtime, ExtensionRegistry registry, byte[] bytes, int offset, int length) {
this(runtime, registry, bytes, offset, length, false, false);
}

public Decoder(Ruby runtime, byte[] bytes, int offset, int length) {
public Decoder(Ruby runtime, ExtensionRegistry registry, byte[] bytes, int offset, int length, boolean symbolizeKeys, boolean allowUnknownExt) {
this.runtime = runtime;
this.registry = registry;
this.symbolizeKeys = symbolizeKeys;
this.allowUnknownExt = allowUnknownExt;
this.binaryEncoding = runtime.getEncodingService().getAscii8bitEncoding();
this.utf8Encoding = UTF8Encoding.INSTANCE;
this.unpackErrorClass = runtime.getModule("MessagePack").getClass("UnpackError");
this.underflowErrorClass = runtime.getModule("MessagePack").getClass("UnderflowError");
this.malformedFormatErrorClass = runtime.getModule("MessagePack").getClass("MalformedFormatError");
this.stackErrorClass = runtime.getModule("MessagePack").getClass("StackError");
this.unexpectedTypeErrorClass = runtime.getModule("MessagePack").getClass("UnexpectedTypeError");
this.unknownExtTypeErrorClass = runtime.getModule("MessagePack").getClass("UnknownExtTypeError");
this.symbolizeKeys = symbolizeKeys;
this.allowUnknownExt = allowUnknownExt;
feed(bytes, offset, length);
}

public void symbolizeKeys(boolean symbolize) {
this.symbolizeKeys = symbolize;
}

public void feed(byte[] bytes) {
feed(bytes, 0, bytes.length);
}
Expand All @@ -73,7 +99,7 @@ public void feed(byte[] bytes, int offset, int length) {
}

public void reset() {
buffer.rewind();
buffer = null;
}

public int offset() {
Expand Down Expand Up @@ -118,7 +144,20 @@ private IRubyObject consumeHash(int size) {
private IRubyObject consumeExtension(int size) {
int type = buffer.get();
byte[] payload = readBytes(size);
return ExtensionValue.newExtensionValue(runtime, type, payload);

if (registry != null) {
IRubyObject proc = registry.lookupUnpackerByTypeId(type);
if (proc != null) {
ByteList byteList = new ByteList(payload, runtime.getEncodingService().getAscii8bitEncoding());
return proc.callMethod(runtime.getCurrentContext(), "call", runtime.newString(byteList));
}
}

if (this.allowUnknownExt) {
return ExtensionValue.newExtensionValue(runtime, type, payload);
}

throw runtime.newRaiseException(unknownExtTypeErrorClass, "unexpected extension type");
}

private byte[] readBytes(int size) {
Expand All @@ -142,11 +181,11 @@ public IRubyObject read_array_header() {
try {
byte b = buffer.get();
if ((b & 0xf0) == 0x90) {
return runtime.newFixnum(b & 0x0f);
return runtime.newFixnum(b & 0x0f);
} else if (b == ARY16) {
return runtime.newFixnum(buffer.getShort() & 0xffff);
return runtime.newFixnum(buffer.getShort() & 0xffff);
} else if (b == ARY32) {
return runtime.newFixnum(buffer.getInt());
return runtime.newFixnum(buffer.getInt());
}
throw runtime.newRaiseException(unexpectedTypeErrorClass, "unexpected type");
} catch (RaiseException re) {
Expand All @@ -163,11 +202,11 @@ public IRubyObject read_map_header() {
try {
byte b = buffer.get();
if ((b & 0xf0) == 0x80) {
return runtime.newFixnum(b & 0x0f);
return runtime.newFixnum(b & 0x0f);
} else if (b == MAP16) {
return runtime.newFixnum(buffer.getShort() & 0xffff);
return runtime.newFixnum(buffer.getShort() & 0xffff);
} else if (b == MAP32) {
return runtime.newFixnum(buffer.getInt());
return runtime.newFixnum(buffer.getInt());
}
throw runtime.newRaiseException(unexpectedTypeErrorClass, "unexpected type");
} catch (RaiseException re) {
Expand Down Expand Up @@ -233,7 +272,7 @@ public IRubyObject next() {
default: return runtime.newFixnum(b);
}
buffer.position(position);
throw runtime.newRaiseException(unpackErrorClass, "Illegal byte sequence");
throw runtime.newRaiseException(malformedFormatErrorClass, "Illegal byte sequence");
} catch (RaiseException re) {
buffer.position(position);
throw re;
Expand Down
39 changes: 31 additions & 8 deletions ext/java/org/msgpack/jruby/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,21 @@ public class Encoder {
private final Encoding binaryEncoding;
private final Encoding utf8Encoding;
private final boolean compatibilityMode;
private final ExtensionRegistry registry;

private ByteBuffer buffer;

public Encoder(Ruby runtime, boolean compatibilityMode) {
public Encoder(Ruby runtime, boolean compatibilityMode, ExtensionRegistry registry) {
this.runtime = runtime;
this.buffer = ByteBuffer.allocate(CACHE_LINE_SIZE - ARRAY_HEADER_SIZE);
this.binaryEncoding = runtime.getEncodingService().getAscii8bitEncoding();
this.utf8Encoding = UTF8Encoding.INSTANCE;
this.compatibilityMode = compatibilityMode;
this.registry = registry;
}

/**
 * Returns the compatibility-mode flag this encoder was constructed with.
 *
 * @return the value of the final {@code compatibilityMode} field
 */
public boolean isCompatibilityMode() {
return compatibilityMode;
}

private void ensureRemainingCapacity(int c) {
Expand Down Expand Up @@ -107,7 +113,7 @@ private void appendObject(IRubyObject object, IRubyObject destination) {
} else if (object instanceof ExtensionValue) {
appendExtensionValue((ExtensionValue) object);
} else {
appendCustom(object, destination);
appendOther(object, destination);
}
}

Expand Down Expand Up @@ -295,12 +301,7 @@ public void visit(IRubyObject key, IRubyObject value) {
}
}

private void appendExtensionValue(ExtensionValue object) {
long type = ((RubyFixnum)object.get_type()).getLongValue();
if (type < -128 || type > 127) {
throw object.getRuntime().newRangeError(String.format("integer %d too big to convert to `signed char'", type));
}
ByteList payloadBytes = ((RubyString)object.payload()).getByteList();
private void appendExt(int type, ByteList payloadBytes) {
int payloadSize = payloadBytes.length();
int outputSize = 0;
boolean fixSize = payloadSize == 1 || payloadSize == 2 || payloadSize == 4 || payloadSize == 8 || payloadSize == 16;
Expand Down Expand Up @@ -338,6 +339,28 @@ private void appendExtensionValue(ExtensionValue object) {
buffer.put(payloadBytes.unsafeBytes(), payloadBytes.begin(), payloadSize);
}

/**
 * Writes an {@code ExtensionValue} to the buffer as a MessagePack ext value.
 *
 * The value's type is read as a fixnum and must fit in a signed byte
 * (-128..127); otherwise a Ruby RangeError is raised. The payload string's
 * bytes are then handed to {@code appendExt} together with the type.
 *
 * @param object the extension value to encode
 */
private void appendExtensionValue(ExtensionValue object) {
  RubyFixnum typeFixnum = (RubyFixnum) object.get_type();
  long typeId = typeFixnum.getLongValue();

  boolean fitsSignedByte = typeId >= -128 && typeId <= 127;
  if (!fitsSignedByte) {
    String message = String.format("integer %d too big to convert to `signed char'", typeId);
    throw object.getRuntime().newRangeError(message);
  }

  RubyString payload = (RubyString) object.payload();
  appendExt((int) typeId, payload.getByteList());
}

/**
 * Encodes an object that none of the built-in branches handled.
 *
 * If a registry is present and it yields a packer for the object's class,
 * the packer proc is invoked with the object, its result is converted to a
 * string, and that string's bytes are written as an ext value with the
 * registered type id. In every other case the work is delegated to
 * {@code appendCustom}.
 *
 * @param object the object to encode
 * @param destination passed through to {@code appendCustom} when no packer applies
 */
private void appendOther(IRubyObject object, IRubyObject destination) {
  IRubyObject[] packerAndTypeId =
      (registry == null) ? null : registry.lookupPackerByClass(object.getType());

  if (packerAndTypeId == null) {
    // No registered extension packer: fall back to the generic path.
    appendCustom(object, destination);
    return;
  }

  IRubyObject packed = packerAndTypeId[0].callMethod(runtime.getCurrentContext(), "call", object);
  RubyString payload = packed.asString();
  long typeId = ((RubyFixnum) packerAndTypeId[1]).getLongValue();
  appendExt((int) typeId, payload.getByteList());
}

private void appendCustom(IRubyObject object, IRubyObject destination) {
if (destination == null) {
IRubyObject result = object.callMethod(runtime.getCurrentContext(), "to_msgpack");
Expand Down
159 changes: 159 additions & 0 deletions ext/java/org/msgpack/jruby/ExtensionRegistry.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package org.msgpack.jruby;

import org.jruby.Ruby;
import org.jruby.RubyHash;
import org.jruby.RubyArray;
import org.jruby.RubyClass;
import org.jruby.RubyFixnum;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;

import java.util.Map;
import java.util.HashMap;

/**
 * Holds the extension types registered for packing and unpacking.
 *
 * <p>Entries are indexed two ways: by the Ruby class they were registered for
 * (consulted when packing) and by their type id (consulted when unpacking).
 * Type ids are stored in a 256-slot array at index {@code typeId + 128}, so
 * only ids in the range -128..127 are addressable. A third map caches the
 * outcome of ancestor-based packer lookups and is cleared on every
 * {@link #put}.
 *
 * <p>This class performs no synchronization of its own.
 */
public class ExtensionRegistry {
  /** Added to a type id to obtain its index in {@code extensionsByTypeId}. */
  private static final int TYPE_ID_OFFSET = 128;
  /** Number of slots in {@code extensionsByTypeId}. */
  private static final int TYPE_ID_SLOTS = 256;

  private final Map<RubyClass, ExtensionEntry> extensionsByClass;
  private final Map<RubyClass, ExtensionEntry> extensionsByAncestor;
  private final ExtensionEntry[] extensionsByTypeId;

  /** Creates an empty registry. */
  public ExtensionRegistry() {
    this(new HashMap<RubyClass, ExtensionEntry>());
  }

  /**
   * Creates a registry holding a copy of the given class-to-entry map.
   * The type-id index is rebuilt from those entries that have an unpacker;
   * the ancestor cache starts out empty.
   */
  private ExtensionRegistry(Map<RubyClass, ExtensionEntry> extensionsByClass) {
    this.extensionsByClass = new HashMap<RubyClass, ExtensionEntry>(extensionsByClass);
    this.extensionsByAncestor = new HashMap<RubyClass, ExtensionEntry>();
    this.extensionsByTypeId = new ExtensionEntry[TYPE_ID_SLOTS];
    for (ExtensionEntry entry : extensionsByClass.values()) {
      if (entry.hasUnpacker()) {
        extensionsByTypeId[entry.getTypeId() + TYPE_ID_OFFSET] = entry;
      }
    }
  }

  /**
   * Returns a new registry with the same class registrations as this one.
   * Later {@link #put} calls on either registry do not affect the other.
   */
  public ExtensionRegistry dup() {
    return new ExtensionRegistry(extensionsByClass);
  }

  /**
   * Builds a Ruby hash mapping each registered class that has a packer to
   * its {@code [typeId, packerProc, packerArg]} tuple.
   */
  public IRubyObject toInternalPackerRegistry(ThreadContext ctx) {
    RubyHash hash = RubyHash.newHash(ctx.getRuntime());
    for (Map.Entry<RubyClass, ExtensionEntry> registration : extensionsByClass.entrySet()) {
      ExtensionEntry entry = registration.getValue();
      if (entry.hasPacker()) {
        hash.put(registration.getKey(), entry.toPackerTuple(ctx));
      }
    }
    return hash;
  }

  /**
   * Builds a Ruby hash mapping each type id that has an unpacker to its
   * {@code [class, unpackerProc, unpackerArg]} tuple.
   */
  public IRubyObject toInternalUnpackerRegistry(ThreadContext ctx) {
    RubyHash hash = RubyHash.newHash(ctx.getRuntime());
    for (int typeIdIndex = 0; typeIdIndex < TYPE_ID_SLOTS; typeIdIndex++) {
      ExtensionEntry entry = extensionsByTypeId[typeIdIndex];
      if (entry != null && entry.hasUnpacker()) {
        IRubyObject typeId = RubyFixnum.newFixnum(ctx.getRuntime(), typeIdIndex - TYPE_ID_OFFSET);
        hash.put(typeId, entry.toUnpackerTuple(ctx));
      }
    }
    return hash;
  }

  /**
   * Registers (or replaces) the extension entry for {@code cls}.
   *
   * @param cls the Ruby class the entry is registered for
   * @param typeId the extension type id; must be in -128..127
   * @param packerProc proc used for packing, or {@code null} for none
   * @param packerArg stored alongside the packer and exposed in the packer tuple
   * @param unpackerProc proc used for unpacking, or {@code null} for none
   * @param unpackerArg stored alongside the unpacker and exposed in the unpacker tuple
   * @throws ArrayIndexOutOfBoundsException if {@code typeId} is outside -128..127
   */
  public void put(RubyClass cls, int typeId, IRubyObject packerProc, IRubyObject packerArg, IRubyObject unpackerProc, IRubyObject unpackerArg) {
    ExtensionEntry entry = new ExtensionEntry(cls, typeId, packerProc, packerArg, unpackerProc, unpackerArg);
    // The array store is the only step that can fail (out-of-range type id),
    // so do it first: a rejected registration then leaves no partial state.
    extensionsByTypeId[typeId + TYPE_ID_OFFSET] = entry;
    extensionsByClass.put(cls, entry);
    // A new registration may change which entry a subclass resolves to.
    extensionsByAncestor.clear();
  }

  /**
   * Returns the unpacker proc registered for {@code typeId}, or {@code null}
   * when no entry with an unpacker occupies that slot.
   *
   * @throws ArrayIndexOutOfBoundsException if {@code typeId} is outside -128..127
   */
  public IRubyObject lookupUnpackerByTypeId(int typeId) {
    ExtensionEntry e = extensionsByTypeId[typeId + TYPE_ID_OFFSET];
    if (e != null && e.hasUnpacker()) {
      return e.getUnpackerProc();
    } else {
      return null;
    }
  }

  /**
   * Finds the packer for {@code cls}.
   *
   * <p>The lookup order is: an entry registered for exactly {@code cls}, then
   * a cached ancestor match, then a scan of the registered classes against
   * {@code cls.ancestors}. A successful scan is cached under {@code cls}.
   *
   * @return a two-element array {@code {packerProc, typeId}}, or {@code null}
   *         when no entry matches or the matching entry has no packer
   */
  public IRubyObject[] lookupPackerByClass(RubyClass cls) {
    ExtensionEntry e = extensionsByClass.get(cls);
    if (e == null) {
      e = extensionsByAncestor.get(cls);
    }
    if (e == null) {
      e = findEntryByClassOrAncestor(cls);
      if (e != null) {
        // Cache under the class that was looked up. Keying by the entry's own
        // class would never be hit, because that class is already a key of
        // extensionsByClass and the cache is read with cls.
        extensionsByAncestor.put(cls, e);
      }
    }
    if (e != null && e.hasPacker()) {
      return e.toPackerProcTypeIdPair(cls.getRuntime().getCurrentContext());
    } else {
      return null;
    }
  }

  /**
   * Returns the entry of a registered class that appears in
   * {@code cls.ancestors}, or {@code null} if there is none. When several
   * registered classes match, the one returned depends on the iteration
   * order of the underlying {@code HashMap}.
   */
  private ExtensionEntry findEntryByClassOrAncestor(final RubyClass cls) {
    if (extensionsByClass.isEmpty()) {
      return null;
    }
    ThreadContext ctx = cls.getRuntime().getCurrentContext();
    // The ancestor list does not depend on the candidate, so fetch it once.
    RubyArray ancestors = (RubyArray) cls.callMethod(ctx, "ancestors");
    for (Map.Entry<RubyClass, ExtensionEntry> registration : extensionsByClass.entrySet()) {
      if (ancestors.callMethod(ctx, "include?", registration.getKey()).isTrue()) {
        return registration.getValue();
      }
    }
    return null;
  }

  /**
   * Immutable record of one registration: the class, its type id, and the
   * optional packer/unpacker procs with their associated arguments.
   */
  private static class ExtensionEntry {
    private final RubyClass cls;
    private final int typeId;
    private final IRubyObject packerProc;
    private final IRubyObject packerArg;
    private final IRubyObject unpackerProc;
    private final IRubyObject unpackerArg;

    public ExtensionEntry(RubyClass cls, int typeId, IRubyObject packerProc, IRubyObject packerArg, IRubyObject unpackerProc, IRubyObject unpackerArg) {
      this.cls = cls;
      this.typeId = typeId;
      this.packerProc = packerProc;
      this.packerArg = packerArg;
      this.unpackerProc = unpackerProc;
      this.unpackerArg = unpackerArg;
    }

    public RubyClass getExtensionClass() {
      return cls;
    }

    public int getTypeId() {
      return typeId;
    }

    /** True when a packer proc was supplied at registration. */
    public boolean hasPacker() {
      return packerProc != null;
    }

    /** True when an unpacker proc was supplied at registration. */
    public boolean hasUnpacker() {
      return unpackerProc != null;
    }

    public IRubyObject getPackerProc() {
      return packerProc;
    }

    public IRubyObject getUnpackerProc() {
      return unpackerProc;
    }

    /** Returns the Ruby array {@code [typeId, packerProc, packerArg]}. */
    public RubyArray toPackerTuple(ThreadContext ctx) {
      return RubyArray.newArray(ctx.getRuntime(), new IRubyObject[] {RubyFixnum.newFixnum(ctx.getRuntime(), typeId), packerProc, packerArg});
    }

    /** Returns the Ruby array {@code [class, unpackerProc, unpackerArg]}. */
    public RubyArray toUnpackerTuple(ThreadContext ctx) {
      return RubyArray.newArray(ctx.getRuntime(), new IRubyObject[] {cls, unpackerProc, unpackerArg});
    }

    /** Returns the Java array {@code {packerProc, typeId}} used by the encoder. */
    public IRubyObject[] toPackerProcTypeIdPair(ThreadContext ctx) {
      return new IRubyObject[] {packerProc, RubyFixnum.newFixnum(ctx.getRuntime(), typeId)};
    }
  }
}
Loading

0 comments on commit 540fdb9

Please sign in to comment.