/*
This code is copyright (C) 1998 Robert O'Callahan.
This code is free for non-commercial use. Modification in all forms is permitted.
This license continues to apply to any modified versions. This license text must be
reproduced and distributed with any modified versions.
As a matter of courtesy I (Robert O'Callahan) would like to be informed of
any potentially useful modifications.
*/

#include "ttssh.h"
#include "util.h"

#include <limits.h>

#include <bn.h>

/* Terminal modes sent with SSH_CMSG_REQUEST_PTY (see prep_pty): a list of
   (opcode, value) byte pairs.  The final 0 byte ends the list; prep_pty sends
   sizeof(ssh_ttymodes) bytes, so that terminator is transmitted as well.
   NOTE(review): the opcodes presumably follow the SSH1 tty-mode table
   (1=VINTR ^C, 2=VQUIT 0x1c, 3=VERASE ^H, 4=VKILL ^U, 5=VEOF ^D) -- confirm. */
static char ssh_ttymodes[] = {
  0x01, 0x03,
  0x02, 0x1c,
  0x03, 0x08,
  0x04, 0x15,
  0x05, 0x04,
  0x00
};

/* Forward declarations: both handlers are installed as packet handlers by
   functions that appear before their definitions. */
static void auth_status_handler(struct _TInstVar FAR * inst, char FAR * pkt, int pkt_len, int pad_len);
static void rsa_challenge_handler(struct _TInstVar FAR * inst, char FAR * pkt, int pkt_len, int pad_len);

/* grab_payload(pvar, num_bytes): for use by the protocol processing code.
   Makes the next NUM_BYTES of the current packet's uncompressed payload
   available, without limiting how far ahead the inflater may decode.
   Evaluates to 0 when the bytes are not available; a fatal error has then
   already been signalled.
   The bytes are reached through pvar->ssh_state.payload, which points just
   past the packet type byte.  That pointer may change during the call (the
   decompression buffer can be reallocated), so re-read it afterwards. */
#define grab_payload(pv, n) (grab_payload_limited((pv), (n), 0))

/* Account for the next NUM_BYTES of the current packet's payload and make
   sure they are available at pvar->ssh_state.payload.
   - Uncompressed packets: only checks that the bytes lie inside the packet.
   - Compressed packets: inflates more data into postdecompress_inbuf as
     needed, growing (and possibly moving) that buffer and reserving one spare
     byte past the requested data.  LIMIT, when non-zero, caps how many bytes
     this call lets inflate produce.
   NUM_BYTES usually comes straight from a length field in the packet, so it
   is validated before any arithmetic.  Negative values (a wire uint32 >= 2^31)
   are rejected, and so are values that would overflow the running total.
   Returns 1 on success, or 0 after signalling a fatal error. */
static int grab_payload_limited(struct _TInstVar FAR * pvar, int num_bytes, unsigned int limit) {
  if (num_bytes < 0) {
    notify_fatal_error(pvar, "Received truncated packet");
    return 0;
  }

  if (!pvar->ssh_state.compressing) {
    /* Compare with the remaining length instead of adding first, so a huge
       num_bytes cannot overflow payload_grabbed. */
    if (num_bytes > pvar->ssh_state.payloadlen - pvar->ssh_state.payload_grabbed) {
      notify_fatal_error(pvar, "Received truncated packet");
      return 0;
    }
    pvar->ssh_state.payload_grabbed += num_bytes;
    return 1;
  } else {
    int cur_decompressed_bytes = pvar->ssh_state.decompress_stream.next_out -
      pvar->ssh_state.postdecompress_inbuf;
    int new_buf_size;

    /* new_buf_size + 1 is used as a buffer size below; it must not overflow. */
    if (num_bytes > INT_MAX - 1 - pvar->ssh_state.payload_grabbed) {
      notify_fatal_error(pvar, "Received truncated packet");
      return 0;
    }
    new_buf_size = num_bytes + pvar->ssh_state.payload_grabbed;

    if (new_buf_size <= cur_decompressed_bytes) {
      /* Already inflated by an earlier call. */
      pvar->ssh_state.payload_grabbed += num_bytes;
      return 1;
    }

    buf_ensure_size(&pvar->ssh_state.postdecompress_inbuf, &pvar->ssh_state.postdecompress_inbuflen,
      new_buf_size + 1);
    /* Byte 0 of the decompressed buffer is the packet type. */
    pvar->ssh_state.payload = pvar->ssh_state.postdecompress_inbuf + 1;

    /* The buffer may have moved: rebase the inflater's output window. */
    pvar->ssh_state.decompress_stream.next_out = pvar->ssh_state.postdecompress_inbuf
      + cur_decompressed_bytes;
    pvar->ssh_state.decompress_stream.avail_out = pvar->ssh_state.postdecompress_inbuflen
      - cur_decompressed_bytes;

    if (limit > 0 && pvar->ssh_state.decompress_stream.avail_out > limit) {
      pvar->ssh_state.decompress_stream.avail_out = limit;
    }

    while (new_buf_size > cur_decompressed_bytes) {
      if (inflate(&pvar->ssh_state.decompress_stream, Z_PARTIAL_FLUSH) != Z_OK) {
        notify_fatal_error(pvar, "Invalid compressed data in received packet");
        return 0;
      }
      cur_decompressed_bytes = pvar->ssh_state.decompress_stream.next_out -
        pvar->ssh_state.postdecompress_inbuf;
    }
    pvar->ssh_state.payload_grabbed += num_bytes;
    return 1;
  }
}

/* Big-endian ("MSB first") field accessors used by the packet code.
   get_payload_uint32 reads at OFF bytes past the current payload pointer.
   get_mpint_len turns the 16-bit bit count that prefixes an SSH1
   multi-precision integer into its byte length, rounded up. */
#define get_payload_uint32(pv, off) (get_uint32_MSBfirst((pv)->ssh_state.payload + (off)))
#define get_uint32(p) (get_uint32_MSBfirst((p)))
#define set_uint32(p, val) (set_uint32_MSBfirst((p), (val)))
#define get_mpint_len(pv, off) ((get_ushort16_MSBfirst((pv)->ssh_state.payload + (off)) + 7) / 8)
#define get_ushort16(p) (get_ushort16_MSBfirst((p)))

/* Packet trailer checksum: zlib's crc32 seeded with 0xFFFFFFFF, then inverted. */
#define do_crc(data, n) (~(uint32)crc32(0xFFFFFFFF, (data), (n)))

/* Decrypt the payload, checksum it, eat the padding, get the packet type
   and return it.

   DATA/LEN is the received packet body.  It starts with PADDING bytes of
   padding and ends with a 4-byte big-endian CRC computed over all the bytes
   before it.  CRYPT_decrypt is applied to the whole buffer first.
   Returns the packet type byte, or SSH_MSG_NONE after signalling a fatal
   error (bad checksum, empty/truncated packet, or invalid compressed data).
   On success, pvar->ssh_state.payload points just past the type byte and
   payload_grabbed counts that byte. */
static int prep_packet(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  pvar->ssh_state.payload = data;
  pvar->ssh_state.payloadlen = len;

  CRYPT_decrypt(pvar, data, len);
  /* PKT guarantees that the data is always 4-byte aligned */
  if (do_crc(data, len - 4) != get_uint32_MSBfirst(pvar->ssh_state.payload + len - 4)) {
    notify_fatal_error(pvar, "Invalid checksum in received packet");
    return SSH_MSG_NONE;
  } else {
    /* Skip the padding.  payloadlen then covers the type byte and data, but
       not the trailing CRC. */
    pvar->ssh_state.payload += padding;
    pvar->ssh_state.payloadlen -= padding + 4;
    pvar->ssh_state.payload_grabbed = 0;

    if (pvar->ssh_state.compressing) {
      /* Leftover input means the previous packet's handler did not consume
         all of its decompressed payload. */
      if (pvar->ssh_state.decompress_stream.avail_in != 0) {
        notify_nonfatal_error(pvar, "Internal error: a packet was not fully decompressed.\n"
          "This is a bug, please report it.");
      }

      /* Feed the compressed type byte and data to the inflater.  Decompressed
         output restarts at the beginning of postdecompress_inbuf. */
      pvar->ssh_state.decompress_stream.next_in = pvar->ssh_state.payload;
      pvar->ssh_state.decompress_stream.avail_in = pvar->ssh_state.payloadlen;
      pvar->ssh_state.decompress_stream.next_out = pvar->ssh_state.postdecompress_inbuf;
    } else {
      /* Uncompressed: step over the type byte directly. */
      pvar->ssh_state.payload++;
    }
      
    /* Account for the type byte.  With compression on, the limit of 1 makes
       inflate produce just that byte, and payload is re-pointed past it. */
    if (!grab_payload_limited(pvar, 1, 1)) {
      return SSH_MSG_NONE;
    }

    return pvar->ssh_state.payload[-1];
  }
}

