Skip to content

Commit

Permalink
Merge pull request #454 from databrickslabs/feature/remove_vsimem
Browse files Browse the repository at this point in the history
Remove any usage of /vsimem/ due to memory leaks.
  • Loading branch information
Milos Colic authored Dec 4, 2023
2 parents 0616f27 + ba7f276 commit 1aaea1c
Show file tree
Hide file tree
Showing 128 changed files with 2,179 additions and 454 deletions.
3 changes: 2 additions & 1 deletion .github/actions/python_build/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ runs:
run: |
cd python
pip install build wheel pyspark==${{ matrix.spark }} numpy==${{ matrix.numpy }}
pip install gdal==${{ matrix.gdal }}
pip install numpy==${{ matrix.numpy }}
pip install --no-build-isolation --no-cache-dir --force-reinstall gdal==${{ matrix.gdal }}
pip install .
- name: Test and build python package
shell: bash
Expand Down
2 changes: 2 additions & 0 deletions .github/actions/scala_build/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ runs:
pip install databricks-mosaic-gdal==${{ matrix.gdal }}
sudo tar -xf /opt/hostedtoolcache/Python/${{ matrix.python }}/x64/lib/python3.9/site-packages/databricks-mosaic-gdal/resources/gdal-${{ matrix.gdal }}-filetree.tar.xz -C /
sudo tar -xhf /opt/hostedtoolcache/Python/${{ matrix.python }}/x64/lib/python3.9/site-packages/databricks-mosaic-gdal/resources/gdal-${{ matrix.gdal }}-symlinks.tar.xz -C /
pip install numpy==${{ matrix.numpy }}
pip install gdal==${{ matrix.gdal }}
- name: Test and build the scala JAR - skip tests is false
if: inputs.skip_tests == 'false'
shell: bash
Expand Down
Binary file added mosaic-0.3.12-jar-with-dependencies.jar
Binary file not shown.
21 changes: 21 additions & 0 deletions python/mosaic/api/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"grid_cell_intersection_agg",
"rst_merge_agg",
"rst_combineavg_agg",
"rst_derivedband_agg",
"st_intersection_agg",
"st_intersects_agg",
]
Expand Down Expand Up @@ -209,3 +210,23 @@ def rst_combineavg_agg(raster: ColumnOrName) -> Column:
return config.mosaic_context.invoke_function(
"rst_combineavg_agg", pyspark_to_java_column(raster)
)


def rst_derivedband_agg(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName) -> Column:
"""
Returns the raster representing the aggregation of rasters using provided python function.
Parameters
----------
raster: Column
pythonFunc: Column
funcName: Column
Returns
-------
Column
The resulting raster.
"""
return config.mosaic_context.invoke_function(
"rst_derivedband_agg", pyspark_to_java_column(raster), pyspark_to_java_column(pythonFunc), pyspark_to_java_column(funcName)
)
29 changes: 29 additions & 0 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"rst_boundingbox",
"rst_clip",
"rst_combineavg",
"rst_derivedband",
"rst_frombands",
"rst_fromfile",
"rst_georeference",
Expand Down Expand Up @@ -145,6 +146,34 @@ def rst_combineavg(rasters: ColumnOrName) -> Column:
)


def rst_derivedband(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName) -> Column:
"""
Creates a new band by applying the given python function to the input rasters.
The result is a raster tile.
Parameters
----------
raster : Column (StringType)
Path to the raster file.
pythonFunc : Column (StringType)
The python function to apply to the bands.
funcName : Column (StringType)
The name of the function.
Returns
-------
Column (StringType)
The path to the new raster.
"""
return config.mosaic_context.invoke_function(
"rst_derivedband",
pyspark_to_java_column(raster),
pyspark_to_java_column(pythonFunc),
pyspark_to_java_column(funcName),
)


