/* * Copyright (C) 2001-2008 Jacek Sieka, arnetheduck on gmail point com * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. */ #include "stdinc.h" #include "DCPlusPlus.h" #include "BufferedSocket.h" #include "TimerManager.h" #include "SettingsManager.h" #include "Streams.h" #include "SSLSocket.h" #include "CryptoManager.h" #include "ZUtils.h" namespace dcpp { // Polling is used for tasks...should be fixed... #define POLL_TIMEOUT 250 BufferedSocket::BufferedSocket(char aSeparator) throw(ThreadException) : separator(aSeparator), mode(MODE_LINE), dataBytes(0), rollback(0), state(STARTING), disconnecting(false) { start(); Thread::safeInc(sockets); } volatile long BufferedSocket::sockets = 0; BufferedSocket::~BufferedSocket() throw() { Thread::safeDec(sockets); } void BufferedSocket::setMode (Modes aMode, size_t aRollback) { if (mode == aMode) { dcdebug ("WARNING: Re-entering mode %d\n", mode); return; } switch (aMode) { case MODE_LINE: rollback = aRollback; break; case MODE_ZPIPE: filterIn = std::auto_ptr(new UnZFilter); break; case MODE_DATA: break; } mode = aMode; } void BufferedSocket::setSocket(std::auto_ptr s) { dcassert(!sock.get()); if(SETTING(SOCKET_IN_BUFFER) > 0) s->setSocketOpt(SO_RCVBUF, SETTING(SOCKET_IN_BUFFER)); if(SETTING(SOCKET_OUT_BUFFER) > 0) s->setSocketOpt(SO_SNDBUF, SETTING(SOCKET_OUT_BUFFER)); s->setBlocking(false); inbuf.resize(s->getSocketOptInt(SO_RCVBUF)); sock = s; } void BufferedSocket::accept(const Socket& srv, bool secure, bool allowUntrusted) throw(SocketException) { dcdebug("BufferedSocket::accept() %p\n", (void*)this); std::auto_ptr s(secure ? CryptoManager::getInstance()->getServerSocket(allowUntrusted) : new Socket); s->accept(srv); setSocket(s); Lock l(cs); addTask(ACCEPTED, 0); } void BufferedSocket::connect(const string& aAddress, uint16_t aPort, bool secure, bool allowUntrusted, bool proxy) throw(SocketException) { dcdebug("BufferedSocket::connect() %p\n", (void*)this); std::auto_ptr s(secure ? CryptoManager::getInstance()->getClientSocket(allowUntrusted) : new Socket); s->create(); s->bind(0, SETTING(BIND_ADDRESS)); setSocket(s); Lock l(cs); addTask(CONNECT, new ConnectInfo(aAddress, aPort, proxy && (SETTING(OUTGOING_CONNECTIONS) == SettingsManager::OUTGOING_SOCKS5))); } #define CONNECT_TIMEOUT 30000 void BufferedSocket::threadConnect(const string& aAddr, uint16_t aPort, bool proxy) throw(SocketException) { dcassert(state == STARTING); dcdebug("threadConnect %s:%d\n", aAddr.c_str(), (int)aPort); fire(BufferedSocketListener::Connecting()); state = RUNNING; uint64_t startTime = GET_TICK(); if(proxy) { sock->socksConnect(aAddr, aPort, CONNECT_TIMEOUT); } else { sock->connect(aAddr, aPort); } while(sock->wait(POLL_TIMEOUT, Socket::WAIT_CONNECT) != Socket::WAIT_CONNECT) { if(disconnecting) return; if((startTime + 30000) < GET_TICK()) { throw SocketException(_("Connection timeout")); } } fire(BufferedSocketListener::Connected()); } void BufferedSocket::threadRead() throw(SocketException) { if(state != RUNNING) return; int left = sock->read(&inbuf[0], (int)inbuf.size()); if(left == -1) { // EWOULDBLOCK, no data received... return; } else if(left == 0) { // This socket has been closed... throw SocketException(_("Connection closed")); } string::size_type pos = 0; // always uncompressed data string l; int bufpos = 0, total = left; while (left > 0) { switch (mode) { case MODE_ZPIPE: { const int BUF_SIZE = 1024; // Special to autodetect nmdc connections... string::size_type pos = 0; boost::scoped_array buffer(new char[BUF_SIZE]); l = line; // decompress all input data and store in l. while (left) { size_t in = BUF_SIZE; size_t used = left; bool ret = (*filterIn) (&inbuf[0] + total - left, used, &buffer[0], in); left -= used; l.append (&buffer[0], in); // if the stream ends before the data runs out, keep remainder of data in inbuf if (!ret) { bufpos = total-left; setMode (MODE_LINE, rollback); break; } } // process all lines while ((pos = l.find(separator)) != string::npos) { fire(BufferedSocketListener::Line(), l.substr(0, pos)); l.erase (0, pos + 1 /* separator char */); } // store remainder line = l; break; } case MODE_LINE: // Special to autodetect nmdc connections... if(separator == 0) { if(inbuf[0] == '$') { separator = '|'; } else { separator = '\n'; } } l = line + string ((char*)&inbuf[bufpos], left); while ((pos = l.find(separator)) != string::npos) { fire(BufferedSocketListener::Line(), l.substr(0, pos)); l.erase (0, pos + 1 /* separator char */); if (l.length() < (size_t)left) left = l.length(); if (mode != MODE_LINE) { // we changed mode; remainder of l is invalid. l.clear(); bufpos = total - left; break; } } if (pos == string::npos) left = 0; line = l; break; case MODE_DATA: while(left > 0) { if(dataBytes == -1) { fire(BufferedSocketListener::Data(), &inbuf[bufpos], left); bufpos += (left - rollback); left = rollback; rollback = 0; } else { int high = (int)min(dataBytes, (int64_t)left); fire(BufferedSocketListener::Data(), &inbuf[bufpos], high); bufpos += high; left -= high; dataBytes -= high; if(dataBytes == 0) { mode = MODE_LINE; fire(BufferedSocketListener::ModeChange()); } } } break; } } if(mode == MODE_LINE && line.size() > static_cast(SETTING(MAX_COMMAND_LENGTH))) { throw SocketException(_("Maximum command length exceeded")); } } void BufferedSocket::threadSendFile(InputStream* file) throw(Exception) { if(state != RUNNING) return; if(disconnecting) return; dcassert(file != NULL); size_t sockSize = (size_t)sock->getSocketOptInt(SO_SNDBUF); size_t bufSize = max(sockSize, (size_t)64*1024); ByteVector readBuf(bufSize); ByteVector writeBuf(bufSize); size_t readPos = 0; bool readDone = false; dcdebug("Starting threadSend\n"); while(true) { if(!readDone && readBuf.size() > readPos) { // Fill read buffer size_t bytesRead = readBuf.size() - readPos; size_t actual = file->read(&readBuf[readPos], bytesRead); if(bytesRead > 0) { fire(BufferedSocketListener::BytesSent(), bytesRead, 0); } if(actual == 0) { readDone = true; } else { readPos += actual; } } if(readDone && readPos == 0) { fire(BufferedSocketListener::TransmitDone()); return; } readBuf.swap(writeBuf); readBuf.resize(bufSize); writeBuf.resize(readPos); readPos = 0; size_t writePos = 0; while(writePos < writeBuf.size()) { if(disconnecting) return; size_t writeSize = min(sockSize / 2, writeBuf.size() - writePos); int written = sock->write(&writeBuf[writePos], writeSize); if(written > 0) { writePos += written; fire(BufferedSocketListener::BytesSent(), 0, written); } else if(written == -1) { if(!readDone && readPos < readBuf.size()) { // Read a little since we're blocking anyway... size_t bytesRead = min(readBuf.size() - readPos, readBuf.size() / 2); size_t actual = file->read(&readBuf[readPos], bytesRead); if(bytesRead > 0) { fire(BufferedSocketListener::BytesSent(), bytesRead, 0); } if(actual == 0) { readDone = true; } else { readPos += actual; } } else { while(!disconnecting) { int w = sock->wait(POLL_TIMEOUT, Socket::WAIT_WRITE | Socket::WAIT_READ); if(w & Socket::WAIT_READ) { threadRead(); } if(w & Socket::WAIT_WRITE) { break; } } } } } } } void BufferedSocket::write(const char* aBuf, size_t aLen) throw() { if(!sock.get()) return; Lock l(cs); if(writeBuf.empty()) addTask(SEND_DATA, 0); writeBuf.insert(writeBuf.end(), aBuf, aBuf+aLen); } void BufferedSocket::threadSendData() { if(state != RUNNING) return; { Lock l(cs); if(writeBuf.empty()) return; writeBuf.swap(sendBuf); } size_t left = sendBuf.size(); size_t done = 0; while(left > 0) { if(disconnecting) { return; } int w = sock->wait(POLL_TIMEOUT, Socket::WAIT_READ | Socket::WAIT_WRITE); if(w & Socket::WAIT_READ) { threadRead(); } if(w & Socket::WAIT_WRITE) { int n = sock->write(&sendBuf[done], left); if(n > 0) { left -= n; done += n; } } } sendBuf.clear(); } bool BufferedSocket::checkEvents() { while(state == RUNNING ? taskSem.wait(0) : taskSem.wait()) { pair > p; { Lock l(cs); dcassert(tasks.size() > 0); p = tasks.front(); tasks.erase(tasks.begin()); } if(p.first == SHUTDOWN) { return false; } else if(p.first == UPDATED) { fire(BufferedSocketListener::Updated()); } if(state == STARTING) { if(p.first == CONNECT) { ConnectInfo* ci = static_cast(p.second.get()); threadConnect(ci->addr, ci->port, ci->proxy); } else if(p.first == ACCEPTED) { state = RUNNING; } else { dcdebug("%d unexpected in STARTING state", p.first); } } else if(state == RUNNING) { if(p.first == SEND_DATA) { threadSendData(); } else if(p.first == SEND_FILE) { threadSendFile(static_cast(p.second.get())->stream); break; } else if(p.first == DISCONNECT) { fail(_("Disconnected")); } else { dcdebug("%d unexpected in RUNNING state", p.first); } } } return true; } void BufferedSocket::checkSocket() { int waitFor = sock->wait(POLL_TIMEOUT, Socket::WAIT_READ); if(waitFor & Socket::WAIT_READ) { threadRead(); } } /** * Main task dispatcher for the buffered socket abstraction. * @todo Fix the polling... */ int BufferedSocket::run() { dcdebug("BufferedSocket::run() start %p\n", (void*)this); while(true) { try { if(!checkEvents()) { break; } if(state == RUNNING) { checkSocket(); } } catch(const Exception& e) { fail(e.getError()); } } dcdebug("BufferedSocket::run() end %p\n", (void*)this); delete this; return 0; } void BufferedSocket::fail(const string& aError) { sock->disconnect(); if(state == RUNNING) { state = FAILED; fire(BufferedSocketListener::Failed(), aError); } } void BufferedSocket::shutdown() { Lock l(cs); disconnecting = true; addTask(SHUTDOWN, 0); } void BufferedSocket::addTask(Tasks task, TaskData* data) { dcassert(task == SHUTDOWN || sock.get()); tasks.push_back(make_pair(task, data)); taskSem.signal(); } } // namespace dcpp