/* Start building an outgoing packet of message TYPE with a LEN-byte body.
   Records the packet size (type byte + body) for finish_send_packet.
   Without compression the packet is built directly in outbuf: after the
   4-byte length field and the 1-8 padding bytes, with room for the trailing
   CRC.  With compression it is built in precompress_outbuf.
   Returns a pointer to where the caller writes the LEN body bytes. */
static unsigned char FAR * begin_send_packet(struct _TInstVar FAR * pvar, int type, int len) {
  unsigned char FAR * start;

  pvar->ssh_state.outgoing_packet_len = 1 + len;

  if (!pvar->ssh_state.compressing) {
    /* length(4) + padding + type(1) + body + crc(4) == len + pad + 9 */
    int pad = 8 - ((len + 5) % 8);

    buf_ensure_size(&pvar->ssh_state.outbuf, &pvar->ssh_state.outbuflen, len + pad + 9);
    start = pvar->ssh_state.outbuf + 4 + pad;
  } else {
    buf_ensure_size(&pvar->ssh_state.precompress_outbuf, &pvar->ssh_state.precompress_outbuflen,
      len + 1);
    start = pvar->ssh_state.precompress_outbuf;
  }

  *start = (unsigned char)type;
  return &start[1];
}

/* Finish and send the packet started by begin_send_packet, compressing it
   first when compression is enabled. */
#define finish_send_packet(pv) (finish_send_packet_special((pv), 0))

/* Pad, checksum, encrypt and send the packet built since begin_send_packet.

   If skip_compress is true, then the data has already been compressed
   into outbuf + 12 (SSH_send does this), and only the framing is done here.

   Layout written into outbuf:
   - a 4-byte big-endian length (type/data plus CRC, excluding padding);
   - 1-8 random padding bytes;
   - the (possibly compressed) type byte and data;
   - a 4-byte CRC over padding, type and data.
   Everything after the length field is encrypted in place.
   Errors are reported with notify_fatal_error; nothing is returned. */
static void finish_send_packet_special(struct _TInstVar FAR * pvar, int skip_compress) {
  int len = pvar->ssh_state.outgoing_packet_len;
  int padding;
  int total_size;
  /* FIONBIO argument: 0 switches the socket to blocking mode */
  u_long do_block = 0;

  if (pvar->ssh_state.compressing) {
    if (!skip_compress) {
      /* Compress type + data into outbuf + 12, leaving room in front for the
         length field and up to 8 padding bytes.
         NOTE(review): len + len/64 + 50 is presumably meant as an upper bound
         on deflate output -- confirm against zlib's deflateBound. */
      buf_ensure_size(&pvar->ssh_state.outbuf, &pvar->ssh_state.outbuflen, len + (len >> 6) + 50);
      pvar->ssh_state.compress_stream.next_in = pvar->ssh_state.precompress_outbuf;
      pvar->ssh_state.compress_stream.avail_in = len;
      pvar->ssh_state.compress_stream.next_out = pvar->ssh_state.outbuf + 12;
      pvar->ssh_state.compress_stream.avail_out = pvar->ssh_state.outbuflen - 12;
    
      if (deflate(&pvar->ssh_state.compress_stream, Z_PARTIAL_FLUSH) != Z_OK) {
        notify_fatal_error(pvar, "An error occurred while compressing packet data.\n"
          "The connection will close.");
        return;
      }
    }

    /* len becomes the compressed size.  Slide the data down so exactly
       PADDING bytes separate it from the length field. */
    len = pvar->ssh_state.outbuflen - 12 - pvar->ssh_state.compress_stream.avail_out;
    padding = 8 - ((len + 4) % 8);
    if (padding != 8) {
      memmove(pvar->ssh_state.outbuf + 4 + padding, pvar->ssh_state.outbuf + 12, len);
    }
  } else {
    padding = 8 - ((len + 4) % 8);
  }

  total_size = padding + len + 8;
  set_uint32(pvar->ssh_state.outbuf + 0, len + 4);
  CRYPT_set_random_data(pvar, pvar->ssh_state.outbuf + 4, padding);
  set_uint32(pvar->ssh_state.outbuf + total_size - 4,
    do_crc(pvar->ssh_state.outbuf + 4, total_size - 8));
  CRYPT_encrypt(pvar, pvar->ssh_state.outbuf + 4, total_size - 4);

  /* Cancel the asynchronous notifications and make the socket blocking so the
     whole packet goes out in one send() call, then restore the notifications.
     A short send counts as an error.  NOTE(review): re-registering with
     WSAAsyncSelect presumably switches the socket back to non-blocking mode
     (Winsock does this automatically) -- confirm. */
  if ((pvar->PWSAAsyncSelect)(pvar->socket, pvar->NotificationWindow,
        0, 0) == SOCKET_ERROR
    || ioctlsocket(pvar->socket, FIONBIO, &do_block) == SOCKET_ERROR
    || (pvar->Psend)(pvar->socket, pvar->ssh_state.outbuf, total_size, 0) != total_size
    || (pvar->PWSAAsyncSelect)(pvar->socket, pvar->NotificationWindow,
         pvar->notification_msg, pvar->notification_events) == SOCKET_ERROR) {
    notify_fatal_error(pvar, "A communications error occurred while sending an SSH packet.\n"
       "The connection will close.");
  }
}

/* Zero the outgoing packet buffers so that sensitive data just sent (for
   example a password in try_send_credentials) does not linger in memory.
   The pre-compression buffer is only wiped while compression is on. */
static void destroy_packet_buf(struct _TInstVar FAR * pvar) {
  memset(pvar->ssh_state.outbuf, 0, pvar->ssh_state.outbuflen);
  if (!pvar->ssh_state.compressing) {
    return;
  }
  memset(pvar->ssh_state.precompress_outbuf, 0, pvar->ssh_state.precompress_outbuflen);
}

/* Check the server's identification string ID and answer with our own.
   Only "SSH-1.x" servers are accepted.  If the server announces a lower,
   non-zero minor version than ours, our reply announces that minor version
   instead.  On success the random number generator is initialised and the
   host key lookup for ssh_state.hostname is started. */
void SSH_handle_server_ID(struct _TInstVar FAR * pvar, char FAR * ID) {
  char TTSSH_ID[] = "SSH-1.5-TTSSH-1.2\n";
  int id_len = sizeof(TTSSH_ID) - 1;   /* without the terminating NUL */
  int server_minor;

  if (strncmp(ID, "SSH-", 4) != 0) {
    notify_fatal_error(pvar, "The server does not understand the SSH protocol.");
    return;
  }
  if (strncmp(ID + 4, "1.", 2) != 0) {
    notify_fatal_error(pvar, "This program does not understand the server's version of the protocol.");
    return;
  }

  /* TTSSH_ID[6] is our minor version digit. */
  server_minor = atoi(ID + 6);
  if (server_minor > 0 && server_minor < TTSSH_ID[6] - '0') {
    TTSSH_ID[6] = (char)('0' + server_minor);
  }

  if ((pvar->Psend)(pvar->socket, TTSSH_ID, id_len, 0) != id_len) {
    notify_fatal_error(pvar, "An error occurred while sending the SSH ID string.\n"
       "The connection will close.");
    return;
  }

  CRYPT_initialize_random_numbers(pvar);
  /* Start looking up the host key while we wait for the server's reply. */
  HOSTS_prefetch_host_key(pvar, pvar->ssh_state.hostname);
}

/* SSH_MSG_IGNORE: the length-prefixed body is discarded, but it is still
   grabbed so that, with compression on, it gets inflated and the stream
   stays in step. */
static void handle_ignore(struct _TInstVar FAR * pvar) {
  if (!grab_payload(pvar, 4)) {
    return;
  }
  (void)grab_payload(pvar, get_payload_uint32(pvar, 0));
}

/* SSH_MSG_DEBUG: a length-prefixed string.  It is written to the debugger
   output, truncated to fit a local buffer and followed by a newline. */
