Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Switch InstanceLocationConfig to a pydantic BaseModel #15431

Merged
Merged
1 change: 1 addition & 0 deletions changelog.d/15431.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some validation to `instance_map` configuration loading.
realtyem marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 17 additions & 9 deletions synapse/config/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Dict, List, Union

import attr
from pydantic import BaseModel, StrictBool, StrictInt, StrictStr, parse_obj_as

from synapse.config._base import (
Config,
Expand Down Expand Up @@ -50,13 +51,20 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj


@attr.s(auto_attribs=True)
class InstanceLocationConfig:
class InstanceLocationConfig(BaseModel):
realtyem marked this conversation as resolved.
Show resolved Hide resolved
"""The host and port to talk to an instance via HTTP replication."""

host: str
port: int
tls: bool = False
host: StrictStr
port: StrictInt
tls: StrictBool = False

def scheme(self) -> str:
"""Hardcode a retrievable scheme based on self.tls"""
return "https" if self.tls else "http"

def netloc(self) -> str:
"""Nicely format the network location data"""
return f"{self.host}:{self.port}"


@attr.s
Expand Down Expand Up @@ -183,10 +191,10 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
)

# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
# instance_map = config.get("instance_map") or {}
self.instance_map = parse_obj_as(
Dict[str, InstanceLocationConfig], config.get("instance_map") or {}
realtyem marked this conversation as resolved.
Show resolved Hide resolved
)

# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}
Expand Down