diff --git a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java index 30cf3262415..e4e55dd6ce8 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java @@ -1106,6 +1106,69 @@ public int compare(Object o1, Object o2, Schema s) { return compare(o1, o2, s, false); } + protected int compareMaps(final Map m1, final Map m2) { + if (m1 == m2) { + return 0; + } + + if (m2.size() != m2.size()) { + return 1; + } + + /** + * Peek at keys, assuming they're all the same type within a Map + */ + final Object key1 = m1.keySet().iterator().next(); + final Object key2 = m2.keySet().iterator().next(); + boolean utf8ToString = false; + boolean stringToUtf8 = false; + + if (key1 instanceof Utf8 && key2 instanceof String) { + utf8ToString = true; + } else if (key1 instanceof String && key2 instanceof Utf8) { + stringToUtf8 = true; + } + + try { + for (Map.Entry e : m1.entrySet()) { + final Object key = e.getKey(); + Object lookupKey = key; + if (utf8ToString) { + lookupKey = key.toString(); + } else if (stringToUtf8) { + lookupKey = new Utf8((String) lookupKey); + } + final Object value = e.getValue(); + if (value == null) { + if (!(m2.get(lookupKey) == null && m2.containsKey(lookupKey))) { + return 1; + } + } else { + final Object value2 = m2.get(lookupKey); + if (value instanceof Utf8 && value2 instanceof String) { + if (!value.toString().equals(value2)) { + return 1; + } + } else if (value instanceof String && value2 instanceof Utf8) { + if (!new Utf8((String) value).equals(value2)) { + return 1; + } + } else { + if (!value.equals(value2)) { + return 1; + } + } + } + } + } catch (ClassCastException unused) { + return 1; + } catch (NullPointerException unused) { + return 1; + } + + return 0; + } + /** * Comparison implementation. When equals is true, only checks for equality, not * for order. @@ -1142,7 +1205,7 @@ protected int compare(Object o1, Object o2, Schema s, boolean equals) { return e1.hasNext() ? 1 : (e2.hasNext() ? -1 : 0); case MAP: if (equals) - return o1.equals(o2) ? 0 : 1; + return compareMaps((Map) o1, (Map) o2); throw new AvroRuntimeException("Can't compare maps!"); case UNION: int i1 = resolveUnion(s, o1); diff --git a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericData.java b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericData.java index 27006ad594c..a89d740620b 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericData.java +++ b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericData.java @@ -128,6 +128,48 @@ public void testEquals() { assertEquals(r1, r2); } + @Test + public void testMapKeyEquals() { + Schema mapSchema = new Schema.Parser().parse("{\"type\": \"map\", \"values\": \"string\"}"); + Field myMapField = new Field("my_map", Schema.createMap(mapSchema), null, null); + Schema schema = Schema.createRecord("my_record", "doc", "mytest", false); + schema.setFields(Arrays.asList(myMapField)); + GenericRecord r0 = new GenericData.Record(schema); + GenericRecord r1 = new GenericData.Record(schema); + + HashMap pair1 = new HashMap<>(); + pair1.put("keyOne", "valueOne"); + r0.put("my_map", pair1); + + HashMap pair2 = new HashMap<>(); + pair2.put(new Utf8("keyOne"), "valueOne"); + r1.put("my_map", pair2); + + assertEquals(r0, r1); + assertEquals(r1, r0); + } + + @Test + public void testMapValuesEquals() { + Schema mapSchema = new Schema.Parser().parse("{\"type\": \"map\", \"values\": \"string\"}"); + Field myMapField = new Field("my_map", Schema.createMap(mapSchema), null, null); + Schema schema = Schema.createRecord("my_record", "doc", "mytest", false); + schema.setFields(Arrays.asList(myMapField)); + GenericRecord r0 = new GenericData.Record(schema); + GenericRecord r1 = new GenericData.Record(schema); + + HashMap pair1 = new HashMap<>(); + pair1.put("keyOne", "valueOne"); + r0.put("my_map", pair1); + + HashMap pair2 = new HashMap<>(); + pair2.put("keyOne", new Utf8("valueOne")); + r1.put("my_map", pair2); + + assertEquals(r0, r1); + assertEquals(r1, r0); + } + private Schema recordSchema() { List fields = new ArrayList<>(); fields.add(new Field("anArray", Schema.createArray(Schema.create(Type.STRING)), null, null));