Code review comments.
cmnbroad committed Feb 17, 2016
1 parent db98950 commit 3fcc2aa
Showing 5 changed files with 44 additions and 40 deletions.
@@ -315,7 +315,7 @@ protected void runPipeline( JavaSparkContext sparkContext ) {
*/
private void initializeToolInputs(final JavaSparkContext sparkContext) {
initializeReference();
initializeReads(sparkContext); // reference must be intialized before reads
initializeReads(sparkContext); // reference must be initialized before reads
initializeIntervals();
}

@@ -48,7 +48,7 @@ public final class ReadsSparkSource implements Serializable {
private transient final JavaSparkContext ctx;
private ValidationStringency validationStringency = ReadConstants.DEFAULT_READ_VALIDATION_STRINGENCY;

protected final Logger logger = LogManager.getLogger(ReadsSparkSource.class);
private static final Logger logger = LogManager.getLogger(ReadsSparkSource.class);

public ReadsSparkSource(final JavaSparkContext ctx) { this.ctx = ctx; }

@@ -82,6 +82,7 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
* @return RDD of (SAMRecord-backed) GATKReads from the file.
*/
public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final List<SimpleInterval> intervals, final long splitSize) {
// use the Hadoop configuration attached to the Spark context to maintain cumulative settings
final Configuration conf = ctx.hadoopConfiguration();
if (splitSize > 0) {
conf.set("mapreduce.input.fileinputformat.split.maxsize", Long.toString(splitSize));
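The split size set above caps the Hadoop input split size, which in turn determines how many partitions the resulting RDD has; the testPartitionSizing test further down in this diff exercises this path. A minimal usage sketch, assuming an illustrative local BAM path rather than one of the repository test files:

final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
final ReadsSparkSource readSource = new ReadsSparkSource(ctx);

// without an explicit split size, a small (~220 kB) BAM loads as a single partition
final JavaRDD<GATKRead> allInOnePartition = readSource.getParallelReads("reads.bam", null);

// a 100 kB maximum split size produces multiple, smaller partitions
final JavaRDD<GATKRead> smallPartitions = readSource.getParallelReads("reads.bam", null, 100 * 1024);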
@@ -157,7 +158,7 @@ public JavaRDD<GATKRead> getADAMReads(final String inputPath, final List<SimpleI
/**
* Loads the header using Hadoop-BAM.
* @param filePath path to the bam.
* @param referencePath to the reference
* @param referencePath Reference path or null if not available. Reference is required for CRAM files.
* @param auth authentication information if using GCS.
* @return the header for the bam.
*/
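A minimal usage sketch of the header call documented above, mirroring how the tests below invoke it (the file paths are illustrative, and the final auth argument is passed as null, as in the tests):

final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
final ReadsSparkSource readSource = new ReadsSparkSource(ctx);

// BAM input: the reference may be null
final SAMFileHeader bamHeader = readSource.getHeader("reads.bam", null, null);

// CRAM input: a reference is required (see the updated @param above)
final SAMFileHeader cramHeader = readSource.getHeader("reads.cram", "reference.fasta", null);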
@@ -201,14 +202,17 @@ public boolean accept(Path path) {
/**
* Propagate any values that need to be passed to Hadoop-BAM through configuration properties:
*
* - the validation stringency property is always set
* - if the input file is a CRAM file, the reference value, which must be a URI which includes
* a scheme, will also be set
* - the validation stringency property is always set using the current value of the
* validationStringency field
* - if the input file is a CRAM file, the reference value will also be set, and must be a URI
* which includes a scheme. if no scheme is provided a "file://" scheme will be used. for
* non-CRAM input the reference may be null.
* - if the input file is not CRAM, the reference property is *unset* to prevent Hadoop-BAM
* from passing a stale value through to htsjdk when multiple read calls are made serially
* with different inputs but the same Spark context
*/
private void setHadoopBAMConfigurationProperties(final String inputName, final String referenceName) {
// use the Hadoop configuration attached to the Spark context to maintain cumulative settings
final Configuration conf = ctx.hadoopConfiguration();
conf.set(SAMHeaderReader.VALIDATION_STRINGENCY_PROPERTY, validationStringency.name());
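The diff truncates the body of this method, but the rules spelled out in the javadoc above can be sketched as follows. This is illustrative only, not the committed implementation: the ".cram" suffix check and the CRAMInputFormat.REFERENCE_SOURCE_PATH_PROPERTY constant used for the reference are assumptions about how Hadoop-BAM is driven here.

// Illustrative sketch; assumes the ctx and validationStringency fields shown earlier in this diff.
private void setHadoopBAMConfigurationPropertiesSketch(final String inputName, final String referenceName) {
    final Configuration conf = ctx.hadoopConfiguration();

    // the stringency property is always set from the current validationStringency field
    conf.set(SAMHeaderReader.VALIDATION_STRINGENCY_PROPERTY, validationStringency.name());

    if (inputName.endsWith(".cram")) {                // assumed CRAM-detection check
        if (referenceName == null) {
            throw new UserException("A reference is required for CRAM input: " + inputName);
        }
        // the reference must be a URI with a scheme; promote a bare path to "file://"
        final String refURI = java.net.URI.create(referenceName).getScheme() == null
                ? "file://" + new java.io.File(referenceName).getAbsolutePath()
                : referenceName;
        conf.set(CRAMInputFormat.REFERENCE_SOURCE_PATH_PROPERTY, refURI);  // assumed Hadoop-BAM property
    } else {
        // unset the reference so a stale value is not passed through to htsjdk when
        // multiple read calls are made serially with different inputs on one Spark context
        conf.unset(CRAMInputFormat.REFERENCE_SOURCE_PATH_PROPERTY);
    }
}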

@@ -4,7 +4,6 @@
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMRecordCoordinateComparator;
import htsjdk.samtools.ValidationStringency;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
@@ -71,7 +70,7 @@ public void readsSinkTest(String inputBam, String outputFileName, String outputF
final File outputFile = createTempFile(outputFileName, outputFileExtension);
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> rddParallelReads = readSource.getParallelReads(inputBam, null);
SAMFileHeader header = readSource.getHeader(inputBam, null, null);

@@ -97,7 +96,7 @@ public void readsSinkShardedTest(String inputBam, String outputFileName, String
final File outputFile = createTempFile(outputFileName, outputFileExtension);
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> rddParallelReads = readSource.getParallelReads(inputBam, null);
rddParallelReads = rddParallelReads.repartition(2); // ensure that the output is in two shards
SAMFileHeader header = readSource.getHeader(inputBam, null, null);
@@ -124,7 +123,7 @@ public void readsSinkADAMTest(String inputBam, String outputDirectoryName) throw

JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> rddParallelReads = readSource.getParallelReads(inputBam, null);
SAMFileHeader header = readSource.getHeader(inputBam, null, null);

@@ -16,13 +16,13 @@
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadConstants;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.test.BaseTest;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import javax.validation.Valid;
import java.io.File;
import java.io.IOException;
import java.util.List;
@@ -36,13 +36,11 @@ public class ReadsSparkSourceUnitTest extends BaseTest {
@DataProvider(name = "loadReads")
public Object[][] loadReads() {
return new Object[][]{
{NA12878_chr17_1k_BAM, null, ValidationStringency.SILENT},

{dirBQSR + "HiSeq.1mb.1RG.2k_lines.alternate.bam", null, ValidationStringency.STRICT},
{dirBQSR + "expected.HiSeq.1mb.1RG.2k_lines.alternate.recalibrated.DIQ.bam", null, ValidationStringency.STRICT},

{NA12878_chr17_1k_CRAM, v37_chr17_1Mb_Reference, ValidationStringency.SILENT},
{dir + "valid.cram", dir + "valid.fasta", ValidationStringency.STRICT}
{NA12878_chr17_1k_BAM, null},
{dirBQSR + "HiSeq.1mb.1RG.2k_lines.alternate.bam", null},
{dirBQSR + "expected.HiSeq.1mb.1RG.2k_lines.alternate.recalibrated.DIQ.bam", null},
{NA12878_chr17_1k_CRAM, v37_chr17_1Mb_Reference},
{dir + "valid.cram", dir + "valid.fasta"}
};
}

@@ -62,7 +60,11 @@ public Object[][] loadReadsValidation() {
};
}

private void doLoadReadsTest(String bam, String referencePath, ValidationStringency validationStringency) {
private void doLoadReadsTest(String bam, String referencePath) {
doLoadReads(bam, referencePath, ReadConstants.DEFAULT_READ_VALIDATION_STRINGENCY);
}

private void doLoadReads(String bam, String referencePath, ValidationStringency validationStringency) {
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, validationStringency);
@@ -75,31 +77,31 @@ private void doLoadReadsTest(String bam, String referencePath, ValidationStringe
}

@Test(dataProvider = "loadReads", groups = "spark")
public void readsSparkSourceTest(String bam, String referencePath, ValidationStringency validationStringency) {
doLoadReadsTest(bam, referencePath, validationStringency);
public void readsSparkSourceTest(String bam, String referencePath) {
doLoadReadsTest(bam, referencePath);
}

@Test(dataProvider = "loadReadsValidation", groups = "spark", expectedExceptions = SAMFormatException.class)
public void readsSparkSourceTestStrict(String bam, String referencePath) {
doLoadReadsTest(bam, referencePath, ValidationStringency.STRICT);
doLoadReads(bam, referencePath, ValidationStringency.STRICT);
}

@Test(dataProvider = "loadReadsValidation", groups = "spark")
public void readsSparkSourceTestLenient(String bam, String referencePath) {
doLoadReadsTest(bam, referencePath, ValidationStringency.LENIENT);
doLoadReads(bam, referencePath, ValidationStringency.LENIENT);
}

@Test(dataProvider = "loadReadsValidation", groups = "spark")
public void readsSparkSourceTestSilent(String bam, String referencePath) {
doLoadReadsTest(bam, referencePath, ValidationStringency.SILENT);
doLoadReads(bam, referencePath, ValidationStringency.SILENT);
}

@Test(dataProvider = "loadShardedReads", groups = "spark")
public void shardedReadsSparkSourceTest(String expectedBam, String shardedBam, String referencePath) {
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
JavaRDD<GATKRead> rddSerialReads = getSerialReads(ctx, expectedBam, referencePath, ValidationStringency.DEFAULT_STRINGENCY);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> rddSerialReads = getSerialReads(ctx, expectedBam, referencePath, ReadConstants.DEFAULT_READ_VALIDATION_STRINGENCY);
JavaRDD<GATKRead> rddParallelReads = readSource.getParallelReads(shardedBam, referencePath);

List<GATKRead> serialReads = rddSerialReads.collect();
@@ -110,7 +112,7 @@ public void shardedReadsSparkSourceTest(String expectedBam, String shardedBam, S
@Test(groups = "spark")
public void testHeadersAreStripped() {
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
final List<GATKRead> reads = readSource.getParallelReads(dirBQSR + "HiSeq.1mb.1RG.2k_lines.alternate.bam", null).collect();

for ( final GATKRead read : reads ) {
@@ -120,12 +122,12 @@ public void testHeadersAreStripped() {

@Test(groups = "spark", expectedExceptions=UserException.class)
public void testReject2BitCRAMReference() {
doLoadReadsTest(NA12878_chr17_1k_CRAM, b37_2bit_reference_20_21, ValidationStringency.STRICT);
doLoadReadsTest(NA12878_chr17_1k_CRAM, b37_2bit_reference_20_21);
}

@Test(groups = "spark", expectedExceptions=UserException.class)
public void testCRAMReferenceRequired() {
doLoadReadsTest(NA12878_chr17_1k_CRAM, null, ValidationStringency.STRICT);
doLoadReadsTest(NA12878_chr17_1k_CRAM, null);
}

@Test
@@ -134,7 +136,7 @@ public void testPartitionSizing(){
String bam = dirBQSR + "HiSeq.1mb.1RG.2k_lines.alternate.bam"; //file is ~220 kB
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> allInOnePartition = readSource.getParallelReads(bam, null);
JavaRDD<GATKRead> smallPartitions = readSource.getParallelReads(bam, null, 100 * 1024); // 100 kB
Assert.assertEquals(allInOnePartition.partitions().size(), 1);
@@ -155,7 +157,7 @@ public void testReadFromFileAndHDFS() throws IOException {
cluster.getFileSystem().copyFromLocalFile(new Path(bai.toURI()), baiPath);

final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
final ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
final ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx);
final List<GATKRead> localReads = readsSparkSource.getParallelReads(bam.toURI().toString(), null).collect();
final List<GATKRead> hdfsReads = readsSparkSource.getParallelReads(bamPath.toUri().toString(), null).collect();

@@ -179,17 +181,17 @@ public void testCRAMReferenceFromHDFS() throws IOException {
try {
cluster = new MiniDFSCluster.Builder(new Configuration()).build();
final Path workingDirectory = cluster.getFileSystem().getWorkingDirectory();
final Path cramPath = new Path(workingDirectory,"hdfs.cram");
final Path refPath = new Path(workingDirectory, "hdfs.fasta");
final Path refIndexPath = new Path(workingDirectory, "hdfs.fasta.fai");
cluster.getFileSystem().copyFromLocalFile(new Path(cram.toURI()), cramPath);
cluster.getFileSystem().copyFromLocalFile(new Path(reference.toURI()), refPath);
cluster.getFileSystem().copyFromLocalFile(new Path(referenceIndex.toURI()), refIndexPath);
final Path cramHDFSPath = new Path(workingDirectory,"hdfs.cram");
final Path refHDFSPath = new Path(workingDirectory, "hdfs.fasta");
final Path refIndexHDFSPath = new Path(workingDirectory, "hdfs.fasta.fai");
cluster.getFileSystem().copyFromLocalFile(new Path(cram.toURI()), cramHDFSPath);
cluster.getFileSystem().copyFromLocalFile(new Path(reference.toURI()), refHDFSPath);
cluster.getFileSystem().copyFromLocalFile(new Path(referenceIndex.toURI()), refIndexHDFSPath);

final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
final ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx);
final List<GATKRead> localReads = readsSparkSource.getParallelReads(cram.toURI().toString(), reference.toURI().toString()).collect();
final List<GATKRead> hdfsReads = readsSparkSource.getParallelReads(cramPath.toUri().toString(), refPath.toUri().toString()).collect();
final List<GATKRead> hdfsReads = readsSparkSource.getParallelReads(cramHDFSPath.toUri().toString(), refHDFSPath.toUri().toString()).collect();

Assert.assertFalse(localReads.isEmpty());
Assert.assertEquals(localReads, hdfsReads);
@@ -1,7 +1,6 @@
package org.broadinstitute.hellbender.tools.spark.transforms.markduplicates;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.ValidationStringency;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.cmdline.argumentcollections.OpticalDuplicatesArgumentCollection;
@@ -33,7 +32,7 @@ public Object[][] loadReads() {
public void markDupesTest(final String input, final long totalExpected, final long dupsExpected) throws IOException {
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.STRICT);
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> reads = readSource.getParallelReads(input, null);
Assert.assertEquals(reads.count(), totalExpected);

