From 16be8bdd8e9aa73ba8ada2865dadfa4737db998a Mon Sep 17 00:00:00 2001 From: Anthony Accomazzo Date: Wed, 5 Jul 2023 21:25:07 -0700 Subject: [PATCH] Validate startup packets Supavisor's ClientHandler can receive a lot of junk, as any client can attempt to connect to it. This change filters out the noise from unserious clients (like bots/scrapers) sending requests. --- lib/supavisor/client_handler.ex | 72 ++++++++++++++++++-------- test/supavisor/client_handler_test.exs | 24 +++++++++ 2 files changed, 75 insertions(+), 21 deletions(-) diff --git a/lib/supavisor/client_handler.ex b/lib/supavisor/client_handler.ex index 7725a90c..50ff69eb 100644 --- a/lib/supavisor/client_handler.ex +++ b/lib/supavisor/client_handler.ex @@ -71,22 +71,27 @@ defmodule Supavisor.ClientHandler do end def handle_event(:info, {:tcp, _, bin}, :exchange, %{socket: socket} = data) do - hello = decode_startup_packet(bin) - Logger.warning("Client startup message: #{inspect(hello)}") - {user, external_id} = parse_user_info(hello.payload["user"]) - Logger.metadata(project: external_id, user: user) + with {:ok, hello} <- decode_startup_packet(bin) do + Logger.warning("Client startup message: #{inspect(hello)}") + {user, external_id} = parse_user_info(hello.payload["user"]) + Logger.metadata(project: external_id, user: user) - case Tenants.get_user(external_id, user) do - {:ok, user_info} -> - new_data = update_user_data(data, external_id, user_info) + case Tenants.get_user(external_id, user) do + {:ok, user_info} -> + new_data = update_user_data(data, external_id, user_info) - {:keep_state, new_data, - {:next_event, :internal, {:handle, fn -> user_info.db_password end}}} + {:keep_state, new_data, + {:next_event, :internal, {:handle, fn -> user_info.db_password end}}} - {:error, reason} -> - Logger.error("User not found: #{inspect(reason)} #{inspect({user, external_id})}") - Server.send_error(socket, "XX000", "Tenant or user not found") - {:stop, :normal, data} + {:error, reason} -> + Logger.error("User not found: #{inspect(reason)} #{inspect({user, external_id})}") + Server.send_error(socket, "XX000", "Tenant or user not found") + {:stop, :normal, data} + else + {:error, :bad_startup_payload} -> + Logger.warn("Bad startup packet received", bin: bin) + {:stop, :normal, data} + end end end @@ -295,20 +300,45 @@ defmodule Supavisor.ClientHandler do end def decode_startup_packet(<>) do - %{ - len: len, - payload: - String.split(rest, <<0>>, trim: true) - |> Enum.chunk_every(2) - |> Enum.into(%{}, fn [k, v] -> {k, v} end), - tag: :startup - } + with {:ok, payload} <- decode_startup_packet_payload(rest) do + pkt = %{ + len: len, + payload: payload, + tag: :startup + } + + {:ok, pkt} + end end def decode_startup_packet(_) do :undef end + # The startup packet payload is a list of key/value pairs, separated by null bytes + defp decode_startup_packet_payload(payload) do + fields = String.split(payload, <<0>>, trim: true) + + # If the number of fields is odd, then the payload is malformed + if rem(length(fields), 2) == 1 do + {:error, :bad_startup_payload} + else + map = + fields + |> Enum.chunk_every(2) + |> Enum.map(fn [k, v] -> {k, v} end) + |> Map.new() + + # We only do light validation on the fields in the payload. The only field we use at the + # moment is `user`. If that's missing, this is a bad payload. + if Map.has_key?(map, "user") do + {:ok, map} + else + {:error, :bad_startup_payload} + end + end + end + @spec handle_exchange(port, fun) :: :ok | {:error, String.t()} def handle_exchange(socket, password) do :ok = Server.send_request_authentication(socket) diff --git a/test/supavisor/client_handler_test.exs b/test/supavisor/client_handler_test.exs index 86b05714..6311e53d 100644 --- a/test/supavisor/client_handler_test.exs +++ b/test/supavisor/client_handler_test.exs @@ -17,4 +17,28 @@ defmodule Supavisor.ClientHandlerTest do assert external_id == "external_id" end end + + describe "decode_startup_packet/1" do + test "handles bad startup packets" do + packet = <<0, 0, 0, 8, 0, 0, 0, 0, 3>> + assert {:error, _} = ClientHandler.decode_startup_packet(packet) + end + + test "handles valid startup packets" do + payload = %{ + "DateStyle" => "ISO", + "TimeZone" => "Asia/Tokyo", + "client_encoding" => "UTF8", + "database" => "mydbname", + "extra_float_digits" => "2", + "user" => "tenant.mytenant" + } + + fields = Enum.reduce(payload, [], fn {k, v}, acc -> [k, v | acc] end) |> Enum.join(<<0>>) + len = String.length(fields) + 4 + packet = <> + assert {:ok, hello} = ClientHandler.decode_startup_packet(packet) + assert hello[:payload]["user"] == "tenant.mytenant" + end + end end