static void handle_debug(struct _TInstVar FAR * pvar) {
  char buf[2048];
  int len;

  if (!grab_payload(pvar, 4)) {
    return;
  }
  len = get_payload_uint32(pvar, 0);
  /* A wire length >= 2^31 becomes negative here.  Never let it reach the
     grab or the memcpy below. */
  if (len < 0) {
    notify_fatal_error(pvar, "Received truncated packet");
    return;
  }
  if (!grab_payload(pvar, len)) {
    return;
  }

  /* Leave room for the newline and the terminating NUL.  The comparison is
     done in int so signed and unsigned are not mixed. */
  if (len > (int)sizeof(buf) - 2) {
    len = (int)sizeof(buf) - 2;
  }
  memcpy(buf, pvar->ssh_state.payload + 4, len);
  buf[len] = '\n';
  buf[len + 1] = 0;
  OutputDebugString(buf);
}

/* SSH_MSG_DISCONNECT: the payload is a 4-byte length followed by the
   server's reason string.  The reason (or a default text when it is empty)
   is reported as a fatal error. */
static void handle_disconnect(struct _TInstVar FAR * pvar) {
  char reason[512];
  int len;

  if (!grab_payload(pvar, 4)) {
    return;
  }
  len = get_payload_uint32(pvar, 0);
  if (len < 0) {
    notify_fatal_error(pvar, "Received truncated packet");
    return;
  }
  if (!grab_payload(pvar, len)) {
    return;
  }

  /* The string starts after the 4-byte length field.  Copy it out instead of
     NUL-terminating it in place: the byte after the grabbed data is not
     guaranteed to lie inside the decompression buffer.
     NOTE(review): assumes notify_fatal_error does not keep the pointer after
     it returns -- confirm. */
  if (len > (int)sizeof(reason) - 1) {
    len = (int)sizeof(reason) - 1;
  }
  memcpy(reason, pvar->ssh_state.payload + 4, len);
  reason[len] = 0;

  if (reason[0] == 0) {
    notify_fatal_error(pvar, "Server disconnected (no reason given).");
  } else {
    notify_fatal_error(pvar, reason);
  }
}

/* SSH_SMSG_EXITSTATUS: the 4-byte exit status is consumed but not used.
   The exit is acknowledged, and the connection is reported as closed. */
static void handle_exit(struct _TInstVar FAR * pvar) {
  if (!grab_payload(pvar, 4)) {
    return;
  }

  begin_send_packet(pvar, SSH_CMSG_EXIT_CONFIRMATION, 0);
  finish_send_packet(pvar);
  notify_closed_connection(pvar);
}

/* SSH_SMSG_STDOUT_DATA / SSH_SMSG_STDERR_DATA: only the 4-byte length is
   read here (inflating no more than those 4 bytes).  The data itself is
   left for SSH_extract_payload, starting at payload offset 4. */
static void handle_data(struct _TInstVar FAR * pvar) {
  if (!grab_payload_limited(pvar, 4, 4)) {
    return;
  }
  pvar->ssh_state.payload_datastart = 4;
  pvar->ssh_state.payload_datalen = get_payload_uint32(pvar, 0);
}

/* SSH_MSG_PORT_OPEN: the server asks us to open a forwarded connection.
   Payload layout, as read below:
     uint32 channel id, uint32 host length, host bytes,
     uint32 port, uint32 originator length, originator bytes.
   Both lengths come from the wire and are validated before use. */
static void handle_channel_open(struct _TInstVar FAR * pvar) {
  int host_len;
  int originator_len;
  int local_port;

  if (!grab_payload(pvar, 8)) {
    return;
  }
  host_len = get_payload_uint32(pvar, 4);
  /* 16 + host_len is used as an offset below; keep it from overflowing. */
  if (host_len < 0 || host_len > INT_MAX - 16) {
    notify_fatal_error(pvar, "Received truncated packet");
    return;
  }

  /* host bytes + port + originator length */
  if (!grab_payload(pvar, 8 + host_len)) {
    return;
  }
  originator_len = get_payload_uint32(pvar, 12 + host_len);
  if (originator_len < 0) {
    notify_fatal_error(pvar, "Received truncated packet");
    return;
  }
  if (!grab_payload(pvar, originator_len)) {
    return;
  }

  local_port = get_payload_uint32(pvar, 8 + host_len);
  /* NUL-terminate the host name in place.  This overwrites the first byte of
     the port field, which has already been read. */
  pvar->ssh_state.payload[8 + host_len] = 0;
  FWD_open(pvar, get_payload_uint32(pvar, 0), pvar->ssh_state.payload + 8,
    local_port, pvar->ssh_state.payload + 16 + host_len, originator_len);
}

/* SSH_MSG_CHANNEL_OPEN_CONFIRMATION: two uint32 channel numbers, passed to
   FWD_confirmed_open in the order they appear. */
static void handle_channel_open_confirmation(struct _TInstVar FAR * pvar) {
  if (!grab_payload(pvar, 8)) {
    return;
  }
  FWD_confirmed_open(pvar, get_payload_uint32(pvar, 0), get_payload_uint32(pvar, 4));
}

/* SSH_MSG_CHANNEL_OPEN_FAILURE: a single uint32 channel number. */
static void handle_channel_open_failure(struct _TInstVar FAR * pvar) {
  if (!grab_payload(pvar, 4)) {
    return;
  }
  FWD_failed_open(pvar, get_payload_uint32(pvar, 0));
}

/* SSH_MSG_CHANNEL_DATA: uint32 channel number, then a uint32 length and that
   many data bytes, which are handed to FWD_received_data. */
static void handle_channel_data(struct _TInstVar FAR * pvar) {
  int len;

  if (!grab_payload(pvar, 8)) {
    return;
  }
  len = get_payload_uint32(pvar, 4);
  /* A wire length >= 2^31 becomes negative here; reject it instead of
     passing it on. */
  if (len < 0) {
    notify_fatal_error(pvar, "Received truncated packet");
    return;
  }
  if (!grab_payload(pvar, len)) {
    return;
  }

  FWD_received_data(pvar, get_payload_uint32(pvar, 0),
    pvar->ssh_state.payload + 8, len);
}

/* SSH_MSG_CHANNEL_INPUT_EOF: a single uint32 channel number. */
static void handle_channel_input_eof(struct _TInstVar FAR * pvar) {
  if (!grab_payload(pvar, 4)) {
    return;
  }
  FWD_channel_input_eof(pvar, get_payload_uint32(pvar, 0));
}

/* SSH_MSG_CHANNEL_OUTPUT_CLOSED: a single uint32 channel number. */
static void handle_channel_output_eof(struct _TInstVar FAR * pvar) {
  if (!grab_payload(pvar, 4)) {
    return;
  }
  FWD_channel_output_eof(pvar, get_payload_uint32(pvar, 0));
}

/* Packet handler for the interactive session.  Dispatches session output,
   the exit status and port-forwarding messages; anything else is fatal. */
static void interactive_mode_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    /* nothing to dispatch (prep_packet returns this when it signals a fatal error) */
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_EXITSTATUS:
    handle_exit(pvar);
    break;
  case SSH_SMSG_STDOUT_DATA:
  case SSH_SMSG_STDERR_DATA:
    handle_data(pvar);
    break;
  case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
    handle_channel_open_confirmation(pvar);
    break;
  case SSH_MSG_CHANNEL_OPEN_FAILURE:
    handle_channel_open_failure(pvar);
    break;
  case SSH_MSG_CHANNEL_DATA:
    handle_channel_data(pvar);
    break;
  case SSH_MSG_CHANNEL_INPUT_EOF:
    handle_channel_input_eof(pvar);
    break;
  case SSH_MSG_CHANNEL_OUTPUT_CLOSED:
    handle_channel_output_eof(pvar);
    break;
  case SSH_MSG_PORT_OPEN:
    handle_channel_open(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected interactive mode message");
  }
}

/* Packet handler while waiting for the reply to the pty request.  Success
   and failure both enter interactive mode; failure only adds a warning. */
static void prep_pty_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_FAILURE:
    notify_nonfatal_error(pvar, "The server cannot allocate a pseudo-terminal. "
                          "You may encounter some problems with the terminal.");
    /* fall through: carry on into interactive mode without a pty */
  case SSH_SMSG_SUCCESS:
    FWD_enter_interactive_mode(pvar);
    pvar->ssh_state.packet_handler = interactive_mode_handler;
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected pseudo-terminal request acknowledgement");
  }
}

/* Request a pseudo-terminal and start the shell.
   SSH_CMSG_REQUEST_PTY body: string terminal type, uint32 rows, uint32
   columns, two uint32 fields sent as 0 (presumably the pixel dimensions --
   confirm), then the ssh_ttymodes bytes.  SSH_CMSG_EXEC_SHELL follows
   immediately; the next packet expected is the reply to the pty request. */
