Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Switch InstanceLocationConfig to a pydantic BaseModel #15431

Merged
Merged
1 change: 1 addition & 0 deletions changelog.d/15431.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some validation to `instance_map` configuration loading.
28 changes: 27 additions & 1 deletion synapse/config/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable
from typing import Any, Dict, Iterable, Type, TypeVar

import jsonschema
from pydantic import BaseModel, ValidationError, parse_obj_as

from synapse.config._base import ConfigError
from synapse.types import JsonDict
Expand Down Expand Up @@ -64,3 +65,28 @@ def json_error_to_config_error(
else:
path.append(str(p))
return ConfigError(e.message, path)


Model = TypeVar("Model", bound=BaseModel)


def parse_and_validate_mapping(
config: Any,
model_type: Type[Model],
) -> Dict[str, Model]:
"""Parse `config` as a mapping from strings to a given `Model` type.
Args:
config: The configuration data to check
model_type: The BaseModel to validate and parse against.
Returns:
Fully validated and parsed Dict[str, Model].
Raises:
ConfigError, if given improper input.
"""
try:
# type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
# `model_type` is a runtime variable. Pydantic is fine with this.
instances = parse_obj_as(Dict[str, model_type], config) # type: ignore[valid-type]
except ValidationError as e:
raise ConfigError(str(e)) from e
return instances
52 changes: 43 additions & 9 deletions synapse/config/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from typing import Any, Dict, List, Union

import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr

from synapse.config._base import (
Config,
ConfigError,
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
)
from synapse.config._util import parse_and_validate_mapping
from synapse.config.server import (
DIRECT_TCP_ERROR,
TCPListenerConfig,
Expand All @@ -50,13 +52,43 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj


@attr.s(auto_attribs=True)
class InstanceLocationConfig:
class ConfigModel(BaseModel):
"""A custom version of Pydantic's BaseModel which

- ignores unknown fields and
- does not allow fields to be overwritten after construction,

but otherwise uses Pydantic's default behaviour.

For now, ignore unknown fields. In the future, we could change this so that unknown
config values cause a ValidationError, provided the error messages are meaningful to
server operators.

Subclassing in this way is recommended by
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
"""

class Config:
# By default, ignore fields that we don't recognise.
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False


class InstanceLocationConfig(ConfigModel):
"""The host and port to talk to an instance via HTTP replication."""

host: str
port: int
tls: bool = False
host: StrictStr
port: StrictInt
tls: StrictBool = False

def scheme(self) -> str:
"""Hardcode a retrievable scheme based on self.tls"""
return "https" if self.tls else "http"

def netloc(self) -> str:
"""Nicely format the network location data"""
return f"{self.host}:{self.port}"


@attr.s
Expand Down Expand Up @@ -183,10 +215,12 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
)

# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
self.instance_map: Dict[
str, InstanceLocationConfig
] = parse_and_validate_mapping(
config.get("instance_map", {}),
InstanceLocationConfig,
)

# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}
Expand Down