Skip to content
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 52 additions & 20 deletions src/bun.js/bindings/webcore/MessagePortChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ MessagePortChannel::MessagePortChannel(MessagePortChannelRegistry& registry, con
: m_ports { port1, port2 }
, m_registry(registry)
{
relaxAdoptionRequirement();

m_processes[0] = port1.processIdentifier;
m_entangledToProcessProtectors[0] = this;
m_processes[1] = port2.processIdentifier;
Expand All @@ -61,6 +59,7 @@ std::optional<ProcessIdentifier> MessagePortChannel::processForPort(const Messag
{
ASSERT(port == m_ports[0] || port == m_ports[1]);
size_t i = port == m_ports[0] ? 0 : 1;
Locker locker { m_lock };
return m_processes[i];
}

Expand All @@ -76,6 +75,7 @@ void MessagePortChannel::entanglePortWithProcess(const MessagePortIdentifier& po

// LOG(MessagePorts, "MessagePortChannel %s (%p) entangling port %s (that port has %zu messages available)", logString().utf8().data(), this, port.logString().utf8().data(), m_pendingMessages[i].size());

Locker locker { m_lock };
ASSERT(!m_processes[i] || *m_processes[i] == process);
m_processes[i] = process;
m_entangledToProcessProtectors[i] = this;
Expand All @@ -89,27 +89,32 @@ void MessagePortChannel::disentanglePort(const MessagePortIdentifier& port)
ASSERT(port == m_ports[0] || port == m_ports[1]);
size_t i = port == m_ports[0] ? 0 : 1;

ASSERT(m_processes[i] || m_isClosed[i]);
m_processes[i] = std::nullopt;
m_pendingMessagePortTransfers[i].add(this);
RefPtr<MessagePortChannel> protectedThis;
{
Locker locker { m_lock };
ASSERT(m_processes[i] || m_isClosed[i]);
m_processes[i] = std::nullopt;
m_pendingMessagePortTransfers[i].add(this);

// This set of steps is to guarantee that the lock is unlocked before the
// last ref to this object is released.
auto protectedThis = WTF::move(m_entangledToProcessProtectors[i]);
// This set of steps is to guarantee that the lock is unlocked before the
// last ref to this object is released.
protectedThis = WTF::move(m_entangledToProcessProtectors[i]);
}
}

void MessagePortChannel::closePort(const MessagePortIdentifier& port)
{
ASSERT(port == m_ports[0] || port == m_ports[1]);
size_t i = port == m_ports[0] ? 0 : 1;

m_processes[i] = std::nullopt;
m_isClosed[i] = true;

// This set of steps is to guarantee that the lock is unlocked before the
// last ref to this object is released.
Ref protectedThis { *this };

Locker locker { m_lock };
m_processes[i] = std::nullopt;
m_isClosed[i] = true;

m_pendingMessages[i].clear();
m_pendingMessagePortTransfers[i].clear();
m_pendingMessageProtectors[i] = nullptr;
Expand All @@ -121,6 +126,7 @@ bool MessagePortChannel::postMessageToRemote(MessageWithMessagePorts&& message,
ASSERT(remoteTarget == m_ports[0] || remoteTarget == m_ports[1]);
size_t i = remoteTarget == m_ports[0] ? 0 : 1;

Locker locker { m_lock };
m_pendingMessages[i].append(WTF::move(message));
// LOG(MessagePorts, "MessagePortChannel %s (%p) now has %zu messages pending on port %s", logString().utf8().data(), this, m_pendingMessages[i].size(), remoteTarget.logString().utf8().data());

Expand All @@ -140,22 +146,35 @@ void MessagePortChannel::takeAllMessagesForPort(const MessagePortIdentifier& por
ASSERT(port == m_ports[0] || port == m_ports[1]);
size_t i = port == m_ports[0] ? 0 : 1;

if (m_pendingMessages[i].isEmpty()) {
callback({}, [] {});
return;
}
Vector<MessageWithMessagePorts> result;
RefPtr<MessagePortChannel> protectedThis;
bool isEmpty = false;

ASSERT(m_pendingMessageProtectors[i]);
{
Locker locker { m_lock };

Vector<MessageWithMessagePorts> result;
result.swap(m_pendingMessages[i]);
if (m_pendingMessages[i].isEmpty()) {
isEmpty = true;
} else {
ASSERT(m_pendingMessageProtectors[i]);

result.swap(m_pendingMessages[i]);
++m_messageBatchesInFlight;
protectedThis = WTF::move(m_pendingMessageProtectors[i]);
}
}

++m_messageBatchesInFlight;
// Invoke callback outside the lock to avoid potential deadlocks
if (isEmpty) {
callback({}, [] {});
return;
}

// LOG(MessagePorts, "There are %zu messages to take for port %s. Taking them now, messages in flight is now %" PRIu64, result.size(), port.logString().utf8().data(), m_messageBatchesInFlight);

callback(WTF::move(result), [this, port, protectedThis = WTF::move(m_pendingMessageProtectors[i])] {
callback(WTF::move(result), [this, port, protectedThis = WTF::move(protectedThis)] {
UNUSED_PARAM(port);
Locker locker { m_lock };
--m_messageBatchesInFlight;
// LOG(MessagePorts, "Message port channel %s was notified that a batch of %zu message port messages targeted for port %s just completed dispatch, in flight is now %" PRIu64, logString().utf8().data(), size, port.logString().utf8().data(), m_messageBatchesInFlight);
});
Expand All @@ -166,6 +185,7 @@ std::optional<MessageWithMessagePorts> MessagePortChannel::tryTakeMessageForPort
ASSERT(port == m_ports[0] || port == m_ports[1]);
size_t i = port == m_ports[0] ? 0 : 1;

Locker locker { m_lock };
if (m_pendingMessages[i].isEmpty())
return std::nullopt;

Expand All @@ -174,4 +194,16 @@ std::optional<MessageWithMessagePorts> MessagePortChannel::tryTakeMessageForPort
return WTF::move(message);
}

bool MessagePortChannel::hasAnyMessagesPendingOrInFlight() const
{
Locker locker { m_lock };
return !m_pendingMessages[0].isEmpty() || !m_pendingMessages[1].isEmpty() || m_messageBatchesInFlight > 0;
}

uint64_t MessagePortChannel::beingTransferredCount() const
{
Locker locker { m_lock };
return m_pendingMessagePortTransfers[0].size() + m_pendingMessagePortTransfers[1].size();
}

} // namespace WebCore
25 changes: 14 additions & 11 deletions src/bun.js/bindings/webcore/MessagePortChannel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
#include <wtf/HashSet.h>
#include <wtf/RefCounted.h>
#include <wtf/text/WTFString.h>
#include <wtf/RefCountedAndCanMakeWeakPtr.h>
#include <wtf/ThreadSafeRefCounted.h>
#include <wtf/ThreadSafeWeakPtr.h>
#include <wtf/Lock.h>

namespace WebCore {

class MessagePortChannelRegistry;

class MessagePortChannel : public RefCountedAndCanMakeWeakPtr<MessagePortChannel> {
class MessagePortChannel : public ThreadSafeRefCountedAndCanMakeThreadSafeWeakPtr<MessagePortChannel> {
public:
static Ref<MessagePortChannel> create(MessagePortChannelRegistry&, const MessagePortIdentifier& port1, const MessagePortIdentifier& port2);

Expand All @@ -59,7 +61,7 @@ class MessagePortChannel : public RefCountedAndCanMakeWeakPtr<MessagePortChannel

WEBCORE_EXPORT bool hasAnyMessagesPendingOrInFlight() const;

uint64_t beingTransferredCount();
uint64_t beingTransferredCount() const;

#if !LOG_DISABLED
String logString() const
Expand All @@ -72,14 +74,15 @@ class MessagePortChannel : public RefCountedAndCanMakeWeakPtr<MessagePortChannel
MessagePortChannel(MessagePortChannelRegistry&, const MessagePortIdentifier& port1, const MessagePortIdentifier& port2);

MessagePortIdentifier m_ports[2];
bool m_isClosed[2] { false, false };
std::optional<ProcessIdentifier> m_processes[2];
RefPtr<MessagePortChannel> m_entangledToProcessProtectors[2];
Vector<MessageWithMessagePorts> m_pendingMessages[2];
UncheckedKeyHashSet<RefPtr<MessagePortChannel>> m_pendingMessagePortTransfers[2];
RefPtr<MessagePortChannel> m_pendingMessageProtectors[2];
uint64_t m_messageBatchesInFlight { 0 };

bool m_isClosed[2] WTF_GUARDED_BY_LOCK(m_lock) { false, false };
std::optional<ProcessIdentifier> m_processes[2] WTF_GUARDED_BY_LOCK(m_lock);
RefPtr<MessagePortChannel> m_entangledToProcessProtectors[2] WTF_GUARDED_BY_LOCK(m_lock);
Vector<MessageWithMessagePorts> m_pendingMessages[2] WTF_GUARDED_BY_LOCK(m_lock);
UncheckedKeyHashSet<RefPtr<MessagePortChannel>> m_pendingMessagePortTransfers[2] WTF_GUARDED_BY_LOCK(m_lock);
RefPtr<MessagePortChannel> m_pendingMessageProtectors[2] WTF_GUARDED_BY_LOCK(m_lock);
uint64_t m_messageBatchesInFlight WTF_GUARDED_BY_LOCK(m_lock) { 0 };

mutable Lock m_lock;
MessagePortChannelRegistry& m_registry;
};

Expand Down
77 changes: 65 additions & 12 deletions src/bun.js/bindings/webcore/MessagePortChannelRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ void MessagePortChannelRegistry::messagePortChannelCreated(MessagePortChannel& c
{
// ASSERT(isMainThread());

Locker locker { m_lock };

// When a channel is destroyed, its ThreadSafeWeakPtr becomes null but the map entry may still exist.
// Clean up any stale entries before adding new channels with the same port identifiers.
RefPtr existingChannel1 = m_openChannels.get(channel.port1()).get();
if (!existingChannel1)
m_openChannels.remove(channel.port1());

RefPtr existingChannel2 = m_openChannels.get(channel.port2()).get();
if (!existingChannel2)
m_openChannels.remove(channel.port2());

auto result = m_openChannels.add(channel.port1(), channel);
ASSERT_UNUSED(result, result.isNewEntry);

Expand All @@ -69,11 +81,20 @@ void MessagePortChannelRegistry::messagePortChannelDestroyed(MessagePortChannel&
{
// ASSERT(isMainThread());

ASSERT(m_openChannels.get(channel.port1()) == &channel);
ASSERT(m_openChannels.get(channel.port2()) == &channel);
Locker locker { m_lock };
Comment thread
xentobias marked this conversation as resolved.

// The channel might have already been removed from m_openChannels if both ports
// were closed in quick succession from different threads. With ThreadSafeWeakPtr, the entries
// may still exist but point to null, or may have been removed entirely.
// We defensively remove the entries without asserting they match.
RefPtr existingChannel1 = m_openChannels.get(channel.port1()).get();
RefPtr existingChannel2 = m_openChannels.get(channel.port2()).get();

m_openChannels.remove(channel.port1());
m_openChannels.remove(channel.port2());
// Only remove if the entry points to this channel (or is null/stale)
if (!existingChannel1 || existingChannel1.get() == &channel)
m_openChannels.remove(channel.port1());
if (!existingChannel2 || existingChannel2.get() == &channel)
m_openChannels.remove(channel.port2());

// LOG(MessagePorts, "Registry: After removing channel %s there are %u channels left in the registry:", channel.logString().utf8().data(), m_openChannels.size());
}
Expand All @@ -83,7 +104,12 @@ void MessagePortChannelRegistry::didEntangleLocalToRemote(const MessagePortIdent
// ASSERT(isMainThread());

// The channel might be gone if the remote side was closed.
RefPtr channel = m_openChannels.get(local);
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };
channel = m_openChannels.get(local).get();
}

if (!channel)
return;

Expand All @@ -97,7 +123,13 @@ void MessagePortChannelRegistry::didDisentangleMessagePort(const MessagePortIden
// ASSERT(isMainThread());

// The channel might be gone if the remote side was closed.
if (RefPtr channel = m_openChannels.get(port))
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };
channel = m_openChannels.get(port).get();
}

if (channel)
channel->disentanglePort(port);
}

Expand All @@ -107,7 +139,12 @@ void MessagePortChannelRegistry::didCloseMessagePort(const MessagePortIdentifier

// LOG(MessagePorts, "Registry: MessagePort %s closed in registry", port.logString().utf8().data());

RefPtr channel = m_openChannels.get(port);
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };
channel = m_openChannels.get(port).get();
}

if (!channel)
return;

Expand All @@ -129,7 +166,12 @@ bool MessagePortChannelRegistry::didPostMessageToRemote(MessageWithMessagePorts&
// LOG(MessagePorts, "Registry: Posting message to MessagePort %s in registry", remoteTarget.logString().utf8().data());

// The channel might be gone if the remote side was closed.
RefPtr channel = m_openChannels.get(remoteTarget);
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };
channel = m_openChannels.get(remoteTarget).get();
}

if (!channel) {
// LOG(MessagePorts, "Registry: Could not find MessagePortChannel for port %s; It was probably closed. Message will be dropped.", remoteTarget.logString().utf8().data());
return false;
Expand All @@ -143,7 +185,12 @@ void MessagePortChannelRegistry::takeAllMessagesForPort(const MessagePortIdentif
// ASSERT(isMainThread());

// The channel might be gone if the remote side was closed.
RefPtr channel = m_openChannels.get(port);
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };
channel = m_openChannels.get(port).get();
}

if (!channel) {
callback({}, [] {});
return;
Expand All @@ -159,18 +206,24 @@ std::optional<MessageWithMessagePorts> MessagePortChannelRegistry::tryTakeMessag
// LOG(MessagePorts, "Registry: Trying to take a message for MessagePort %s", port.logString().utf8().data());

// The channel might be gone if the remote side was closed.
auto* channel = m_openChannels.get(port);
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };
channel = m_openChannels.get(port).get();
}

if (!channel)
return std::nullopt;

return channel->tryTakeMessageForPort(port);
}

MessagePortChannel* MessagePortChannelRegistry::existingChannelContainingPort(const MessagePortIdentifier& port)
RefPtr<MessagePortChannel> MessagePortChannelRegistry::existingChannelContainingPort(const MessagePortIdentifier& port)
{
// ASSERT(isMainThread());

return m_openChannels.get(port);
Locker locker { m_lock };
return m_openChannels.get(port).get();
}

} // namespace WebCore
7 changes: 5 additions & 2 deletions src/bun.js/bindings/webcore/MessagePortChannelRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include "ProcessIdentifier.h"
#include <wtf/HashMap.h>
#include <wtf/CheckedRef.h>
#include <wtf/Lock.h>
#include <wtf/ThreadSafeWeakPtr.h>

namespace WebCore {

Expand All @@ -51,13 +53,14 @@ class MessagePortChannelRegistry final : public CanMakeWeakPtr<MessagePortChanne
WEBCORE_EXPORT void takeAllMessagesForPort(const MessagePortIdentifier&, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&&);
WEBCORE_EXPORT std::optional<MessageWithMessagePorts> tryTakeMessageForPort(const MessagePortIdentifier&);

WEBCORE_EXPORT MessagePortChannel* existingChannelContainingPort(const MessagePortIdentifier&);
WEBCORE_EXPORT RefPtr<MessagePortChannel> existingChannelContainingPort(const MessagePortIdentifier&);

WEBCORE_EXPORT void messagePortChannelCreated(MessagePortChannel&);
WEBCORE_EXPORT void messagePortChannelDestroyed(MessagePortChannel&);

private:
UncheckedKeyHashMap<MessagePortIdentifier, WeakRef<MessagePortChannel>> m_openChannels;
UncheckedKeyHashMap<MessagePortIdentifier, ThreadSafeWeakPtr<MessagePortChannel>> m_openChannels WTF_GUARDED_BY_LOCK(m_lock);
Lock m_lock;
};

} // namespace WebCore
Loading