static void prep_pty(struct _TInstVar FAR * pvar) {
  int term_len = strlen(pvar->ts->TermType);
  unsigned char FAR * msg = begin_send_packet(pvar, SSH_CMSG_REQUEST_PTY,
    4 + term_len + 16 + sizeof(ssh_ttymodes));
  unsigned char FAR * p = msg;

  set_uint32(p, term_len);
  p += 4;
  memcpy(p, pvar->ts->TermType, term_len);
  p += term_len;
  set_uint32(p, pvar->ssh_state.win_rows);
  p += 4;
  set_uint32(p, pvar->ssh_state.win_cols);
  p += 4;
  set_uint32(p, 0);
  p += 4;
  set_uint32(p, 0);
  p += 4;
  memcpy(p, ssh_ttymodes, sizeof(ssh_ttymodes));
  finish_send_packet(pvar);

  begin_send_packet(pvar, SSH_CMSG_EXEC_SHELL, 0);
  finish_send_packet(pvar);

  pvar->ssh_state.packet_handler = prep_pty_handler;
}

/* SSH_MSG_DISCONNECT received while port forwarding is being set up.  Same
   payload as in handle_disconnect (4-byte length + reason string), but the
   default text mentions the likely cause: a refused forwarding. */
static void handle_forwarding_disconnect(struct _TInstVar FAR * pvar) {
  char reason[512];
  int len;

  if (!grab_payload(pvar, 4)) {
    return;
  }
  len = get_payload_uint32(pvar, 0);
  if (len < 0) {
    notify_fatal_error(pvar, "Received truncated packet");
    return;
  }
  if (!grab_payload(pvar, len)) {
    return;
  }

  /* The string starts after the 4-byte length field.  Copy it out instead of
     NUL-terminating it in place: the byte after the grabbed data is not
     guaranteed to lie inside the decompression buffer.
     NOTE(review): assumes notify_fatal_error does not keep the pointer after
     it returns -- confirm. */
  if (len > (int)sizeof(reason) - 1) {
    len = (int)sizeof(reason) - 1;
  }
  memcpy(reason, pvar->ssh_state.payload + 4, len);
  reason[len] = 0;

  if (reason[0] == 0) {
    notify_fatal_error(pvar, "Server disconnected (no reason given).\n"
      "It may have disconnected because it was unable to forward a port you requested to be forwarded from the server.");
  } else {
    notify_fatal_error(pvar, reason);
  }
}

/* Packet handler while waiting for the server to acknowledge our port
   forwarding requests.  Each SUCCESS or FAILURE counts as one
   acknowledgement; a refused forwarding is not treated as an error here.
   Once every expected reply has arrived, the pty request is sent. */
static void prep_forwarding_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  switch (prep_packet(pvar, data, len, padding)) {
  case SSH_MSG_NONE:          break;
  case SSH_MSG_DISCONNECT:    handle_forwarding_disconnect(pvar); break;
  case SSH_MSG_IGNORE:        handle_ignore(pvar); break;
  case SSH_MSG_DEBUG:         handle_debug(pvar); break;
  case SSH_SMSG_FAILURE:
  case SSH_SMSG_SUCCESS: {
    pvar->ssh_state.num_FWD_server_acks_expected--;
    if (pvar->ssh_state.num_FWD_server_acks_expected <= 0) {
      prep_pty(pvar);
    }
    break;
  }
  default:
    /* This state waits for forwarding acks, not the compression ack. */
    notify_fatal_error(pvar, "Invalid packet; expected port forwarding request acknowledgement");
  }
}

/* Ask FWD to send its port forwarding requests.  FWD_prep_forwarding
   presumably bumps num_FWD_server_acks_expected once per request that needs
   a reply -- confirm.  If no replies are expected, go straight to the pty
   request; otherwise wait for them in prep_forwarding_handler. */
static void prep_forwarding(struct _TInstVar FAR * pvar) {
  pvar->ssh_state.num_FWD_server_acks_expected = 0;
  FWD_prep_forwarding(pvar);

  if (pvar->ssh_state.num_FWD_server_acks_expected != 0) {
    pvar->ssh_state.packet_handler = prep_forwarding_handler;
  } else {
    prep_pty(pvar);
  }
}

/* Set up the zlib streams for both directions at
   ssh_state.compression_level and mark compression as active.  If either
   stream cannot be initialised, a fatal error is signalled and compression
   stays off. */
static void enable_compression(struct _TInstVar FAR * pvar) {
  pvar->ssh_state.compress_stream.zalloc = NULL;
  pvar->ssh_state.compress_stream.zfree = NULL;
  pvar->ssh_state.compress_stream.opaque = NULL;
  if (deflateInit(&pvar->ssh_state.compress_stream, pvar->ssh_state.compression_level) != Z_OK) {
    notify_fatal_error(pvar, "An error occurred while setting up compression.\n"
      "The connection will close.");
    return;
  }

  pvar->ssh_state.decompress_stream.zalloc = NULL;
  pvar->ssh_state.decompress_stream.zfree = NULL;
  pvar->ssh_state.decompress_stream.opaque = NULL;
  /* zlib requires next_in/avail_in to be set before inflateInit.  prep_packet
     also treats a non-zero avail_in as leftover data from a previous packet,
     so the stream must start with avail_in == 0. */
  pvar->ssh_state.decompress_stream.next_in = NULL;
  pvar->ssh_state.decompress_stream.avail_in = 0;
  if (inflateInit(&pvar->ssh_state.decompress_stream) != Z_OK) {
    deflateEnd(&pvar->ssh_state.compress_stream);
    notify_fatal_error(pvar, "An error occurred while setting up compression.\n"
      "The connection will close.");
    return;
  }

  pvar->ssh_state.compressing = 1;
}

/* Packet handler while waiting for the reply to the compression request.
   Either way the setup continues with port forwarding; on success,
   compression is switched on first. */
static void prep_compression_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_FAILURE:
    /* server declined compression: continue without it */
    prep_forwarding(pvar);
    break;
  case SSH_SMSG_SUCCESS:
    enable_compression(pvar);
    prep_forwarding(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected compression request acknowledgement");
  }
}

/* First post-authentication step.  If the user asked for compression
   (CompressionLevel > 0), request it and wait for the reply.  Otherwise,
   continue directly with port forwarding setup. */
static void prep_compression(struct _TInstVar FAR * pvar) {
  unsigned char FAR * msg;

  if (pvar->session_settings.CompressionLevel <= 0) {
    prep_forwarding(pvar);
    return;
  }

  msg = begin_send_packet(pvar, SSH_CMSG_REQUEST_COMPRESSION, 4);
  set_uint32(msg, pvar->session_settings.CompressionLevel);
  finish_send_packet(pvar);

  pvar->ssh_state.compression_level = pvar->session_settings.CompressionLevel;
  pvar->ssh_state.packet_handler = prep_compression_handler;
}

/* Authentication succeeded: start the session setup sequence.  Each step
   starts the next one: compression request, then port forwarding, then the
   pty request and shell. */
static void handle_prep_phase(struct _TInstVar FAR * pvar) {
  prep_compression(pvar);   /* step 1 of the chain */
}

/* SSH_SMSG_AUTH_RSA_CHALLENGE: the payload is an mp-int (16-bit bit count +
   bytes).  The 16-byte response is computed by
   CRYPT_generate_RSA_challenge_response and sent.  The current credential is
   then destroyed, and we wait for the authentication verdict. */
static void handle_rsa_challenge(struct _TInstVar FAR * pvar) {
  int nbytes;
  unsigned char FAR * reply;

  if (!grab_payload(pvar, 2)) {
    return;
  }
  nbytes = get_mpint_len(pvar, 0);
  if (!grab_payload(pvar, nbytes)) {
    return;
  }

  reply = begin_send_packet(pvar, SSH_CMSG_AUTH_RSA_RESPONSE, 16);
  if (!CRYPT_generate_RSA_challenge_response(pvar, pvar->ssh_state.payload + 2,
    nbytes, reply)) {
    notify_fatal_error(pvar, "An error occurred while decrypting the RSA challenge.\n"
      "Perhaps the key file is corrupted.");
    return;
  }

  AUTH_destroy_cur_cred(pvar);
  finish_send_packet(pvar);
  pvar->ssh_state.packet_handler = auth_status_handler;
}

