Skip to content

Commit

Permalink
fix: add argument to indicate the property that contains the band nam…
Browse files Browse the repository at this point in the history
…es (#425)
  • Loading branch information
12rambau authored Feb 5, 2025
2 parents f478dbf + 656f5c0 commit 14fdc79
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 13 deletions.
40 changes: 27 additions & 13 deletions geetools/ee_image_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,12 +793,18 @@ def validPixel(self, band: str | ee.String = "") -> ee.Image:
validPct = validPixel.divide(self._obj.size()).multiply(100).rename("pct_valid")
return validPixel.addBands(validPct)

def containsBandNames(self, bandNames: list[str] | ee.List, filter: str) -> ee.ImageCollection:
def containsBandNames(
self,
bandNames: list[str] | ee.List,
filter: str,
bandNamesProperty: str | ee.String = "system:band_names",
) -> ee.ImageCollection:
"""Filter the :py:class:`ee.ImageCollection` by band names using the provided filter.
Args:
bandNames: List of band names to filter.
filter: Type of filter to apply. To keep images that contains all the specified bands use ``"ALL"``. To get the images including at least one of the specified band use ``"ANY"``.
bandNamesProperty: the name of the property that contains the band names. Defaults to GEE native default: 'system:band_name'.
Returns:
A filtered :py:class:`ee.ImageCollection`
Expand All @@ -821,31 +827,34 @@ def containsBandNames(self, bandNames: list[str] | ee.List, filter: str) -> ee.I
filter = {"ALL": "Filter.and", "ANY": "Filter.or"}[filter]
bandNames = ee.List(bandNames)

# add bands as metadata in a temporary property
band_name = uuid.uuid4().hex
ic = self._obj.map(lambda i: i.set(band_name, i.bandNames()))

# create a filter by combining a listContain filter over all the band names from the
# user list. Combine them with a "Or" to get a "any" filter and "And" to get a "all".
# We use a workaround until this is solved: https://issuetracker.google.com/issues/322838709
filterList = bandNames.map(lambda b: ee.Filter.listContains(band_name, b))
filterList = bandNames.map(lambda b: ee.Filter.listContains(bandNamesProperty, b))
filterCombination = apifunction.ApiFunction.call_(filter, ee.List(filterList))

# apply this filter and remove the temporary property. Exclude parameter is additive so
# we do a blank multiplication to remove all the properties beforhand
ic = ee.ImageCollection(ic.filter(filterCombination))
ic = ic.map(lambda i: ee.Image(i.multiply(1).copyProperties(i, exclude=[band_name])))
ic = ee.ImageCollection(self._obj.filter(filterCombination))
ic = ic.map(
lambda i: ee.Image(i.multiply(1).copyProperties(i, exclude=[bandNamesProperty]))
)

return ee.ImageCollection(ic)

def containsAllBands(self, bandNames: list[str] | ee.List) -> ee.ImageCollection:
def containsAllBands(
self,
bandNames: list[str] | ee.List,
bandNamesProperty: str | ee.String = "system:band_names",
) -> ee.ImageCollection:
"""Filter the :py:class:`ee.ImageCollection` keeping only the images with all the provided bands.
Args:
bandNames: List of band names to filter.
bandNamesProperty: the name of the property that contains the band names. Defaults to GEE native default: 'system:band_name'.
Returns:
A filtered :py:class:`ee.ImageCollection`
A filtered :py:class:`ee.ImageCollection`.
Examples:
.. code-block::
Expand All @@ -863,13 +872,18 @@ def containsAllBands(self, bandNames: list[str] | ee.List) -> ee.ImageCollection
filtered = collection.geetools.containsAllBands(["B1", "B2"])
print(filtered.getInfo())
"""
return self.containsBandNames(bandNames, "ALL")
return self.containsBandNames(bandNames, "ALL", bandNamesProperty)

def containsAnyBands(self, bandNames: list[str] | ee.List) -> ee.ImageCollection:
def containsAnyBands(
self,
bandNames: list[str] | ee.List,
bandNamesProperty: str | ee.String = "system:band_names",
) -> ee.ImageCollection:
"""Filter the :py:class:`ee.ImageCollection` keeping only the images with any of the provided bands.
Args:
bandNames: List of band names to filter.
bandNamesProperty: the name of the property that contains the band names. Defaults to GEE native default: 'system:band_name'.
Returns:
A filtered :py:class:`ee.ImageCollection`
Expand All @@ -890,7 +904,7 @@ def containsAnyBands(self, bandNames: list[str] | ee.List) -> ee.ImageCollection
filtered = collection.geetools.containsAnyBands(["B1", "B2"])
print(filtered.getInfo())
"""
return self.containsBandNames(bandNames, "ANY")
return self.containsBandNames(bandNames, "ANY", bandNamesProperty)

def aggregateArray(self, properties: list[str] | ee.List | None = None) -> ee.Dictionary:
"""Aggregate the :py:class:`ee.ImageCollection` selected properties into a dictionary.
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def s2_sr(amazonas) -> ee.ImageCollection:
)


@pytest.fixture
def aster(vatican) -> ee.ImageCollection:
"""Aster collection in Vatican City for year 2020."""
return (
ee.ImageCollection("ASTER/AST_L1T_003")
.filterBounds(vatican.geometry())
.filterDate("2020-01-01", "2021-01-01")
)


@pytest.fixture
def vatican():
"""Return the vatican city."""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_ImageCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def test_contains_all(self, s2_sr):
ic = ic.geetools.containsAllBands(["B2", "B3"])
assert ic.size().getInfo() == 2449

def test_contains_all_property_name(self, aster):
ic = aster.geetools.containsAllBands(
["B3N", "B02", "B01"], bandNamesProperty="ORIGINAL_BANDS_PRESENT"
)
assert ic.size().getInfo() == 2

def test_contains_all_mismatch(self, s2_sr):
ic = s2_sr.select(["B2", "B3", "B4"])
ic = ic.geetools.containsAllBands(["B2", "B3", "B5"])
Expand All @@ -248,6 +254,12 @@ def test_contains_any(self, s2_sr):
ic = ic.geetools.containsAnyBands(["B2", "B3", "B5"])
assert ic.size().getInfo() == 2449

def test_contains_any_property_name(self, aster):
ic = aster.geetools.containsAnyBands(
["B3N", "B02", "B01"], bandNamesProperty="ORIGINAL_BANDS_PRESENT"
)
assert ic.size().getInfo() == 2

def test_contains_any_mismatch(self, s2_sr):
ic = s2_sr.select(["B2", "B3", "B4"])
ic = ic.geetools.containsAnyBands(["B5", "B6"])
Expand Down

0 comments on commit 14fdc79

Please sign in to comment.