#include "inspector_socket_server.h"

#include "node.h"
#include "uv.h"
#include "zlib.h"

#include <algorithm>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>

namespace node {
namespace inspector {

namespace {

// Compressed DevTools protocol description, embedded at build time from
// v8_inspector_protocol_json.h (presumably a generated header). Layout as
// consumed by SendProtocolJson(): bytes 0-2 hold the decompressed size
// (big-endian, 24 bits); the remaining bytes are a zlib stream.
static const uint8_t PROTOCOL_JSON[] = {
  #include "v8_inspector_protocol_json.h"  // NOLINT(build/include_order)
};

// Makes |string| safe to embed between double quotes in the hand-built JSON
// below by overwriting every '"' and '\\' with '_' in place.
void Escape(std::string* string) {
  std::replace_if(string->begin(), string->end(),
                  [](char ch) { return ch == '\"' || ch == '\\'; }, '_');
}

// Builds the "host:port/id" part of a WebSocket debugger URL (callers add the
// "ws://" scheme). The result is assembled with std::string so long host or
// id values are never truncated.
std::string GetWsUrl(const std::string& host, int port, const std::string& id) {
  const std::string port_text = std::to_string(port);
  std::string url;
  url.reserve(host.size() + 1 + port_text.size() + 1 + id.size());
  url += host;
  url += ':';
  url += port_text;
  url += '/';
  url += id;
  return url;
}

// Serializes a flat string->string map as a pretty-printed JSON object, one
// "key": "value" pair per line. Keys and values are written verbatim, so
// callers must escape them first. The output always ends with "\n} ".
std::string MapToString(const std::map<std::string, std::string>& object) {
  std::ostringstream out;
  out << "{\n";
  const char* separator = "";
  for (const auto& entry : object) {
    out << separator << "  \"" << entry.first << "\": \"" << entry.second
        << "\"";
    separator = ",\n";
  }
  out << "\n} ";
  return out.str();
}

// Serializes a list of flat maps as a JSON array. Each element is produced by
// MapToString(), elements are separated by ", ", and the array is followed by
// a blank line.
std::string MapsToString(
    const std::vector<std::map<std::string, std::string>>& array) {
  std::string result = "[ ";
  for (size_t i = 0; i < array.size(); ++i) {
    if (i > 0)
      result += ", ";
    result += MapToString(array[i]);
  }
  result += "]\n\n";
  return result;
}

// Checks whether |path| starts with the segment |expected|, compared with
// StringEqualNoCaseN (case-insensitive, per its name). The segment must be
// followed by '/' or by the end of the string. On a match, returns a pointer
// just past the segment (skipping the '/'); otherwise returns nullptr.
const char* MatchPathSegment(const char* path, const char* expected) {
  const size_t segment_length = strlen(expected);
  if (!StringEqualNoCaseN(path, expected, segment_length))
    return nullptr;
  const char* rest = path + segment_length;
  switch (*rest) {
    case '/':
      return rest + 1;
    case '\0':
      return rest;
    default:
      return nullptr;
  }
}

// libuv allocation callback: hands out a fresh heap buffer of exactly the
// suggested size. SocketSession::OnRemoteDataIO releases it with delete[].
void OnBufferAlloc(uv_handle_t* handle, size_t len, uv_buf_t* buf) {
  char* storage = new char[len];
  buf->len = len;
  buf->base = storage;
}

// Prints one "Debugger listening on ws://..." line per target id to |out|,
// followed by a pointer to the inspector documentation, then flushes. Does
// nothing when |out| is null.
void PrintDebuggerReadyMessage(const std::string& host,
                               int port,
                               const std::vector<std::string>& ids,
                               FILE* out) {
  if (out != nullptr) {
    for (auto it = ids.begin(); it != ids.end(); ++it) {
      const std::string url = GetWsUrl(host, port, *it);
      fprintf(out, "Debugger listening on ws://%s\n", url.c_str());
    }
    fprintf(out, "For help see %s\n", "https://nodejs.org/en/docs/inspector");
    fflush(out);
  }
}

// Writes |response| to |socket| as a complete HTTP/1.0 200 reply with a JSON
// content type, caching disabled, and an exact Content-Length header.
void SendHttpResponse(InspectorSocket* socket, const std::string& response) {
  const char kHeaderTemplate[] =
      "HTTP/1.0 200 OK\r\n"
      "Content-Type: application/json; charset=UTF-8\r\n"
      "Cache-Control: no-cache\r\n"
      "Content-Length: %zu\r\n"
      "\r\n";
  // "%zu" expands to at most 20 decimal digits for a 64-bit size_t, so 20
  // spare bytes always suffice.
  char header_buffer[sizeof(kHeaderTemplate) + 20];
  const int written = snprintf(header_buffer, sizeof(header_buffer),
                               kHeaderTemplate, response.size());
  inspector_write(socket, header_buffer, written);
  inspector_write(socket, response.data(), response.size());
}

// Answers /json/version with a JSON object naming the node.js build and the
// DevTools protocol version this server speaks.
void SendVersionResponse(InspectorSocket* socket) {
  const std::map<std::string, std::string> version_info = {
      {"Browser", "node.js/" NODE_VERSION},
      {"Protocol-Version", "1.1"},
  };
  SendHttpResponse(socket, MapToString(version_info));
}

// Inflates the zlib-compressed protocol description embedded in PROTOCOL_JSON
// and sends it as the HTTP response body. The first three bytes of the blob
// hold the decompressed length (big-endian, 24 bits); the zlib stream follows.
void SendProtocolJson(InspectorSocket* socket) {
  static const size_t kDecompressedSize =
      PROTOCOL_JSON[0] * 0x10000u +
      PROTOCOL_JSON[1] * 0x100u +
      PROTOCOL_JSON[2];
  std::string data(kDecompressedSize, '\0');

  z_stream strm;
  strm.zalloc = Z_NULL;
  strm.zfree = Z_NULL;
  strm.opaque = Z_NULL;
  // zlib requires next_in/avail_in (along with zalloc/zfree/opaque) to be
  // initialized before inflateInit() is called.
  strm.next_in = const_cast<uint8_t*>(PROTOCOL_JSON + 3);
  strm.avail_in = sizeof(PROTOCOL_JSON) - 3;
  CHECK_EQ(Z_OK, inflateInit(&strm));
  strm.next_out = reinterpret_cast<Byte*>(&data[0]);
  strm.avail_out = data.size();
  // One-shot inflate: the output buffer has exactly the recorded size, so the
  // stream must finish and fill it completely.
  CHECK_EQ(Z_STREAM_END, inflate(&strm, Z_FINISH));
  CHECK_EQ(0, strm.avail_out);
  CHECK_EQ(Z_OK, inflateEnd(&strm));
  SendHttpResponse(socket, data);
}

// Looks up the local address |socket| is bound to and stores it as a printable
// IP string in |*out_host|. Returns 0 on success, or the libuv error code, in
// which case |*out_host| is left unchanged.
int GetSocketHost(uv_tcp_t* socket, std::string* out_host) {
  sockaddr_storage local;
  int local_len = sizeof(local);
  int status = uv_tcp_getsockname(
      socket, reinterpret_cast<struct sockaddr*>(&local), &local_len);
  if (status != 0)
    return status;

  char ip[INET6_ADDRSTRLEN];
  if (local.ss_family == AF_INET6) {
    status = uv_ip6_name(reinterpret_cast<const sockaddr_in6*>(&local),
                         ip, sizeof(ip));
  } else {
    status = uv_ip4_name(reinterpret_cast<const sockaddr_in*>(&local),
                         ip, sizeof(ip));
  }
  if (status == 0)
    *out_host = ip;
  return status;
}

// Stores the local port |socket| is bound to, in host byte order, in
// |*out_port|. Returns 0 on success, or the libuv error code, in which case
// |*out_port| is left unchanged.
int GetPort(uv_tcp_t* socket, int* out_port) {
  sockaddr_storage local;
  int local_len = sizeof(local);
  const int status = uv_tcp_getsockname(
      socket, reinterpret_cast<struct sockaddr*>(&local), &local_len);
  if (status == 0) {
    const int network_port = (local.ss_family == AF_INET6)
        ? reinterpret_cast<const sockaddr_in6*>(&local)->sin6_port
        : reinterpret_cast<const sockaddr_in*>(&local)->sin_port;
    *out_port = ntohs(network_port);
  }
  return status;
}

}  // namespace


// Tracks pending close operations for an InspectorSocketServer. Once every
// expected close has completed, it runs the registered stop callbacks and then
// deletes itself through server->closer_.
class Closer {
 public:
  explicit Closer(InspectorSocketServer* server) : server_(server),
                                                   close_count_(0) { }