/* Send the current authentication credential (from AUTH_get_cur_cred),
   unless sending is suppressed by STATUS_DONT_SEND_CREDENTIALS.

   SSH_AUTH_NONE sends nothing and leaves the flag clear, so a later call
   (e.g. via SSH_notify_cred) can try again.  Password and rhosts
   credentials are destroyed as soon as they have been copied into the
   packet.  The RSA methods keep the credential because the server's
   challenge still has to be answered (see rsa_challenge_handler).
   After sending, the outgoing packet buffers are wiped, and further sends
   are suppressed until the flag is cleared again. */
static void try_send_credentials(struct _TInstVar FAR * pvar) {
  if ((pvar->ssh_state.status_flags & STATUS_DONT_SEND_CREDENTIALS) == 0) {
    AUTHCred FAR * cred = AUTH_get_cur_cred(pvar);

    switch (cred->method) {
    case SSH_AUTH_NONE: return;
    case SSH_AUTH_PASSWORD: {
      /* body: uint32 length + password bytes */
      int len = strlen(cred->password);
      unsigned char FAR * outmsg = begin_send_packet(pvar, SSH_CMSG_AUTH_PASSWORD, 4 + len);

      set_uint32(outmsg, len);
      memcpy(outmsg + 4, cred->password, len);
      AUTH_destroy_cur_cred(pvar);
      break;
    }
    case SSH_AUTH_RHOSTS: {
      /* body: uint32 length + client user name */
      int len = strlen(cred->rhosts_client_user);
      unsigned char FAR * outmsg = begin_send_packet(pvar, SSH_CMSG_AUTH_RHOSTS, 4 + len);

      set_uint32(outmsg, len);
      memcpy(outmsg + 4, cred->rhosts_client_user, len);
      AUTH_destroy_cur_cred(pvar);
      break;
    }
    case SSH_AUTH_RSA: {
      /* body: public modulus as an mp-int (16-bit bit count + bytes).
         NOTE(review): the bit count sent is len*8 (whole bytes), not the exact
         BN_num_bits(n).  A receiver deriving the byte count as (bits+7)/8 gets
         the same length -- confirm servers accept it. */
      int len = BN_num_bytes(cred->RSA_key->n);
      unsigned char FAR * outmsg = begin_send_packet(pvar, SSH_CMSG_AUTH_RSA, 2 + len);

      set_ushort16_MSBfirst(outmsg, len*8);
      BN_bn2bin(cred->RSA_key->n, outmsg + 2);
      /* don't destroy the current credentials yet */
      pvar->ssh_state.packet_handler = rsa_challenge_handler;
      break;
    }
    case SSH_AUTH_RHOSTS_RSA: {
      /* body: string client user name, uint32 key bits, mp-int public
         exponent, mp-int public modulus (12 bytes of length fields in all) */
      int mod_len = BN_num_bytes(cred->RSA_key->n);
      int name_len = strlen(cred->rhosts_client_user);
      int exp_len = BN_num_bytes(cred->RSA_key->e);
      int index;
      unsigned char FAR * outmsg = begin_send_packet(pvar, SSH_CMSG_AUTH_RHOSTS_RSA, 12 + mod_len
        + name_len + exp_len);

      set_uint32(outmsg, name_len);
      memcpy(outmsg + 4, cred->rhosts_client_user, name_len);
      index = 4 + name_len;

      set_uint32(outmsg + index, 8*mod_len);
      set_ushort16_MSBfirst(outmsg + index + 4, 8*exp_len);
      BN_bn2bin(cred->RSA_key->e, outmsg + index + 6);
      index += 6 + exp_len;

      set_ushort16_MSBfirst(outmsg + index, 8*mod_len);
      BN_bn2bin(cred->RSA_key->n, outmsg + index + 2);
      /* don't destroy the current credentials yet */
      pvar->ssh_state.packet_handler = rsa_challenge_handler;
      break;
    }
    default:
      notify_fatal_error(pvar, "Internal error: unsupported authentication method");
      return;
    }

    finish_send_packet(pvar);
    /* Wipe the outgoing packet buffers, which held the credential. */
    destroy_packet_buf(pvar);

    pvar->ssh_state.status_flags |= STATUS_DONT_SEND_CREDENTIALS;
  }
}

/* The server rejected the current credential.  Move on to the next one,
   re-allow sending, and try to send it right away (try_send_credentials
   sends nothing if its method is SSH_AUTH_NONE). */
static void handle_auth_failure(struct _TInstVar FAR * pvar) {
  AUTH_advance_to_next_cred(pvar);

  pvar->ssh_state.status_flags = pvar->ssh_state.status_flags & ~STATUS_DONT_SEND_CREDENTIALS;
  pvar->ssh_state.packet_handler = auth_status_handler;

  try_send_credentials(pvar);
}

/* The server refused our RSA key.  The credential was kept for the challenge
   (see try_send_credentials), so release it here, then proceed as for any
   other authentication failure. */
static void handle_rsa_auth_refused(struct _TInstVar FAR * pvar) {
  AUTH_destroy_cur_cred(pvar);

  handle_auth_failure(pvar);
}

/* Packet handler after an RSA key has been offered: expects either the
   challenge or a refusal. */
static void rsa_challenge_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_FAILURE:
    handle_rsa_auth_refused(pvar);
    break;
  case SSH_SMSG_AUTH_RSA_CHALLENGE:
    handle_rsa_challenge(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected authentication acknowledgement");
  }
}

/* Packet handler waiting for the verdict on a submitted credential.  Success
   moves on to session setup; failure tries the next credential. */
static void auth_status_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_FAILURE:
    handle_auth_failure(pvar);
    break;
  case SSH_SMSG_SUCCESS:
    handle_prep_phase(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected authentication acknowledgement");
  }
}

/* The server wants authentication after receiving the user name.  Allow
   credentials to be sent and try the current one.
   The first AUTH_advance_to_next_cred is issued early by ttssh.c, so unlike
   handle_auth_failure this does not advance. */
static void handle_auth_required(struct _TInstVar FAR * pvar) {
  pvar->ssh_state.status_flags = pvar->ssh_state.status_flags & ~STATUS_DONT_SEND_CREDENTIALS;
  pvar->ssh_state.packet_handler = auth_status_handler;

  try_send_credentials(pvar);
}

/* Packet handler for the server's reply to the user name.  SUCCESS means no
   authentication is needed; FAILURE means credentials must be sent. */
static void first_auth_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_FAILURE:
    handle_auth_required(pvar);
    break;
  case SSH_SMSG_SUCCESS:
    handle_prep_phase(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected authentication acknowledgement");
  }
}

/* Packet handler waiting for the server to confirm, with its first encrypted
   packet, that encryption is on.  The reply to the user name is then
   handled by first_auth_handler. */
static void crypt_status_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_SMSG_SUCCESS:
    pvar->ssh_state.packet_handler = first_auth_handler;
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected encrypted acknowledgement");
  }
}

/* Packet handler while the client waits for the user (e.g. host key
   confirmation).  Only housekeeping messages are acceptable. */
static void expecting_nothing_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected nothing while waiting for user response");
  }
}

/* Send SSH_CMSG_USER (uint32 length + name) if sending is allowed and
   AUTH_get_user_name returns a name.  After a successful send, further
   sends are suppressed via STATUS_DONT_SEND_USER_NAME. */
static void try_send_user_name(struct _TInstVar FAR * pvar) {
  char FAR * name;
  int name_len;
  unsigned char FAR * msg;

  if ((pvar->ssh_state.status_flags & STATUS_DONT_SEND_USER_NAME) != 0) {
    return;
  }

  name = AUTH_get_user_name(pvar);
  if (name == NULL) {
    return;   /* not known yet */
  }

  name_len = strlen(name);
  msg = begin_send_packet(pvar, SSH_CMSG_USER, 4 + name_len);
  set_uint32(msg, name_len);
  memcpy(msg + 4, name, name_len);
  finish_send_packet(pvar);

  pvar->ssh_state.status_flags |= STATUS_DONT_SEND_USER_NAME;
}

/* Build and send SSH_CMSG_SESSION_KEY, then switch on encryption.
   Body: cipher byte, 8-byte anti-spoofing cookie, the encrypted session key
   as an mp-int (bytes 9-10 hold its length in bits, big-endian), and a
   uint32 of protocol flags.  Once encryption is started the user name is
   sent if known, and the server's acknowledgement is awaited in
   crypt_status_handler. */
