Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Switch InstanceLocationConfig to a pydantic BaseModel #15431

Merged
Merged
1 change: 1 addition & 0 deletions changelog.d/15431.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some validation to `instance_map` configuration loading.
30 changes: 29 additions & 1 deletion synapse/config/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable
from typing import Any, Dict, Iterable, TypeAlias, TypeVar

import jsonschema
from pydantic import BaseModel, ValidationError, parse_obj_as

from synapse.config._base import ConfigError
from synapse.types import JsonDict
Expand Down Expand Up @@ -64,3 +65,30 @@ def json_error_to_config_error(
else:
path.append(str(p))
return ConfigError(e.message, path)


Model = TypeVar("Model", bound=BaseModel)


def validate_instance_map_config(
realtyem marked this conversation as resolved.
Show resolved Hide resolved
config: Any,
model_type: TypeAlias,
realtyem marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[str, Model]:
"""Validate Dict type data in the style of Dict[str, ExampleBaseModel] and check it
for realness. Acts as a wrapper for parse_obj_as() from pydantic that changes any
ValidationError to ConfigError.

Args:
config: The configuration data to check
model_type: The BaseModel to validate and parse against.
Returns: Fully validated and parsed Dict[str, ExampleBaseModel]
"""
realtyem marked this conversation as resolved.
Show resolved Hide resolved
try:
instances = parse_obj_as(Dict[str, model_type], config)
except ValidationError as e:
raise validation_error_to_config_error(e)
return instances


def validation_error_to_config_error(e: ValidationError) -> ConfigError:
return ConfigError(str(e))
realtyem marked this conversation as resolved.
Show resolved Hide resolved
51 changes: 42 additions & 9 deletions synapse/config/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from typing import Any, Dict, List, Union

import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr

from synapse.config._base import (
Config,
ConfigError,
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
)
from synapse.config._util import validate_instance_map_config
from synapse.config.server import (
DIRECT_TCP_ERROR,
TCPListenerConfig,
Expand All @@ -50,13 +52,42 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj


@attr.s(auto_attribs=True)
class InstanceLocationConfig:
class InstanceLocationConfigModel(BaseModel):
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""A custom version of Pydantic's BaseModel which

- ignores unknown fields and
- does not allow fields to be overwritten after construction,

but otherwise uses Pydantic's default behaviour.

Ignoring unknown fields is a useful default. It means that clients can provide
unstable field not known to the server without the request being refused outright.

realtyem marked this conversation as resolved.
Show resolved Hide resolved
realtyem marked this conversation as resolved.
Show resolved Hide resolved
Subclassing in this way is recommended by
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
"""

class Config:
# By default, ignore fields that we don't recognise.
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False


class InstanceLocationConfig(InstanceLocationConfigModel):
"""The host and port to talk to an instance via HTTP replication."""

host: str
port: int
tls: bool = False
host: StrictStr
port: StrictInt
tls: StrictBool = False

def scheme(self) -> str:
"""Hardcode a retrievable scheme based on self.tls"""
return "https" if self.tls else "http"

def netloc(self) -> str:
"""Nicely format the network location data"""
return f"{self.host}:{self.port}"


@attr.s
Expand Down Expand Up @@ -183,10 +214,12 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
)

# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
self.instance_map: Dict[
str, InstanceLocationConfig
] = validate_instance_map_config(
config.get("instance_map", {}),
InstanceLocationConfig,
)

# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}
Expand Down