def rst_georeference(raster: ColumnOrName) -> Column:
"""
Returns GeoTransform of the raster as a GT array of doubles.
Expand Down
2 changes: 1 addition & 1 deletion python/test/test_raster_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_raster_flatmap_functions(self):
)

overlap_result.write.format("noop").mode("overwrite").save()
self.assertEqual(overlap_result.count(), 86)
self.assertEqual(overlap_result.count(), 87)

def test_raster_aggregator_functions(self):
collection = (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package com.databricks.labs.mosaic.core.raster.api

object FormatLookup {

// ShortDriverName -> FormatExtension
val formats: Map[String, String] = Map(
"AAIGrid" -> "asc",
"ACE2" -> "ace2",
"ADRG" -> "gen",
"AIG" -> "aig",
"AIRSAR" -> "airsar",
"ARCGEN" -> "gen",
"ARG" -> "arg",
"BLX" -> "blx",
"BMP" -> "bmp",
"BT" -> "bt",
"CAD" -> "dwg",
"CEOS" -> "ceos",
"COASP" -> "coasp",
"COSAR" -> "cosar",
"CPG" -> "cpg",
"CSW" -> "csw",
"CTG" -> "ctg",
"DB2ODBC" -> "db2",
"DERIVED" -> "derived",
"DGN" -> "dgn",
"DIMAP" -> "dim",
"DIPEx" -> "dipex",
"DOQ1" -> "doq1",
"DOQ2" -> "doq2",
"DTED" -> "dt0",
"DXF" -> "dxf",
"ECRGTOC" -> "toc",
"ECRGTP" -> "ecrgtp",
"EEDA" -> "eeda",
"EIR" -> "eir",
"ELAS" -> "elas",
"ENVI" -> "hdr",
"ERS" -> "ers",
"ESAT" -> "esat",
"ESRI Shapefile" -> "shp",
"ESRI" -> "ers",
"FAST" -> "fst",
"FIT" -> "fit",
"FITS" -> "fits",
"GFF" -> "gff",
"GIF" -> "gif",
"GLOBE" -> "globe",
"GMT" -> "gmt",
"GNM" -> "gnm",
"GRASSASCIIGrid" -> "asc",
"GRASS" -> "grass",
"GRIB" -> "grb",
"GTiff" -> "tif",
"GXF" -> "gxf",
"HDF4" -> "hdf4",
"HDF5" -> "hdf5",
"HF2" -> "hf2",
"HFA" -> "img",
"HTTP" -> "http",
"IDRISI" -> "rst",
"ILWIS" -> "mpr",
"INGR" -> "grd",
"IRIS" -> "ppm",
"ISIS2" -> "cub",
"ISIS3" -> "cub",
"JDEM" -> "mem",
"JPEG2000" -> "jp2",
"JPEG" -> "jpg",
"JP2OpenJPEG" -> "jp2",
"KMLSUPEROVERLAY" -> "kml",
"LAN" -> "lan",
"LCP" -> "lcp",
"L1B" -> "l1b",
"MBTiles" -> "mbtiles",
"MEM" -> "mem",
"MFF" -> "mff",
"MG4Lidar" -> "mg4l",
"MRF" -> "mrf",
"MSGN" -> "msgn",
"NDF" -> "ndf",
"NITF" -> "ntf",
"NTv2" -> "gsb",
"ODBC" -> "odbc",
"OGR_GMT" -> "gmt",
"OGR_PDS" -> "pds",
"OGR_SDTS" -> "sdts",
"OGR_VRT" -> "vrt",
"OGR" -> "shp",
"OpenAir" -> "oar",
"PCIDSK" -> "pix",
"PCRaster" -> "map",
"PDF" -> "pdf",
"PDS" -> "pds",
"PGDUMP" -> "pgdump",
"PGeo" -> "mdb",
"PLMOSAIC" -> "mosaic",
"PNG" -> "png",
"PostgreSQL" -> "pg",
"R" -> "r",
"RDA" -> "rda",
"RIK" -> "rik",
"RMF" -> "rmf",
"ROI_PAC" -> "rsc",
"RPFTOC" -> "toc",
"RS2" -> "rs2",
"RST" -> "rst",
"SAGA" -> "sdat",
"SAR_CEOS" -> "ceos",
"SAR_SG" -> "sgm",
"SDTS" -> "sdts",
"SEGUKOOA" -> "dat",
"SEGY" -> "segy",
"Sentinel2" -> "jp2",
"SRTMHGT" -> "hgt",
"SQLite" -> "sqlite",
"SUA" -> "sua",
"SVG" -> "svg",
"TIGER" -> "tiger",
"TIL" -> "til",
"TSX" -> "tsx",
"USGSDEM" -> "dem",
"VDV" -> "vdv",
"VICAR" -> "vicar",
"VFK" -> "vfk",
"VRT" -> "vrt",
"WCS" -> "wcs",
"WFS" -> "wfs",
"WMS" -> "wms",
"XLS" -> "xls",
"XLSX" -> "xlsx",
"XPlane" -> "bin",
"netCDF" -> "nc",
"Zarr" -> "zarr"
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import org.apache.spark.unsafe.types.UTF8String
import org.gdal.gdal.gdal
import org.gdal.gdalconst.gdalconstConstants._

import java.util.UUID

/**
* GDAL Raster API. It uses [[MosaicRasterGDAL]] as the
* [[com.databricks.labs.mosaic.core.raster.io.RasterReader]].
Expand Down Expand Up @@ -66,8 +68,9 @@ object GDAL {
def getExtension(driverShortName: String): String = {
val driver = gdal.GetDriverByName(driverShortName)
val result = driver.GetMetadataItem("DMD_EXTENSION")
val toReturn = if (result == null) FormatLookup.formats(driverShortName) else result
driver.delete()
result
toReturn
}

/**
Expand All @@ -84,7 +87,7 @@ object GDAL {
* Returns a Raster object.
*/
def readRaster(
inputRaster: => Any,
inputRaster: Any,
parentPath: String,
shortDriverName: String,
inputDT: DataType
Expand Down Expand Up @@ -117,13 +120,14 @@ object GDAL {
* @return
* Returns the paths of the written rasters.
*/
def writeRasters(generatedRasters: => Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = {
def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = {
generatedRasters.map(raster =>
if (raster != null) {
rasterDT match {
case StringType =>
val uuid = UUID.randomUUID().toString
val extension = GDAL.getExtension(raster.getDriversShortName)
val writePath = s"$checkpointPath/${raster.uuid}.$extension"
val writePath = s"$checkpointPath/$uuid.$extension"
val outPath = raster.writeToPath(writePath)
RasterCleaner.dispose(raster)
UTF8String.fromString(outPath)
Expand Down Expand Up @@ -159,7 +163,7 @@ object GDAL {
* @return
* Returns a Raster object.
*/
def raster(content: => Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL =
def raster(content: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL =
MosaicRasterGDAL.readRaster(content, parentPath, driverShortName)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import scala.collection.JavaConverters.dictionaryAsScalaMapConverter
import scala.util._

/** GDAL implementation of the MosaicRasterBand trait. */
class MosaicRasterBandGDAL(band: => Band, id: Int) {
case class MosaicRasterBandGDAL(band: Band, id: Int) {

def getBand: Band = band

Expand Down
Loading

0 comments on commit 1aaea1c

Please sign in to comment.