static void send_session_key(struct _TInstVar FAR * pvar) {
  int key_len = CRYPT_get_encrypted_session_key_len(pvar);
  unsigned char FAR * body = begin_send_packet(pvar, SSH_CMSG_SESSION_KEY, 15 + key_len);
  unsigned char FAR * key_field = body + 11;

  body[0] = (unsigned char)CRYPT_choose_cipher(pvar);
  memcpy(&body[1], CRYPT_get_cookie(pvar), 8);
  body[9] = (unsigned char)(key_len >> 5);    /* high byte of key_len * 8 */
  body[10] = (unsigned char)(key_len << 3);   /* low byte of key_len * 8 */
  if (!CRYPT_choose_session_key(pvar, key_field)) {
    return;
  }
  set_uint32(key_field + key_len,
    SSH_PROTOFLAG_SCREEN_NUMBER | SSH_PROTOFLAG_HOST_IN_FWD_OPEN);
  finish_send_packet(pvar);

  if (!CRYPT_start_encryption(pvar)) {
    return;
  }
  notify_established_secure_connection(pvar);

  pvar->ssh_state.status_flags &= ~STATUS_DONT_SEND_USER_NAME;
  try_send_user_name(pvar);

  pvar->ssh_state.packet_handler = crypt_status_handler;
}

/* SSH_SMSG_PUBLIC_KEY.  Fields in order, as parsed below:
     8-byte anti-spoofing cookie
     uint32 server key bits, mp-int server exponent, mp-int server modulus
     uint32 host key bits,   mp-int host exponent,   mp-int host modulus
     uint32 protocol flags, uint32 supported ciphers, uint32 supported
     authentication types
   An mp-int is a 16-bit bit count followed by the bytes (see get_mpint_len).
   The payload is grabbed piecewise because each mp-int's length is known
   only once its prefix has been read.  Each grab also takes in the
   fixed-size fields up to and including the next length prefix.
   On success the values are handed to the CRYPT and AUTH modules, and the
   host key check is started. */
static void handle_server_public_key(struct _TInstVar FAR * pvar) {
  int server_key_public_exponent_len;
  int server_key_public_modulus_pos;
  int server_key_public_modulus_len;
  int host_key_bits_pos;
  int host_key_public_exponent_len;
  int host_key_public_modulus_pos;
  int host_key_public_modulus_len;
  int protocol_flags_pos;
  char FAR * inmsg;
 
  /* cookie (8) + server key bits (4) + server exponent bit count (2) */
  if (!grab_payload(pvar, 14)) return;
  server_key_public_exponent_len = get_mpint_len(pvar, 12);

  /* server exponent bytes + server modulus bit count (2) */
  if (!grab_payload(pvar, server_key_public_exponent_len + 2)) return;
  server_key_public_modulus_pos = 14 + server_key_public_exponent_len;
  server_key_public_modulus_len = get_mpint_len(pvar, server_key_public_modulus_pos);

  /* server modulus bytes + host key bits (4) + host exponent bit count (2) */
  if (!grab_payload(pvar, server_key_public_modulus_len + 6)) return;
  host_key_bits_pos = server_key_public_modulus_pos + 2 + server_key_public_modulus_len;
  host_key_public_exponent_len = get_mpint_len(pvar, host_key_bits_pos + 4);

  /* host exponent bytes + host modulus bit count (2) */
  if (!grab_payload(pvar, host_key_public_exponent_len + 2)) return;
  host_key_public_modulus_pos = host_key_bits_pos + 6 + host_key_public_exponent_len;
  host_key_public_modulus_len = get_mpint_len(pvar, host_key_public_modulus_pos);

  /* host modulus bytes + protocol flags, ciphers, auth types (3 x 4) */
  if (!grab_payload(pvar, host_key_public_modulus_len + 12)) return;
  protocol_flags_pos = host_key_public_modulus_pos + 2 + host_key_public_modulus_len;

  /* Read payload only after the last grab: grabbing may move the buffer. */
  inmsg = pvar->ssh_state.payload;

  if (!CRYPT_set_cookie(pvar, inmsg)) return;
  if (!CRYPT_set_server_key(pvar, get_uint32(inmsg + 8), pvar->ssh_state.payload + 12,
    inmsg + server_key_public_modulus_pos)) return;
  if (!CRYPT_set_host_key(pvar, get_uint32(inmsg + host_key_bits_pos),
    inmsg + host_key_bits_pos + 4,
    inmsg + host_key_public_modulus_pos)) return;
  pvar->ssh_state.server_protocol_flags = get_uint32(inmsg + protocol_flags_pos);
  if (!CRYPT_set_supported_ciphers(pvar, get_uint32(inmsg + protocol_flags_pos + 4))) return;
  if (!AUTH_set_supported_auth_types(pvar, get_uint32(inmsg + protocol_flags_pos + 8))) return;

  pvar->ssh_state.packet_handler = expecting_nothing_handler;
  /* this must be the LAST THING in this function, since it can cause
     host_is_OK to be called. */
  HOSTS_check_host_key(pvar, pvar->ssh_state.hostname,
    get_uint32(inmsg + host_key_bits_pos),
    inmsg + host_key_bits_pos + 4,
    inmsg + host_key_public_modulus_pos);
}

/* Initial packet handler (installed by SSH_init): expects the server's
   public key message. */
static void server_public_key_handler(struct _TInstVar FAR * pvar, char FAR * data, int len, int padding) {
  int msg_type = prep_packet(pvar, data, len, padding);

  switch (msg_type) {
  case SSH_MSG_NONE:
    break;
  case SSH_MSG_IGNORE:
    handle_ignore(pvar);
    break;
  case SSH_MSG_DEBUG:
    handle_debug(pvar);
    break;
  case SSH_MSG_DISCONNECT:
    handle_disconnect(pvar);
    break;
  case SSH_SMSG_PUBLIC_KEY:
    handle_server_public_key(pvar);
    break;
  default:
    notify_fatal_error(pvar, "Invalid packet; expected server public key");
  }
}

/* Reset the SSH state for a new connection.  Copies the host name, creates
   the packet buffers, records the terminal size, and waits for the server's
   public key with user name and credential sending suppressed. */
void SSH_init(struct _TInstVar FAR * pvar) {
  /* protocol state */
  pvar->ssh_state.packet_handler = server_public_key_handler;
  pvar->ssh_state.status_flags = STATUS_DONT_SEND_USER_NAME | STATUS_DONT_SEND_CREDENTIALS;
  pvar->ssh_state.compressing = FALSE;
  pvar->ssh_state.payload = NULL;
  pvar->ssh_state.payload_datalen = 0;

  /* terminal geometry, sent later with the pty request */
  pvar->ssh_state.win_rows = pvar->ts->TerminalHeight;
  pvar->ssh_state.win_cols = pvar->ts->TerminalWidth;

  /* owned copy of the host name and the packet buffers */
  pvar->ssh_state.hostname = _strdup(pvar->ts->HostName);
  buf_create(&pvar->ssh_state.outbuf, &pvar->ssh_state.outbuflen);
  buf_create(&pvar->ssh_state.precompress_outbuf, &pvar->ssh_state.precompress_outbuflen);
  buf_create(&pvar->ssh_state.postdecompress_inbuf, &pvar->ssh_state.postdecompress_inbuflen);
}

/* Send SSH_MSG_DISCONNECT carrying REASON as a length-prefixed string.
   A NULL reason is sent as an empty string. */
void SSH_notify_disconnecting(struct _TInstVar FAR * pvar, char FAR * reason) {
  int reason_len = 0;
  unsigned char FAR * msg;

  if (reason != NULL) {
    reason_len = strlen(reason);
  }

  msg = begin_send_packet(pvar, SSH_MSG_DISCONNECT, 4 + reason_len);
  set_uint32(msg, reason_len);
  if (reason != NULL) {
    memcpy(msg + 4, reason, reason_len);
  }
  finish_send_packet(pvar);
}

/* The server's host key has been accepted (see the HOSTS_check_host_key call
   in handle_server_public_key).  The session key is sent the first time
   only; later calls do nothing. */
void SSH_notify_host_OK(struct _TInstVar FAR * pvar) {
  if (pvar->ssh_state.status_flags & STATUS_HOST_OK) {
    return;
  }

  pvar->ssh_state.status_flags |= STATUS_HOST_OK;
  send_session_key(pvar);
}

/* Record the new terminal size.  In interactive mode it is also sent to the
   server as SSH_CMSG_WINDOW_SIZE (rows, cols, then two zero fields).  Before
   that, prep_pty sends the stored size with the pty request. */
