Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid allocating spire uint objects during apply agglomerate #6532

Merged
merged 7 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released
[Commits](https://github.com/scalableminds/webknossos/compare/22.10.0...HEAD)

### Added
- Improved performance for applying agglomerate mappings. [#6532](https://github.com/scalableminds/webknossos/pull/6532)

### Changed
- Creating tasks in bulk now also supports referencing task types by their summary instead of id. [#6486](https://github.com/scalableminds/webknossos/pull/6486)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
package com.scalableminds.webknossos.datastore.services

import java.nio._
import java.nio.file.{Files, Paths}

import ch.systemsx.cisd.hdf5._
import com.scalableminds.util.io.PathUtils
import com.scalableminds.webknossos.datastore.DataStoreConfig
import com.scalableminds.webknossos.datastore.EditableMapping.{AgglomerateEdge, AgglomerateGraph}
import com.scalableminds.webknossos.datastore.SkeletonTracing.{Edge, SkeletonTracing, Tree}
import com.scalableminds.webknossos.datastore.geometry.Vec3IntProto
import com.scalableminds.webknossos.datastore.helpers.{NodeDefaults, SkeletonTracingDefaults}
import com.scalableminds.webknossos.datastore.models.datasource.ElementClass
import com.scalableminds.webknossos.datastore.models.requests.DataServiceDataRequest
import com.scalableminds.webknossos.datastore.storage._
import com.typesafe.scalalogging.LazyLogging
import javax.inject.Inject
import net.liftweb.common.Box.tryo
import net.liftweb.common.{Box, Failure, Full}
import net.liftweb.util.Helpers.tryo
import org.apache.commons.io.FilenameUtils
import spire.math.{UByte, UInt, ULong, UShort}

import java.nio._
import java.nio.file.{Files, Paths}
import javax.inject.Inject

class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverter with LazyLogging {
private val agglomerateDir = "agglomerates"
Expand All @@ -41,15 +41,11 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte
}

def applyAgglomerate(request: DataServiceDataRequest)(data: Array[Byte]): Array[Byte] = {
def byteFunc(buf: ByteBuffer, lon: Long) = buf put lon.toByte
def shortFunc(buf: ByteBuffer, lon: Long) = buf putShort lon.toShort
def intFunc(buf: ByteBuffer, lon: Long) = buf putInt lon.toInt
def longFunc(buf: ByteBuffer, lon: Long) = buf putLong lon

val agglomerateFileKey = AgglomerateFileKey.fromDataRequest(request)

def convertToAgglomerate(input: Array[ULong],
numBytes: Int,
def convertToAgglomerate(input: Array[Long],
bytesPerElement: Int,
bufferFunc: (ByteBuffer, Long) => ByteBuffer): Array[Byte] = {

val cachedAgglomerateFile = agglomerateFileCache.withCache(agglomerateFileKey)(initHDFReader)
Expand All @@ -64,16 +60,26 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte
cachedAgglomerateFile.finishAccess()

agglomerateIds
.foldLeft(ByteBuffer.allocate(numBytes * input.length).order(ByteOrder.LITTLE_ENDIAN))(bufferFunc)
.foldLeft(ByteBuffer.allocate(bytesPerElement * input.length).order(ByteOrder.LITTLE_ENDIAN))(bufferFunc)
.array
}

val bytesPerElement = ElementClass.bytesPerElement(request.dataLayer.elementClass)
convertData(data, request.dataLayer.elementClass) match {
case data: Array[UByte] => convertToAgglomerate(data.map(e => ULong(e.toLong)), 1, byteFunc)
case data: Array[UShort] => convertToAgglomerate(data.map(e => ULong(e.toLong)), 2, shortFunc)
case data: Array[UInt] => convertToAgglomerate(data.map(e => ULong(e.toLong)), 4, intFunc)
case data: Array[ULong] => convertToAgglomerate(data, 8, longFunc)
case _ => data
case data: Array[Byte] =>
val longBuffer = LongBuffer.allocate(data.length)
data.foreach(e => longBuffer.put(uByteToLong(e)))
convertToAgglomerate(longBuffer.array, bytesPerElement, putByte)
case data: Array[Short] =>
val longBuffer = LongBuffer.allocate(data.length)
data.foreach(e => longBuffer.put(uShortToLong(e)))
convertToAgglomerate(longBuffer.array, bytesPerElement, putShort)
case data: Array[Int] =>
val longBuffer = LongBuffer.allocate(data.length)
data.foreach(e => longBuffer.put(uIntToLong(e)))
convertToAgglomerate(longBuffer.array, bytesPerElement, putInt)
case data: Array[Long] => convertToAgglomerate(data, bytesPerElement, putLong)
fm3 marked this conversation as resolved.
Show resolved Hide resolved
case _ => data
}
}

Expand Down Expand Up @@ -109,7 +115,7 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte

val defaultCache: Either[AgglomerateIdCache, BoundingBoxCache] =
if (Files.exists(cumsumPath)) {
Right(CumsumParser.parse(cumsumPath.toFile, ULong(config.Datastore.Cache.AgglomerateFile.cumsumMaxReaderRange)))
Right(CumsumParser.parse(cumsumPath.toFile, config.Datastore.Cache.AgglomerateFile.cumsumMaxReaderRange))
} else {
Left(agglomerateIdCache)
}
Expand Down Expand Up @@ -216,8 +222,8 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte
val cachedAgglomerateFile = agglomerateFileCache.withCache(agglomerateFileKey)(initHDFReader)

tryo {
val agglomerateIds = segmentIds.map { segmentId =>
cachedAgglomerateFile.agglomerateIdCache.withCache(ULong(segmentId),
val agglomerateIds = segmentIds.map { segmentId: Long =>
cachedAgglomerateFile.agglomerateIdCache.withCache(segmentId,
cachedAgglomerateFile.reader,
cachedAgglomerateFile.dataset)(readHDF)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,40 @@ package com.scalableminds.webknossos.datastore.services

import com.scalableminds.util.tools.FoxImplicits
import com.scalableminds.webknossos.datastore.models.datasource.ElementClass
import spire.math._
import spire.math.{ULong, _}

import java.nio._
import scala.reflect.ClassTag

trait DataConverter extends FoxImplicits {

/** Writes `lon` narrowed to a Byte (via `toByte`) at the buffer's current position; returns the result of `buf.put`. */
def putByte(buf: ByteBuffer, lon: Long): ByteBuffer = buf.put(lon.toByte)

/** Writes `lon` narrowed to a Short (via `toShort`) at the buffer's current position; returns the result of `buf.putShort`. */
def putShort(buf: ByteBuffer, lon: Long): ByteBuffer = buf.putShort(lon.toShort)

/** Writes `lon` narrowed to an Int (via `toInt`) at the buffer's current position; returns the result of `buf.putInt`. */
def putInt(buf: ByteBuffer, lon: Long): ByteBuffer = buf.putInt(lon.toInt)

/** Writes all 64 bits of `lon` at the buffer's current position; returns the result of `buf.putLong`. */
def putLong(buf: ByteBuffer, lon: Long): ByteBuffer = buf.putLong(lon)

/** Reinterprets the 8 bits of `uByte` as an unsigned value and widens it to a non-negative Long (0 to 255). */
def uByteToLong(uByte: Byte): Long = java.lang.Byte.toUnsignedLong(uByte)

/** Reinterprets the 16 bits of `uShort` as an unsigned value and widens it to a non-negative Long (0 to 65535). */
def uShortToLong(uShort: Short): Long = java.lang.Short.toUnsignedLong(uShort)

/** Reinterprets the 32 bits of `uInt` as an unsigned value and widens it to a non-negative Long (0 to 2^32 - 1). */
def uIntToLong(uInt: Int): Long = java.lang.Integer.toUnsignedLong(uInt)

def convertData(data: Array[Byte],
elementClass: ElementClass.Value,
filterZeroes: Boolean = false): Array[_ >: UByte with UShort with UInt with ULong with Float] =
elementClass: ElementClass.Value): Array[_ >: Byte with Short with Int with Long with Float] =
elementClass match {
case ElementClass.uint8 =>
case ElementClass.uint8 | ElementClass.int8 =>
convertDataImpl[Byte, ByteBuffer](data, DataTypeFunctors[Byte, ByteBuffer](identity, _.get(_), _.toByte))
.map(UByte(_))
.filter(!filterZeroes || _ != UByte(0))
case ElementClass.uint16 =>
case ElementClass.uint16 | ElementClass.int16 =>
convertDataImpl[Short, ShortBuffer](data,
DataTypeFunctors[Short, ShortBuffer](_.asShortBuffer, _.get(_), _.toShort))
.map(UShort(_))
.filter(!filterZeroes || _ != UShort(0))
case ElementClass.uint24 =>
convertDataImpl[Byte, ByteBuffer](data, DataTypeFunctors[Byte, ByteBuffer](identity, _.get(_), _.toByte))
.map(UByte(_))
.filter(!filterZeroes || _ != UByte(0))
case ElementClass.uint32 =>
case ElementClass.uint32 | ElementClass.int32 =>
convertDataImpl[Int, IntBuffer](data, DataTypeFunctors[Int, IntBuffer](_.asIntBuffer, _.get(_), _.toInt))
.map(UInt(_))
.filter(!filterZeroes || _ != UInt(0))
case ElementClass.uint64 =>
case ElementClass.uint64 | ElementClass.int64 =>
convertDataImpl[Long, LongBuffer](data, DataTypeFunctors[Long, LongBuffer](_.asLongBuffer, _.get(_), identity))
.map(ULong(_))
.filter(!filterZeroes || _ != ULong(0))
case ElementClass.float =>
convertDataImpl[Float, FloatBuffer](data,
DataTypeFunctors[Float, FloatBuffer](
_.asFloatBuffer(),
_.get(_),
_.toFloat)).filter(!_.isNaN).filter(!filterZeroes || _ != 0f)
convertDataImpl[Float, FloatBuffer](
data,
DataTypeFunctors[Float, FloatBuffer](_.asFloatBuffer(), _.get(_), _.toFloat))
}

private def convertDataImpl[T: ClassTag, B <: Buffer](data: Array[Byte],
Expand All @@ -50,4 +46,59 @@ trait DataConverter extends FoxImplicits {
dataTypeFunctor.copyDataFn(srcBuffer, dstArray)
dstArray
}

/**
  * Converts an array of JVM primitives into an array of the matching spire unsigned wrapper type
  * (Byte -> UByte, Short -> UShort, Int -> UInt, Long -> ULong).
  *
  * Float arrays are passed through untouched (the same array instance is returned).
  * There is no default case, so any other array type results in a MatchError.
  */
def toUnsigned(data: Array[_ >: Byte with Short with Int with Long with Float])
  : Array[_ >: UByte with UShort with UInt with ULong with Float] =
  data match {
    case bytes: Array[Byte]   => bytes.map(b => UByte(b))
    case shorts: Array[Short] => shorts.map(s => UShort(s))
    case ints: Array[Int]     => ints.map(i => UInt(i))
    case longs: Array[Long]   => longs.map(l => ULong(l))
    case floats: Array[Float] => floats
  }

/**
  * Removes zero-valued elements from a typed primitive array.
  *
  * For Float arrays, NaN elements are dropped as well as zeroes.
  *
  * @param data typed array as produced by convertData
  * @param skip if true, no filtering happens at all and `data` itself is returned
  *             (in that case NaNs in Float arrays are kept, too)
  * @return a new array of the same element type without the filtered elements, or `data` if `skip` is set
  */
def filterZeroes(data: Array[_ >: Byte with Short with Int with Long with Float],
                 skip: Boolean = false): Array[_ >: Byte with Short with Int with Long with Float] =
  data match {
    // The guard comes first so that skipping never inspects the element type.
    case _ if skip            => data
    case bytes: Array[Byte]   => bytes.filter(_ != 0)
    case shorts: Array[Short] => shorts.filter(_ != 0)
    case ints: Array[Int]     => ints.filter(_ != 0)
    case longs: Array[Long]   => longs.filter(_ != 0L)
    case floats: Array[Float] => floats.filter(value => !value.isNaN && value != 0f)
  }

/**
  * Serializes an array of spire unsigned values (or floats) into a little-endian byte array.
  *
  * The output buffer is sized as ElementClass.bytesPerElement(elementClass) * typed.length.
  * Each element is written through its `signed` representation; floats are written with putFloat.
  * There is no default case, so any other array type results in a MatchError.
  *
  * @param typed        the values to serialize
  * @param elementClass determines the number of bytes reserved per element
  * @return the backing array of the filled buffer
  */
def toBytesSpire(typed: Array[_ >: UByte with UShort with UInt with ULong with Float],
                 elementClass: ElementClass.Value): Array[Byte] = {
  val bytesPerElement = ElementClass.bytesPerElement(elementClass)
  val out = ByteBuffer.allocate(bytesPerElement * typed.length).order(ByteOrder.LITTLE_ENDIAN)
  typed match {
    case uBytes: Array[UByte] => for (value <- uBytes) out.put(value.signed)
    // NOTE(review): written with putChar, presumably because UShort.signed is a Char — confirm against spire.
    case uShorts: Array[UShort] => for (value <- uShorts) out.putChar(value.signed)
    case uInts: Array[UInt]     => for (value <- uInts) out.putInt(value.signed)
    case uLongs: Array[ULong]   => for (value <- uLongs) out.putLong(value.signed)
    case floats: Array[Float]   => for (value <- floats) out.putFloat(value)
  }
  out.array()
}

/**
  * Serializes a typed primitive array into a little-endian byte array.
  *
  * The output buffer is sized as ElementClass.bytesPerElement(elementClass) * typed.length.
  * Elements are copied with a single bulk put per array instead of one call per element.
  * For multi-byte element types this goes through a typed view of the byte buffer, which
  * inherits the little-endian order set on the byte buffer before the view is created.
  * There is no default case, so any other array type results in a MatchError.
  *
  * @param typed        the values to serialize
  * @param elementClass determines the number of bytes reserved per element
  * @return the backing array of the filled buffer
  */
def toBytes(typed: Array[_ >: Byte with Short with Int with Long with Float],
            elementClass: ElementClass.Value): Array[Byte] = {
  val numBytes = ElementClass.bytesPerElement(elementClass)
  val byteBuffer = ByteBuffer.allocate(numBytes * typed.length).order(ByteOrder.LITTLE_ENDIAN)
  typed match {
    case data: Array[Byte]  => byteBuffer.put(data)
    case data: Array[Short] => byteBuffer.asShortBuffer().put(data)
    case data: Array[Int]   => byteBuffer.asIntBuffer().put(data)
    case data: Array[Long]  => byteBuffer.asLongBuffer().put(data)
    case data: Array[Float] => byteBuffer.asFloatBuffer().put(data)
  }
  // Views share the backing storage, so the full backing array holds the written data.
  byteBuffer.array()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package com.scalableminds.webknossos.datastore.services

import com.google.inject.Inject
import com.scalableminds.util.geometry.Vec3Int
import com.scalableminds.util.tools.{Fox, FoxImplicits, Math}
import com.scalableminds.util.tools.Math
import com.scalableminds.util.tools.{Fox, FoxImplicits}
import com.scalableminds.webknossos.datastore.models.datasource.{DataLayer, DataSource, ElementClass}
import com.scalableminds.webknossos.datastore.models.requests.DataServiceDataRequest
import com.scalableminds.webknossos.datastore.models.{DataRequest, VoxelPosition}
import net.liftweb.common.Full
import play.api.libs.json.{Json, OFormat}
import spire.math._
import spire.math.{UByte, UInt, ULong, UShort}

import scala.annotation.tailrec
import scala.concurrent.ExecutionContext
Expand Down Expand Up @@ -137,14 +138,15 @@ class FindDataService @Inject()(dataServicesHolder: BinaryDataServiceHolder)(imp
} yield positionAndResolutionOpt

def meanAndStdDev(dataSource: DataSource, dataLayer: DataLayer): Fox[(Double, Double)] = {
Fox.successful(5.0, 5.0)

def convertNonZeroDataToDouble(data: Array[Byte], elementClass: ElementClass.Value): Array[Double] =
convertData(data, elementClass, filterZeroes = true) match {
case d: Array[UByte] => d.map(_.toDouble)
case d: Array[UShort] => d.map(_.toDouble)
case d: Array[UInt] => d.map(_.toDouble)
case d: Array[ULong] => d.map(_.toDouble)
case d: Array[Float] => d.map(_.toDouble)
filterZeroes(convertData(data, elementClass)) match {
case d: Array[Byte] => d.map(uByteToLong).map(_.toDouble)
case d: Array[Short] => d.map(uShortToLong).map(_.toDouble)
case d: Array[Int] => d.map(uIntToLong).map(_.toDouble)
case d: Array[Long] => d.map(_.toDouble)
case d: Array[Float] => d.map(_.toDouble)
}

def meanAndStdDevForPositions(positions: List[Vec3Int], resolution: Vec3Int): Fox[(Double, Double)] =
Expand Down Expand Up @@ -199,7 +201,9 @@ class FindDataService @Inject()(dataServicesHolder: BinaryDataServiceHolder)(imp
}
if (isUint24) {
val listOfCounts = counts.grouped(256).toList
listOfCounts.map(counts => { counts(0) = 0; Histogram(counts, counts.sum.toInt, extrema._1, extrema._2) })
listOfCounts.map(counts => {
counts(0) = 0; Histogram(counts, counts.sum.toInt, extrema._1, extrema._2)
})
} else
List(Histogram(counts, data.length, extrema._1, extrema._2))
}
Expand All @@ -208,7 +212,7 @@ class FindDataService @Inject()(dataServicesHolder: BinaryDataServiceHolder)(imp
for {
dataConcatenated <- getConcatenatedDataFor(dataSource, dataLayer, positions, resolution) ?~> "dataSet.noData"
isUint24 = dataLayer.elementClass == ElementClass.uint24
convertedData = convertData(dataConcatenated, dataLayer.elementClass, filterZeroes = !isUint24)
convertedData = toUnsigned(filterZeroes(convertData(dataConcatenated, dataLayer.elementClass), skip = isUint24))
} yield calculateHistogramValues(convertedData, dataLayer.bytesPerElement, isUint24)

if (dataLayer.resolutions.nonEmpty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.scalableminds.webknossos.datastore.dataformats.SafeCachable
import com.scalableminds.webknossos.datastore.models.VoxelPosition
import com.scalableminds.webknossos.datastore.models.requests.{Cuboid, DataServiceDataRequest}
import com.typesafe.scalalogging.LazyLogging
import spire.math.{ULong, max, min}

import scala.collection.mutable

Expand Down Expand Up @@ -73,27 +72,27 @@ class AgglomerateFileCache(val maxEntries: Int) extends LRUConcurrentCache[Agglo
class AgglomerateIdCache(val maxEntries: Int, val standardBlockSize: Int) extends LRUConcurrentCache[Long, Long] {
// On cache miss, reads whole blocks of IDs (number of elements is standardBlockSize)

def withCache(segmentId: ULong, reader: IHDF5Reader, dataSet: HDF5DataSet)(
def withCache(segmentId: Long, reader: IHDF5Reader, dataSet: HDF5DataSet)(
readFromFile: (IHDF5Reader, HDF5DataSet, Long, Long) => Array[Long]): Long = {

def handleUncachedAgglomerate(): Long = {
val minId =
if (segmentId < ULong(standardBlockSize / 2)) ULong(0) else segmentId - ULong(standardBlockSize / 2)
if (segmentId < standardBlockSize / 2) 0L else segmentId - standardBlockSize / 2

val agglomerateIds = readFromFile(reader, dataSet, minId.toLong, standardBlockSize)
val agglomerateIds = readFromFile(reader, dataSet, minId, standardBlockSize)

agglomerateIds.zipWithIndex.foreach {
case (id, index) => put(index + minId.toLong, id)
case (id, index) => put(index + minId, id)
}

agglomerateIds((segmentId - minId).toInt)
}

getOrHandleUncachedKey(segmentId.toLong, () => handleUncachedAgglomerate())
getOrHandleUncachedKey(segmentId, () => handleUncachedAgglomerate())
}
}

case class BoundingBoxValues(idRange: (ULong, ULong), dimensions: (Long, Long, Long))
case class BoundingBoxValues(idRange: (Long, Long), dimensions: (Long, Long, Long))

case class BoundingBoxFinder(
xCoordinates: util.TreeSet[Long], // TreeSets allow us to find the largest coordinate, which is smaller than the requested cuboid
Expand All @@ -116,7 +115,7 @@ case class BoundingBoxFinder(
class BoundingBoxCache(
val cache: mutable.HashMap[(Long, Long, Long), BoundingBoxValues], // maps bounding box top left to range and bb dimensions
val boundingBoxFinder: BoundingBoxFinder, // saves the bb top left positions
val maxReaderRange: ULong) // config value for maximum amount of elements that are allowed to be read as once
val maxReaderRange: Long) // config value for maximum amount of elements that are allowed to be read at once
extends LazyLogging {
private def getGlobalCuboid(cuboid: Cuboid): Cuboid = {
val res = cuboid.mag
Expand All @@ -130,7 +129,7 @@ class BoundingBoxCache(
}

// get the segment ID range for one cuboid
private def getReaderRange(request: DataServiceDataRequest): (ULong, ULong) = {
private def getReaderRange(request: DataServiceDataRequest): (Long, Long) = {
// convert cuboid to global coordinates (in res 1)
val globalCuboid = getGlobalCuboid(request.cuboid)

Expand Down Expand Up @@ -158,7 +157,7 @@ class BoundingBoxCache(
while (z < requestedCuboid.voxelZInMag && z < dataLayerBox.z) {
// get cached values for current bb and update the reader range by extending if necessary
cache.get((x, y, z)).foreach { value =>
range = (min(range._1, value.idRange._1), max(range._2, value.idRange._2))
range = (Math.min(range._1, value.idRange._1), Math.max(range._2, value.idRange._2))
currDimensions = value.dimensions
}
z = z + currDimensions._3
Expand All @@ -174,24 +173,24 @@ class BoundingBoxCache(
range
}

def withCache(request: DataServiceDataRequest, input: Array[ULong], reader: IHDF5Reader)(
def withCache(request: DataServiceDataRequest, input: Array[Long], reader: IHDF5Reader)(
readHDF: (IHDF5Reader, Long, Long) => Array[Long]): Array[Long] = {
val readerRange = getReaderRange(request)
if (readerRange._2 - readerRange._1 < maxReaderRange) {
val agglomerateIds = readHDF(reader, readerRange._1.toLong, (readerRange._2 - readerRange._1).toLong + 1)
input.map(i => if (i == ULong(0)) 0L else agglomerateIds((i - readerRange._1).toInt))
input.map(i => if (i == 0L) 0L else agglomerateIds((i - readerRange._1).toInt))
} else {
// if reader range does not fit in main memory, read agglomerate ids in chunks
var offset = readerRange._1
val result = Array.ofDim[Long](input.length)
val isTransformed = Array.fill(input.length)(false)
while (offset <= readerRange._2) {
val agglomerateIds =
val agglomerateIds: Array[Long] =
readHDF(reader, offset.toLong, spire.math.min(maxReaderRange, readerRange._2 - offset).toLong + 1)
for (i <- input.indices) {
val inputElement = input(i)
if (!isTransformed(i) && inputElement >= offset && inputElement < offset + maxReaderRange) {
result(i) = if (inputElement == ULong(0)) 0L else agglomerateIds((inputElement - offset).toInt)
result(i) = if (inputElement == 0L) 0L else agglomerateIds((inputElement - offset).toInt)
isTransformed(i) = true
}
}
Expand Down
Loading