diff --git a/geo/Geoserver.py b/geo/Geoserver.py index 87acb9e..e17a14b 100644 --- a/geo/Geoserver.py +++ b/geo/Geoserver.py @@ -1,6 +1,7 @@ # inbuilt libraries import os -from typing import List, Optional, Set +from typing import List, Optional, Set, Union +from pathlib import Path # third-party libraries import requests @@ -9,7 +10,7 @@ # custom functions from .Calculation_gdal import raster_value from .Style import catagorize_xml, classified_xml, coverage_style_xml, outline_only_xml -from .supports import prepare_zip_file +from .supports import prepare_zip_file, is_valid_xml # Custom exceptions. @@ -1130,7 +1131,8 @@ def upload_style( ----- The name of the style file will be, sld_name:workspace This function will create the style file in a specified workspace. - Inputs: path to the sld_file, workspace, + `path` can either be the path to the SLD file itself, or a string containing valid XML to be used for the style + Inputs: path to the sld_file or the contents of an SLD file itself, workspace, """ if name is None: name = os.path.basename(path) @@ -1138,6 +1140,17 @@ def upload_style( if len(f) > 0: name = f[0] + if Path(path).exists(): + # path is pointing to an existing file + with open(path, "rb") as f: + xml = f.read() + elif is_valid_xml(path): + # path is actually just the xml itself + xml = path + else: + # path is non-existing file or not valid xml + raise ValueError("`path` must be either a path to a style file, or a valid XML string.") + headers = {"content-type": "text/xml"} url = "{}/rest/workspaces/{}/styles".format(self.service_url, workspace) @@ -1158,13 +1171,13 @@ def upload_style( r = self._requests(method="post", url=url, data=style_xml, headers=headers) if r.status_code == 201: - with open(path, "rb") as f: - r_sld = requests.put( - url + "/" + name, - data=f.read(), - auth=(self.username, self.password), - headers=header_sld, - ) + r_sld = requests.put( + url + "/" + name, + data=xml, + auth=(self.username, self.password), + headers=header_sld, + ) + if r_sld.status_code == 200: return r_sld.status_code else: diff --git a/geo/supports.py b/geo/supports.py index 06786ae..db533ec 100644 --- a/geo/supports.py +++ b/geo/supports.py @@ -2,6 +2,7 @@ from tempfile import mkstemp from typing import Dict from zipfile import ZipFile +import xml.etree.ElementTree as ET def prepare_zip_file(name: str, data: Dict) -> str: @@ -37,3 +38,25 @@ def prepare_zip_file(name: str, data: Dict) -> str: zip_file.close() os.close(fd) return path + + +def is_valid_xml(xml_string: str) -> bool: + + """ + Returns True if string is valid XML, false otherwise + + Parameters + ---------- + xml_string : string containing xml + + Returns + ------- + bool + """ + + try: + # Attempt to parse the XML string + ET.fromstring(xml_string) + return True + except ET.ParseError: + return False diff --git a/tests/data/style.sld b/tests/data/style.sld new file mode 100644 index 0000000..f6077ff --- /dev/null +++ b/tests/data/style.sld @@ -0,0 +1,86 @@ + + + + generic + + Generic + Generic style + + + raster + Opaque Raster + + + + true + + + + 1.0 + + + + Polygon + Grey Polygon + + + + + + 2 + + + + + #AAAAAA + + + #000000 + 1 + + + + + Line + Blue Line + + + + + + 1 + + + + + #0000FF + 1 + + + + + point + Red Square Point + + + + + square + + #FF0000 + + + 6 + + + + first + + + + diff --git a/tests/test_geoserver.py b/tests/test_geoserver.py index 73c448b..e6ec835 100644 --- a/tests/test_geoserver.py +++ b/tests/test_geoserver.py @@ -1,9 +1,13 @@ +import pathlib + import pytest from geo.Style import catagorize_xml, classified_xml +from geo.Geoserver import GeoserverException from .common import geo +HERE = pathlib.Path(__file__).parent.resolve() @pytest.mark.skip(reason="Only setup for local testing.") class TestRequest: @@ -122,6 +126,58 @@ def test_styles(self): ) +class TestUploadStyles: + + def test_upload_style_from_file(self): + + try: + geo.delete_style("test_upload_style") + except GeoserverException: + pass + + geo.upload_style(f"{HERE}/data/style.sld", "test_upload_style") + style = geo.get_style("test_upload_style") + assert style["style"]["name"] == "test_upload_style" + + def test_upload_style_from_malformed_file_fails(self): + + try: + geo.delete_style("style_doesnt_exist") + except GeoserverException: + pass + + with pytest.raises(ValueError): + geo.upload_style(f"{HERE}/data/style_doesnt_exist.sld", "style_doesnt_exist") + with pytest.raises(GeoserverException): + style = geo.get_style("style_doesnt_exist") + print() + + def test_upload_style_from_xml(self): + + try: + geo.delete_style("test_upload_style") + except GeoserverException: + pass + + xml = open(f"{HERE}/data/style.sld").read() + geo.upload_style(xml, "test_upload_style") + style = geo.get_style("test_upload_style") + assert style["style"]["name"] == "test_upload_style" + + def test_upload_style_from_malformed_xml_fails(self): + + try: + geo.delete_style("style_malformed") + except GeoserverException: + pass + + xml = open(f"{HERE}/data/style.sld").read()[1:] + with pytest.raises(ValueError): + geo.upload_style(xml, "style_malformed") + with pytest.raises(GeoserverException): + style = geo.get_style("style_malformed") + + @pytest.mark.skip(reason="Only setup for local testing.") class TestPostGres: from geo.Postgres import Db