void SSH_notify_win_size(struct _TInstVar FAR * pvar, int cols, int rows) {
  unsigned char FAR * msg;

  pvar->ssh_state.win_cols = cols;
  pvar->ssh_state.win_rows = rows;

  if (pvar->ssh_state.packet_handler != interactive_mode_handler) {
    return;
  }

  msg = begin_send_packet(pvar, SSH_CMSG_WINDOW_SIZE, 16);
  set_uint32(msg, rows);
  set_uint32(msg + 4, cols);
  set_uint32(msg + 8, 0);
  set_uint32(msg + 12, 0);
  finish_send_packet(pvar);
}

/* Presumably called when the user name becomes available.  Sends it unless
   STATUS_DONT_SEND_USER_NAME is set. */
void SSH_notify_user_name(struct _TInstVar FAR * pvar) {
  try_send_user_name(pvar);   /* no-op if sending is suppressed */
}

/* Presumably called when a credential becomes available.  Sends it unless
   STATUS_DONT_SEND_CREDENTIALS is set. */
void SSH_notify_cred(struct _TInstVar FAR * pvar) {
  try_send_credentials(pvar);   /* no-op if sending is suppressed */
}

/* Sends terminal input (buf, buflen bytes) to the server as one or more
   SSH_CMSG_STDIN_DATA packets.
   Data is silently dropped unless the session is in interactive mode.
   Input is split into chunks of at most SSH_MAX_SEND_PACKET_SIZE bytes;
   each chunk becomes one packet whose body is a 32-bit length followed by
   the data. On a deflate failure a fatal error is signalled and the
   function returns without finishing the current packet. */
void SSH_send(struct _TInstVar FAR * pvar, unsigned char const FAR * buf, int buflen) {
  if (pvar->ssh_state.packet_handler != interactive_mode_handler) {
    return;
  }

  while (buflen > 0) {
    int len = buflen > SSH_MAX_SEND_PACKET_SIZE ? SSH_MAX_SEND_PACKET_SIZE : buflen;
    unsigned char FAR * outmsg = begin_send_packet(pvar, SSH_CMSG_STDIN_DATA, 4 + len);
    
    set_uint32(outmsg, len);
    
    if (pvar->ssh_state.compressing) {
      /* Headroom for the compressed output: len + len/64 + 50 bytes
         (presumably a bound on deflate's worst-case expansion). */
      buf_ensure_size(&pvar->ssh_state.outbuf, &pvar->ssh_state.outbuflen, len + (len >> 6) + 50);
      /* First compress the 5 header bytes staged in precompress_outbuf
         (presumably the packet type byte plus the 4-byte length written
         above; confirm against begin_send_packet). */
      pvar->ssh_state.compress_stream.next_in = pvar->ssh_state.precompress_outbuf;
      pvar->ssh_state.compress_stream.avail_in = 5;
      /* Compressed bytes are written at offset 12 of outbuf; presumably
         finish_send_packet_special fills in the bytes before that. */
      pvar->ssh_state.compress_stream.next_out = pvar->ssh_state.outbuf + 12;
      pvar->ssh_state.compress_stream.avail_out = pvar->ssh_state.outbuflen - 12;
      
      if (deflate(&pvar->ssh_state.compress_stream, Z_NO_FLUSH) != Z_OK) {
        notify_fatal_error(pvar, "Error compressing packet data");
        return;
      }

      /* Then feed the caller's data straight from buf, which avoids copying
         it into precompress_outbuf. The partial flush emits all output
         for this packet. */
      pvar->ssh_state.compress_stream.next_in = (unsigned char FAR *)buf;
      pvar->ssh_state.compress_stream.avail_in = len;

      if (deflate(&pvar->ssh_state.compress_stream, Z_PARTIAL_FLUSH) != Z_OK) {
        notify_fatal_error(pvar, "Error compressing packet data");
        return;
      }
    } else {
      memcpy(outmsg + 4, buf, len);
    }
    
    /* Second argument 1: presumably tells the packet writer the payload
       has already been handled here (compressed or copied). Verify
       against finish_send_packet_special. */
    finish_send_packet_special(pvar, 1);

    buflen -= len;
    buf += len;
  }
}

/* Copies up to len bytes of the current packet's data payload into dest.
   Without compression the bytes are copied from the payload buffer at
   payload_datastart, which is then advanced. With compression they are
   produced by inflating decompress_stream directly into dest (its input
   is presumably prepared elsewhere).
   Returns the number of bytes delivered: the smaller of len and the
   remaining payload_datalen. Returns 0 after signalling a fatal error if
   inflate does not return Z_OK. */
int SSH_extract_payload(struct _TInstVar FAR * pvar, unsigned char FAR * dest, int len) {
  int remaining = pvar->ssh_state.payload_datalen;
  int count = remaining < len ? remaining : len;

  if (pvar->ssh_state.compressing) {
    if (count > 0) {
      pvar->ssh_state.decompress_stream.next_out = dest;
      pvar->ssh_state.decompress_stream.avail_out = count;

      if (inflate(&pvar->ssh_state.decompress_stream, Z_PARTIAL_FLUSH) != Z_OK) {
        notify_fatal_error(pvar, "Invalid compressed data in received packet");
        return 0;
      }
    }
  } else {
    memcpy(dest, pvar->ssh_state.payload + pvar->ssh_state.payload_datastart, count);
    pvar->ssh_state.payload_datastart += count;
  }

  pvar->ssh_state.payload_datalen -= count;
  return count;
}

/* Writes a human-readable summary of the compression state into dest,
   a caller-supplied buffer of len bytes.
   - Compression off: "None".
   - Compression on but no traffic counted yet: "Level <n>".
   - Otherwise: level, the overall compression ratio (total plaintext bytes
     / total compressed bytes, both directions combined), and the raw zlib
     counters for each direction.
   The output is always NUL-terminated when len > 0. MSVC's _snprintf does
   not terminate on truncation, so termination is forced explicitly.
   Nothing is written if dest is NULL or len <= 0. */
void SSH_get_compression_info(struct _TInstVar FAR * pvar, char FAR * dest, int len) {
  if (dest == NULL || len <= 0) {
    return;
  }

  if (!pvar->ssh_state.compressing) {
    _snprintf(dest, len, "%s", "None");
  } else {
    /* Outgoing: compress_stream consumes plaintext, produces compressed data.
       Incoming: decompress_stream consumes compressed data, produces plaintext. */
    unsigned long up_in = pvar->ssh_state.compress_stream.total_in;
    unsigned long up_out = pvar->ssh_state.compress_stream.total_out;
    unsigned long down_in = pvar->ssh_state.decompress_stream.total_in;
    unsigned long down_out = pvar->ssh_state.decompress_stream.total_out;
    /* Summed as double so large counters cannot wrap. */
    double total_plain = (double)up_in + (double)down_out;
    double total_packed = (double)up_out + (double)down_in;

    if (total_plain > 0 && total_packed > 0) {
      _snprintf(dest, len, "Level %d; ratio %.1f (up %lu:%lu, down %lu:%lu)",
        pvar->ssh_state.compression_level, total_plain / total_packed,
        up_in, up_out, down_in, down_out);
    } else {
      /* The original passed no argument for %d here (undefined behavior). */
      _snprintf(dest, len, "Level %d", pvar->ssh_state.compression_level);
    }
  }

  dest[len - 1] = '\0';
}

/* Releases the resources held by the SSH protocol state:
   - the host name copy made in SSH_init;
   - the outgoing, pre-compression and post-decompression buffers;
   - both zlib streams, if compression was enabled.
   hostname is reset to NULL after freeing so the state never holds a
   dangling pointer. A repeated SSH_end then frees NULL, which is a no-op. */
void SSH_end(struct _TInstVar FAR * pvar) {
  free(pvar->ssh_state.hostname);
  pvar->ssh_state.hostname = NULL;

  buf_destroy(&pvar->ssh_state.outbuf, &pvar->ssh_state.outbuflen);
  buf_destroy(&pvar->ssh_state.precompress_outbuf, &pvar->ssh_state.precompress_outbuflen);
  buf_destroy(&pvar->ssh_state.postdecompress_inbuf, &pvar->ssh_state.postdecompress_inbuflen);

  if (pvar->ssh_state.compressing) {
    deflateEnd(&pvar->ssh_state.compress_stream);
    inflateEnd(&pvar->ssh_state.decompress_stream);
  }
}

/* support for port forwarding */
/* Sends len bytes from buf on a forwarded channel as one
   SSH_MSG_CHANNEL_DATA packet.
   Body: 32-bit remote channel number, 32-bit data length, then the data.
   Unlike SSH_send, the data is not split into chunks here. Callers
   presumably keep len within the packet size limit; verify at call sites.
   On a deflate failure a fatal error is signalled and the function returns
   without finishing the packet. */
