Skip to content
Closed
30 changes: 6 additions & 24 deletions src/bun.js/bindings/webcore/MessagePortChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ MessagePortChannel::MessagePortChannel(MessagePortChannelRegistry& registry, con
: m_ports { port1, port2 }
, m_registry(registry)
{
relaxAdoptionRequirement();

m_processes[0] = port1.processIdentifier;
m_entangledToProcessProtectors[0] = this;
m_processes[1] = port2.processIdentifier;
Expand Down Expand Up @@ -92,10 +90,7 @@ void MessagePortChannel::disentanglePort(const MessagePortIdentifier& port)
ASSERT(m_processes[i] || m_isClosed[i]);
m_processes[i] = std::nullopt;
m_pendingMessagePortTransfers[i].add(this);

// This set of steps is to guarantee that the lock is unlocked before the
// last ref to this object is released.
auto protectedThis = WTF::move(m_entangledToProcessProtectors[i]);
m_entangledToProcessProtectors[i] = nullptr;
}

void MessagePortChannel::closePort(const MessagePortIdentifier& port)
Expand All @@ -106,10 +101,6 @@ void MessagePortChannel::closePort(const MessagePortIdentifier& port)
m_processes[i] = std::nullopt;
m_isClosed[i] = true;

// This set of steps is to guarantee that the lock is unlocked before the
// last ref to this object is released.
Ref protectedThis { *this };

m_pendingMessages[i].clear();
m_pendingMessagePortTransfers[i].clear();
m_pendingMessageProtectors[i] = nullptr;
Expand All @@ -136,32 +127,23 @@ bool MessagePortChannel::postMessageToRemote(MessageWithMessagePorts&& message,
return false;
}

void MessagePortChannel::takeAllMessagesForPort(const MessagePortIdentifier& port, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&& callback)
Vector<MessageWithMessagePorts> MessagePortChannel::takeAllMessagesForPort(const MessagePortIdentifier& port)
{
// LOG(MessagePorts, "MessagePortChannel %p taking all messages for port %s", this, port.logString().utf8().data());

ASSERT(port == m_ports[0] || port == m_ports[1]);
size_t i = port == m_ports[0] ? 0 : 1;

if (m_pendingMessages[i].isEmpty()) {
callback({}, [] {});
return;
}
if (m_pendingMessages[i].isEmpty())
return {};

ASSERT(m_pendingMessageProtectors[i]);

Vector<MessageWithMessagePorts> result;
result.swap(m_pendingMessages[i]);
m_pendingMessageProtectors[i] = nullptr;

++m_messageBatchesInFlight;

// LOG(MessagePorts, "There are %zu messages to take for port %s. Taking them now, messages in flight is now %" PRIu64, result.size(), port.logString().utf8().data(), m_messageBatchesInFlight);

callback(WTF::move(result), [this, port, protectedThis = WTF::move(m_pendingMessageProtectors[i])] {
UNUSED_PARAM(port);
--m_messageBatchesInFlight;
// LOG(MessagePorts, "Message port channel %s was notified that a batch of %zu message port messages targeted for port %s just completed dispatch, in flight is now %" PRIu64, logString().utf8().data(), size, port.logString().utf8().data(), m_messageBatchesInFlight);
});
return result;
}

std::optional<MessageWithMessagePorts> MessagePortChannel::tryTakeMessageForPort(const MessagePortIdentifier port)
Comment thread
claude[bot] marked this conversation as resolved.
Expand Down
15 changes: 6 additions & 9 deletions src/bun.js/bindings/webcore/MessagePortChannel.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@
#include "MessageWithMessagePorts.h"
#include "ProcessIdentifier.h"
#include <wtf/HashSet.h>
#include <wtf/RefCounted.h>
#include <wtf/ThreadSafeWeakPtr.h>
#include <wtf/text/WTFString.h>
#include <wtf/RefCountedAndCanMakeWeakPtr.h>

namespace WebCore {

class MessagePortChannelRegistry;

class MessagePortChannel : public RefCountedAndCanMakeWeakPtr<MessagePortChannel> {
// In WebKit this is RefCountedAndCanMakeWeakPtr because the registry is main-thread-only.
// Bun serializes registry/channel access with a Lock instead (MessagePortChannelRegistry::m_lock),
// so the refcount and weak control block must be atomic — RefPtrs can be released on any thread.
class MessagePortChannel : public ThreadSafeRefCountedAndCanMakeThreadSafeWeakPtr<MessagePortChannel> {
public:
static Ref<MessagePortChannel> create(MessagePortChannelRegistry&, const MessagePortIdentifier& port1, const MessagePortIdentifier& port2);

Expand All @@ -54,13 +56,9 @@ class MessagePortChannel : public RefCountedAndCanMakeWeakPtr<MessagePortChannel
void closePort(const MessagePortIdentifier&);
bool postMessageToRemote(MessageWithMessagePorts&&, const MessagePortIdentifier& remoteTarget);

void takeAllMessagesForPort(const MessagePortIdentifier&, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&&);
Vector<MessageWithMessagePorts> takeAllMessagesForPort(const MessagePortIdentifier&);
std::optional<MessageWithMessagePorts> tryTakeMessageForPort(const MessagePortIdentifier);

WEBCORE_EXPORT bool hasAnyMessagesPendingOrInFlight() const;

uint64_t beingTransferredCount();

#if !LOG_DISABLED
String logString() const
{
Expand All @@ -78,7 +76,6 @@ class MessagePortChannel : public RefCountedAndCanMakeWeakPtr<MessagePortChannel
Vector<MessageWithMessagePorts> m_pendingMessages[2];
UncheckedKeyHashSet<RefPtr<MessagePortChannel>> m_pendingMessagePortTransfers[2];
RefPtr<MessagePortChannel> m_pendingMessageProtectors[2];
uint64_t m_messageBatchesInFlight { 0 };

MessagePortChannelRegistry& m_registry;
};
Expand Down
2 changes: 0 additions & 2 deletions src/bun.js/bindings/webcore/MessagePortChannelProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ static MessagePortChannelProviderImpl* globalProvider;

MessagePortChannelProvider& MessagePortChannelProvider::singleton()
{
// TODO: I think this assertion is relevant. Bun will call this on the Worker's thread
// ASSERT(isMainThread());
static std::once_flag onceFlag;
std::call_once(onceFlag, [] {
if (!globalProvider)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,8 @@ void MessagePortChannelProviderImpl::postMessageToRemote(MessageWithMessagePorts
MessagePort::notifyMessageAvailable(remoteTarget);
}

void MessagePortChannelProviderImpl::takeAllMessagesForPort(const MessagePortIdentifier& port, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&& outerCallback)
void MessagePortChannelProviderImpl::takeAllMessagesForPort(const MessagePortIdentifier& port, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&& callback)
{
// It is the responsibility of outerCallback to get itself to the appropriate thread (e.g. WebWorker thread)
auto callback = [outerCallback = WTF::move(outerCallback)](Vector<MessageWithMessagePorts>&& messages, CompletionHandler<void()>&& messageDeliveryCallback) mutable {
// ASSERT(isMainThread());
outerCallback(WTF::move(messages), WTF::move(messageDeliveryCallback));
};

m_registry.takeAllMessagesForPort(port, WTF::move(callback));
}

Expand Down
129 changes: 69 additions & 60 deletions src/bun.js/bindings/webcore/MessagePortChannelRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@

// #include "Logging.h"
#include <wtf/CompletionHandler.h>
#include <wtf/MainThread.h>

// ASSERT(isMainThread()) is used alot here, and I think it may be required, but i'm not 100% sure.
// we totally are calling these off the main thread in many cases in Bun, so ........
#include <wtf/Locker.h>

namespace WebCore {

Expand All @@ -49,14 +46,14 @@ MessagePortChannelRegistry::~MessagePortChannelRegistry()
void MessagePortChannelRegistry::didCreateMessagePortChannel(const MessagePortIdentifier& port1, const MessagePortIdentifier& port2)
{
// LOG(MessagePorts, "Registry: Creating MessagePortChannel %p linking %s and %s", this, port1.logString().utf8().data(), port2.logString().utf8().data());
// ASSERT(isMainThread());

// No lock here: the channel constructor calls back into messagePortChannelCreated() which locks.
MessagePortChannel::create(*this, port1, port2);
}

void MessagePortChannelRegistry::messagePortChannelCreated(MessagePortChannel& channel)
{
// ASSERT(isMainThread());
Locker locker { m_lock };

auto result = m_openChannels.add(channel.port1(), channel);
ASSERT_UNUSED(result, result.isNewEntry);
Expand All @@ -67,10 +64,7 @@ void MessagePortChannelRegistry::messagePortChannelCreated(MessagePortChannel& c

void MessagePortChannelRegistry::messagePortChannelDestroyed(MessagePortChannel& channel)
{
// ASSERT(isMainThread());

ASSERT(m_openChannels.get(channel.port1()) == &channel);
ASSERT(m_openChannels.get(channel.port2()) == &channel);
Locker locker { m_lock };

m_openChannels.remove(channel.port1());
m_openChannels.remove(channel.port2());
Expand All @@ -80,97 +74,112 @@ void MessagePortChannelRegistry::messagePortChannelDestroyed(MessagePortChannel&

void MessagePortChannelRegistry::didEntangleLocalToRemote(const MessagePortIdentifier& local, const MessagePortIdentifier& remote, ProcessIdentifier process)
{
// ASSERT(isMainThread());
// The channel RefPtr must outlive the lock so its destructor (which re-enters
// messagePortChannelDestroyed and locks) cannot deadlock.
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };

// The channel might be gone if the remote side was closed.
RefPtr channel = m_openChannels.get(local);
if (!channel)
return;
// The channel might be gone if the remote side was closed.
channel = m_openChannels.get(local).get();
if (!channel)
return;

ASSERT_UNUSED(remote, channel->includesPort(remote));
ASSERT_UNUSED(remote, channel->includesPort(remote));

channel->entanglePortWithProcess(local, process);
channel->entanglePortWithProcess(local, process);
}
}

void MessagePortChannelRegistry::didDisentangleMessagePort(const MessagePortIdentifier& port)
{
// ASSERT(isMainThread());
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };

// The channel might be gone if the remote side was closed.
channel = m_openChannels.get(port).get();
if (!channel)
return;

// The channel might be gone if the remote side was closed.
if (RefPtr channel = m_openChannels.get(port))
channel->disentanglePort(port);
}
}

void MessagePortChannelRegistry::didCloseMessagePort(const MessagePortIdentifier& port)
{
// ASSERT(isMainThread());

// LOG(MessagePorts, "Registry: MessagePort %s closed in registry", port.logString().utf8().data());

RefPtr channel = m_openChannels.get(port);
if (!channel)
return;
RefPtr<MessagePortChannel> channel;
{
Locker locker { m_lock };

#ifndef NDEBUG
// if (channel && channel->hasAnyMessagesPendingOrInFlight())
// LOG(MessagePorts, "Registry: (Note) The channel closed for port %s had messages pending or in flight", port.logString().utf8().data());
#endif
channel = m_openChannels.get(port).get();
if (!channel)
return;

channel->closePort(port);
channel->closePort(port);
}

// FIXME: When making message ports be multi-process, this should probably push a notification
// to the remaining port to tell it this port closed.
}

bool MessagePortChannelRegistry::didPostMessageToRemote(MessageWithMessagePorts&& message, const MessagePortIdentifier& remoteTarget)
{
// ASSERT(isMainThread());

// LOG(MessagePorts, "Registry: Posting message to MessagePort %s in registry", remoteTarget.logString().utf8().data());

// The channel might be gone if the remote side was closed.
RefPtr channel = m_openChannels.get(remoteTarget);
if (!channel) {
// LOG(MessagePorts, "Registry: Could not find MessagePortChannel for port %s; It was probably closed. Message will be dropped.", remoteTarget.logString().utf8().data());
return false;
}
RefPtr<MessagePortChannel> channel;
bool result;
{
Locker locker { m_lock };

return channel->postMessageToRemote(WTF::move(message), remoteTarget);
// The channel might be gone if the remote side was closed.
channel = m_openChannels.get(remoteTarget).get();
if (!channel) {
// LOG(MessagePorts, "Registry: Could not find MessagePortChannel for port %s; It was probably closed. Message will be dropped.", remoteTarget.logString().utf8().data());
return false;
}

result = channel->postMessageToRemote(WTF::move(message), remoteTarget);
}
return result;
}

void MessagePortChannelRegistry::takeAllMessagesForPort(const MessagePortIdentifier& port, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&& callback)
{
// ASSERT(isMainThread());

// The channel might be gone if the remote side was closed.
RefPtr channel = m_openChannels.get(port);
if (!channel) {
callback({}, [] {});
return;
RefPtr<MessagePortChannel> channel;
Vector<MessageWithMessagePorts> messages;
{
Locker locker { m_lock };

// The channel might be gone if the remote side was closed.
channel = m_openChannels.get(port).get();
if (channel)
messages = channel->takeAllMessagesForPort(port);
}

channel->takeAllMessagesForPort(port, WTF::move(callback));
// Invoked outside the lock: the callback re-enters the registry via MessagePort::entanglePorts.
callback(WTF::move(messages), [] {});
}

std::optional<MessageWithMessagePorts> MessagePortChannelRegistry::tryTakeMessageForPort(const MessagePortIdentifier& port)
{
// ASSERT(isMainThread());

// LOG(MessagePorts, "Registry: Trying to take a message for MessagePort %s", port.logString().utf8().data());

// The channel might be gone if the remote side was closed.
auto* channel = m_openChannels.get(port);
if (!channel)
return std::nullopt;
RefPtr<MessagePortChannel> channel;
std::optional<MessageWithMessagePorts> result;
{
Locker locker { m_lock };

return channel->tryTakeMessageForPort(port);
}

MessagePortChannel* MessagePortChannelRegistry::existingChannelContainingPort(const MessagePortIdentifier& port)
{
// ASSERT(isMainThread());
// The channel might be gone if the remote side was closed.
channel = m_openChannels.get(port).get();
if (!channel)
return std::nullopt;

return m_openChannels.get(port);
result = channel->tryTakeMessageForPort(port);
}
return result;
}

} // namespace WebCore
10 changes: 7 additions & 3 deletions src/bun.js/bindings/webcore/MessagePortChannelRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "ProcessIdentifier.h"
#include <wtf/HashMap.h>
#include <wtf/CheckedRef.h>
#include <wtf/Lock.h>

namespace WebCore {

Expand All @@ -51,13 +52,16 @@ class MessagePortChannelRegistry final : public CanMakeWeakPtr<MessagePortChanne
WEBCORE_EXPORT void takeAllMessagesForPort(const MessagePortIdentifier&, CompletionHandler<void(Vector<MessageWithMessagePorts>&&, CompletionHandler<void()>&&)>&&);
WEBCORE_EXPORT std::optional<MessageWithMessagePorts> tryTakeMessageForPort(const MessagePortIdentifier&);

WEBCORE_EXPORT MessagePortChannel* existingChannelContainingPort(const MessagePortIdentifier&);

WEBCORE_EXPORT void messagePortChannelCreated(MessagePortChannel&);
WEBCORE_EXPORT void messagePortChannelDestroyed(MessagePortChannel&);

private:
UncheckedKeyHashMap<MessagePortIdentifier, WeakRef<MessagePortChannel>> m_openChannels;
// WebKit guarantees single-threaded access via ASSERT(isMainThread()) and routes worker calls through
// WorkerMessagePortChannelProvider → callOnMainThread. Bun has no equivalent main-thread runloop and
// additionally needs synchronous receiveMessageOnPort() from any thread, so we serialize with a lock
// instead. All MessagePortChannel state mutation happens via this registry and is covered by m_lock.
Lock m_lock;
UncheckedKeyHashMap<MessagePortIdentifier, ThreadSafeWeakPtr<MessagePortChannel>> m_openChannels WTF_GUARDED_BY_LOCK(m_lock);
};

} // namespace WebCore
3 changes: 0 additions & 3 deletions src/bun.js/bindings/webcore/Performance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ MonotonicTime Performance::monotonicTimeFromRelativeTime(DOMHighResTimeStamp rel

PerformanceTiming* Performance::timing()
{
// if (!is<Document>(scriptExecutionContext()))
// return nullptr;
// ASSERT(isMainThread());
if (!m_timing)
m_timing = PerformanceTiming::create();
return m_timing.get();
Expand Down
Loading
Loading