diff --git a/parquet-column/src/main/java/parquet/column/ParquetProperties.java b/parquet-column/src/main/java/parquet/column/ParquetProperties.java
index c083867c09..b32554855c 100644
--- a/parquet-column/src/main/java/parquet/column/ParquetProperties.java
+++ b/parquet-column/src/main/java/parquet/column/ParquetProperties.java
@@ -202,19 +202,20 @@ public boolean isEnableDictionary() {
 
   public ColumnWriteStore newColumnWriteStore(
       MessageType schema,
-      PageWriteStore pageStore, int pageSize,
-      int initialPageBufferSize) {
+      PageWriteStore pageStore,
+      int pageSize) {
     switch (writerVersion) {
     case PARQUET_1_0:
       return new ColumnWriteStoreV1(
           pageStore,
-          pageSize, initialPageBufferSize, dictionaryPageSizeThreshold,
+          pageSize,
+          dictionaryPageSizeThreshold,
           enableDictionary, writerVersion);
     case PARQUET_2_0:
       return new ColumnWriteStoreV2(
           schema,
           pageStore,
-          pageSize, initialPageBufferSize,
+          pageSize,
           new ParquetProperties(dictionaryPageSizeThreshold, writerVersion, enableDictionary));
     default:
       throw new IllegalArgumentException("unknown version " + writerVersion);
diff --git a/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV1.java b/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV1.java
index 884c665570..06bde5839f 100644
--- a/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV1.java
+++ b/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV1.java
@@ -36,14 +36,12 @@ public class ColumnWriteStoreV1 implements ColumnWriteStore {
   private final int pageSizeThreshold;
   private final int dictionaryPageSizeThreshold;
   private final boolean enableDictionary;
-  private final int initialSizePerCol;
   private final WriterVersion writerVersion;
 
-  public ColumnWriteStoreV1(PageWriteStore pageWriteStore, int pageSizeThreshold, int initialSizePerCol, int dictionaryPageSizeThreshold, boolean enableDictionary, WriterVersion writerVersion) {
+  public ColumnWriteStoreV1(PageWriteStore pageWriteStore, int pageSizeThreshold, int dictionaryPageSizeThreshold, boolean enableDictionary, WriterVersion writerVersion) {
     super();
     this.pageWriteStore = pageWriteStore;
     this.pageSizeThreshold = pageSizeThreshold;
-    this.initialSizePerCol = initialSizePerCol;
     this.dictionaryPageSizeThreshold = dictionaryPageSizeThreshold;
     this.enableDictionary = enableDictionary;
     this.writerVersion = writerVersion;
@@ -64,7 +62,7 @@ public Set<ColumnDescriptor> getColumnDescriptors() {
 
   private ColumnWriterV1 newMemColumn(ColumnDescriptor path) {
     PageWriter pageWriter = pageWriteStore.getPageWriter(path);
-    return new ColumnWriterV1(path, pageWriter, pageSizeThreshold, initialSizePerCol, dictionaryPageSizeThreshold, enableDictionary, writerVersion);
+    return new ColumnWriterV1(path, pageWriter, pageSizeThreshold, dictionaryPageSizeThreshold, enableDictionary, writerVersion);
   }
 
   @Override
diff --git a/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV2.java b/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV2.java
index 03a219d13e..c1046965d8 100644
--- a/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV2.java
+++ b/parquet-column/src/main/java/parquet/column/impl/ColumnWriteStoreV2.java
@@ -53,7 +53,7 @@ public class ColumnWriteStoreV2 implements ColumnWriteStore {
   public ColumnWriteStoreV2(
       MessageType schema,
       PageWriteStore pageWriteStore,
-      int pageSizeThreshold, int initialSizePerCol,
+      int pageSizeThreshold,
       ParquetProperties parquetProps) {
     super();
     this.pageSizeThreshold = pageSizeThreshold;
@@ -61,7 +61,7 @@ public ColumnWriteStoreV2(
     Map<ColumnDescriptor, ColumnWriterV2> mcolumns = new TreeMap<ColumnDescriptor, ColumnWriterV2>();
     for (ColumnDescriptor path : schema.getColumns()) {
       PageWriter pageWriter = pageWriteStore.getPageWriter(path);
-      mcolumns.put(path, new ColumnWriterV2(path, pageWriter, initialSizePerCol, parquetProps, pageSizeThreshold));
+      mcolumns.put(path, new ColumnWriterV2(path, pageWriter, parquetProps, pageSizeThreshold));
     }
     this.columns = unmodifiableMap(mcolumns);
     this.writers = this.columns.values();
diff --git a/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV1.java b/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV1.java
index ac3fc19e3c..fdca2f5502 100644
--- a/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV1.java
+++ b/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV1.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 
 import parquet.Log;
+import parquet.bytes.CapacityByteArrayOutputStream;
 import parquet.column.ColumnDescriptor;
 import parquet.column.ColumnWriter;
 import parquet.column.ParquetProperties;
@@ -31,6 +32,9 @@
 import parquet.io.ParquetEncodingException;
 import parquet.io.api.Binary;
 
+import static java.lang.Math.max;
+import static java.lang.Math.pow;
+
 /**
  * Writes (repetition level, definition level, value) triplets and deals with writing pages to the underlying layer.
  *
@@ -41,6 +45,7 @@ final class ColumnWriterV1 implements ColumnWriter {
   private static final Log LOG = Log.getLog(ColumnWriterV1.class);
   private static final boolean DEBUG = Log.DEBUG;
   private static final int INITIAL_COUNT_FOR_SIZE_CHECK = 100;
+  private static final int MIN_SLAB_SIZE = 64;
 
   private final ColumnDescriptor path;
   private final PageWriter pageWriter;
@@ -57,7 +62,6 @@ public ColumnWriterV1(
       ColumnDescriptor path,
       PageWriter pageWriter,
       int pageSizeThreshold,
-      int initialSizePerCol,
       int dictionaryPageSizeThreshold,
       boolean enableDictionary,
       WriterVersion writerVersion) {
@@ -69,9 +73,12 @@ public ColumnWriterV1(
     resetStatistics();
 
     ParquetProperties parquetProps = new ParquetProperties(dictionaryPageSizeThreshold, writerVersion, enableDictionary);
-    this.repetitionLevelColumn = ParquetProperties.getColumnDescriptorValuesWriter(path.getMaxRepetitionLevel(), initialSizePerCol, pageSizeThreshold);
-    this.definitionLevelColumn = ParquetProperties.getColumnDescriptorValuesWriter(path.getMaxDefinitionLevel(), initialSizePerCol, pageSizeThreshold);
-    this.dataColumn = parquetProps.getValuesWriter(path, initialSizePerCol, pageSizeThreshold);
+
+    this.repetitionLevelColumn = ParquetProperties.getColumnDescriptorValuesWriter(path.getMaxRepetitionLevel(), MIN_SLAB_SIZE, pageSizeThreshold);
+    this.definitionLevelColumn = ParquetProperties.getColumnDescriptorValuesWriter(path.getMaxDefinitionLevel(), MIN_SLAB_SIZE, pageSizeThreshold);
+
+    int initialSlabSize = CapacityByteArrayOutputStream.initialSlabSizeHeuristic(MIN_SLAB_SIZE, pageSizeThreshold, 10);
+    this.dataColumn = parquetProps.getValuesWriter(path, initialSlabSize, pageSizeThreshold);
   }
 
   private void log(Object value, int r, int d) {
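
With this change the level writers and the data column are sized differently: the repetition and definition level writers always start from MIN_SLAB_SIZE (64 bytes), while the data column starts from initialSlabSizeHeuristic(MIN_SLAB_SIZE, pageSizeThreshold, 10); ColumnWriterV2 and DictionaryValuesWriter below make the same call. The sketch that follows is illustrative only (not part of the patch) and inlines the heuristic, max(minSlabSize, targetCapacity / 2^targetNumSlabs), to show the resulting starting sizes for a few example page sizes.

public class InitialSlabSizeExample {

  // Mirrors CapacityByteArrayOutputStream.initialSlabSizeHeuristic:
  // max(minSlabSize, targetCapacity / 2^targetNumSlabs)
  static int initialSlabSizeHeuristic(int minSlabSize, int targetCapacity, int targetNumSlabs) {
    return Math.max(minSlabSize, (int) (targetCapacity / Math.pow(2, targetNumSlabs)));
  }

  public static void main(String[] args) {
    final int minSlabSize = 64;    // MIN_SLAB_SIZE in ColumnWriterV1/V2
    final int targetNumSlabs = 10; // the value both column writers pass
    for (int pageSize : new int[] { 64 * 1024, 1024 * 1024, 8 * 1024 * 1024 }) {
      System.out.printf("pageSize=%,d -> initial slab size=%,d bytes%n",
          pageSize, initialSlabSizeHeuristic(minSlabSize, pageSize, targetNumSlabs));
    }
    // 64KB page -> 64 bytes (the minimum applies)
    // 1MB page  -> 1,024 bytes
    // 8MB page  -> 8,192 bytes
  }
}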
diff --git a/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV2.java b/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV2.java
index 100bca28ae..df1075a499 100644
--- a/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV2.java
+++ b/parquet-column/src/main/java/parquet/column/impl/ColumnWriterV2.java
@@ -15,6 +15,8 @@
  */
 package parquet.column.impl;
 
+import static java.lang.Math.max;
+import static java.lang.Math.pow;
 import static parquet.bytes.BytesUtils.getWidthFromMaxInt;
 
 import java.io.IOException;
@@ -22,6 +24,7 @@
 import parquet.Ints;
 import parquet.Log;
 import parquet.bytes.BytesInput;
+import parquet.bytes.CapacityByteArrayOutputStream;
 import parquet.column.ColumnDescriptor;
 import parquet.column.ColumnWriter;
 import parquet.column.Encoding;
@@ -43,6 +46,7 @@
 final class ColumnWriterV2 implements ColumnWriter {
   private static final Log LOG = Log.getLog(ColumnWriterV2.class);
   private static final boolean DEBUG = Log.DEBUG;
+  private static final int MIN_SLAB_SIZE = 64;
 
   private final ColumnDescriptor path;
   private final PageWriter pageWriter;
@@ -57,15 +61,17 @@ final class ColumnWriterV2 implements ColumnWriter {
   public ColumnWriterV2(
       ColumnDescriptor path,
       PageWriter pageWriter,
-      int initialSizePerCol,
       ParquetProperties parquetProps,
       int pageSize) {
     this.path = path;
     this.pageWriter = pageWriter;
     resetStatistics();
-    this.repetitionLevelColumn = new RunLengthBitPackingHybridEncoder(getWidthFromMaxInt(path.getMaxRepetitionLevel()), initialSizePerCol, pageSize);
-    this.definitionLevelColumn = new RunLengthBitPackingHybridEncoder(getWidthFromMaxInt(path.getMaxDefinitionLevel()), initialSizePerCol, pageSize);
-    this.dataColumn = parquetProps.getValuesWriter(path, initialSizePerCol, pageSize);
+
+    this.repetitionLevelColumn = new RunLengthBitPackingHybridEncoder(getWidthFromMaxInt(path.getMaxRepetitionLevel()), MIN_SLAB_SIZE, pageSize);
+    this.definitionLevelColumn = new RunLengthBitPackingHybridEncoder(getWidthFromMaxInt(path.getMaxDefinitionLevel()), MIN_SLAB_SIZE, pageSize);
+
+    int initialSlabSize = CapacityByteArrayOutputStream.initialSlabSizeHeuristic(MIN_SLAB_SIZE, pageSize, 10);
+    this.dataColumn = parquetProps.getValuesWriter(path, initialSlabSize, pageSize);
   }
 
   private void log(Object value, int r, int d) {
diff --git a/parquet-column/src/main/java/parquet/column/values/dictionary/DictionaryValuesWriter.java b/parquet-column/src/main/java/parquet/column/values/dictionary/DictionaryValuesWriter.java
index 9488a4c8ce..8fca4fd613 100644
--- a/parquet-column/src/main/java/parquet/column/values/dictionary/DictionaryValuesWriter.java
+++ b/parquet-column/src/main/java/parquet/column/values/dictionary/DictionaryValuesWriter.java
@@ -39,6 +39,7 @@
 import parquet.Log;
 import parquet.bytes.BytesInput;
 import parquet.bytes.BytesUtils;
+import parquet.bytes.CapacityByteArrayOutputStream;
 import parquet.column.Encoding;
 import parquet.column.page.DictionaryPage;
 import parquet.column.values.RequiresFallback;
@@ -62,6 +63,7 @@ public abstract class DictionaryValuesWriter extends ValuesWriter implements Req
 
   /* max entries allowed for the dictionary will fail over to plain encoding if reached */
   private static final int MAX_DICTIONARY_ENTRIES = Integer.MAX_VALUE - 1;
+  private static final int MIN_INITIAL_SLAB_SIZE = 64;
 
   /* encoding to label the data page */
   private final Encoding encodingForDataPage;
@@ -142,8 +144,12 @@ public BytesInput getBytes() {
     int maxDicId = getDictionarySize() - 1;
     if (DEBUG) LOG.debug("max dic id " + maxDicId);
     int bitWidth = BytesUtils.getWidthFromMaxInt(maxDicId);
-    // TODO: what is a good initialCapacity?
-    RunLengthBitPackingHybridEncoder encoder = new RunLengthBitPackingHybridEncoder(bitWidth, 64 * 1024, maxDictionaryByteSize);
+
+    int initialSlabSize =
+        CapacityByteArrayOutputStream.initialSlabSizeHeuristic(MIN_INITIAL_SLAB_SIZE, maxDictionaryByteSize, 10);
+
+    RunLengthBitPackingHybridEncoder encoder =
+        new RunLengthBitPackingHybridEncoder(bitWidth, initialSlabSize, maxDictionaryByteSize);
     IntIterator iterator = encodedValues.iterator();
     try {
       while (iterator.hasNext()) {
diff --git a/parquet-column/src/test/java/parquet/column/impl/TestColumnReaderImpl.java b/parquet-column/src/test/java/parquet/column/impl/TestColumnReaderImpl.java
index bcff4761e2..dda8187e66 100644
--- a/parquet-column/src/test/java/parquet/column/impl/TestColumnReaderImpl.java
+++ b/parquet-column/src/test/java/parquet/column/impl/TestColumnReaderImpl.java
@@ -38,7 +38,7 @@ public void test() {
     MessageType schema = MessageTypeParser.parseMessageType("message test { required binary foo; }");
     ColumnDescriptor col = schema.getColumns().get(0);
     MemPageWriter pageWriter = new MemPageWriter();
-    ColumnWriterV2 columnWriterV2 = new ColumnWriterV2(col, pageWriter, 1024, new ParquetProperties(1024, PARQUET_2_0, true), 2048);
+    ColumnWriterV2 columnWriterV2 = new ColumnWriterV2(col, pageWriter, new ParquetProperties(1024, PARQUET_2_0, true), 2048);
     for (int i = 0; i < rows; i++) {
       columnWriterV2.write(Binary.fromString("bar" + i % 10), 0, 0);
       if ((i + 1) % 1000 == 0) {
@@ -73,7 +73,7 @@ public void testOptional() {
     MessageType schema = MessageTypeParser.parseMessageType("message test { optional binary foo; }");
     ColumnDescriptor col = schema.getColumns().get(0);
     MemPageWriter pageWriter = new MemPageWriter();
-    ColumnWriterV2 columnWriterV2 = new ColumnWriterV2(col, pageWriter, 1024, new ParquetProperties(1024, PARQUET_2_0, true), 2048);
+    ColumnWriterV2 columnWriterV2 = new ColumnWriterV2(col, pageWriter, new ParquetProperties(1024, PARQUET_2_0, true), 2048);
     for (int i = 0; i < rows; i++) {
       columnWriterV2.writeNull(0, 0);
       if ((i + 1) % 1000 == 0) {
diff --git a/parquet-column/src/test/java/parquet/column/mem/TestMemColumn.java b/parquet-column/src/test/java/parquet/column/mem/TestMemColumn.java
index a386bbba92..b0abf55fa2 100644
--- a/parquet-column/src/test/java/parquet/column/mem/TestMemColumn.java
+++ b/parquet-column/src/test/java/parquet/column/mem/TestMemColumn.java
@@ -156,6 +156,6 @@ public void testMemColumnSeveralPagesRepeated() throws Exception {
   }
 
   private ColumnWriteStoreV1 newColumnWriteStoreImpl(MemPageStore memPageStore) {
-    return new ColumnWriteStoreV1(memPageStore, 2048, 2048, 2048, false, WriterVersion.PARQUET_1_0);
+    return new ColumnWriteStoreV1(memPageStore, 2048, 2048, false, WriterVersion.PARQUET_1_0);
   }
 }
diff --git a/parquet-column/src/test/java/parquet/io/PerfTest.java b/parquet-column/src/test/java/parquet/io/PerfTest.java
index 9cd31e3097..2642e09d15 100644
--- a/parquet-column/src/test/java/parquet/io/PerfTest.java
+++ b/parquet-column/src/test/java/parquet/io/PerfTest.java
@@ -74,7 +74,7 @@ private static void read(MemPageStore memPageStore, MessageType myschema,
 
 
   private static void write(MemPageStore memPageStore) {
-    ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 50*1024*1024, 50*1024*1024, 50*1024*1024, false, WriterVersion.PARQUET_1_0);
+    ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 50*1024*1024, 50*1024*1024, false, WriterVersion.PARQUET_1_0);
     MessageColumnIO columnIO = newColumnFactory(schema);
 
     GroupWriter groupWriter = new GroupWriter(columnIO.getRecordWriter(columns), schema);
diff --git a/parquet-column/src/test/java/parquet/io/TestColumnIO.java b/parquet-column/src/test/java/parquet/io/TestColumnIO.java
index d4442df6a8..bf93c6eb17 100644
--- a/parquet-column/src/test/java/parquet/io/TestColumnIO.java
+++ b/parquet-column/src/test/java/parquet/io/TestColumnIO.java
@@ -514,7 +514,7 @@ public void testPushParser() {
   }
 
   private ColumnWriteStoreV1 newColumnWriteStore(MemPageStore memPageStore) {
-    return new ColumnWriteStoreV1(memPageStore, 800, 800, 800, useDictionary, WriterVersion.PARQUET_1_0);
+    return new ColumnWriteStoreV1(memPageStore, 800, 800, useDictionary, WriterVersion.PARQUET_1_0);
   }
 
   @Test
diff --git a/parquet-column/src/test/java/parquet/io/TestFiltered.java b/parquet-column/src/test/java/parquet/io/TestFiltered.java
index 7acf6f1e69..2ba9c19e84 100644
--- a/parquet-column/src/test/java/parquet/io/TestFiltered.java
+++ b/parquet-column/src/test/java/parquet/io/TestFiltered.java
@@ -254,7 +254,7 @@ public void testFilteredNotPaged() {
 
   private MemPageStore writeTestRecords(MessageColumnIO columnIO, int number) {
     MemPageStore memPageStore = new MemPageStore(number * 2);
-    ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 800, 800, 800, false, WriterVersion.PARQUET_1_0);
+    ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 800, 800, false, WriterVersion.PARQUET_1_0);
 
     GroupWriter groupWriter = new GroupWriter(columnIO.getRecordWriter(columns), schema);
     for ( int i = 0; i < number; i++ ) {
diff --git a/parquet-encoding/src/main/java/parquet/bytes/CapacityByteArrayOutputStream.java b/parquet-encoding/src/main/java/parquet/bytes/CapacityByteArrayOutputStream.java
index 3efe9d0e78..eaa068902a 100644
--- a/parquet-encoding/src/main/java/parquet/bytes/CapacityByteArrayOutputStream.java
+++ b/parquet-encoding/src/main/java/parquet/bytes/CapacityByteArrayOutputStream.java
@@ -16,6 +16,7 @@
 package parquet.bytes;
 
 import static java.lang.Math.max;
+import static java.lang.Math.pow;
 import static java.lang.String.format;
 import static java.lang.System.arraycopy;
 import static parquet.Preconditions.checkArgument;
@@ -29,60 +30,92 @@
 import parquet.Log;
 
 /**
- * functionality of ByteArrayOutputStream without the memory and copy overhead
+ * Similar to a {@link ByteArrayOutputStream}, but uses a different strategy for growing that does not involve copying.
+ * Where ByteArrayOutputStream is backed by a single array that "grows" by copying into a new larger array, this output
+ * stream grows by allocating a new array (slab) and adding it to a list of previous arrays.
  *
- * It will linearly create a new slab of the initial size when needed (instead of creating a new buffer and copying the data).
- * After 10 slabs their size will increase exponentially (similar to {@link ByteArrayOutputStream} behavior) by making the new slab size the size of the existing data.
+ * Each new slab is allocated to be the same size as all the previous slabs combined, so these allocations become
+ * exponentially less frequent, just like ByteArrayOutputStream, with one difference: this output stream also accepts
+ * a max capacity hint, an estimate of the maximum amount of data that will be written to this stream. As the
+ * total size of this stream nears this max, the stream starts to grow linearly instead of exponentially.
+ * New slabs are then allocated to be 1/5th of the max capacity hint,
+ * instead of equal to the total previous size of all slabs. This prevents allocating roughly
+ * twice the needed space when a new slab is added just before the stream is done being used.
  *
- * When reusing a buffer it will adjust the slab size based on the previous data size ({@link CapacityByteArrayOutputStream#reset()})
+ * When reusing this stream it will adjust the initial slab size based on the previous data size, aiming for fewer
+ * allocations, with the assumption that a similar amount of data will be written to this stream on re-use.
+ * See {@link CapacityByteArrayOutputStream#reset()}.
  *
  * @author Julien Le Dem
  *
  */
 public class CapacityByteArrayOutputStream extends OutputStream {
   private static final Log LOG = Log.getLog(CapacityByteArrayOutputStream.class);
-
   private static final byte[] EMPTY_SLAB = new byte[0];
 
-  private int initialSize;
-  private final int pageSize;
-  private List<byte[]> slabs = new ArrayList<byte[]>();
+  private int initialSlabSize;
+  private final int maxCapacityHint;
+  private final List<byte[]> slabs = new ArrayList<byte[]>();
+
   private byte[] currentSlab;
-  private int capacity = 0;
   private int currentSlabIndex;
-  private int currentSlabPosition;
-  private int size;
+  private int bytesAllocated = 0;
+  private int bytesUsed = 0;
 
   /**
-   * defaults pageSize to 1MB
-   * @param initialSize
-   * @deprecated use {@link CapacityByteArrayOutputStream#CapacityByteArrayOutputStream(int, int)}
+   * Return an initial slab size such that a CapacityByteArrayOutputStream constructed with it
+   * will end up allocating targetNumSlabs slabs in order to reach targetCapacity. This aims to
+   * balance the overhead of creating new slabs against the memory wasted by eagerly making the
+   * initial slabs too big.
+   *
+   * Note that targetCapacity here need not match maxCapacityHint in the constructor of
+   * CapacityByteArrayOutputStream, though often that would make sense.
+   *
+   * @param minSlabSize the smallest slab size to use, no matter what the other parameters imply
+   * @param targetCapacity the total capacity we should have after allocating targetNumSlabs slabs
+   * @param targetNumSlabs the number of slabs it should take to reach targetCapacity
    */
-  @Deprecated
-  public CapacityByteArrayOutputStream(int initialSize) {
-    this(initialSize, 1024 * 1024);
+  public static int initialSlabSizeHeuristic(int minSlabSize, int targetCapacity, int targetNumSlabs) {
+    // initialSlabSize = (targetCapacity / (2^targetNumSlabs)) means the total size can double
+    // targetNumSlabs times before reaching targetCapacity.
+    // e.g. for a target capacity of 1MB and 10 target slabs, we start at 1024 bytes.
+    // we don't want to start too small either, so we also apply a minimum.
+    return max(minSlabSize, ((int) (targetCapacity / pow(2, targetNumSlabs))));
   }
 
   /**
-   * @param initialSize the initialSize of the buffer (also slab size)
-   * @param pageSize
+   * Construct a CapacityByteArrayOutputStream configured such that its initial slab size is
+   * determined by {@link #initialSlabSizeHeuristic}, with targetCapacity == maxCapacityHint.
    */
-  public CapacityByteArrayOutputStream(int initialSize, int pageSize) {
-    checkArgument(initialSize > 0, "initialSize must be > 0");
-    checkArgument(pageSize > 0, "pageSize must be > 0");
-    this.pageSize = pageSize;
-    initSlabs(initialSize);
+  public static CapacityByteArrayOutputStream withTargetNumSlabs(
+      int minSlabSize, int maxCapacityHint, int targetNumSlabs) {
+
+    return new CapacityByteArrayOutputStream(
+        initialSlabSizeHeuristic(minSlabSize, maxCapacityHint, targetNumSlabs),
+        maxCapacityHint);
   }
 
-  private void initSlabs(int initialSize) {
-    if (Log.DEBUG) LOG.debug(String.format("initial slab of size %d", initialSize));
-    this.initialSize = initialSize;
-    this.slabs.clear();
-    this.capacity = 0;
-    this.currentSlab = EMPTY_SLAB;
-    this.currentSlabIndex = -1;
-    this.currentSlabPosition = 0;
-    this.size = 0;
+  /**
+   * Defaults maxCapacityHint to 1MB.
+   * @param initialSlabSize the size to make the first slab
+   * @deprecated use {@link CapacityByteArrayOutputStream#CapacityByteArrayOutputStream(int, int)}
+   */
+  @Deprecated
+  public CapacityByteArrayOutputStream(int initialSlabSize) {
+    this(initialSlabSize, 1024 * 1024);
+  }
+
+  /**
+   * @param initialSlabSize the size to make the first slab
+   * @param maxCapacityHint a hint (not guarantee) of the max amount of data written to this stream
+   */
+  public CapacityByteArrayOutputStream(int initialSlabSize, int maxCapacityHint) {
+    checkArgument(initialSlabSize > 0, "initialSlabSize must be > 0");
+    checkArgument(maxCapacityHint > 0, "maxCapacityHint must be > 0");
+    checkArgument(maxCapacityHint >= initialSlabSize, String.format("maxCapacityHint can't be less than initialSlabSize (%d < %d)", maxCapacityHint, initialSlabSize));
+    this.initialSlabSize = initialSlabSize;
+    this.maxCapacityHint = maxCapacityHint;
+    reset();
   }
 
   /**
@@ -90,56 +123,60 @@ private void initSlabs(int initialSize) {
    * @param minimumSize the size of the data we want to copy in the new slab
    */
   private void addSlab(int minimumSize) {
-    this.currentSlabIndex += 1;
     int nextSlabSize;
-    if (size == 0) {
-      nextSlabSize = initialSize;
-    } else if (size > pageSize / 5) {
+
+    if (bytesUsed == 0) {
+      nextSlabSize = initialSlabSize;
+    } else if (bytesUsed > maxCapacityHint / 5) {
       // to avoid an overhead of up to twice the needed size, we get linear when approaching target page size
-      nextSlabSize = pageSize / 5;
+      nextSlabSize = maxCapacityHint / 5;
     } else {
       // double the size every time
-      nextSlabSize = size;
+      nextSlabSize = bytesUsed;
     }
+
     if (nextSlabSize < minimumSize) {
       if (Log.DEBUG) LOG.debug(format("slab size %,d too small for value of size %,d. Bumping up slab size", nextSlabSize, minimumSize));
       nextSlabSize = minimumSize;
     }
-    if (Log.DEBUG) LOG.debug(format("used %d slabs, new slab size %d", currentSlabIndex, nextSlabSize));
+
+    if (Log.DEBUG) LOG.debug(format("used %d slabs, adding new slab of size %d", slabs.size(), nextSlabSize));
+
     this.currentSlab = new byte[nextSlabSize];
     this.slabs.add(currentSlab);
-    this.capacity += nextSlabSize;
-    this.currentSlabPosition = 0;
+    this.bytesAllocated += nextSlabSize;
+    this.currentSlabIndex = 0;
   }
 
   @Override
   public void write(int b) {
-    if (currentSlabPosition == currentSlab.length) {
+    if (currentSlabIndex == currentSlab.length) {
       addSlab(1);
     }
-    currentSlab[currentSlabPosition] = (byte) b;
-    currentSlabPosition += 1;
-    size += 1;
+    currentSlab[currentSlabIndex] = (byte) b;
+    currentSlabIndex += 1;
+    bytesUsed += 1;
   }
 
   @Override
   public void write(byte b[], int off, int len) {
     if ((off < 0) || (off > b.length) || (len < 0) ||
         ((off + len) - b.length > 0)) {
-      throw new IndexOutOfBoundsException();
+      throw new IndexOutOfBoundsException(
+          String.format("Given byte array of size %d, with requested length(%d) and offset(%d)", b.length, len, off));
     }
-    if (currentSlabPosition + len >= currentSlab.length) {
-      final int length1 = currentSlab.length - currentSlabPosition;
-      arraycopy(b, off, currentSlab, currentSlabPosition, length1);
+    if (currentSlabIndex + len >= currentSlab.length) {
+      final int length1 = currentSlab.length - currentSlabIndex;
+      arraycopy(b, off, currentSlab, currentSlabIndex, length1);
       final int length2 = len - length1;
       addSlab(length2);
-      arraycopy(b, off + length1, currentSlab, currentSlabPosition, length2);
-      currentSlabPosition = length2;
+      arraycopy(b, off + length1, currentSlab, currentSlabIndex, length2);
+      currentSlabIndex = length2;
     } else {
-      arraycopy(b, off, currentSlab, currentSlabPosition, len);
-      currentSlabPosition += len;
+      arraycopy(b, off, currentSlab, currentSlabIndex, len);
+      currentSlabIndex += len;
     }
-    size += len;
+    bytesUsed += len;
   }
 
   /**
@@ -150,18 +187,26 @@ public void write(byte b[], int off, int len) {
    * @exception  IOException  if an I/O error occurs.
    */
   public void writeTo(OutputStream out) throws IOException {
-    for (int i = 0; i < currentSlabIndex; i++) {
+    for (int i = 0; i < slabs.size() - 1; i++) {
       final byte[] slab = slabs.get(i);
-      out.write(slab, 0, slab.length);
+      out.write(slab);
     }
-    out.write(currentSlab, 0, currentSlabPosition);
+    out.write(currentSlab, 0, currentSlabIndex);
   }
 
   /**
-   * @return the size of the allocated buffer
+   * @return The total size in bytes of data written to this stream.
+   */
+  public long size() {
+    return bytesUsed;
+  }
+
+  /**
+   *
+   * @return The total size in bytes currently allocated for this stream.
    */
   public int getCapacity() {
-    return capacity;
+    return bytesAllocated;
   }
 
   /**
@@ -172,23 +217,22 @@ public int getCapacity() {
   public void reset() {
     // readjust slab size.
     // 7 = 2^3 - 1 so that doubling the initial size 3 times will get to the same size
-    initSlabs(max(size / 7, initialSize));
-  }
-
-  /**
-   * @return the size of the buffered data
-   */
-  public long size() {
-    return size;
+    this.initialSlabSize = max(bytesUsed / 7, initialSlabSize);
+    if (Log.DEBUG) LOG.debug(String.format("initial slab of size %d", initialSlabSize));
+    this.slabs.clear();
+    this.bytesAllocated = 0;
+    this.bytesUsed = 0;
+    this.currentSlab = EMPTY_SLAB;
+    this.currentSlabIndex = 0;
   }
 
   /**
-   * @return the index of the last value being written to this stream, which
+   * @return the index of the last value written to this stream, which
    * can be passed to {@link #setByte(long, byte)} in order to change it
    */
   public long getCurrentIndex() {
-    checkArgument(size > 0, "This is an empty stream");
-    return size - 1;
+    checkArgument(bytesUsed > 0, "This is an empty stream");
+    return bytesUsed - 1;
   }
 
   /**
@@ -198,10 +242,10 @@ public long getCurrentIndex() {
    * @param value the value to replace it with
    */
   public void setByte(long index, byte value) {
-    checkArgument(index < size, "Index: " + index + " is >= the current size of: " + size);
+    checkArgument(index < bytesUsed, "Index: " + index + " is >= the current size of: " + bytesUsed);
 
     long seen = 0;
-    for (int i = 0; i <= currentSlabIndex; i++) {
+    for (int i = 0; i < slabs.size(); i++) {
       byte[] slab = slabs.get(i);
       if (index < seen + slab.length) {
         // ok found index
@@ -221,7 +265,7 @@ public String memUsageString(String prefix) {
   }
 
   /**
-   * @return the total count of allocated slabs
+   * @return the total number of allocated slabs
    */
   int getSlabCount() {
     return slabs.size();
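
The growth strategy in addSlab can be read off its branches: the first slab is initialSlabSize, each subsequent slab equals the bytes used so far (so capacity roughly doubles), and once the bytes used exceed maxCapacityHint / 5, new slabs are pinned at maxCapacityHint / 5. Below is a minimal simulation of that sequence, illustrative only and not part of the patch, using the 1MB target and 1024-byte initial slab from the comments above.

public class SlabGrowthSketch {
  public static void main(String[] args) {
    final int initialSlabSize = 1024;        // e.g. initialSlabSizeHeuristic(64, 1MB, 10)
    final int maxCapacityHint = 1024 * 1024; // e.g. the page size
    int bytesUsed = 0;
    int slabCount = 0;
    while (bytesUsed < maxCapacityHint) {
      int nextSlabSize;
      if (bytesUsed == 0) {
        nextSlabSize = initialSlabSize;       // first slab
      } else if (bytesUsed > maxCapacityHint / 5) {
        nextSlabSize = maxCapacityHint / 5;   // linear growth near the hint
      } else {
        nextSlabSize = bytesUsed;             // doubling growth
      }
      slabCount++;
      bytesUsed += nextSlabSize;              // assume each slab is filled completely
      System.out.printf("slab %2d: %,9d bytes (total %,9d)%n", slabCount, nextSlabSize, bytesUsed);
    }
  }
}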
diff --git a/parquet-encoding/src/main/java/parquet/bytes/ConcatenatingByteArrayCollector.java b/parquet-encoding/src/main/java/parquet/bytes/ConcatenatingByteArrayCollector.java
new file mode 100644
index 0000000000..9ea8296589
--- /dev/null
+++ b/parquet-encoding/src/main/java/parquet/bytes/ConcatenatingByteArrayCollector.java
@@ -0,0 +1,45 @@
+package parquet.bytes;
+
+import static java.lang.String.format;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ConcatenatingByteArrayCollector extends BytesInput {
+  private final List<byte[]> slabs = new ArrayList<byte[]>();
+  private long size = 0;
+
+  public void collect(BytesInput bytesInput) throws IOException {
+    byte[] bytes = bytesInput.toByteArray();
+    slabs.add(bytes);
+    size += bytes.length;
+  }
+
+  public void reset() {
+    size = 0;
+    slabs.clear();
+  }
+
+  @Override
+  public void writeAllTo(OutputStream out) throws IOException {
+    for (byte[] slab : slabs) {
+      out.write(slab);
+    }
+  }
+
+  @Override
+  public long size() {
+    return size;
+  }
+
+  /**
+   * @param prefix  a prefix to be used for every new line in the string
+   * @return a text representation of the memory usage of this structure
+   */
+  public String memUsageString(String prefix) {
+    return format("%s %s %d slabs, %,d bytes", prefix, getClass().getSimpleName(), slabs.size(), size);
+  }
+
+}
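
A usage sketch for the new collector, illustrative only and not part of the patch; it assumes the existing BytesInput.from(byte[]) factory. Each collect() call keeps its bytes as a separate slab, and writeAllTo streams the slabs out in order without ever copying them into one combined array.

import java.io.ByteArrayOutputStream;
import java.io.IOException;

import parquet.bytes.BytesInput;
import parquet.bytes.ConcatenatingByteArrayCollector;

public class CollectorSketch {
  public static void main(String[] args) throws IOException {
    ConcatenatingByteArrayCollector chunk = new ConcatenatingByteArrayCollector();
    chunk.collect(BytesInput.from(new byte[] { 1, 2, 3 }));    // e.g. a serialized page header
    chunk.collect(BytesInput.from(new byte[] { 4, 5, 6, 7 })); // e.g. compressed page bytes

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    chunk.writeAllTo(out);             // writes the slabs back to back
    System.out.println(chunk.size());  // 7
    System.out.println(out.size());    // 7
  }
}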
diff --git a/parquet-hadoop/src/main/java/parquet/hadoop/ColumnChunkPageWriteStore.java b/parquet-hadoop/src/main/java/parquet/hadoop/ColumnChunkPageWriteStore.java
index 1a0e42eaee..f17b8d3204 100644
--- a/parquet-hadoop/src/main/java/parquet/hadoop/ColumnChunkPageWriteStore.java
+++ b/parquet-hadoop/src/main/java/parquet/hadoop/ColumnChunkPageWriteStore.java
@@ -18,6 +18,7 @@
 import static parquet.Log.INFO;
 import static parquet.column.statistics.Statistics.getStatsBasedOnType;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -27,7 +28,7 @@
 
 import parquet.Log;
 import parquet.bytes.BytesInput;
-import parquet.bytes.CapacityByteArrayOutputStream;
+import parquet.bytes.ConcatenatingByteArrayCollector;
 import parquet.column.ColumnDescriptor;
 import parquet.column.Encoding;
 import parquet.column.page.DictionaryPage;
@@ -41,6 +42,7 @@
 
 class ColumnChunkPageWriteStore implements PageWriteStore {
   private static final Log LOG = Log.getLog(ColumnChunkPageWriteStore.class);
+  private static final int COLUMN_CHUNK_WRITER_MAX_SIZE_HINT = 64 * 1024;
 
   private static ParquetMetadataConverter parquetMetadataConverter = new ParquetMetadataConverter();
 
@@ -49,7 +51,8 @@ private static final class ColumnChunkPageWriter implements PageWriter {
     private final ColumnDescriptor path;
     private final BytesCompressor compressor;
 
-    private final CapacityByteArrayOutputStream buf;
+    private final ByteArrayOutputStream tempOutputStream = new ByteArrayOutputStream();
+    private final ConcatenatingByteArrayCollector buf;
     private DictionaryPage dictionaryPage;
 
     private long uncompressedLength;
@@ -61,10 +64,10 @@ private static final class ColumnChunkPageWriter implements PageWriter {
 
     private Statistics totalStatistics;
 
-    private ColumnChunkPageWriter(ColumnDescriptor path, BytesCompressor compressor, int initialSize, int pageSize) {
+    private ColumnChunkPageWriter(ColumnDescriptor path, BytesCompressor compressor, int pageSize) {
       this.path = path;
       this.compressor = compressor;
-      this.buf = new CapacityByteArrayOutputStream(initialSize, pageSize);
+      this.buf = new ConcatenatingByteArrayCollector();
       this.totalStatistics = getStatsBasedOnType(this.path.getType());
     }
 
@@ -88,6 +91,7 @@ public void writePage(BytesInput bytes,
             "Cannot write compressed page larger than Integer.MAX_VALUE bytes: "
             + compressedSize);
       }
+      tempOutputStream.reset();
       parquetMetadataConverter.writeDataPageHeader(
           (int)uncompressedSize,
           (int)compressedSize,
@@ -96,13 +100,15 @@ public void writePage(BytesInput bytes,
           rlEncoding,
           dlEncoding,
           valuesEncoding,
-          buf);
+          tempOutputStream);
       this.uncompressedLength += uncompressedSize;
       this.compressedLength += compressedSize;
       this.totalValueCount += valueCount;
       this.pageCount += 1;
       this.totalStatistics.mergeStatistics(statistics);
-      compressedBytes.writeAllTo(buf);
+      // by concatenating before collecting instead of collecting twice,
+      // we only allocate one buffer to copy into instead of multiple.
+      buf.collect(BytesInput.concat(BytesInput.from(tempOutputStream), compressedBytes));
       encodings.add(rlEncoding);
       encodings.add(dlEncoding);
       encodings.add(valuesEncoding);
@@ -124,21 +130,30 @@ public void writePageV2(
       int compressedSize = toIntWithCheck(
           compressedData.size() + repetitionLevels.size() + definitionLevels.size()
       );
+      tempOutputStream.reset();
       parquetMetadataConverter.writeDataPageV2Header(
           uncompressedSize, compressedSize,
           valueCount, nullCount, rowCount,
           statistics,
           dataEncoding,
-          rlByteLength, dlByteLength,
-          buf);
+          rlByteLength,
+          dlByteLength,
+          tempOutputStream);
       this.uncompressedLength += uncompressedSize;
       this.compressedLength += compressedSize;
       this.totalValueCount += valueCount;
       this.pageCount += 1;
       this.totalStatistics.mergeStatistics(statistics);
-      repetitionLevels.writeAllTo(buf);
-      definitionLevels.writeAllTo(buf);
-      compressedData.writeAllTo(buf);
+
+      // by concatenating before collecting instead of collecting each input separately,
+      // we only allocate one buffer to copy into instead of several.
+      buf.collect(
+          BytesInput.concat(
+            BytesInput.from(tempOutputStream),
+            repetitionLevels,
+            definitionLevels,
+            compressedData)
+      );
       encodings.add(dataEncoding);
     }
 
@@ -162,7 +177,7 @@ public void writeToFileWriter(ParquetFileWriter writer) throws IOException {
         writer.writeDictionaryPage(dictionaryPage);
         encodings.add(dictionaryPage.getEncoding());
       }
-      writer.writeDataPages(BytesInput.from(buf), uncompressedLength, compressedLength, totalStatistics, new ArrayList<Encoding>(encodings));
+      writer.writeDataPages(buf, uncompressedLength, compressedLength, totalStatistics, new ArrayList<Encoding>(encodings));
       writer.endColumn();
       if (INFO) {
         LOG.info(
@@ -180,7 +195,7 @@ public void writeToFileWriter(ParquetFileWriter writer) throws IOException {
 
     @Override
     public long allocatedSize() {
-      return buf.getCapacity();
+      return buf.size();
     }
 
     @Override
@@ -202,9 +217,9 @@ public String memUsageString(String prefix) {
 
   private final Map<ColumnDescriptor, ColumnChunkPageWriter> writers = new HashMap<ColumnDescriptor, ColumnChunkPageWriter>();
 
-  public ColumnChunkPageWriteStore(BytesCompressor compressor, MessageType schema, int initialSize, int pageSize) {
+  public ColumnChunkPageWriteStore(BytesCompressor compressor, MessageType schema, int pageSize) {
     for (ColumnDescriptor path : schema.getColumns()) {
-      writers.put(path,  new ColumnChunkPageWriter(path, compressor, initialSize, pageSize));
+      writers.put(path,  new ColumnChunkPageWriter(path, compressor, pageSize));
     }
   }
 
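
The "concatenating before collecting" comments describe the per-page write path: the page header is serialized into a reusable ByteArrayOutputStream, then header and payload are combined with BytesInput.concat so that collect() copies them into a single array per page instead of one array per piece. A stripped-down sketch of that pattern follows, illustrative only; the header and page bytes are fake placeholders.

import java.io.ByteArrayOutputStream;
import java.io.IOException;

import parquet.bytes.BytesInput;
import parquet.bytes.ConcatenatingByteArrayCollector;

public class ConcatBeforeCollectSketch {
  public static void main(String[] args) throws IOException {
    ConcatenatingByteArrayCollector buf = new ConcatenatingByteArrayCollector();
    ByteArrayOutputStream tempOutputStream = new ByteArrayOutputStream();

    tempOutputStream.reset();
    tempOutputStream.write(new byte[] { 0x15, 0x00 });            // stand-in for a page header
    BytesInput compressedBytes = BytesInput.from(new byte[1024]); // stand-in for compressed page data

    // one collect() -> one allocation holding header + data, rather than one per piece
    buf.collect(BytesInput.concat(BytesInput.from(tempOutputStream), compressedBytes));
    System.out.println(buf.size()); // 1026
  }
}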
diff --git a/parquet-hadoop/src/main/java/parquet/hadoop/InternalParquetRecordWriter.java b/parquet-hadoop/src/main/java/parquet/hadoop/InternalParquetRecordWriter.java
index cd8875d590..1bfff8ef17 100644
--- a/parquet-hadoop/src/main/java/parquet/hadoop/InternalParquetRecordWriter.java
+++ b/parquet-hadoop/src/main/java/parquet/hadoop/InternalParquetRecordWriter.java
@@ -42,7 +42,6 @@
 class InternalParquetRecordWriter<T> {
   private static final Log LOG = Log.getLog(InternalParquetRecordWriter.class);
 
-  private static final int MINIMUM_BUFFER_SIZE = 64 * 1024;
   private static final int MINIMUM_RECORD_COUNT_FOR_CHECK = 100;
   private static final int MAXIMUM_RECORD_COUNT_FOR_CHECK = 10000;
 
@@ -98,22 +97,11 @@ public InternalParquetRecordWriter(
   }
 
   private void initStore() {
-    // we don't want this number to be too small
-    // ideally we divide the block equally across the columns
-    // it is unlikely all columns are going to be the same size.
-    // its value is likely below Integer.MAX_VALUE (2GB), although rowGroupSize is a long type.
-    // therefore this size is cast to int, since allocating byte array in under layer needs to
-    // limit the array size in an int scope.
-    int initialBlockBufferSize = Ints.checkedCast(max(MINIMUM_BUFFER_SIZE, rowGroupSize / schema.getColumns().size() / 5));
-    pageStore = new ColumnChunkPageWriteStore(compressor, schema, initialBlockBufferSize, pageSize);
-    // we don't want this number to be too small either
-    // ideally, slightly bigger than the page size, but not bigger than the block buffer
-    int initialPageBufferSize = max(MINIMUM_BUFFER_SIZE, min(pageSize + pageSize / 10, initialBlockBufferSize));
+    pageStore = new ColumnChunkPageWriteStore(compressor, schema, pageSize);
     columnStore = parquetProperties.newColumnWriteStore(
         schema,
         pageStore,
-        pageSize,
-        initialPageBufferSize);
+        pageSize);
     MessageColumnIO columnIO = new ColumnIOFactory(validating).getColumnIO(schema);
     writeSupport.prepareForWrite(columnIO.getRecordWriter(columnStore));
   }
diff --git a/parquet-hadoop/src/test/java/parquet/hadoop/TestColumnChunkPageWriteStore.java b/parquet-hadoop/src/test/java/parquet/hadoop/TestColumnChunkPageWriteStore.java
index e1223b666c..60337cddc8 100644
--- a/parquet-hadoop/src/test/java/parquet/hadoop/TestColumnChunkPageWriteStore.java
+++ b/parquet-hadoop/src/test/java/parquet/hadoop/TestColumnChunkPageWriteStore.java
@@ -64,7 +64,7 @@ public void test() throws Exception {
       writer.start();
       writer.startBlock(rowCount);
       {
-        ColumnChunkPageWriteStore store = new ColumnChunkPageWriteStore(f.getCompressor(codec, pageSize ), schema , initialSize, pageSize);
+        ColumnChunkPageWriteStore store = new ColumnChunkPageWriteStore(f.getCompressor(codec, pageSize ), schema, pageSize);
         PageWriter pageWriter = store.getPageWriter(col);
         pageWriter.writePageV2(
             rowCount, nullCount, valueCount,
diff --git a/parquet-pig/src/test/java/parquet/pig/TupleConsumerPerfTest.java b/parquet-pig/src/test/java/parquet/pig/TupleConsumerPerfTest.java
index 68ad1fed3e..9e590d855a 100644
--- a/parquet-pig/src/test/java/parquet/pig/TupleConsumerPerfTest.java
+++ b/parquet-pig/src/test/java/parquet/pig/TupleConsumerPerfTest.java
@@ -56,7 +56,7 @@ public static void main(String[] args) throws Exception {
     MessageType schema = new PigSchemaConverter().convert(Utils.getSchemaFromString(pigSchema));
 
     MemPageStore memPageStore = new MemPageStore(0);
-    ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 50*1024*1024, 50*1024*1024, 50*1024*1024, false, WriterVersion.PARQUET_1_0);
+    ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 50*1024*1024, 50*1024*1024, false, WriterVersion.PARQUET_1_0);
     write(memPageStore, columns, schema, pigSchema);
     columns.flush();
     read(memPageStore, pigSchema, pigSchemaProjected, pigSchemaNoString);
diff --git a/parquet-thrift/src/test/java/parquet/thrift/TestParquetReadProtocol.java b/parquet-thrift/src/test/java/parquet/thrift/TestParquetReadProtocol.java
index eb2041250d..43e1a884f2 100644
--- a/parquet-thrift/src/test/java/parquet/thrift/TestParquetReadProtocol.java
+++ b/parquet-thrift/src/test/java/parquet/thrift/TestParquetReadProtocol.java
@@ -145,7 +145,7 @@ private <T extends TBase<?,?>> void validate(T expected) throws TException {
     final MessageType schema = schemaConverter.convert(thriftClass);
     LOG.info(schema);
     final MessageColumnIO columnIO = new ColumnIOFactory(true).getColumnIO(schema);
-    final ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 10000, 10000, 10000, false, WriterVersion.PARQUET_1_0);
+    final ColumnWriteStoreV1 columns = new ColumnWriteStoreV1(memPageStore, 10000, 10000, false, WriterVersion.PARQUET_1_0);
     final RecordConsumer recordWriter = columnIO.getRecordWriter(columns);
     final StructType thriftType = schemaConverter.toStructType(thriftClass);
     ParquetWriteProtocol parquetWriteProtocol = new ParquetWriteProtocol(recordWriter, columnIO, thriftType);