  // Registers |callback| to run when all closes finish. Null callbacks are
  // ignored. Because the callbacks are stored in a std::set, duplicates
  // collapse, and the callbacks run in pointer order rather than in
  // registration order.
  void AddCallback(InspectorSocketServer::ServerCallback callback) {
    if (callback == nullptr)
      return;
    callbacks_.insert(callback);
  }

  // Records that one expected close has finished, then notifies if none remain.
  // This may delete |this|; callers must not touch the Closer afterwards.
  void DecreaseExpectedCount() {
    --close_count_;
    NotifyIfDone();
  }

  // Records one more close that must finish before the callbacks run.
  void IncreaseExpectedCount() {
    ++close_count_;
  }

  // If no closes are pending, invokes every callback with the server, then
  // deletes server->closer_. That is expected to be |this|, because Stop()
  // creates the Closer as server->closer_. Nothing may touch members after the
  // delete.
  void NotifyIfDone() {
    if (close_count_ == 0) {
      for (auto callback : callbacks_) {
        callback(server_);
      }
      // Copy server_ into a local first: the delete below destroys |this|,
      // and server_ with it.
      InspectorSocketServer* server = server_;
      delete server->closer_;
      server->closer_ = nullptr;
    }
  }

 private:
  InspectorSocketServer* server_;
  std::set<InspectorSocketServer::ServerCallback> callbacks_;
  // Number of close operations still pending.
  int close_count_;
};

// One client connection to the inspector server. The session embeds its
// InspectorSocket (From() maps a socket back to its session) and moves through
// HTTP -> WebSocket -> closing states. It is heap-allocated per accepted
// connection and deleted by InspectorSocketServer::SessionTerminated().
class SocketSession {
 public:
  SocketSession(InspectorSocketServer* server, int id);
  // Starts closing the socket; calling it while already closing is a bug.
  void Close();
  // Marks the session as rejected during the WebSocket upgrade.
  void Declined() { state_ = State::kDeclined; }
  // Recovers the session that embeds |socket|.
  static SocketSession* From(InspectorSocket* socket) {
    return node::ContainerOf(&SocketSession::socket_, socket);
  }
  void FrontendConnected();
  InspectorSocketServer* GetServer() const { return server_; }
  int Id() const { return id_; }
  void Send(const std::string& message);
  // The target id may only be assigned while it is still empty.
  void SetTargetId(const std::string& target_id) {
    CHECK(target_id_.empty());
    target_id_ = target_id;
  }
  InspectorSocket* Socket() { return &socket_; }
  // Returned by reference: callers only compare it, so a copy per call (as a
  // by-value const return would force) is wasted work.
  const std::string& TargetId() const { return target_id_; }

