508 lines
17 KiB
C++
508 lines
17 KiB
C++
#ifndef SOCKET_UWS_H
|
|
#define SOCKET_UWS_H
|
|
|
|
#include "Networking.h"
|
|
|
|
namespace uS {
|
|
|
|
struct TransferData {
|
|
// Connection state
|
|
uv_os_sock_t fd;
|
|
SSL *ssl;
|
|
|
|
// Poll state
|
|
void (*pollCb)(Poll *, int, int);
|
|
int pollEvents;
|
|
|
|
// User state
|
|
void *userData;
|
|
|
|
// Destination
|
|
NodeData *destination;
|
|
void (*transferCb)(Poll *);
|
|
};
|
|
|
|
// perfectly 64 bytes (4 + 60)
|
|
struct WIN32_EXPORT Socket : Poll {
|
|
protected:
|
|
struct {
|
|
int poll : 4;
|
|
int shuttingDown : 4;
|
|
} state = {0, false};
|
|
|
|
SSL *ssl;
|
|
void *user = nullptr;
|
|
NodeData *nodeData;
|
|
|
|
// this is not needed by HttpSocket!
|
|
struct Queue {
|
|
struct Message {
|
|
const char *data;
|
|
size_t length;
|
|
Message *nextMessage = nullptr;
|
|
void (*callback)(void *socket, void *data, bool cancelled, void *reserved) = nullptr;
|
|
void *callbackData = nullptr, *reserved = nullptr;
|
|
};
|
|
|
|
Message *head = nullptr, *tail = nullptr;
|
|
void pop()
|
|
{
|
|
Message *nextMessage;
|
|
if ((nextMessage = head->nextMessage)) {
|
|
delete [] (char *) head;
|
|
head = nextMessage;
|
|
} else {
|
|
delete [] (char *) head;
|
|
head = tail = nullptr;
|
|
}
|
|
}
|
|
|
|
bool empty() {return head == nullptr;}
|
|
Message *front() {return head;}
|
|
|
|
void push(Message *message)
|
|
{
|
|
message->nextMessage = nullptr;
|
|
if (tail) {
|
|
tail->nextMessage = message;
|
|
tail = message;
|
|
} else {
|
|
head = message;
|
|
tail = message;
|
|
}
|
|
}
|
|
} messageQueue;
|
|
|
|
int getPoll() {
|
|
return state.poll;
|
|
}
|
|
|
|
int setPoll(int poll) {
|
|
state.poll = poll;
|
|
return poll;
|
|
}
|
|
|
|
void setShuttingDown(bool shuttingDown) {
|
|
state.shuttingDown = shuttingDown;
|
|
}
|
|
|
|
void transfer(NodeData *nodeData, void (*cb)(Poll *)) {
|
|
// userData is invalid from now on till onTransfer
|
|
setUserData(new TransferData({getFd(), ssl, getCb(), getPoll(), getUserData(), nodeData, cb}));
|
|
stop(this->nodeData->loop);
|
|
close(this->nodeData->loop, [](Poll *p) {
|
|
Socket *s = (Socket *) p;
|
|
TransferData *transferData = (TransferData *) s->getUserData();
|
|
|
|
transferData->destination->asyncMutex->lock();
|
|
bool wasEmpty = transferData->destination->transferQueue.empty();
|
|
transferData->destination->transferQueue.push_back(s);
|
|
transferData->destination->asyncMutex->unlock();
|
|
|
|
if (wasEmpty) {
|
|
transferData->destination->async->send();
|
|
}
|
|
});
|
|
}
|
|
|
|
void changePoll(Socket *socket) {
|
|
if (!threadSafeChange(nodeData->loop, this, socket->getPoll())) {
|
|
if (socket->nodeData->tid != pthread_self()) {
|
|
socket->nodeData->asyncMutex->lock();
|
|
socket->nodeData->changePollQueue.push_back(socket);
|
|
socket->nodeData->asyncMutex->unlock();
|
|
socket->nodeData->async->send();
|
|
} else {
|
|
change(socket->nodeData->loop, socket, socket->getPoll());
|
|
}
|
|
}
|
|
}
|
|
|
|
// clears user data!
|
|
template <void onTimeout(Socket *)>
|
|
void startTimeout(int timeoutMs = 15000) {
|
|
Timer *timer = new Timer(nodeData->loop);
|
|
timer->setData(this);
|
|
timer->start([](Timer *timer) {
|
|
Socket *s = (Socket *) timer->getData();
|
|
s->cancelTimeout();
|
|
onTimeout(s);
|
|
}, timeoutMs, 0);
|
|
|
|
user = timer;
|
|
}
|
|
|
|
void cancelTimeout() {
|
|
Timer *timer = (Timer *) getUserData();
|
|
if (timer) {
|
|
timer->stop();
|
|
timer->close();
|
|
user = nullptr;
|
|
}
|
|
}
|
|
|
|
template <class STATE>
|
|
static void sslIoHandler(Poll *p, int status, int events) {
|
|
Socket *socket = (Socket *) p;
|
|
|
|
if (status < 0) {
|
|
STATE::onEnd((Socket *) p);
|
|
return;
|
|
}
|
|
|
|
if (!socket->messageQueue.empty() && ((events & UV_WRITABLE) || SSL_want(socket->ssl) == SSL_READING)) {
|
|
socket->cork(true);
|
|
while (true) {
|
|
Queue::Message *messagePtr = socket->messageQueue.front();
|
|
int sent = SSL_write(socket->ssl, messagePtr->data, messagePtr->length);
|
|
if (sent == (ssize_t) messagePtr->length) {
|
|
if (messagePtr->callback) {
|
|
messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved);
|
|
}
|
|
socket->messageQueue.pop();
|
|
if (socket->messageQueue.empty()) {
|
|
if ((socket->state.poll & UV_WRITABLE) && SSL_want(socket->ssl) != SSL_WRITING) {
|
|
socket->change(socket->nodeData->loop, socket, socket->setPoll(UV_READABLE));
|
|
}
|
|
break;
|
|
}
|
|
} else if (sent <= 0) {
|
|
switch (SSL_get_error(socket->ssl, sent)) {
|
|
case SSL_ERROR_WANT_READ:
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if ((socket->getPoll() & UV_WRITABLE) == 0) {
|
|
socket->change(socket->nodeData->loop, socket, socket->setPoll(socket->getPoll() | UV_WRITABLE));
|
|
}
|
|
break;
|
|
default:
|
|
STATE::onEnd((Socket *) p);
|
|
return;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
socket->cork(false);
|
|
}
|
|
|
|
if (events & UV_READABLE) {
|
|
do {
|
|
int length = SSL_read(socket->ssl, socket->nodeData->recvBuffer, socket->nodeData->recvLength);
|
|
if (length <= 0) {
|
|
switch (SSL_get_error(socket->ssl, length)) {
|
|
case SSL_ERROR_WANT_READ:
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if ((socket->getPoll() & UV_WRITABLE) == 0) {
|
|
socket->change(socket->nodeData->loop, socket, socket->setPoll(socket->getPoll() | UV_WRITABLE));
|
|
}
|
|
break;
|
|
default:
|
|
STATE::onEnd((Socket *) p);
|
|
return;
|
|
}
|
|
break;
|
|
} else {
|
|
// Warning: onData can delete the socket! Happens when HttpSocket upgrades
|
|
socket = STATE::onData((Socket *) p, socket->nodeData->recvBuffer, length);
|
|
if (socket->isClosed() || socket->isShuttingDown()) {
|
|
return;
|
|
}
|
|
}
|
|
} while (SSL_pending(socket->ssl));
|
|
}
|
|
}
|
|
|
|
template <class STATE>
|
|
static void ioHandler(Poll *p, int status, int events) {
|
|
Socket *socket = (Socket *) p;
|
|
NodeData *nodeData = socket->nodeData;
|
|
Context *netContext = nodeData->netContext;
|
|
|
|
if (status < 0) {
|
|
STATE::onEnd((Socket *) p);
|
|
return;
|
|
}
|
|
|
|
if (events & UV_WRITABLE) {
|
|
if (!socket->messageQueue.empty() && (events & UV_WRITABLE)) {
|
|
socket->cork(true);
|
|
while (true) {
|
|
Queue::Message *messagePtr = socket->messageQueue.front();
|
|
ssize_t sent = ::send(socket->getFd(), messagePtr->data, messagePtr->length, MSG_NOSIGNAL);
|
|
if (sent == (ssize_t) messagePtr->length) {
|
|
if (messagePtr->callback) {
|
|
messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved);
|
|
}
|
|
socket->messageQueue.pop();
|
|
if (socket->messageQueue.empty()) {
|
|
// todo, remove bit, don't set directly
|
|
socket->change(socket->nodeData->loop, socket, socket->setPoll(UV_READABLE));
|
|
break;
|
|
}
|
|
} else if (sent == SOCKET_ERROR) {
|
|
if (!netContext->wouldBlock()) {
|
|
STATE::onEnd((Socket *) p);
|
|
return;
|
|
}
|
|
break;
|
|
} else {
|
|
messagePtr->length -= sent;
|
|
messagePtr->data += sent;
|
|
break;
|
|
}
|
|
}
|
|
socket->cork(false);
|
|
}
|
|
}
|
|
|
|
if (events & UV_READABLE) {
|
|
int length = recv(socket->getFd(), nodeData->recvBuffer, nodeData->recvLength, 0);
|
|
if (length > 0) {
|
|
STATE::onData((Socket *) p, nodeData->recvBuffer, length);
|
|
} else if (length <= 0 || (length == SOCKET_ERROR && !netContext->wouldBlock())) {
|
|
STATE::onEnd((Socket *) p);
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
template<class STATE>
|
|
void setState() {
|
|
if (ssl) {
|
|
setCb(sslIoHandler<STATE>);
|
|
} else {
|
|
setCb(ioHandler<STATE>);
|
|
}
|
|
}
|
|
|
|
bool hasEmptyQueue() {
|
|
return messageQueue.empty();
|
|
}
|
|
|
|
void enqueue(Queue::Message *message) {
|
|
messageQueue.push(message);
|
|
}
|
|
|
|
Queue::Message *allocMessage(size_t length, const char *data = 0) {
|
|
Queue::Message *messagePtr = (Queue::Message *) new char[sizeof(Queue::Message) + length];
|
|
messagePtr->length = length;
|
|
messagePtr->data = ((char *) messagePtr) + sizeof(Queue::Message);
|
|
messagePtr->nextMessage = nullptr;
|
|
|
|
if (data) {
|
|
memcpy((char *) messagePtr->data, data, messagePtr->length);
|
|
}
|
|
|
|
return messagePtr;
|
|
}
|
|
|
|
void freeMessage(Queue::Message *message) {
|
|
delete [] (char *) message;
|
|
}
|
|
|
|
bool write(Queue::Message *message, bool &wasTransferred) {
|
|
ssize_t sent = 0;
|
|
if (messageQueue.empty()) {
|
|
|
|
if (ssl) {
|
|
sent = SSL_write(ssl, message->data, message->length);
|
|
if (sent == (ssize_t) message->length) {
|
|
wasTransferred = false;
|
|
return true;
|
|
} else if (sent < 0) {
|
|
switch (SSL_get_error(ssl, sent)) {
|
|
case SSL_ERROR_WANT_READ:
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if ((getPoll() & UV_WRITABLE) == 0) {
|
|
setPoll(getPoll() | UV_WRITABLE);
|
|
changePoll(this);
|
|
}
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
} else {
|
|
sent = ::send(getFd(), message->data, message->length, MSG_NOSIGNAL);
|
|
if (sent == (ssize_t) message->length) {
|
|
wasTransferred = false;
|
|
return true;
|
|
} else if (sent == SOCKET_ERROR) {
|
|
if (!nodeData->netContext->wouldBlock()) {
|
|
return false;
|
|
}
|
|
} else {
|
|
message->length -= sent;
|
|
message->data += sent;
|
|
}
|
|
|
|
if ((getPoll() & UV_WRITABLE) == 0) {
|
|
setPoll(getPoll() | UV_WRITABLE);
|
|
changePoll(this);
|
|
}
|
|
}
|
|
}
|
|
messageQueue.push(message);
|
|
wasTransferred = true;
|
|
return true;
|
|
}
|
|
|
|
template <class T, class D>
|
|
void sendTransformed(const char *message, size_t length, void(*callback)(void *socket, void *data, bool cancelled, void *reserved), void *callbackData, D transformData) {
|
|
size_t estimatedLength = T::estimate(message, length) + sizeof(Queue::Message);
|
|
|
|
if (hasEmptyQueue()) {
|
|
if (estimatedLength <= uS::NodeData::preAllocMaxSize) {
|
|
int memoryLength = estimatedLength;
|
|
int memoryIndex = nodeData->getMemoryBlockIndex(memoryLength);
|
|
|
|
Queue::Message *messagePtr = (Queue::Message *) nodeData->getSmallMemoryBlock(memoryIndex);
|
|
messagePtr->data = ((char *) messagePtr) + sizeof(Queue::Message);
|
|
messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData);
|
|
|
|
bool wasTransferred;
|
|
if (write(messagePtr, wasTransferred)) {
|
|
if (!wasTransferred) {
|
|
nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex);
|
|
if (callback) {
|
|
callback(this, callbackData, false, nullptr);
|
|
}
|
|
} else {
|
|
messagePtr->callback = callback;
|
|
messagePtr->callbackData = callbackData;
|
|
}
|
|
} else {
|
|
nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex);
|
|
if (callback) {
|
|
callback(this, callbackData, true, nullptr);
|
|
}
|
|
}
|
|
} else {
|
|
Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(Queue::Message));
|
|
messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData);
|
|
|
|
bool wasTransferred;
|
|
if (write(messagePtr, wasTransferred)) {
|
|
if (!wasTransferred) {
|
|
freeMessage(messagePtr);
|
|
if (callback) {
|
|
callback(this, callbackData, false, nullptr);
|
|
}
|
|
} else {
|
|
messagePtr->callback = callback;
|
|
messagePtr->callbackData = callbackData;
|
|
}
|
|
} else {
|
|
freeMessage(messagePtr);
|
|
if (callback) {
|
|
callback(this, callbackData, true, nullptr);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(Queue::Message));
|
|
messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData);
|
|
messagePtr->callback = callback;
|
|
messagePtr->callbackData = callbackData;
|
|
enqueue(messagePtr);
|
|
}
|
|
}
|
|
|
|
public:
|
|
Socket(NodeData *nodeData, Loop *loop, uv_os_sock_t fd, SSL *ssl) : Poll(loop, fd), ssl(ssl), nodeData(nodeData) {
|
|
if (ssl) {
|
|
// OpenSSL treats SOCKETs as int
|
|
SSL_set_fd(ssl, (int) fd);
|
|
SSL_set_mode(ssl, SSL_MODE_RELEASE_BUFFERS);
|
|
}
|
|
}
|
|
|
|
NodeData *getNodeData() {
|
|
return nodeData;
|
|
}
|
|
|
|
Poll *next = nullptr, *prev = nullptr;
|
|
|
|
void *getUserData() {
|
|
return user;
|
|
}
|
|
|
|
void setUserData(void *user) {
|
|
this->user = user;
|
|
}
|
|
|
|
struct Address {
|
|
unsigned int port;
|
|
const char *address;
|
|
const char *family;
|
|
};
|
|
|
|
Address getAddress();
|
|
|
|
void setNoDelay(int enable) {
|
|
setsockopt(getFd(), IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
|
}
|
|
|
|
void cork(int enable) {
|
|
#if defined(TCP_CORK)
|
|
// Linux & SmartOS have proper TCP_CORK
|
|
setsockopt(getFd(), IPPROTO_TCP, TCP_CORK, &enable, sizeof(int));
|
|
#elif defined(TCP_NOPUSH)
|
|
// Mac OS X & FreeBSD have TCP_NOPUSH
|
|
setsockopt(getFd(), IPPROTO_TCP, TCP_NOPUSH, &enable, sizeof(int));
|
|
if (!enable) {
|
|
// Tested on OS X, FreeBSD situation is unclear
|
|
::send(getFd(), "", 0, MSG_NOSIGNAL);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
void shutdown() {
|
|
if (ssl) {
|
|
//todo: poll in/out - have the io_cb recall shutdown if failed
|
|
SSL_shutdown(ssl);
|
|
} else {
|
|
::shutdown(getFd(), SHUT_WR);
|
|
}
|
|
}
|
|
|
|
template <class T>
|
|
void closeSocket() {
|
|
uv_os_sock_t fd = getFd();
|
|
Context *netContext = nodeData->netContext;
|
|
stop(nodeData->loop);
|
|
netContext->closeSocket(fd);
|
|
|
|
if (ssl) {
|
|
SSL_free(ssl);
|
|
}
|
|
|
|
Poll::close(nodeData->loop, [](Poll *p) {
|
|
delete (T *) p;
|
|
});
|
|
}
|
|
|
|
bool isShuttingDown() {
|
|
return state.shuttingDown;
|
|
}
|
|
|
|
friend class Node;
|
|
friend struct NodeData;
|
|
};
|
|
|
|
struct ListenSocket : Socket {
|
|
|
|
ListenSocket(NodeData *nodeData, Loop *loop, uv_os_sock_t fd, SSL *ssl) : Socket(nodeData, loop, fd, ssl) {
|
|
|
|
}
|
|
|
|
Timer *timer = nullptr;
|
|
uS::TLS::Context sslContext;
|
|
};
|
|
|
|
}
|
|
|
|
#endif // SOCKET_UWS_H
|