diff --git a/pygmt/base_plotting.py b/pygmt/base_plotting.py index 90648e02d6a..83f882ebb71 100644 --- a/pygmt/base_plotting.py +++ b/pygmt/base_plotting.py @@ -2,6 +2,7 @@ Base class with plot generating commands. Does not define any special non-GMT methods (savefig, show, etc). """ +import contextlib import csv import os import numpy as np @@ -294,6 +295,7 @@ def grdimage(self, grid, **kwargs): Project grids or images and plot them on maps. Takes a grid file name or an xarray.DataArray object as input. + Alternatively, pass in a list of red, green, blue grids to be imaged. Full option list at :gmt-docs:`grdimage.html` @@ -301,20 +303,43 @@ def grdimage(self, grid, **kwargs): Parameters ---------- - grid : str or xarray.DataArray + grid : str, xarray.DataArray or list The file name of the input grid or the grid loaded as a DataArray. - + For plotting RGB grids, pass in a list made up of either file names or + DataArrays to the individual red, green and blue grids. """ kwargs = self._preprocess(**kwargs) - kind = data_kind(grid, None, None) + + if isinstance(grid, list): + if all([data_kind(g) == "file" for g in grid]): + kind = "file" + grid = " ".join(grid) + elif all([data_kind(g) == "grid" for g in grid]): + kind = "grid" + else: + kind = data_kind(grid) + with Session() as lib: if kind == "file": - file_context = dummy_context(grid) + file_contexts = [dummy_context(grid)] elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) + if isinstance(grid, list): + file_contexts = [ + lib.virtualfile_from_grid(grid[0]), + lib.virtualfile_from_grid(grid[1]), + lib.virtualfile_from_grid(grid[2]), + ] + else: + file_contexts = [lib.virtualfile_from_grid(grid)] else: raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid))) - with file_context as fname: + with contextlib.ExitStack() as stack: + fname = " ".join( + [ + stack.enter_context(file_context) + for file_context in file_contexts + ] + ) arg_str = " ".join([fname, build_arg_string(kwargs)]) lib.call_module("grdimage", arg_str) diff --git a/pygmt/tests/baseline/test_grdimage_rgb_files.png b/pygmt/tests/baseline/test_grdimage_rgb_files.png new file mode 100644 index 00000000000..3bf59a2f873 Binary files /dev/null and b/pygmt/tests/baseline/test_grdimage_rgb_files.png differ diff --git a/pygmt/tests/baseline/test_grdimage_rgb_grid.png b/pygmt/tests/baseline/test_grdimage_rgb_grid.png new file mode 100644 index 00000000000..c6a2a69302a Binary files /dev/null and b/pygmt/tests/baseline/test_grdimage_rgb_grid.png differ diff --git a/pygmt/tests/test_grdimage.py b/pygmt/tests/test_grdimage.py index 72759137f0f..7e88ff717e7 100644 --- a/pygmt/tests/test_grdimage.py +++ b/pygmt/tests/test_grdimage.py @@ -3,6 +3,7 @@ """ import numpy as np import pytest +import xarray as xr from .. import Figure from ..exceptions import GMTInvalidInput @@ -41,6 +42,38 @@ def test_grdimage_file(): return fig +@pytest.mark.mpl_image_compare +def test_grdimage_rgb_files(): + "Plot an image using Red, Green, and Blue file inputs" + fig = Figure() + fig.grdimage(grid=["@earth_relief_60m", "@earth_relief_60m", "@earth_relief_60m"]) + return fig + + +@pytest.mark.mpl_image_compare +def test_grdimage_rgb_grid(): + "Plot an image using Red, Green, and Blue xarray.DataArray inputs" + red = xr.DataArray( + data=[[128, 0, 0], [128, 0, 0]], + dims=("lat", "lon"), + coords={"lat": [0, 1], "lon": [2, 3, 4]}, + ) + green = xr.DataArray( + data=[[0, 128, 0], [0, 128, 0]], + dims=("lat", "lon"), + coords={"lat": [0, 1], "lon": [2, 3, 4]}, + ) + blue = xr.DataArray( + data=[[0, 0, 128], [0, 0, 128]], + dims=("lat", "lon"), + coords={"lat": [0, 1], "lon": [2, 3, 4]}, + ) + + fig = Figure() + fig.grdimage(grid=[red, green, blue], projection="x5c", frame=True) + return fig + + def test_grdimage_fails(): "Should fail for unrecognized input" fig = Figure()