 private:
  enum class State { kHttp, kWebSocket, kClosing, kEOF, kDeclined };
  static void CloseCallback_(InspectorSocket* socket, int code);
  static void ReadCallback_(uv_stream_t* stream, ssize_t read,
                            const uv_buf_t* buf);
  void OnRemoteDataIO(ssize_t read, const uv_buf_t* buf);
  const int id_;
  InspectorSocket socket_;
  InspectorSocketServer* server_;
  std::string target_id_;
  State state_;
};

// Creates a server in the kNew state. Nothing is bound or listened on until
// Start() is called with a loop.
InspectorSocketServer::InspectorSocketServer(SocketServerDelegate* delegate,
                                             const std::string& host,
                                             int port,
                                             FILE* out)
    : loop_(nullptr),
      delegate_(delegate),
      host_(host),
      port_(port),
      server_(uv_tcp_t()),
      closer_(nullptr),
      next_session_id_(0),
      out_(out) {
  state_ = ServerState::kNew;
}


// static
// Routes handshake events from an InspectorSocket to the owning server:
// HTTP GETs are answered by RespondToGet(); WebSocket upgrades are accepted
// only if SessionStarted() approves the target id taken from the request path
// (with its leading '/' removed); failed handshakes tear the session down.
bool InspectorSocketServer::HandshakeCallback(InspectorSocket* socket,
                                              inspector_handshake_event event,
                                              const std::string& path) {
  SocketSession* session = SocketSession::From(socket);
  InspectorSocketServer* server = session->GetServer();
  const std::string target_id = path.empty() ? std::string() : path.substr(1);
  switch (event) {
    case kInspectorHandshakeHttpGet:
      return server->RespondToGet(socket, path);
    case kInspectorHandshakeUpgrading:
      return server->SessionStarted(session, target_id);
    case kInspectorHandshakeUpgraded:
      session->FrontendConnected();
      return true;
    case kInspectorHandshakeFailed:
      server->SessionTerminated(session);
      return false;
    default:
      UNREACHABLE();
      return false;
  }
}

// Tries to attach |session| to target |id|. This succeeds only if the target
// exists and the delegate agrees to start a session. On success the session
// is recorded as connected and bound to the target; otherwise it is marked
// declined.
bool InspectorSocketServer::SessionStarted(SocketSession* session,
                                           const std::string& id) {
  const bool accepted =
      TargetExists(id) && delegate_->StartSession(session->Id(), id);
  if (!accepted) {
    session->Declined();
    return false;
  }
  connected_sessions_[session->Id()] = session;
  session->SetTargetId(id);
  return true;
}

// Forgets |session| and frees it. If it was a connected session, the delegate
// is told the session ended. When that leaves no connected sessions, either
// the ready banner is printed again (server still running) or the delegate is
// told the server is done (server already stopped).
void InspectorSocketServer::SessionTerminated(SocketSession* session) {
  const int session_id = session->Id();
  if (connected_sessions_.erase(session_id) > 0) {
    delegate_->EndSession(session_id);
    const bool last_session_gone = connected_sessions_.empty();
    if (last_session_gone && state_ == ServerState::kRunning) {
      PrintDebuggerReadyMessage(host_, port_, delegate_->GetTargetIds(), out_);
    }
    if (last_session_gone && state_ == ServerState::kStopped) {
      delegate_->ServerDone();
    }
  }
  delete session;
}

// Serves the DevTools discovery HTTP endpoints under /json:
//   /json, /json/list   -> list of targets
//   /json/protocol      -> protocol description
//   /json/version       -> version information
//   /json/activate/<id> -> acknowledgement, if <id> is a known target
// Returns false (request not handled) for any other path.
bool InspectorSocketServer::RespondToGet(InspectorSocket* socket,
                                         const std::string& path) {
  const char* command = MatchPathSegment(path.c_str(), "/json");
  if (command == nullptr)
    return false;

  if (command[0] == '\0' || MatchPathSegment(command, "list") != nullptr) {
    SendListResponse(socket);
    return true;
  }
  if (MatchPathSegment(command, "protocol") != nullptr) {
    SendProtocolJson(socket);
    return true;
  }
  if (MatchPathSegment(command, "version") != nullptr) {
    SendVersionResponse(socket);
    return true;
  }
  const char* target_id = MatchPathSegment(command, "activate");
  if (target_id != nullptr && TargetExists(target_id)) {
    SendHttpResponse(socket, "Target activated");
    return true;
  }
  return false;
}

// Answers /json and /json/list with a JSON array describing every target the
// delegate reports. Targets that already have a connected session are listed
// without the devtoolsFrontendUrl/webSocketDebuggerUrl fields.
void InspectorSocketServer::SendListResponse(InspectorSocket* socket) {
  // The local address the client connected to is the same for every target,
  // so look it up once. If it cannot be determined, fall back to the host the
  // server was configured with rather than putting an empty host in the URLs.
  std::string host;
  if (GetSocketHost(&socket->client, &host) != 0)
    host = host_;

  std::vector<std::map<std::string, std::string>> response;
  for (const std::string& id : delegate_->GetTargetIds()) {
    response.emplace_back();
    std::map<std::string, std::string>& target_map = response.back();
    target_map["description"] = "node.js instance";
    target_map["faviconUrl"] = "https://nodejs.org/static/favicon.ico";
    target_map["id"] = id;
    target_map["title"] = delegate_->GetTargetTitle(id);
    Escape(&target_map["title"]);
    target_map["type"] = "node";
    // This attribute value is a "best effort" URL that is passed as a JSON
    // string. It is not guaranteed to resolve to a valid resource.
    target_map["url"] = delegate_->GetTargetUrl(id);
    Escape(&target_map["url"]);

    bool connected = false;
    for (const auto& session : connected_sessions_) {
      if (session.second->TargetId() == id) {
        connected = true;
        break;
      }
    }
    if (!connected) {
      const std::string address = GetWsUrl(host, port_, id);
      target_map["devtoolsFrontendUrl"] =
          "chrome-devtools://devtools/bundled"
          "/inspector.html?experiments=true&v8only=true&ws=" + address;
      target_map["webSocketDebuggerUrl"] = "ws://" + address;
    }
  }
  SendHttpResponse(socket, MapsToString(response));
}

// Binds the listening socket to host_:port_ (IPv4 only), stores the port that
// was actually bound back into port_, and starts accepting connections. On
// success, prints the "Debugger listening" banner and returns true. On failure,
// prints an error to out_ (if set), closes the handle and returns false.
bool InspectorSocketServer::Start(uv_loop_t* loop) {
  CHECK_EQ(state_, ServerState::kNew);
  loop_ = loop;
  sockaddr_in addr;
  uv_tcp_init(loop_, &server_);
  // host_ must be a valid IPv4 literal. If parsing fails, report it like any
  // other start error instead of binding whatever address was produced.
  int err = uv_ip4_addr(host_.c_str(), port_, &addr);
  if (err == 0) {
    err = uv_tcp_bind(&server_,
                      reinterpret_cast<const struct sockaddr*>(&addr), 0);
  }
  if (err == 0)
    err = GetPort(&server_, &port_);
  if (err == 0) {
    err = uv_listen(reinterpret_cast<uv_stream_t*>(&server_), 1,
                    SocketConnectedCallback);
  }
  if (err == 0 && connected_sessions_.empty()) {
    state_ = ServerState::kRunning;
    PrintDebuggerReadyMessage(host_, port_, delegate_->GetTargetIds(), out_);
  }
  if (err != 0 && connected_sessions_.empty()) {
    if (out_ != NULL) {
      fprintf(out_, "Starting inspector on %s:%d failed: %s\n",
              host_.c_str(), port_, uv_strerror(err));
      fflush(out_);
    }
    uv_close(reinterpret_cast<uv_handle_t*>(&server_), nullptr);
    return false;
  }
  return true;
}

// Starts shutting down the listening socket. Only valid while the server is
// running. |cb| (if non-null) is handed to the Closer, which runs it once the
// listening handle has closed (see ServerClosedCallback).
void InspectorSocketServer::Stop(ServerCallback cb) {
  CHECK_EQ(state_, ServerState::kRunning);
  if (closer_ == nullptr) {
    closer_ = new Closer(this);
  }
  closer_->AddCallback(cb);
  // The handle close below is the pending operation the Closer waits for.
  closer_->IncreaseExpectedCount();
  state_ = ServerState::kStopping;
  // Relies on libuv invoking the close callback asynchronously. Otherwise
  // closer_ could already be deleted and null on the next line.
  uv_close(reinterpret_cast<uv_handle_t*>(&server_), ServerClosedCallback);
  // The expected count is at least 1 here, so this does not fire while the
  // handle close is still pending.
  closer_->NotifyIfDone();
}

void InspectorSocketServer::TerminateConnections() {
  for (const auto& session : connected_sessions_) {
    session.second->Close();
  }
}

bool InspectorSocketServer::TargetExists(const std::string& id) {
  const std::vector<std::string>& target_ids = delegate_->GetTargetIds();
  const auto& found = std::find(target_ids.begin(), target_ids.end(), id);
  return found != target_ids.end();
}

// Forwards |message| to the connected session with |session_id|. Ids that are
// not (or no longer) connected are silently ignored.
void InspectorSocketServer::Send(int session_id, const std::string& message) {
  const auto found = connected_sessions_.find(session_id);
  if (found == connected_sessions_.end())
    return;
  found->second->Send(message);
}

// static
// libuv close callback for the listening handle closed by Stop(). In order, it:
// lets the Closer run the Stop() callbacks (they still see state_ ==
// kStopping); tells the delegate the server is done if no sessions remain
// (otherwise SessionTerminated() does so when the last one ends); and marks
// the server stopped.
// NOTE(review): the server object is used after the Stop() callbacks run, so
// a callback must not destroy the server. Confirm this against the callers.
void InspectorSocketServer::ServerClosedCallback(uv_handle_t* server) {
  InspectorSocketServer* socket_server = InspectorSocketServer::From(server);
  CHECK_EQ(socket_server->state_, ServerState::kStopping);
  if (socket_server->closer_) {
    socket_server->closer_->DecreaseExpectedCount();
  }
  if (socket_server->connected_sessions_.empty()) {
    socket_server->delegate_->ServerDone();
  }
  socket_server->state_ = ServerState::kStopped;
}

// static
// libuv connection callback: on success, creates a session for the new client
// and hands its socket to inspector_accept(). Failed connection attempts are
// ignored.
void InspectorSocketServer::SocketConnectedCallback(uv_stream_t* server,
                                                    int status) {
  if (status != 0)
    return;
  InspectorSocketServer* socket_server = InspectorSocketServer::From(server);
  // The session gets the next id and is normally freed when its socket closes
  // (SessionTerminated()); only a failed accept frees it here.
  const int session_id = socket_server->next_session_id_++;
  SocketSession* session = new SocketSession(socket_server, session_id);
  if (inspector_accept(server, session->Socket(), HandshakeCallback) != 0)
    delete session;
}

// SocketSession: per-connection session tracking
// A new session starts in the HTTP state with no target id assigned.
SocketSession::SocketSession(InspectorSocketServer* server, int id)
    : id_(id),
      server_(server),
      state_(State::kHttp) {}

void SocketSession::Close() {
  CHECK_NE(state_, State::kClosing);
  state_ = State::kClosing;
  inspector_close(&socket_, CloseCallback_);
}

// static
void SocketSession::CloseCallback_(InspectorSocket* socket, int code) {
  // Invoked when inspector_close() completes. The server unregisters the
  // session and deletes it.
  SocketSession* closed = From(socket);
  CHECK_EQ(State::kClosing, closed->state_);
  closed->GetServer()->SessionTerminated(closed);
}

void SocketSession::FrontendConnected() {
  CHECK_EQ(State::kHttp, state_);
  state_ = State::kWebSocket;
  inspector_read_start(&socket_, OnBufferAlloc, ReadCallback_);
}

// static
void SocketSession::ReadCallback_(uv_stream_t* stream, ssize_t read,
                                  const uv_buf_t* buf) {
  // Trampoline from the C-style read callback back to the owning session.
  From(inspector_from_stream(stream))->OnRemoteDataIO(read, buf);
}

// Handles one read result. A positive |read| delivers that many bytes to the
// delegate as a message; zero or a negative value closes the session. A
// non-null buffer (allocated by OnBufferAlloc) is always released here.
// NOTE(review): libuv uses nread == 0 to mean "no data right now", not EOF.
// If the socket wrapper can forward 0, this closes the session; confirm.
void SocketSession::OnRemoteDataIO(ssize_t read, const uv_buf_t* buf) {
  if (read <= 0) {
    Close();
  } else {
    server_->Delegate()->MessageReceived(id_, std::string(buf->base, read));
  }
  const bool owns_buffer = buf != nullptr && buf->base != nullptr;
  if (owns_buffer)
    delete[] buf->base;
}

void SocketSession::Send(const std::string& message) {
  inspector_write(&socket_, message.data(), message.length());
}

}  // namespace inspector
}  // namespace node