void SSH_channel_send(struct _TInstVar FAR * pvar, uint32 remote_channel_num,
                      unsigned char FAR * buf, int len) {
  unsigned char FAR * outmsg = begin_send_packet(pvar, SSH_MSG_CHANNEL_DATA, 8 + len);
  
  set_uint32(outmsg, remote_channel_num);
  set_uint32(outmsg + 4, len);
  
  if (pvar->ssh_state.compressing) {
    /* Headroom for the compressed output: len + len/64 + 50 bytes
       (presumably a bound on deflate's worst-case expansion). */
    buf_ensure_size(&pvar->ssh_state.outbuf, &pvar->ssh_state.outbuflen, len + (len >> 6) + 50);
    /* First compress the 9 header bytes staged in precompress_outbuf
       (presumably the packet type byte plus the two 32-bit fields written
       above; confirm against begin_send_packet). */
    pvar->ssh_state.compress_stream.next_in = pvar->ssh_state.precompress_outbuf;
    pvar->ssh_state.compress_stream.avail_in = 9;
    /* Compressed bytes are written at offset 12 of outbuf; presumably
       finish_send_packet_special fills in the bytes before that. */
    pvar->ssh_state.compress_stream.next_out = pvar->ssh_state.outbuf + 12;
    pvar->ssh_state.compress_stream.avail_out = pvar->ssh_state.outbuflen - 12;
    
    if (deflate(&pvar->ssh_state.compress_stream, Z_NO_FLUSH) != Z_OK) {
      notify_fatal_error(pvar, "Error compressing packet data");
      return;
    }
    
    /* Then feed the channel data straight from buf, which avoids copying it
       into precompress_outbuf. The partial flush emits all output for
       this packet. */
    pvar->ssh_state.compress_stream.next_in = (unsigned char FAR *)buf;
    pvar->ssh_state.compress_stream.avail_in = len;
    
    if (deflate(&pvar->ssh_state.compress_stream, Z_PARTIAL_FLUSH) != Z_OK) {
      notify_fatal_error(pvar, "Error compressing packet data");
      return;
    }
  } else {
    memcpy(outmsg + 8, buf, len);
  }
  
  /* Second argument 1: presumably tells the packet writer the payload has
     already been handled here (compressed or copied). Verify against
     finish_send_packet_special. */
  finish_send_packet_special(pvar, 1);
}

/* Refuses a channel-open request from the server.
   Sends SSH_MSG_CHANNEL_OPEN_FAILURE whose only body field is the
   server's 32-bit channel number. */
void SSH_fail_channel_open(struct _TInstVar FAR * pvar, uint32 remote_channel_num) {
  unsigned char FAR * body;

  body = begin_send_packet(pvar, SSH_MSG_CHANNEL_OPEN_FAILURE, 4);
  set_uint32(body, remote_channel_num);

  finish_send_packet(pvar);
}

/* Accepts a channel-open request from the server.
   Sends SSH_MSG_CHANNEL_OPEN_CONFIRMATION carrying the server's channel
   number, then our own local channel number (two 32-bit fields). */
void SSH_confirm_channel_open(struct _TInstVar FAR * pvar, uint32 remote_channel_num, uint32 local_channel_num) {
  unsigned char FAR * cursor = begin_send_packet(pvar, SSH_MSG_CHANNEL_OPEN_CONFIRMATION, 8);

  set_uint32(cursor, remote_channel_num);
  cursor += 4;
  set_uint32(cursor, local_channel_num);

  finish_send_packet(pvar);
}

/* Sends SSH_MSG_CHANNEL_OUTPUT_CLOSED for the given remote channel.
   The body is the 32-bit remote channel number only. */
void SSH_channel_output_eof(struct _TInstVar FAR * pvar, uint32 remote_channel_num) {
  unsigned char FAR * body;

  body = begin_send_packet(pvar, SSH_MSG_CHANNEL_OUTPUT_CLOSED, 4);
  set_uint32(body, remote_channel_num);

  finish_send_packet(pvar);
}

/* Sends SSH_MSG_CHANNEL_INPUT_EOF for the given remote channel.
   The body is the 32-bit remote channel number only. */
void SSH_channel_input_eof(struct _TInstVar FAR * pvar, uint32 remote_channel_num) {
  unsigned char FAR * body;

  body = begin_send_packet(pvar, SSH_MSG_CHANNEL_INPUT_EOF, 4);
  set_uint32(body, remote_channel_num);

  finish_send_packet(pvar);
}

/* Asks the server to forward connections on from_server_port back to
   to_local_host:to_local_port (SSH_CMSG_PORT_FORWARD_REQUEST).
   Body layout: 32-bit server port, 32-bit host length, host bytes (no
   terminator), 32-bit local port.
   Each request increments num_FWD_server_acks_expected by one. */
void SSH_request_forwarding(struct _TInstVar FAR * pvar, int from_server_port,
                            char FAR * to_local_host, int to_local_port) {
  int host_len = strlen(to_local_host);
  unsigned char FAR * cursor = begin_send_packet(pvar, SSH_CMSG_PORT_FORWARD_REQUEST,
    12 + host_len);

  set_uint32(cursor, from_server_port);
  cursor += 4;
  set_uint32(cursor, host_len);
  cursor += 4;
  memcpy(cursor, to_local_host, host_len);
  cursor += host_len;
  set_uint32(cursor, to_local_port);

  finish_send_packet(pvar);
  pvar->ssh_state.num_FWD_server_acks_expected++;
}

/* Requests X11 forwarding (SSH_CMSG_X11_REQUEST_FORWARDING).
   Body layout:
   - 32-bit length + auth_protocol bytes;
   - 32-bit length + auth_data bytes;
   - a trailing 32-bit screen_num, only when the server advertised
     SSH_PROTOFLAG_SCREEN_NUMBER.
   Each request increments num_FWD_server_acks_expected by one. */
void SSH_request_X11_forwarding(struct _TInstVar FAR * pvar,
                                char FAR * auth_protocol, char FAR * auth_data, int screen_num) {
  int protocol_len = strlen(auth_protocol);
  int data_len = strlen(auth_data);
  int with_screen =
    (pvar->ssh_state.server_protocol_flags & SSH_PROTOFLAG_SCREEN_NUMBER) != 0;
  unsigned char FAR * cursor = begin_send_packet(pvar, SSH_CMSG_X11_REQUEST_FORWARDING,
    8 + protocol_len + data_len + (with_screen ? 4 : 0));

  set_uint32(cursor, protocol_len);
  cursor += 4;
  memcpy(cursor, auth_protocol, protocol_len);
  cursor += protocol_len;

  set_uint32(cursor, data_len);
  cursor += 4;
  memcpy(cursor, auth_data, data_len);
  cursor += data_len;

  if (with_screen) {
    set_uint32(cursor, screen_num);
  }

  finish_send_packet(pvar);
  pvar->ssh_state.num_FWD_server_acks_expected++;
}

/* Asks the server to open a forwarded connection to
   to_remote_host:to_remote_port (SSH_MSG_PORT_OPEN).
   Body layout:
   - 32-bit local channel number;
   - 32-bit host length + host bytes;
   - 32-bit port;
   - when the server advertised SSH_PROTOFLAG_HOST_IN_FWD_OPEN, also a
     32-bit length + originator bytes.
   originator is only read (strlen/memcpy) when that flag is set. */
void SSH_open_channel(struct _TInstVar FAR * pvar, uint32 local_channel_num,
                      char FAR * to_remote_host, int to_remote_port, char FAR * originator) {
  int host_len = strlen(to_remote_host);
  int with_originator =
    (pvar->ssh_state.server_protocol_flags & SSH_PROTOFLAG_HOST_IN_FWD_OPEN) != 0;
  int originator_len = with_originator ? (int)strlen(originator) : 0;
  unsigned char FAR * cursor = begin_send_packet(pvar, SSH_MSG_PORT_OPEN,
    12 + host_len + (with_originator ? 4 + originator_len : 0));

  set_uint32(cursor, local_channel_num);
  cursor += 4;
  set_uint32(cursor, host_len);
  cursor += 4;
  memcpy(cursor, to_remote_host, host_len);
  cursor += host_len;
  set_uint32(cursor, to_remote_port);

  if (with_originator) {
    cursor += 4;
    set_uint32(cursor, originator_len);
    cursor += 4;
    memcpy(cursor, originator, originator_len);
  }

  finish_send_packet(pvar);
}
