Logo Search packages:      
Sourcecode: sbrsh version File versions  Download package

daemon.c

/*
 * Copyright (c) 2003, 2004, 2005 Nokia
 * Author: Timo Savola <tsavola@movial.fi>
 *
 * This program is licensed under GPL (see COPYING for details)
 */

#define _GNU_SOURCE

#include "types.h"
#include "daemon.h"
#include "common.h"
#include "protocol.h"
#include "buffer.h"
#include "mount.h"
#include "fakeroot.h"
#include "version.h"

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <stdarg.h>
#include <ctype.h>
#include <syslog.h>
#include <pwd.h>
#include <netdb.h>
#include <signal.h>
#include <errno.h>
#include <pty.h>
#include <utmp.h>
#include <fcntl.h>
#include <getopt.h>
#include <time.h>
#include <grp.h>
#include <limits.h>
#include <libgen.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/resource.h>
#include <sys/select.h>
#include <sys/param.h>
#include <netinet/in.h>
#include <assert.h>

/* True when FD is a valid descriptor (>= 0) that is a member of SET, so
 * descriptors that have been set to -1 can be tested safely. */
#define ISSET(fd, set)  ((fd) >= 0 && FD_ISSET((fd), (set)))
/* Logs a message followed by a string vector when debugging is enabled.
 * NOTE(review): expands to a bare `if`; do not use it as the unbraced body
 * of an if/else. */
#define debug_vector    if (debug_file) print_debug_vector

extern char **environ;

/* Debug log stream; NULL while debugging is disabled. */
FILE *debug_file = NULL;
/* Debug log path; when NULL, open_debug_log() builds one from
 * DEFAULT_DEBUG_FILENAME_FMT and the port. */
static char *debug_filename = NULL;

/* NOTE(review): presumably the daemon's pid once running; it is not
 * assigned in this part of the file. */
static int daemon_pid = -1;

/* Daemon settings (all of them are logged by open_debug_log()). */
static char *port = DEFAULT_PORT;
static bool_t local_only = FALSE;
static bool_t sandbox = TRUE;
static bool_t allow_root = FALSE;
/* Seconds an unused mount is kept before it expires: 0 means immediately,
 * a negative value means never. */
static int mount_expiration = DEFAULT_MOUNT_EXPIRATION;
static char *mount_cmd = DEFAULT_MOUNT_CMD;
static char *umount_cmd = DEFAULT_UMOUNT_CMD;
/* Extra mount_cmd arguments for bind mounts; check_for_busybox() replaces
 * it with DEFAULT_BIND_OPT_BUSYBOX when mount_cmd is Busybox. */
static char *bind_opt = DEFAULT_BIND_OPT;

/** Doubly linked list of mount entries, kept in ascending order of mount
 *  point strings (see mnt_list_add()). */
static struct {
      mount_t *head;
      mount_t *tail;
} all_mounts = {
      NULL,
      NULL
};

/** The mounts of the handler processes. */
static pid_mounts_t *pid_mounts = NULL;

/* Tag for debug log lines; it applies only to the process whose pid is
 * recorded here (set by set_debug_name()). */
static struct {
      pid_t pid;
      const char *name;
} debug_info = {
      -1,
      NULL
};

/**
 * Tags the debug log lines of the calling process with NAME.
 * print_debug_prefix() shows the tag only while the logging process's pid
 * matches the pid recorded here, so forked children do not inherit it.
 */
void set_debug_name(const char *name)
{
      pid_t self = getpid();

      debug_info.name = name;
      debug_info.pid = self;
}

/**
 * Writes the prefix of a debug log line: the UTC date and time with
 * millisecond precision, the pid of the calling process, and the name
 * registered with set_debug_name() if this process registered one.
 */
static void print_debug_prefix(void)
{
      struct timeval now;
      struct tm *utc;
      pid_t self;
      const char *tag;

      gettimeofday(&now, NULL);
      utc = gmtime(&now.tv_sec);

      self = getpid();
      tag = (debug_info.pid == self) ? debug_info.name : "";

      fprintf(debug_file, "%02d-%02d-%04d %02d:%02d:%02d.%03ld %5d %7s ",
            utc->tm_mday, utc->tm_mon + 1, utc->tm_year + 1900,
            utc->tm_hour, utc->tm_min, utc->tm_sec,
            now.tv_usec / 1000, self, tag);
}

static void open_debug_log(void)
{
      if (debug_file) {
            debug("Debugging is already enabled");
            return;
      }

      if (!debug_filename) {
            debug_filename = malloc(strlen(DEFAULT_DEBUG_FILENAME_FMT) + 5);
            if (!debug_filename) {
                  oom_error();
                  return;
            }

            sprintf(debug_filename, DEFAULT_DEBUG_FILENAME_FMT, port);
      }

      debug_file = fopen(debug_filename, "a");
      if (!debug_file) {
            error("Can't append to %s", debug_filename);
            return;
      }

      debug("Debugging enabled");
      debug("sbrshd version %d%s", PROTOCOL_VERSION, REVISION);
      debug("Port: %s", port);
      debug("Local only: %s", local_only ? "yes" : "no");
      debug("Sandbox: %s", sandbox ? "yes" : "no");
      debug("Allow root: %s", allow_root ? "yes" : "no");

      if (mount_expiration > 0) {
            debug("Mount expiration: %d seconds", mount_expiration);
      } else if (mount_expiration == 0) {
            debug("Mount expiration: immediate");
      } else {
            debug("Mount expiration: never");
      }
}

/**
 * Disables debug logging: writes a final message and closes the debug log.
 * Does nothing when debugging is not enabled.
 */
static void close_debug_log(void)
{
      FILE *log;

      if (!debug_file)
            return;

      debug("Debugging disabled");

      log = debug_file;
      debug_file = NULL;
      fclose(log);
}

/**
 * Writes one printf-style formatted line to the debug log, preceded by the
 * standard prefix, and flushes the log.  debug_file is not checked here and
 * must already be open.
 */
void print_debug(const char *msg, ...)
{
      va_list ap;

      print_debug_prefix();

      va_start(ap, msg);
      vfprintf(debug_file, msg, ap);
      va_end(ap);

      fputc('\n', debug_file);
      fflush(debug_file);
}

/**
 * Writes MSG followed by every string of the NULL-terminated vector VEC,
 * each preceded by a space, as one debug log line, then flushes the log.
 */
static void print_debug_vector(const char *msg, char **vec)
{
      char **item;

      print_debug_prefix();
      fputs(msg, debug_file);

      for (item = vec; *item != NULL; ++item) {
            fputc(' ', debug_file);
            fputs(*item, debug_file);
      }

      fputc('\n', debug_file);
      fflush(debug_file);
}

/**
 * Formats a printf-style message and reports it.  When errno is non-zero and
 * its description fits, it is appended in parentheses.
 *
 * @param data      handler state; when non-NULL the message is also sent to
 *                  the client in an ERROR packet
 * @param progname  when non-NULL the message is printed to stderr prefixed
 *                  with it; otherwise it goes to syslog and the debug log
 * @param msg       printf-style format
 * @param arg       arguments for msg
 * @param priority  syslog priority (LOG_WARNING is tagged "Warning:" in the
 *                  debug log, everything else "Error:")
 *
 * errno is sampled on entry, before any library call can overwrite it, and
 * restored before returning.
 */
static void print_error(handler_t *data, const char *progname, const char *msg,
                  va_list arg, int priority)
{
      int saved_errno = errno;
      char str[1024] = { '\0' };
      size_t len;

      vsnprintf(str, sizeof (str), msg, arg);
      len = strlen(str);

      if (saved_errno) {
            const char *err = strerror(saved_errno);

            /* " (" + err + ")" must fit together with the terminating NUL */
            if (err && len + strlen(err) + 3 < sizeof (str))
                  snprintf(str + len, sizeof (str) - len, " (%s)", err);
      }

      if (progname) {
            fprintf(stderr, "%s: %s\n", progname, str);
      } else {
            syslog(priority, "%s", str);
            debug(priority == LOG_WARNING ? "Warning: %s" : "Error: %s", str);
      }

      if (data)
            write_str_packet(data->sd, PTYPE_ERROR, str);

      errno = saved_errno;
}

/**
 * Reports an error to syslog and the debug log, and to the client when DATA
 * is non-NULL.
 *
 * @param data handler state, or NULL to only log the message
 * @param msg  printf-style format; when NULL, the description of the current
 *             errno is sent on its own (errno is then cleared so that the
 *             description is not appended a second time)
 */
void send_error(handler_t *data, const char *msg, ...)
{
      va_list arg;

      if (!msg) {
            /* The description is data, not a format: never let '%' in it be
             * interpreted as a conversion. */
            const char *desc = strerror(errno);

            errno = 0;
            send_error(data, "%s", desc);
            return;
      }

      va_start(arg, msg);
      print_error(data, NULL, msg, arg, LOG_ERR);
      va_end(arg);
}

/**
 * Reports a printf-style error message at LOG_ERR priority, with the errno
 * description appended when errno is non-zero.  With a non-NULL PROGNAME the
 * message is printed to stderr as "PROGNAME: message"; with NULL it goes to
 * syslog and the debug log instead.
 */
static void error_err(const char *progname, const char *msg, ...)
{
      va_list ap;

      va_start(ap, msg);
      print_error(NULL, progname, msg, ap, LOG_ERR);
      va_end(ap);
}

/**
 * Logs a printf-style error message to syslog and the debug log at LOG_ERR
 * priority, appending the errno description when errno is non-zero.
 * Nothing is sent to a client.
 */
void error(const char *msg, ...)
{
      va_list ap;

      va_start(ap, msg);
      print_error(NULL, NULL, msg, ap, LOG_ERR);
      va_end(ap);
}

/**
 * Logs a printf-style warning to syslog and the debug log at LOG_WARNING
 * priority.  errno is cleared first, so no errno description is appended.
 */
static void warn(const char *msg, ...)
{
      va_list ap;

      va_start(ap, msg);
      errno = 0;
      print_error(NULL, NULL, msg, ap, LOG_WARNING);
      va_end(ap);
}

/**
 * Sends the *LENP buffered bytes of BUF to the client as an error message
 * (via send_error()) and resets *LENP to zero.  Does nothing when the buffer
 * is empty.
 *
 * BUF must have room for the terminating NUL written at buf[*lenp].
 */
static void flush(handler_t *data, char *buf, size_t *lenp)
{
      if (*lenp == 0)
            return;

      buf[*lenp] = '\0';
      *lenp = 0;

      /* BUF holds output of an external command: pass it as an argument so
       * any '%' in it is not treated as a format directive. */
      errno = 0;
      send_error(data, "%s", buf);
}

/**
 * Runs argv[0] (a path; execv() does no PATH search) with ARGV in a child
 * process and waits for it.
 *
 * Everything the child writes to stdout or stderr is forwarded line by line
 * as error messages through flush()/send_error().  Lines longer than the
 * line buffer are split.
 *
 * @param data handler state for error reporting; may be NULL, in which case
 *             messages are only logged
 * @return -1 on error (including a non-zero exit status or death by signal),
 *         0 otherwise
 */
static int execute(handler_t *data, char **argv)
{
      pid_t pid;
      int status, err[2];
      size_t pos = 0;
      char line[1024];

      if (pipe(err) < 0) {
            send_error(data, "Can't create pipe");
            return -1;
      }

      pid = fork();
      if (pid < 0) {
            send_error(data, "Can't fork");
            close(err[0]);
            close(err[1]);
            return -1;
      }

      if (pid == 0) {
            /* child */

            set_debug_name("EXECUTE");

            if (dup2(err[1], STDOUT_FILENO) != STDOUT_FILENO) {
                  send_error(data, "Can't duplicate pipe as stdout");
                  exit(1);
            }
            if (dup2(err[1], STDERR_FILENO) != STDERR_FILENO) {
                  send_error(data, "Can't duplicate pipe as stderr");
                  exit(1);
            }
            close(err[0]);

            /* the pipe is now stdout/stderr; drop the extra descriptor
             * unless it already is one of them */
            if (err[1] != STDOUT_FILENO && err[1] != STDERR_FILENO)
                  close(err[1]);

            execv(argv[0], argv);

            send_error(data, "Can't execute command: %s", argv[0]);
            exit(1);
      }

      /* parent */

      close(err[1]);

      while (1) {
            ssize_t len;
            char c;

            len = read(err[0], &c, 1);
            if (len < 0 && errno == EINTR)
                  continue;

            if (len <= 0) {
                  flush(data, line, &pos);
                  break;
            }

            if (c == '\n' || c == '\0') {
                  flush(data, line, &pos);
                  continue;
            }

            /* keep one byte free for the NUL that flush() appends */
            if (pos >= sizeof (line) - 1)
                  flush(data, line, &pos);

            line[pos++] = c;
      }

      close(err[0]);

      while (waitpid(pid, &status, 0) != pid) {
            if (errno != EINTR)
                  return -1;
      }

      /* WEXITSTATUS is meaningless unless the child exited normally; a child
       * killed by a signal must not be mistaken for success. */
      if (!WIFEXITED(status)) {
            errno = 0;
            return -1;
      }

      errno = 0;
      return (WEXITSTATUS(status) != 0) ? -1 : 0;
}

static void check_for_busybox(const char *progname)
{
      char mount_buf[PATH_MAX], *real_mount;

      real_mount = realpath(mount_cmd, mount_buf);
      if (!real_mount) {
            error_err(progname, "Can't get real path of %s", mount_cmd);
            exit(1);
      }

      if (strstr(basename(real_mount), "busybox") != NULL) {
            debug("%s is Busybox", mount_cmd);
            bind_opt = DEFAULT_BIND_OPT_BUSYBOX;
      }
}

/**
 * Reads a config file into a NULL-terminated vector of heap-allocated lines.
 *
 * Trailing newlines are removed, everything from the first '#' on a line is
 * discarded as a comment, and lines that end up empty are skipped.
 *
 * @param filename path of the config file
 * @return the vector (release it with free_vec()); an empty vector when the
 *         file does not exist; NULL on any other error (already reported)
 */
static char **read_config(const char *filename)
{
      int count, i, slot;
      char **lines;
      FILE *file;

      file = fopen(filename, "r");
      if (!file) {
            if (errno == ENOENT)
                  return calloc(1, sizeof (char *));

            error_err(filename, "cannot open for reading");
            return NULL;
      }

      /* First pass: count lines so the vector is allocated once.  The end
       * of the file counts as a line terminator as well. */
      count = 1;
      while (1) {
            int c = fgetc(file);
            if (c == EOF)
                  break;
            if (c == '\n')
                  ++count;
      }

      rewind(file);

      lines = calloc(count + 1, sizeof (char *));
      if (!lines) {
            oom_error();
            goto _out;
      }

      for (i = 0, slot = 0; i < count; ++i) {
            long pos;
            size_t len;
            char *line, *ptr;

            pos = ftell(file);
            if (pos < 0) {
                  error_err(filename, "tell");
                  goto _err;
            }

            /* measure the next line, including its newline if present */
            len = 0;
            while (1) {
                  int c = fgetc(file);
                  if (c == EOF)
                        break;
                  ++len;
                  if (c == '\n')
                        break;
            }

            /* nothing left to read */
            if (len == 0)
                  break;

            if (fseek(file, pos, SEEK_SET) < 0) {
                  error_err(filename, "seek");
                  goto _err;
            }

            line = calloc(len + 1, sizeof (char));
            if (!line) {
                  oom_error();
                  goto _err;
            }

            if (fread(line, sizeof (char), len, file) != len) {
                  error_err(filename, "read");
                  free(line);
                  goto _err;
            }

            if (line[len - 1] == '\n')
                  line[len - 1] = '\0';

            ptr = strchr(line, '#');
            if (ptr)
                  *ptr = '\0';

            if (line[0] == '\0') {
                  /* blank or comment-only line: not kept */
                  free(line);
                  continue;
            }

            lines[slot++] = line;
      }

_out:
      fclose(file);
      return lines;

_err:
      /* never hand a freed vector back to the caller */
      free_vec((void **) lines, NULL);
      lines = NULL;
      goto _out;
}

/**
 * Replaces the contents of a config file with a NULL-terminated vector of
 * lines.  A newline is added after every line that does not already end
 * with one.
 *
 * @return 0 on success, -1 on error (already reported)
 */
static int write_config(const char *filename, char **lines)
{
      FILE *file;
      int i, retval = -1;

      file = fopen(filename, "w");
      if (!file) {
            error_err(filename, "cannot open for writing");
            return -1;
      }

      for (i = 0; lines[i]; ++i) {
            const char *line = lines[i];
            size_t len = strlen(line);

            if (fwrite(line, sizeof (char), len, file) != len) {
                  error_err(filename, "write");
                  goto _out;
            }

            if ((len == 0 || line[len - 1] != '\n') && fputc('\n', file) == EOF) {
                  error_err(filename, "write");
                  goto _out;
            }
      }

      retval = 0;

_out:
      /* Buffered data is written out by fclose(), so a full disk or I/O
       * error may only be detected here. */
      if (fclose(file) != 0 && retval == 0) {
            error_err(filename, "write");
            retval = -1;
      }

      return retval;
}

/**
 * Appends HOST as a new line to CONFIG_NAME unless some line's first
 * whitespace-delimited field already equals HOST.
 *
 * @return 0 on success or when HOST is already listed, -1 on error
 */
static int add_to_config(char *host)
{
      char **lines, **grown;
      size_t hostlen;
      int n, rc = 0;

      lines = read_config(CONFIG_NAME);
      if (lines == NULL)
            return -1;

      hostlen = strlen(host);

      for (n = 0; lines[n] != NULL; ++n) {
            char *end = find_space(lines[n]);
            size_t fieldlen = end - lines[n];

            if (fieldlen == hostlen && strncmp(lines[n], host, fieldlen) == 0)
                  goto _out;   /* already present */
      }

      n = calc_vec_len((void **) lines);

      /* room for the new line plus the terminating NULL */
      grown = calloc(n + 2, sizeof (char *));
      if (grown == NULL) {
            oom_error();
            rc = -1;
            goto _out;
      }

      memcpy(grown, lines, n * sizeof (char *));
      free(lines);
      lines = grown;

      lines[n] = strdup(host);
      if (lines[n] == NULL) {
            oom_error();
            rc = -1;
            goto _out;
      }

      rc = write_config(CONFIG_NAME, lines);

_out:
      free_vec((void **) lines, NULL);
      return rc;
}

/**
 * Checks whether the client's address (data->host) is listed in a config
 * file.
 *
 * Each line holds one host pattern (comments are already stripped by
 * read_config()).  A pattern ending in '*' matches every host starting with
 * the text before the '*'; any other pattern must match the host exactly.
 *
 * @param data handler state
 * @param filename the config file path
 * @return TRUE if a matching line was found; FALSE otherwise, including when
 *         the file cannot be read or has a line with more than one field
 */
static bool_t find_host(handler_t *data, const char *filename)
{
      bool_t found = FALSE;
      size_t hostlen;
      char **lines;
      int i;

      lines = read_config(filename);
      if (lines == NULL)
            return FALSE;

      hostlen = strlen(data->host);

      for (i = 0; lines[i]; ++i) {
            char *line, *end, *p, *star;
            size_t len;

            line = lines[i];
            end = find_space(line);

            for (p = end; *p; p++) {
                  /* cast: isspace() of a negative char is undefined */
                  if (!isspace((unsigned char) *p)) {
                        errno = 0;
                        error("Malformed line in %s: \"%s\" (old config file format?)",
                              filename, line);
                        goto _out;   /* still release the lines */
                  }
            }

            *end = '\0';
            len = end - line;

            star = strchr(line, '*');
            if (star) {
                  /* '*' is only allowed as the last character (len >= 1 here) */
                  if ((size_t) (star - line) != len - 1) {
                        errno = 0;
                        error("Malformed line in %s: %s", filename, line);
                        continue;
                  }

                  --len;
            } else if (hostlen != len) {
                  continue;
            }

            if (strncmp(line, data->host, len) == 0) {
                  debug("Found matching line: %s", line);
                  found = TRUE;
                  break;
            }
      }

_out:
      free_vec((void **) lines, NULL);

      return found;
}

/**
 * Creates directory PATH, creating missing parent directories first.  New
 * directories get MKDIR_PERMS.  A PATH that already exists counts as
 * success.
 *
 * @return 0 on success, -1 on failure
 */
static int mkdirs(const char *path)
{
      char *parent, *slash;
      int rc;

      if (mkdir(path, MKDIR_PERMS) == 0 || errno == EEXIST)
            return 0;

      /* only a missing ancestor is worth recursing for */
      if (errno != ENOENT)
            return -1;

      parent = strdup(path);
      if (parent == NULL)
            return -1;

      slash = strrchr(parent, '/');
      if (slash == NULL || slash == parent) {
            /* no parent component left that could be created */
            free(parent);
            return -1;
      }

      *slash = '\0';
      rc = mkdirs(parent);
      free(parent);

      if (rc != 0)
            return -1;

      if (mkdir(path, MKDIR_PERMS) < 0 && errno != EEXIST)
            return -1;

      return 0;
}

/**
 * Frees a mount entry together with the strings owned by its info
 * (options, device and mount point).
 */
static void mnt_free(mount_t *m)
{
      /* free(NULL) is a no-op, so unset fields need no checks */
      free(m->info.opts);
      free(m->info.device);
      free(m->info.point);
      free(m);
}

/**
 * Looks up the entry of the all_mounts list whose mount point equals POINT.
 * @param point the mount point of the entry
 * @return the entry, or NULL when there is none
 */
static mount_t *mnt_list_find(const char *point)
{
      mount_t *m = all_mounts.head;

      while (m != NULL && strcmp(m->info.point, point) != 0)
            m = m->next;

      return m;
}

/**
 * Inserts MNT into the all_mounts list so that the list stays sorted in
 * ascending strcmp() order of mount points.  MNT is placed after any entries
 * with an equal mount point.
 */
static void mnt_list_add(mount_t *mnt)
{
      mount_t *before = NULL;
      mount_t *after;

      /* find the first entry that sorts strictly after mnt */
      for (after = all_mounts.head; after != NULL; after = after->next) {
            if (strcmp(after->info.point, mnt->info.point) > 0)
                  break;
            before = after;
      }

      mnt->prev = before;
      mnt->next = after;

      if (before == NULL)
            all_mounts.head = mnt;
      else
            before->next = mnt;

      if (after == NULL)
            all_mounts.tail = mnt;
      else
            after->prev = mnt;
}

/**
 * Unlinks MNT from the all_mounts list and frees it with mnt_free().
 */
static void mnt_list_del(mount_t *mnt)
{
      mount_t *before = mnt->prev;
      mount_t *after = mnt->next;

      if (before == NULL)
            all_mounts.head = after;
      else
            before->next = after;

      if (after == NULL)
            all_mounts.tail = before;
      else
            after->prev = before;

      mnt_free(mnt);
}

/**
 * Mounts MI by running mount_cmd through execute().
 *
 * NFS mounts run "mount_cmd DEVICE POINT -t nfs [-o OPTS]".  Bind mounts run
 * "mount_cmd DEVICE POINT" followed by up to four words split from bind_opt.
 *
 * @return the result of execute(); -1 with errno = EINVAL for an unknown
 *         mount type; -1 if the bind options cannot be copied
 */
static int do_mount(handler_t *data, const mount_info_t *mi)
{
      char *argv[8];
      char *bind_words = NULL;
      int rc, saved_errno;

      memset(argv, 0, sizeof (argv));

      argv[0] = mount_cmd;
      argv[1] = mi->device;
      argv[2] = mi->point;

      switch (mi->type) {
      case MTYPE_NFS:
            argv[3] = "-t";
            argv[4] = "nfs";
            argv[5] = mi->opts ? "-o" : NULL;
            argv[6] = mi->opts;

            if (mi->opts) {
                  debug("Executing: %s %s %s -t nfs -o %s", argv[0], argv[1], argv[2], argv[6]);
            } else {
                  debug("Executing: %s %s %s -t nfs", argv[0], argv[1], argv[2]);
            }
            break;

      case MTYPE_BIND:
            /* split_string() hands back pointers into the string it is
             * given (compare is_mounted()), so it presumably splits in place
             * -- TODO confirm.  bind_opt may point at a string literal and
             * must stay intact for later mounts, so split a private copy. */
            bind_words = strdup(bind_opt);
            if (!bind_words) {
                  send_error(data, "Can't copy bind options");
                  return -1;
            }

            split_string(bind_words, &argv[3], &argv[4], &argv[5], &argv[6], NULL);

            debug("Executing: %s %s %s %s %s %s %s", argv[0], argv[1], argv[2],
                  argv[3] ? argv[3] : "", argv[4] ? argv[4] : "",
                  argv[5] ? argv[5] : "", argv[6] ? argv[6] : "");

            break;

      default:
            errno = EINVAL;
            return -1;
      }

      rc = execute(data, argv);

      /* argv[3..6] may point into bind_words: release it only now, keeping
       * the errno left by execute() */
      saved_errno = errno;
      free(bind_words);
      errno = saved_errno;

      return rc;
}

/**
 * Unmounts POINT by running "umount_cmd POINT" through execute().
 * @return the result of execute()
 */
static int do_unmount(handler_t *data, char *point)
{
      char *argv[3];

      argv[0] = umount_cmd;
      argv[1] = point;
      argv[2] = NULL;

      debug("Executing: %s %s", umount_cmd, point);

      return execute(data, argv);
}

/**
 * Returns the all_mounts entry for MI's mount point.  When none exists yet,
 * a new zero-initialized entry holding a copy of MI is created and inserted.
 *
 * @return the entry, or NULL if memory runs out (reported to the client)
 */
static mount_t *mnt_create(handler_t *data, mount_info_t *mi)
{
      mount_t *existing, *fresh;

      existing = mnt_list_find(mi->point);
      if (existing != NULL)
            return existing;

      fresh = calloc(1, sizeof (*fresh));
      if (fresh == NULL) {
            errno = 0;
            send_error(data, oom);
            return NULL;
      }

      mntinfo_copy(&fresh->info, mi);
      mnt_list_add(fresh);

      return fresh;
}

/**
 * Checks whether POINT is currently mounted according to MOUNTS_FILE.
 * POINT and each listed mount point are compared after realpath().
 *
 * @param data handler state (used to report a failure to open MOUNTS_FILE)
 * @param point the directory to look for
 * @return TRUE if mounted; FALSE if not, or if POINT's real path cannot be
 *         resolved; -1 if MOUNTS_FILE cannot be opened
 */
static int is_mounted(handler_t *data, const char *point)
{
      char line[1024], want_buf[PATH_MAX], have_buf[PATH_MAX];
      char *want;
      FILE *mounts;
      int mounted = FALSE;

      want = realpath(point, want_buf);
      if (want == NULL)
            return FALSE;

      mounts = fopen(MOUNTS_FILE, "r");
      if (mounts == NULL) {
            send_error(data, "Can't open " MOUNTS_FILE);
            return -1;
      }

      /* stop reading as soon as a match has been found */
      while (!mounted && read_line(mounts, line, sizeof (line)) >= 0) {
            char *device, *listed, *have;

            split_string(line, &device, &listed, NULL);
            if (device == NULL || listed == NULL)
                  continue;

            have = realpath(listed, have_buf);
            if (have == NULL) {
                  error("Can't get real path of %s", listed);
                  continue;
            }

            if (strcmp(want, have) == 0)
                  mounted = TRUE;
      }

      fclose(mounts);

      debug(mounted ? "%s is mounted" : "%s is not mounted", point);

      return mounted;
}

/**
 * Makes sure MI is mounted and takes a reference on its mount entry.
 *
 * When MI's mount point is not mounted yet, its directory (and any missing
 * parents) is created and mount_cmd is run.  On success the entry's usage
 * count is incremented.
 *
 * NOTE(review): if directory creation or mounting fails, an entry created
 * by mnt_create() stays in all_mounts with zero usage.
 *
 * @return the mount entry, or NULL on error
 */
static mount_t *add_mount(handler_t *data, mount_info_t *mi)
{
      mount_t *mnt;
      int state = is_mounted(data, mi->point);

      if (state < 0)
            return NULL;

      mnt = mnt_create(data, mi);
      if (mnt == NULL)
            return NULL;

      if (state == 0) {
            debug("Creating directory %s", mnt->info.point);

            if (mkdirs(mnt->info.point) < 0) {
                  send_error(data, "Can't create directory: %s", mnt->info.point);
                  return NULL;
            }

            if (do_mount(data, &mnt->info) < 0)
                  return NULL;
      }

      mnt->usage++;
      return mnt;
}

/**
 * Unmounts MI's mount point if it is mounted.  An unused all_mounts entry
 * for it is dropped.  If the entry is still in use, a warning is sent to the
 * client and the entry is kept, but the unmount is attempted anyway.
 *
 * @return -1 on error, 0 otherwise
 */
static int remove_mount(handler_t *data, const mount_info_t *mi)
{
      mount_t *entry;
      int state;

      state = is_mounted(data, mi->point);
      if (state < 0)
            return -1;

      entry = mnt_list_find(mi->point);
      if (entry != NULL && entry->usage == 0)
            mnt_list_del(entry);
      else if (entry != NULL)
            send_error(data, "Warning: %s usage is %d", entry->info.point, entry->usage);

      if (!state)
            return 0;

      return (do_unmount(data, mi->point) < 0) ? -1 : 0;
}

/**
 * Drops one reference to MNT.  When no users remain and expiration is
 * enabled (mount_expiration >= 0), the entry is scheduled to expire
 * mount_expiration seconds from now.
 */
static void release_mount(mount_t *mnt)
{
      mnt->usage -= 1;

      debug("Releasing mount %s (%d users remain)", mnt->info.point, mnt->usage);

      if (mnt->usage > 0 || mount_expiration < 0)
            return;

      mnt->expiration = time(NULL) + mount_expiration;
}

/**
 * Steps through the MOUNTS list and unmounts all unused and expired entries.
 */
static void expire_mounts(void)
{
      mount_t *prev, *curr;
      time_t now;

      debug("Checking for expired mounts");

      now = time(NULL);

      prev = NULL;
      curr = all_mounts.tail;
      while (curr) {
            prev = curr->prev;

            if (curr->usage <= 0 && curr->expiration <= now) {
                  debug("Unmounting %s", curr->info.point);

                  if (curr->next)
                        curr->next->prev = curr->prev;
                  else
                        all_mounts.tail = curr->prev;

                  if (curr->prev)
                        curr->prev->next = curr->next;
                  else
                        all_mounts.head = curr->next;

                  do_unmount(NULL, curr->info.point);
                  mnt_free(curr);
            }

            curr = prev;
      }
}

/**
 * Unmounts every entry of the all_mounts list, last entry first, frees the
 * entries and leaves the list empty.  Unmount errors are only logged, since
 * no client is passed to do_unmount().
 */
static void unmount_all(void)
{
      mount_t *mnt, *prev;

      for (mnt = all_mounts.tail; mnt != NULL; mnt = prev) {
            prev = mnt->prev;

            debug("Unmounting %s", mnt->info.point);

            do_unmount(NULL, mnt->info.point);
            mnt_free(mnt);
      }

      all_mounts.head = NULL;
      all_mounts.tail = NULL;
}

/**
 * Unmounts all filesystems listed in the command info's NULL-terminated
 * mount info vector (data->param.mounts), last one first.  Every entry is
 * attempted even if an earlier one fails.
 * @return 0 or -1 on error
 */
static int unmount_infos(handler_t *data)
{
      mount_info_t **mounts = data->param.mounts;
      size_t count = 0;
      int rc = 0;

      debug("Unmounting filesystems");

      if (!mounts)
            return 0;

      while (mounts[count])
            ++count;

      /* Index downwards.  Stepping a pointer below the start of an empty
       * vector, as pointer decrementing did, is undefined behavior. */
      while (count-- > 0)
            if (remove_mount(data, mounts[count]) < 0)
                  rc = -1;

      return rc;
}

/**
 * Returns how many bytes to read from FD.  This is MAX when FD is not a
 * terminal.  For a terminal it is the number of bytes FIONREAD reports as
 * available, capped at MAX.
 *
 * @return the byte count, or -1 on error (reported to the client)
 */
static ssize_t get_data_length(handler_t *data, int fd, size_t max)
{
      int avail;

      if (!isatty(fd))
            return max;

      /* FIONREAD stores an int.  Passing a wider ssize_t would leave part
       * of it unwritten, or fill the wrong half on big-endian hosts. */
      if (ioctl(fd, FIONREAD, &avail) < 0) {
            send_error(data, "Can't check pty for available data");
            return -1;
      }

      if (avail < 0) {
            send_error(data, "ioctl() gave invalid read length");
            return -1;
      }

      if ((size_t) avail > max)
            return max;

      return avail;
}

/**
 * Reads some data from a fd and writes it into the socket (sd) in a packet.
 * At most out->req bytes are read, and out->req is reduced by the amount
 * sent.  At EOF, out->fd is set to -1 and out->req to 0.  On failure,
 * data->error is set.
 * @param data handler state
 * @param out describes the output
 */
static void send_data(handler_t *data, output_desc_t *out)
{
      ssize_t len;

      len = get_data_length(data, out->fd, out->req);
      if (len < 0)
            goto _error;

      /* nothing to do? */
      if (len == 0)
            return;

      len = read_ni(out->fd, data->tmp_buf, len);
      if (len < 0) {
            send_error(data, "Can't read output of the process");
            goto _error;
      }

      if (len) {
            /* %d expects an int but ssize_t may be wider.  len is at most
             * out->req, so the cast does not truncate in practice. */
            debug("Sending %s DATA packet (%d bytes)",
                  output_desc_is_stdout(out) ? "OUT" : "ERR", (int) len);

            if (write_buf_packet(data->sd, out->data_type, len, data->tmp_buf) < 0) {
                  error("Can't write packet to socket");
                  goto _error;
            }

            out->req -= len;
      } else {
            debug(output_desc_is_stdout(out) ? "Stdout hit EOF" : "Stderr hit EOF");

            out->fd = -1;
            out->req = 0;
      }

      return;

_error:
      data->error = TRUE;
}

/**
 * Asks the client for more stdin data by sending IN->req_type, unless a
 * request is already outstanding.  On success IN is marked as waiting.  If
 * the socket write fails, data->error is set.
 */
static void send_request(handler_t *data, input_desc_t *in)
{
      if (in->wait)
            return;

      debug("Sending IN REQ packet");

      if (write_enum(data->sd, in->req_type) >= 0) {
            in->wait = TRUE;
            return;
      }

      error("Can't write packet to socket");
      data->error = TRUE;
}

/**
 * Hands the buffered stdin data to buf_write_out() for writing to IN->fd.
 * A failure is only logged; &in->fd is passed, presumably so the helper can
 * mark the descriptor closed (TODO confirm).  Once the buffer is empty, IN
 * is no longer marked as waiting.  DATA is unused.
 */
static void write_buffer(handler_t *data, input_desc_t *in)
{
      debug("Writing buffer to stdin");

      if (buf_write_out(in->buf, &in->fd) < 0)
            debug("Stdin is closed (%s)", strerror(errno));

      if (!buf_is_empty(in->buf))
            return;

      in->wait = FALSE;
}

/**
 * Reads the body of an IN DATA packet from the socket.  The body is a
 * 32-bit length followed by that many bytes, which are appended to
 * data->in.buf.  A zero length marks end of input and sets EOF on the
 * buffer.  On failure the error is reported to the client and data->error
 * is set.
 * @param data handler state
 */
static void receive_stream(handler_t *data)
{
      uint32_t len;

      if (read_uint32(data->sd, &len) < 0) {
            send_error(data, "Can't read data packet length");
            goto _error;
      }

      if (len > 0) {
            /* len is an unsigned 32-bit value: print it with %u */
            debug("Receiving IN DATA packet (%u bytes)", (unsigned int) len);
            if (buf_read_in(data->in.buf, data->sd, len) < 0) {
                  send_error(data, "Can't read IN DATA packet to buffer");
                  goto _error;
            }
      } else {
            debug("Receiving IN DATA packet: EOF");
            buf_set_eof(data->in.buf);
      }

      return;

_error:
      data->error = TRUE;
}

/**
 * Reads one packet header from the client and dispatches on its type.
 * IN DATA packets are read into the stdin buffer, and OUT/ERR REQ packets
 * grant BUFFER_SIZE bytes of stdout/stderr output.  A header read failure
 * or an unexpected packet type sets data->error.
 * @return 1 if EOF from client, 0 otherwise
 */
static int receive_packet(handler_t *data)
{
      ptype_t type = read_enum(data->sd);

      switch (type) {
      case PTYPE_IN_DATA:
            receive_stream(data);
            return 0;

      case PTYPE_OUT_REQ:
            debug("Receiving OUT REQ packet");
            data->out.req = BUFFER_SIZE;
            return 0;

      case PTYPE_ERR_REQ:
            debug("Receiving ERR REQ packet");
            data->err.req = BUFFER_SIZE;
            return 0;

      case -1:
            /* read_enum() failed; errno == 0 means the client hung up */
            if (errno == 0) {
                  debug("EOF from client");
                  return 1;
            }

            error("Can't read packet header from socket");
            break;

      default:
            errno = 0;
            send_error(data, "Received packet has unexpected type (0x%02x)", type);
            break;
      }

      data->error = TRUE;
      return 0;
}

/**
 * Logs to the debug log how process PID ended or stopped, according to its
 * wait STATUS.  An unrecognized status is reported through error().
 */
static void print_status(int pid, int status)
{
      int sig;

      if (WIFEXITED(status)) {
            debug("Process %d returned: %d", pid, WEXITSTATUS(status));
            return;
      }

      if (WIFSIGNALED(status)) {
            sig = WTERMSIG(status);
            debug("Process %d terminated by signal: %s (%d)", pid, strsignal(sig), sig);
            return;
      }

      if (WIFSTOPPED(status)) {
            sig = WSTOPSIG(status);
            debug("Process %d stopped by signal: %s (%d)", pid, strsignal(sig), sig);
            return;
      }

      error("Invalid status %d for pid %d", status, pid);
}

#define print_debug_bool(m, b)  print_debug(m, (b) ? "yes" : "no")
#define debug_bool(m, b)        if (debug_file) print_debug_bool(m, b)

#define debug_status() \
      if (debug_file) { \
            print_debug_bool("Process alive  = %s", alive); \
            print_debug_bool("Stdin open     = %s", data->in.fd >= 0); \
            print_debug_bool("Stdin waiting  = %s", data->in.wait); \
            print_debug     ("Stdin buffered = %d bytes%s", \
                         buf_size(data->in.buf), \
                         data->in.buf->eof ? " (EOF set)" : ""); \
            print_debug_bool("Stdout open    = %s", data->out.fd >= 0); \
            print_debug     ("Stdout request = %d bytes", data->out.req); \
            print_debug_bool("Stderr open    = %s", data->err.fd >= 0); \
            print_debug     ("Stderr request = %d bytes", data->err.req); \
            print_debug_bool("Error          = %s", data->error); \
      }

#define debug_select(msg) \
      if (debug_file) { \
            print_debug(msg " [%3s %2s %3s %3s][%2s]", \
                      ISSET(data->sd,      &readfds) ? "NET" : "", \
                      ISSET(data->in.fd,   &readfds) ? "IN"  : "", \
                      ISSET(data->out.fd,  &readfds) ? "OUT" : "", \
                      ISSET(data->err.fd,  &readfds) ? "ERR" : "", \
                      ISSET(data->in.fd,  &writefds) ? "IN"  : ""); \
      }

/**
 * Manages the streams of the command executor (= the child-child process).
 *
 * Multiplexes the client socket and the command's stdin/stdout/stderr with
 * select().  The loop ends when:
 * - all three command streams are closed (non-TTY mode);
 * - the command has died and its pty output is drained (TTY mode);
 * - the client closes the socket;
 * - an error occurs.
 * Every child of this process that is reaped along the way is logged with
 * print_status().
 *
 * @param data handler state
 * @param pid of the child process
 * @return the command's exit code, or -1 if it did not exit normally or on
 *         error (in some error cases the command is sent SIGTERM)
 */
static int handler_manage(handler_t *data, pid_t pid)
{
      fd_set readfds, writefds;
      int maxfd1, status;
      int cmd_status = 0;     /* wait status of `pid` once it has been reaped */
      bool_t alive = TRUE;

      debug("Managing process %d", pid);

      maxfd1 = MAX(data->sd, data->in.fd);
      maxfd1 = MAX(maxfd1, data->out.fd);
      maxfd1 = MAX(maxfd1, data->err.fd);
      ++maxfd1;

      data->in.buf = buf_alloc();
      if (!data->in.buf) {
            errno = 0;
            send_error(data, oom);
            goto _kill;
      }

      while (1) {
            int count, val;

            debug_status();

            if (data->error) {
                  debug("Ending loop due to error");
                  goto _kill;
            }

            /* non-TTY mode: check if we should exit */
            if (!data->param.term && data->in.fd < 0 && data->out.fd < 0 && data->err.fd < 0) {
                  debug("Ending loop due to closed stdin, stdout and stderr descriptors");
                  break;
            }

            FD_ZERO(&readfds);
            FD_ZERO(&writefds);

            FD_SET(data->sd, &readfds);

            if (data->out.fd >= 0 && data->out.req > 0)
                  FD_SET(data->out.fd, &readfds);

            if (!data->param.term && data->err.fd >= 0 && data->err.req > 0)
                  FD_SET(data->err.fd, &readfds);

            if (data->in.fd >= 0) {
                  /* we want select to wake up if stdin hits EOF
                   * (child exited but we haven't noticed yet) */
                  if (!data->param.term)
                        FD_SET(data->in.fd, &readfds);

                  if (!buf_is_empty(data->in.buf) || !data->in.wait)
                        FD_SET(data->in.fd, &writefds);
            }

            debug_select("Selecting");

            count = select(maxfd1, &readfds, &writefds, NULL, NULL);
            if (count < 0) {
                  if (errno == EINTR) {
                        debug("Select interrupted");
                  } else {
                        error("Select failed");
                        debug("Ending loop due to failed select");
                        goto _kill;
                  }

            } else if (count > 0) {
                  debug_select("Selected ");

                  /* TTY mode: check if we should exit */
                  if (data->param.term && !alive && ISSET(data->out.fd, &readfds) &&
                      get_data_length(data, data->out.fd, 1) <= 0) {
                        debug("Ending loop due to dead process and empty stdout buffer");
                        break;
                  }

                  /* EOF from client? */
                  if (ISSET(data->sd, &readfds) && receive_packet(data) == 1) {
                        close(data->sd);
                        data->sd = -1;

                        debug("Ending loop due to closed socket");
                        goto _kill;
                  }

                  if (ISSET(data->out.fd, &readfds))
                        send_data(data, &data->out);

                  if (ISSET(data->err.fd, &readfds))
                        send_data(data, &data->err);

                  if (ISSET(data->in.fd, &writefds)) {
                        if (buf_is_empty(data->in.buf))
                              send_request(data, &data->in);
                        else
                              write_buffer(data, &data->in);
                  }

                  /* stdin at EOF?  Close our end instead of just forgetting
                   * it, so the descriptor isn't leaked. */
                  if (!data->param.term && ISSET(data->in.fd, &readfds)) {
                        close(data->in.fd);
                        data->in.fd = -1;
                  }
            }

            /* collect late children (any child of this process, not only pid) */

            val = waitpid(-1, &status, WNOHANG);
            if (val < 0 && errno != ECHILD) {
                  send_error(data, "Can't wait for children");
                  return -1;
            }

            if (val > 0) {
                  if (val == pid) {
                        alive = FALSE;
                        /* `status` is reused by later iterations that may
                         * reap other children; keep the command's own. */
                        cmd_status = status;
                  }

                  print_status(val, status);
            }
      }

      if (alive) {
            while (waitpid(pid, &cmd_status, 0) < 0) {
                  if (errno == EINTR)
                        continue;

                  send_error(data, "Can't wait for child %d", pid);
                  return -1;
            }

            print_status(pid, cmd_status);
      }

      if (WIFEXITED(cmd_status))
            return WEXITSTATUS(cmd_status);

      return -1;

_kill:
      /* once the command has been reaped its pid may be reused by an
       * unrelated process, so only signal it while it is still alive */
      if (alive) {
            debug("Sending SIGTERM to command process %d", pid);
            kill(pid, SIGTERM);
      }

      return -1;
}

/**
 * Checks whether the connecting host is listed in the configuration file
 * CONFIG_NAME (via find_host()); unauthorized hosts are logged with warn().
 * @return TRUE/FALSE or -1 on error
 */
static int authenticate(handler_t *data)
{
      bool_t authorized;

      debug("Searching for host %s in %s", data->host, CONFIG_NAME);

      authorized = find_host(data, CONFIG_NAME);
      if (authorized)
            return authorized;

      warn("Unauthorized connection from %s", data->host);
      return authorized;
}

/**
 * More or less like forkpty.
 *
 * Opens a pseudo-terminal with the window size from the client's WINSIZE
 * parameter (data->param.term) and forks.  The child calls login_tty() on
 * the slave side.  The parent keeps the master side.
 *
 * @param data handler state
 * @param ptyfd where the parent stores the pty master descriptor
 * @return pid of the child process (0 for child) or -1 on error; on error
 *         both pty descriptors are closed
 */
static pid_t fork_pty(handler_t *data, int *ptyfd)
{
      int master, slave;
      pid_t pid;

      if (openpty(&master, &slave, NULL, NULL, data->param.term) < 0) {
            send_error(data, "Can't open a pseudo-tty");
            return -1;
      }

      pid = fork();
      if (pid < 0) {
            /* report before closing: close() may clobber errno */
            send_error(data, "Can't fork");
            close(master);
            close(slave);
            return -1;
      }

      if (pid == 0) {
            /* child */

            set_debug_name("COMMAND");

            close(master);
            if (login_tty(slave)) {
                  send_error(data, "Can't login to a pseudo-tty");
                  exit(1);
            }
      } else {
            /* parent */

            close(slave);
            *ptyfd = master;
      }

      return pid;
}

/**
 * Closes both ends of a socket pair, skipping ends marked as unopened (-1).
 */
static void close_sockpair(int sd[2])
{
      if (sd[0] >= 0)
            close(sd[0]);
      if (sd[1] >= 0)
            close(sd[1]);
}

/**
 * More or less like fork, but the stdin/stdout/stderr of the child process
 * are set to point to sockets.
 * @param data handler state
 * @param infd a place to write the parent-end of the child's stdin socketpair
 *             (the parent-end is made non-blocking)
 * @param outfd a place to write the parent-end of the child's stdout socketpair
 * @param errfd a place to write the parent-end of the child's stderr socketpair
 * @return pid of the child process (0 for child) or -1 on error; on error no
 *         child process has been created and all socket pairs are closed
 */
static pid_t fork_sockets(handler_t *data, int *infd, int *outfd, int *errfd)
{
      int insd[2] = { -1, -1 };
      int outsd[2] = { -1, -1 };
      int errsd[2] = { -1, -1 };
      pid_t pid;

      if (socketpair(AF_UNIX, SOCK_STREAM, 0, insd) < 0 ||
          socketpair(AF_UNIX, SOCK_STREAM, 0, outsd) < 0 ||
          socketpair(AF_UNIX, SOCK_STREAM, 0, errsd) < 0) {
            send_error(data, "Can't create socket pairs");
            goto _close;
      }

      /* insd[0] is the parent's end (the child closes it).  Doing this
       * before fork() means a failure can't leave an orphaned command
       * process running. */
      if (set_nonblocking(insd[0], TRUE)) {
            send_error(data, "Can't make socket non-blocking");
            goto _close;
      }

      pid = fork();
      if (pid < 0) {
            send_error(data, "Can't fork");
            goto _close;
      }

      if (pid == 0) {
            /* child */

            set_debug_name("COMMAND");

            close(insd[0]);
            close(outsd[0]);
            close(errsd[0]);

            if (dup2(insd[1], STDIN_FILENO) != STDIN_FILENO) {
                  send_error(data, "Can't duplicate socket as stdin");
                  exit(1);
            }

            if (dup2(outsd[1], STDOUT_FILENO) != STDOUT_FILENO) {
                  send_error(data, "Can't duplicate socket as stdout");
                  exit(1);
            }

            if (dup2(errsd[1], STDERR_FILENO) != STDERR_FILENO) {
                  send_error(data, "Can't duplicate socket as stderr");
                  exit(1);
            }
      } else {
            /* parent */

            *infd = insd[0];
            *outfd = outsd[0];
            *errfd = errsd[0];
      }

      /* both sides: the child-ends are no longer needed under these numbers */
      close(insd[1]);
      close(outsd[1]);
      close(errsd[1]);

      return pid;

_close:
      close_sockpair(insd);
      close_sockpair(outsd);
      close_sockpair(errsd);
      return -1;
}

/**
 * Read "SBOX_RLIMIT"-fields from ENVIRONment and set them in place.
 *
 * For each supported resource, the variable ENV_RLIMIT_PREFIX<NAME> is looked
 * up.  Leading spaces are skipped; unset or empty variables are ignored.  A
 * value starting with ENV_RLIMIT_UNLIMITED sets the soft limit to
 * RLIM_INFINITY.  Otherwise the value must be a non-negative decimal number,
 * optionally followed by whitespace.  Only the soft limit (rlim_cur) is
 * changed; the hard limit is kept.
 *
 * @return -1 on error (reported to the client), 0 otherwise
 */
static int set_rlimits(handler_t *data)
{
      static const struct { const char *key; int resource; } info[] = {
            { ENV_RLIMIT_PREFIX "CPU",     RLIMIT_CPU     },
            { ENV_RLIMIT_PREFIX "FSIZE",   RLIMIT_FSIZE   },
            { ENV_RLIMIT_PREFIX "DATA",    RLIMIT_DATA    },
            { ENV_RLIMIT_PREFIX "STACK",   RLIMIT_STACK   },
            { ENV_RLIMIT_PREFIX "CORE",    RLIMIT_CORE    },
            { ENV_RLIMIT_PREFIX "RSS",     RLIMIT_RSS     },
            { ENV_RLIMIT_PREFIX "NPROC",   RLIMIT_NPROC   },
            { ENV_RLIMIT_PREFIX "NOFILE",  RLIMIT_NOFILE  },
            { ENV_RLIMIT_PREFIX "MEMLOCK", RLIMIT_MEMLOCK },
            { ENV_RLIMIT_PREFIX "AS",      RLIMIT_AS      },
            { NULL,                        0              }
      };

      const size_t prefixlen = strlen(ENV_RLIMIT_PREFIX);
      size_t i;

      for (i = 0; info[i].key; ++i) {
            const char *resname = info[i].key + prefixlen;
            struct rlimit lim;
            char *str;

            str = getenv(info[i].key);
            if (!str)
                  continue;

            str = skip_spaces(str);
            if (*str == '\0')
                  continue;

            debug("Setting %s resource to %s", resname, str);

            if (getrlimit(info[i].resource, &lim) < 0) {
                  send_error(data, "Can't get %s resource limit", resname);
                  return -1;
            }

            lim.rlim_cur = RLIM_INFINITY;

            if (strncmp(str, ENV_RLIMIT_UNLIMITED, strlen(ENV_RLIMIT_UNLIMITED)) != 0) {
                  unsigned long long value;
                  char *end;

                  /* strtoull (unlike atol) reports overflow and where
                   * parsing stopped, so malformed values are rejected */
                  errno = 0;
                  value = strtoull(str, &end, 10);
                  while (isspace((unsigned char) *end))
                        ++end;

                  if (*str == '-' || end == str || *end != '\0' || errno == ERANGE) {
                        send_error(data, "Invalid %s resource limit value: %s", resname, str);
                        return -1;
                  }

                  lim.rlim_cur = (rlim_t) value;
            }

            if (setrlimit(info[i].resource, &lim) < 0) {
                  /* rlim_t is an unsigned type that may be 64 bits wide;
                   * printing it with %d was undefined behaviour */
                  send_error(data, "Can't set %s resource limit to %llu while maximum is %llu",
                             resname,
                             (unsigned long long) lim.rlim_cur,
                             (unsigned long long) lim.rlim_max);
                  return -1;
            }
      }

      return 0;
}


/**
 * Returns orig itself if it already starts with "/", otherwise a newly
 * malloc()ed copy with a "/" prepended.
 *
 * Because the result may or may not alias orig, callers cannot tell whether
 * to free it.  The copy is simply leaked, which is only acceptable because
 * the sole users run in the command process just before execvp().
 *
 * If the allocation fails, an out-of-memory error is reported to the client
 * and the process terminates with exit(1).
 */
static char *add_root_prefix(handler_t *data, char *orig)
{
      size_t len;
      char *dir;

      if (orig[0] == '/')
            return orig;

      len = strlen(orig);

      dir = malloc(len + 2);  /* "/" + orig + NUL */
      if (!dir) {
            errno = 0;
            send_error(data, oom);
            exit(1);
      }

      dir[0] = '/';
      memcpy(dir + 1, orig, len + 1);

      return dir;
}

/**
 * Executes the command. Never returns.
 *
 * Runs in the forked command process and performs these steps in order:
 * 1. chroot() to data->root (when set), then change to "/".
 * 2. Switch to the gid, the supplementary groups and the uid from the IDS
 *    parameter.  IDs equal to 0 are refused unless allow_root is set.
 * 3. Replace the environment with the ENVIRON parameter (an empty one if it
 *    is absent) and add the fakeroot key, if any.
 * 4. Apply resource limits from the environment (set_rlimits()).
 * 5. Set the umask.
 * 6. Change to the CWD parameter, if one was given.
 * 7. execvp() the ARGS parameter, or $SHELL / DEFAULT_SHELL if absent.
 *
 * Every failure is reported to the client with send_error() and ends the
 * process with exit(1).
 */
static void execute_command(handler_t *data)
{
      char **argv;

      /*
       * Change the root directory
       */

      if (data->root) {
            char *dir = add_root_prefix(data, data->root);

            debug("Changing root directory to %s", dir);
            if (chroot(dir) < 0) {
                  send_error(data, "Can't change root directory to: %s", dir);
                  exit(1);
            }
      }

      /* Mandatory after chroot(): otherwise the working directory may still
       * lie outside the new root and the sandbox could be escaped. */
      if (chdir("/") < 0) {
            send_error(data, "Can't change current directory to /");
            exit(1);
      }

      /*
       * Change gids and uid (in that order because we need to be root)
       */
      {
            uid_t uid;
            gid_t gid;
            int gcount, i;
            gid_t groups[NGROUPS_MAX];

            if (!data->param.ids || data->param.ids->len < 2) {
                  errno = 0;
                  send_error(data, "Valid IDS parameter required");
                  exit(1);
            }

            /* IDS layout: uid, gid, then supplementary gids */
            uid = data->param.ids->vec[0];
            gid = data->param.ids->vec[1];

            if (!allow_root)
                  for (i = 0; i < data->param.ids->len; i++)
                        if (data->param.ids->vec[i] == 0) {
                              send_error(data, "root access denied");
                              exit(1);
                        }

            gcount = data->param.ids->len - 2;
            if (gcount > NGROUPS_MAX)
                  gcount = NGROUPS_MAX;

            for (i = 0; i < gcount; ++i)
                  groups[i] = data->param.ids->vec[i + 2];

            debug("Changing gid to %d", gid);
            if (setgid(gid) < 0) {
                  send_error(data, "Can't change group ID to %d", gid);
                  exit(1);
            }

            debug("Setting %d supplementary gids", gcount);
            if (setgroups(gcount, groups) < 0) {
                  send_error(data, "Can't set supplementary group IDs");
#ifndef DEBUG
                  exit(1);
#endif
            }

            debug("Changing uid to %d", uid);
            if (setuid(uid) < 0) {
                  send_error(data, "Can't change user ID to %d", uid);
                  exit(1);
            }
      }

      /*
       * Change environment
       */
      {
            char **env = data->param.environ;
            if (!env) {
                  /* an empty, NULL-terminated environment */
                  env = calloc(1, sizeof (char *));
                  if (!env) {
                        errno = 0;
                        send_error(data, oom);
                        exit(1);
                  }
            }
            environ = env;
      }

      if (data->fakerootkey && putenv(data->fakerootkey) < 0) {
            send_error(data, "Can't put %s to environment", data->fakerootkey);
            exit(1);
      }

      /*
       * Read resource limits from environment
       */

      if (set_rlimits(data) < 0)
            exit(1);

      /*
       * Set umask
       */

      umask(data->param.umask);

      /*
       * Build command and arguments
       */

      argv = data->param.args;
      if (!argv) {
            argv = calloc(2, sizeof (char *));
            if (!argv) {
                  errno = 0;
                  send_error(data, oom);
                  exit(1);
            }

            /* looked up in the client-supplied environment set above */
            argv[0] = getenv("SHELL");
            if (!argv[0])
                  argv[0] = DEFAULT_SHELL;
      }

      /*
       * Change current directory
       */

      if (data->param.cwd) {
            char *dir = add_root_prefix(data, data->param.cwd);

            debug("Changing current directory to %s", dir);
            if (chdir(dir) < 0) {
                  send_error(data, "Can't change current directory to %s inside sandbox", dir);
                  exit(1);
            }
      }

      /* hit it! */

      debug_vector("Executing command:", argv);

      set_closeonexec(data->sd);
      execvp(argv[0], argv);

      send_error(data, "Can't execute command: %s", argv[0]);
      exit(1);
}

/**
 * Does the handshake & stuff for a client.
 * @return action or -1 on error
 */
static int handler_startup(handler_t *data, mount_t ***mounts_ptr)
{
      size_t i, mount_count = 0;
      ptype_t action = 0;

      debug("Client IP address is %s", data->host);

      /*
       * Version
       */

      debug("Sending VERSION packet");
      if (send_version(data->sd) < 0) {
            error("Can't send protocol version packet");
            return -1;
      }

      debug("Reading client's VERSION packet");
      data->client_version = get_version(data->sd);
      if (data->client_version < 0) {
            error("Can't read protocol version packet");
            return -1;
      }

      if (data->client_version < PROTOCOL_VERSION) {
            errno = 0;
            send_error(data, "Client version %d is too old (version %d required)",
                     data->client_version, PROTOCOL_VERSION);
            return -1;
      }

      /*
       * Username
       */

      debug("Reading USER packet");
      if (read_enum(data->sd) != PTYPE_USER) {
            error("Received packet has unexpected type");
            return -1;
      }

      data->user = read_str(data->sd);
      if (!data->user) {
            error("Can't read USER packet");
            return -1;
      }

      /*
       * Authenticate
       */
      {
            uint16_t auth;

            int ok = authenticate(data);
            if (ok < 0)
                  return -1;

            auth = ok;

            debug("Sending AUTH packet");
            if (write_uint16_packet(data->sd, PTYPE_AUTH, auth) < 0) {
                  error("Can't write AUTH packet to socket");
                  return -1;
            }

            if (!auth)
                  return -1;
      }

      /*
       * Parameter packets
       */
      while (!action) {
            ptype_t type = read_enum(data->sd);
            switch (type) {
            case -1:
                  error("Can't read packet type from socket");
                  return -1;

            case PTYPE_TARGET:
                  debug("Receiving TARGET packet");
                  data->param.target = read_str(data->sd);
                  if (!data->param.target) {
                        error("Can't read TARGET packet");
                        return -1;
                  }
                  break;

            case PTYPE_MOUNTS:
                  debug("Receiving MOUNTS packet");
                  data->param.mounts = read_mountv(data->sd);
                  if (!data->param.mounts) {
                        error("Can't read MOUNTS packet");
                        return -1;
                  }
                  break;

            case PTYPE_ARGS:
                  debug("Receiving ARGS packet");
                  data->param.args = read_strv(data->sd);
                  if (!data->param.args) {
                        error("Can't read ARGS packet");
                        return -1;
                  }
                  break;

            case PTYPE_CWD:
                  debug("Receiving CWD packet");
                  data->param.cwd = read_str(data->sd);
                  if (!data->param.cwd) {
                        error("Can't read CWD packet");
                        return -1;
                  }
                  break;

            case PTYPE_ENVIRON:
                  debug("Receiving ENVIRON packet");
                  data->param.environ = read_strv(data->sd);
                  if (!data->param.environ) {
                        error("Can't read ENVIRON packet");
                        return -1;
                  }
                  break;

            case PTYPE_IDS:
                  debug("Receiving IDS packet");
                  data->param.ids = read_uint16v(data->sd);
                  if (!data->param.ids) {
                        error("Can't read IDS packet");
                        return -1;
                  }
                  break;

            case PTYPE_UMASK:
                  debug("Receiving UMASK packet");
                  if (read_uint16(data->sd, &data->param.umask) < 0) {
                        error("Can't read UMASK packet");
                        return -1;
                  }
                  break;

            case PTYPE_WINSIZE:
                  debug("Receiving WINSIZE packet");
                  data->param.term = read_winsize(data->sd);
                  if (!data->param.term) {
                        error("Can't read WINSIZE packet");
                        return -1;
                  }
                  break;

            case PTYPE_COMMAND:
                  debug("Received COMMAND packet");
                  action = type;
                  break;

            case PTYPE_MOUNT:
                  debug("Received MOUNT packet");
                  action = type;
                  break;

            case PTYPE_UMOUNT:
                  debug("Received UMOUNT packet");
                  action = type;
                  break;

            default:
                  errno = 0;
                  send_error(data, "Received packet has unexpected type (0x%02x)", type);
                  return -1;
            }
      }

      if (!data->param.target) {
            errno = 0;
            send_error(data, "TARGET parameter required");
            return -1;
      }
      if (action != PTYPE_COMMAND && !data->param.mounts) {
            errno = 0;
            send_error(data, "MOUNTS parameter required");
            return -1;
      }

      /*
       * Build path to new root
       */

      if (sandbox) {
            char root[PATH_MAX];

            snprintf(root, sizeof (root) - 1, SANDBOX_ROOT "/%s@%s/%s",
                   data->user, data->host, data->param.target);

            debug("Sandbox root is %s", root);

            data->root = strdup(root);
            if (!data->root) {
                  errno = 0;
                  send_error(data, oom);
                  return -1;
            }
      }

      /*
       * Count mounts
       */

      if (data->param.mounts)
            mount_count = calc_vec_len((void **) data->param.mounts);

      /*
       * Prepend root path to mount points
       */

      if (mount_count > 0 && data->root)
            for (i = 0; i < mount_count; ++i) {
                  char tmp[PATH_MAX], *str;
                  mount_info_t *mi;

                  mi = data->param.mounts[i];

                  snprintf(tmp, sizeof (tmp) - 1, "%s%s", data->root, mi->point);

                  str = strdup(tmp);
                  if (!str) {
                        send_error(data, NULL);
                        return -1;
                  }

                  free(mi->point);
                  mi->point = str;
            }

      if (action == PTYPE_UMOUNT) {
            /*
             * Unmount target's filesystems
             */

            debug("--umount requested");

            if (unmount_infos(data) < 0)
                  return -1;

      } else if (mount_count > 0) {
            /*
             * Mount filesystems
             */

            mount_t **mounts;

            mntinfo_sort_vec(data->param.mounts);

            mounts = calloc(mount_count + 1, sizeof (mount_t *));
            if (!mounts) {
                  errno = 0;
                  send_error(data, oom);
                  return -1;
            }

            for (i = 0; i < mount_count; ++i) {
                  mounts[i] = add_mount(data, data->param.mounts[i]);
                  if (!mounts[i]) {
                        errno = 0;
                        send_error(data, "Can't mount to point: %s", data->param.mounts[i]->point);

                        free_vec((void **) mounts, (free_func_t *) release_mount);
                        return -1;
                  }
            }

            *mounts_ptr = mounts;
      }

      return action;
}

/**
 * Sends the RC (return code) packet to the client, if the socket is still
 * open.  After a successful write it reads and discards socket input until
 * the peer closes its end or an error occurs (select() is retried on EINTR).
 * The socket is then closed and data->sd is set to -1.
 */
static void send_rc(handler_t *data, uint16_t rc)
{
      fd_set rset;

      if (data->sd < 0) {
            debug("Not sending RC packet");
            return;
      }

      debug("Sending RC packet: %d", rc);

      if (write_uint16_packet(data->sd, PTYPE_RC, rc) < 0) {
            error("Can't write RC packet to socket");
      } else {
            /* linger, draining input, until the client hangs up */
            for (;;) {
                  FD_ZERO(&rset);
                  FD_SET(data->sd, &rset);

                  if (select(data->sd + 1, &rset, NULL, NULL, NULL) < 0) {
                        if (errno != EINTR)
                              break;
                        continue;
                  }

                  if (read(data->sd, data->tmp_buf, BUFFER_SIZE) <= 0)
                        break;
            }
      }

      close(data->sd);
      data->sd = -1;
}

/**
 * Serves a COMMAND request inside the handler process.
 *
 * 1. Calls fakeroot_relay().  A negative result aborts the request; a
 *    positive pid is sent SIGTERM at the end.
 * 2. Forks the command process: in a pty when the client sent WINSIZE,
 *    otherwise with socket-pair stdio.  The child runs execute_command().
 * 3. Relays the command's streams with handler_manage().
 * 4. Sends the return code to the client: INTERNAL_ERROR_CODE on failure or
 *    abnormal termination.
 */
static void handler_handle(handler_t *data)
{
      int rc = -1;
      pid_t relay_pid, pid;

      relay_pid = fakeroot_relay(data);
      if (relay_pid < 0)
            goto _rc;

      if (data->param.term) {
            /* initialized so nothing indeterminate is stored if fork_pty()
             * fails before setting it */
            int fd = -1;

            debug("Creating command process in a pty");
            pid = fork_pty(data, &fd);

            data->in.fd = fd;
            data->out.fd = fd;
            data->err.fd = -1;
      } else {
            debug("Creating command process without a pty");
            pid = fork_sockets(data, &data->in.fd, &data->out.fd, &data->err.fd);
      }

      /* child? */
      if (pid == 0)
            execute_command(data);  /* never returns */

      if (pid > 0)
            rc = handler_manage(data, pid);

      if (relay_pid > 0) {
            debug("Sending SIGTERM to relay process %d", relay_pid);
            kill(relay_pid, SIGTERM);
      }

_rc:
      if (rc < 0)
            rc = INTERNAL_ERROR_CODE;

      send_rc(data, rc);
}

/**
 * Calls release_mount() on every entry of a NULL-terminated mount vector.
 * The vector itself is not freed.
 */
static void release_mounts(mount_t **vec)
{
      size_t idx;

      for (idx = 0; vec[idx] != NULL; ++idx)
            release_mount(vec[idx]);
}

/**
 * Allocates a zero-initialized handler whose descriptors are all marked as
 * not open (-1) and whose stream packet types are preset.
 * @return the new handler, or NULL if out of memory
 */
static handler_t *alloc_handler(void)
{
      handler_t *h = calloc(1, sizeof (*h));

      if (!h)
            return NULL;

      /* calloc() leaves descriptors at 0, a valid fd (stdin) that would pass
       * every "fd >= 0" check in this file; mark them as not open instead */
      h->sd = -1;
      h->in.fd = -1;
      h->out.fd = -1;
      h->err.fd = -1;

      h->in.req_type = PTYPE_IN_REQ;
      h->out.data_type = PTYPE_OUT_DATA;
      h->err.data_type = PTYPE_ERR_DATA;

      /* h->error is already 0 from calloc() */

      return h;
}

/**
 * Frees all resources used by a handler.
 *
 * Closes the client socket if it is still open and frees every heap-owned
 * parameter, then the handler itself.  The command's stdin/stdout/stderr
 * descriptors are not touched.
 */
static void free_handler(handler_t *data)
{
      if (data->sd >= 0)
            close(data->sd);

      /* free(NULL) is a no-op, so plain heap strings need no guard */
      free((char *) data->user);
      free(data->param.target);
      free(data->param.cwd);
      free(data->root);
      free(data->fakerootkey);

      /* the project's vector/buffer destructors keep their NULL guards */
      if (data->param.mounts)
            free_vec((void **) data->param.mounts, (free_func_t *) mntinfo_free);

      if (data->param.args)
            free_vec((void **) data->param.args, NULL);

      if (data->param.environ)
            free_vec((void **) data->param.environ, NULL);

      if (data->param.ids)
            uint16v_free(data->param.ids);

      if (data->in.buf)
            buf_free(data->in.buf);

      free(data);
}

/**
 * Appends an entry recording that handler process `pid` uses the mount
 * vector `vec` to the end of the global pid_mounts list.  The entry keeps
 * `vec`, which pid_mounts_del() later releases and frees.  If the entry
 * can't be allocated, oom_error() is called and nothing is recorded.
 */
static void pid_mounts_add(pid_t pid, mount_t **vec)
{
      pid_mounts_t **link;
      pid_mounts_t *entry = calloc(1, sizeof (pid_mounts_t));

      if (entry == NULL) {
            oom_error();
            return;
      }

      entry->pid = pid;
      entry->mounts = vec;

      /* walk to the terminating NULL link and hook the entry there */
      for (link = &pid_mounts; *link != NULL; link = &(*link)->next)
            ;
      *link = entry;
}

static void pid_mounts_del(pid_t pid)
{
      pid_mounts_t *pm = NULL, *node;

      if (!pid_mounts)
            return;

      /* find and remove */

      if (pid_mounts->pid == pid) {
            pm = pid_mounts;
            pid_mounts = pm->next;
      } else {
            for (node = pid_mounts; node->next; node = node->next)
                  if (node->next->pid == pid) {
                        pm = node->next;
                        node->next = pm->next;
                        break;
                  }
      }

      if (!pm)
            return;

      debug("Found mounts for pid %d", pid);

      /* deallocate */

      if (pm->mounts) {
            release_mounts(pm->mounts);
            free(pm->mounts);
      }

      free(pm);
}

static void accept_conn(int srvsd)
{
      struct sockaddr_storage addr;
      socklen_t len = sizeof (addr);
      int clisd;
      handler_t *data;
      mount_t **mounts = NULL;
      int action;
      pid_t pid;

      clisd = accept(srvsd, (struct sockaddr *) &addr, &len);
      if (clisd < 0) {
            error("Can't accept connection");
            return;
      }

      debug("New connection");

      data = alloc_handler();
      if (!data) {
            oom_error();
            close(clisd);
            return;
      }

      data->sd = clisd;

      /* get the IP address */
      if (getnameinfo((struct sockaddr *) &addr, len, data->host, sizeof (data->host),
                  NULL, 0, NI_NUMERICHOST) < 0) {
            error("Can't get client's IP address");
            goto _error;
      }

      action = handler_startup(data, &mounts);
      if (action < 0)
            goto _error;

      if (action == PTYPE_COMMAND) {
            /* start handler process */
            pid = fork();
            if (pid < 0) {
                  error("Can't fork");
                  goto _error;
            }

            /* child? */
            if (pid == 0) {
                  set_debug_name("HANDLER");

                  close(srvsd);

                  handler_handle(data);

                  debug("Handler process exiting");
                  exit(0);

                  /* not reached */
            }

            /* map pid to mounts vector */
            if (mounts)
                  pid_mounts_add(pid, mounts);

            free_handler(data);
      } else {
            send_rc(data, 0);

            free_handler(data);
            free_vec((void **) mounts, (free_func_t *) release_mount);
      }

      return;

_error:
      send_rc(data, INTERNAL_ERROR_CODE);

      free_handler(data);
      if (mounts)
            free(mounts);
}

/**
 * Terminates the process with exit code `rc` after calling unmount_all() to
 * undo the filesystems mounted by the daemon.  Never returns.
 */
static void clean_exit(int rc)
{
      unmount_all();
      debug("sbrshd exiting");

      exit(rc);
}

/**
 * Signal handler that does nothing except log the signal name in DEBUG
 * builds.  errno is saved and restored so the interrupted code is not
 * disturbed.
 */
static void sig_dummy(int sig)
{
#ifdef DEBUG
      int stored_errno = errno;
      /* never use a non-literal string as the format */
      debug("%s", strsignal(sig));
      errno = stored_errno;
#else
      (void) sig;
#endif
}

/*
 * Exits. If invoked on the DAEMON process, then it also unmounts everything
 * (via clean_exit()).  The exit status is 0 in both cases.
 */
static void sig_exit(int sig)
{
      /* never use a non-literal string as the format */
      debug("%s", strsignal(sig));

      if (getpid() == daemon_pid)
            clean_exit(0);  /* does not return */

      exit(0);
}

/**
 * Toggles the debug log at runtime.  In the daemon process SIGUSR1 opens the
 * log and SIGUSR2 closes it; in other processes the signal is only logged.
 * errno is saved and restored around the handler.
 */
static void sig_debug(int sig)
{
      int stored_errno = errno;

      /* never use a non-literal string as the format */
      debug("%s", strsignal(sig));

      if (getpid() != daemon_pid)
            goto _ret;

      if (sig == SIGUSR1)
            open_debug_log();
      else if (sig == SIGUSR2)
            close_debug_log();

_ret:
      errno = stored_errno;
}

/**
 * Turns a possibly relative file path into an absolute one.  The directory
 * part is resolved with realpath() and the file name is re-appended.  The
 * caller's string is not modified.
 *
 * Reports the problem with error_err() and exits with status 1 if:
 * - memory runs out;
 * - the directory can't be resolved;
 * - the result would exceed PATH_MAX.
 *
 * @param progname program name used in error messages
 * @param relpath path to convert
 * @return a malloc()ed absolute path
 */
static char *get_absolute_path(const char *progname, char *relpath)
{
      char *dircopy, *filecopy, *dir, *file, *abspath;
      const char *sep;

      /* dirname() and basename() may modify their argument; give each its
       * own copy so that the caller's string stays intact */
      dircopy = strdup(relpath);
      filecopy = strdup(relpath);
      if (!dircopy || !filecopy) {
            errno = 0;
            error_err(progname, oom);
            exit(1);
      }

      dir = dirname(dircopy);
      file = basename(filecopy);

      abspath = malloc(PATH_MAX + 1);
      if (!abspath) {
            errno = 0;
            error_err(progname, oom);
            exit(1);
      }

      /* keep our own pointer: realpath() returns NULL on failure */
      if (!realpath(dir, abspath)) {
            error_err(progname, "Can't get real path of %s", dir);
            exit(1);
      }

      /* realpath() of the root directory is "/"; don't produce "//file" */
      sep = strcmp(abspath, "/") == 0 ? "" : "/";

      if (strlen(abspath) + strlen(sep) + strlen(file) > PATH_MAX) {
            error_err(progname, "Path is too long: %s/%s", abspath, file);
            exit(1);
      }

      strcat(abspath, sep);
      strcat(abspath, file);

      free(dircopy);
      free(filecopy);

      return abspath;
}

/**
 * Prints the command-line synopsis to stderr and exits with status 0.
 * The options listed here match the getopt table in read_args().
 */
static void usage(char *progname)
{
      fprintf(stderr, "Usage: %s [-p|--port <port>]"
                         " [-l|--local-only]"
                         " [-n|--no-sandbox]"
                         " [-r|--allow-root]"
                         " [-d|--debug <path>]"
                         " [-e|--mount-expiration <minutes>|none]"
                         " [-m|--mount-bin <path>]"
                         " [-u|--umount-bin <path>]"
                         " [-b|--bind-opt <options>]\n"
                  "       %s add <address>\n"
                  "       %s -v|--version\n"
                  "       %s -h|--help\n",
            progname, progname, progname, progname);

      exit(0);
}

/**
 * Reads options. Prints usage and exits when necessary.
 *
 * Parsed values are stored in the file-scope settings (port, local_only,
 * sandbox, allow_root, mount_expiration, mount_cmd, umount_cmd, bind_opt,
 * debug_filename). -v prints the version and exits; -h and unrecognized
 * options go through usage(). An invalid --mount-expiration value is
 * reported and the process exits with status 1.
 *
 * @param progname program name used in messages
 * @param argc     argument count from main()
 * @param argv     argument vector from main()
 */
static void read_args(char *progname, int argc, char **argv)
{
      static const char optstring[] = "hvp:lnrd:e:m:u:b:";
      static const struct option longopts[] = {
            { "help",             no_argument,       0, 'h' },
            { "version",          no_argument,       0, 'v' },
            { "port",             required_argument, 0, 'p' },
            { "local-only",       no_argument,       0, 'l' },
            { "no-sandbox",       no_argument,       0, 'n' },
            { "allow-root",       no_argument,       0, 'r' },
            { "debug",            required_argument, 0, 'd' },
            { "mount-expiration", required_argument, 0, 'e' },
            { "mount-bin",        required_argument, 0, 'm' },
            { "umount-bin",       required_argument, 0, 'u' },
            { "bind-opt",         required_argument, 0, 'b' },
            { 0 }
      };

      char *debugname = NULL;
      char *exp_str = NULL;
      bool_t default_bind_opt = TRUE;

      while (1) {
            int c = getopt_long(argc, argv, optstring, longopts, NULL);
            if (c < 0)
                  break;

            switch (c) {
            case 'p':
                  port = optarg;
                  break;

            case 'l':
                  local_only = TRUE;
                  break;

            case 'n':
                  sandbox = FALSE;
                  break;

            case 'r':
                  allow_root = TRUE;
                  break;

            case 'd':
                  debugname = optarg;
                  break;

            case 'e':
                  exp_str = optarg;
                  break;

            case 'm':
                  mount_cmd = optarg;
                  break;

            case 'u':
                  umount_cmd = optarg;
                  break;

            case 'b':
                  bind_opt = optarg;
                  default_bind_opt = FALSE;
                  break;

            case 'v':
                  fprintf(stderr, "Scratchbox Remote Shell daemon %d%s\n",
                        PROTOCOL_VERSION, REVISION);
                  exit(0);

            case 'h':
            case '?':
            default:
                  usage(progname);
            }
      }

      if (exp_str) {
            if (strcmp(exp_str, MOUNT_EXPIRATION_NONE) == 0) {
                  /* never expire mounts */
                  mount_expiration = -1;

            } else {
                  char *end;
                  long minutes;

                  errno = 0;
                  minutes = strtol(exp_str, &end, 10);

                  /* Reject empty input, trailing junk, negative values and
                   * anything that would overflow an int once converted to
                   * seconds (atoi() silently accepted "5min" and the
                   * multiplication could overflow). */
                  if (end == exp_str || *end != '\0' || errno == ERANGE ||
                      minutes < 0 || minutes > INT_MAX / 60) {
                        error_err(progname, "Invalid expiration time: %s minutes", exp_str);
                        exit(1);
                  }

                  /* 0 keeps its previous meaning: expire on every wakeup */
                  mount_expiration = (int) (minutes * 60);
            }
      }

      /* we need the absolute path since we chdir to root */
      if (debugname)
            debug_filename = get_absolute_path(progname, debugname);

      if (default_bind_opt)
            check_for_busybox(progname);
}

/**
 * Do stuff, close fds and direct in/out/err to /dev/null or debug log.
 * To make this function sufficiently odd, setsid() will not be called.
 * @param listenfd this descriptor won't be closed
 */
int daemonize(int listenfd)
{
      int debugfd, fd;

      chdir("/");
      umask(0);

      /* Don't close debug file */
      debugfd = debug_file ? fileno(debug_file) : -1;

      for (fd = getdtablesize(); fd-- > 0; )
            if (fd != listenfd && fd != debugfd)
                  close(fd);

      fd = open(NULL_FILE, O_RDWR);
      if (fd < 0) {
            error("Can't open " NULL_FILE);
            return -1;
      }

      assert(fd == STDIN_FILENO);

      if (debugfd >= 0)
            fd = debugfd;

      if (dup2(fd, STDOUT_FILENO) != STDOUT_FILENO) {
            error("Can't duplicate descriptor %d as stdout", fd);
            return -1;
      }

      if (dup2(fd, STDERR_FILENO) != STDERR_FILENO) {
            error("Can't duplicate descriptor %d as stderr", fd);
            return -1;
      }

      return 0;
}

/*
 * Startup and main loop.
 *
 * With "add <address>" the address is passed to add_to_config() and the
 * negated result is returned.  Otherwise the options are parsed and a TCP
 * listening socket is created.  Unless built with DEBUG, the process then
 * forks and the parent prints the child's pid.  Finally signal handlers are
 * installed and the accept / reap / expire loop runs forever.
 */
int main(int argc, char **argv)
{
      char *progname;
      int srvsd;
      struct sigaction act_dummy, act_exit, act_debug;
      pid_t pid;
      unsigned int timeout = 0;
      struct timeval tv;

      progname = get_progname(argv[0]);

      /* add lines to config file */

      if (argc >= 2 && strcmp(argv[1], "add") == 0) {
            char *host;

            /* NOTE(review): usage() calls exit(0), so the "return 1"
             * statements after it are not reached */
            if (argc != 3) {
                  usage(progname);
                  return 1;
            }

            host = argv[2];

            if (strlen(host) == 0 || strchr(host, '@') != NULL) {
                  usage(progname);
                  return 1;
            }

            return -add_to_config(host);
      }

      /* read config */

      openlog(progname, LOG_PID, LOG_DAEMON);
      read_args(progname, argc, argv);

      /* create a socket, bind it and listen to it */
      {
            struct addrinfo *ai, hints = { 0 };
            int rc;

            hints.ai_flags = AI_ADDRCONFIG;
            hints.ai_family = AF_INET;
            hints.ai_socktype = SOCK_STREAM;

            /* with a NULL node, AI_PASSIVE selects the wildcard address;
             * without it getaddrinfo() returns the loopback address */
            if (!local_only)
                  hints.ai_flags |= AI_PASSIVE;

            /* failure is any non-zero EAI_* code (not necessarily negative) */
            rc = getaddrinfo(NULL, port, &hints, &ai);
            if (rc != 0) {
                  /* errno is only meaningful for EAI_SYSTEM */
                  if (rc != EAI_SYSTEM)
                        errno = 0;
                  error_err(progname, "Can't get address info: %s", gai_strerror(rc));
                  return 1;
            }

            srvsd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
            if (srvsd < 0) {
                  error_err(progname, "Can't create socket");
                  freeaddrinfo(ai);
                  return 1;
            }

            if (setsockopt_bool(srvsd, SOL_SOCKET, SO_REUSEADDR, TRUE) < 0) {
                  error_err(progname, "Can't set socket option (SO_REUSEADDR)");
                  freeaddrinfo(ai);
                  return 1;
            }

            if (bind(srvsd, ai->ai_addr, ai->ai_addrlen) < 0) {
                  error_err(progname, "Can't bind socket");
                  freeaddrinfo(ai);
                  return 1;
            }

            freeaddrinfo(ai);

            if (listen(srvsd, SOMAXCONN) < 0) {
                  error_err(progname, "Can't listen with socket");
                  return 1;
            }
      }

      /* daemonize */

#ifndef DEBUG
      pid = fork();
      if (pid < 0) {
            error_err(progname, "Can't fork");
            return 1;
      }
      if (pid > 0) {
            printf("%d\n", (int) pid);
            return 0;
      }
#endif

      set_debug_name("DAEMON");
      if (debug_filename)
            open_debug_log();

      setsid();
      if (daemonize(srvsd) < 0)
            return 1;

      /* signal handlers */

      daemon_pid = getpid();

      act_dummy.sa_handler = sig_dummy;
      sigemptyset(&act_dummy.sa_mask);
      act_dummy.sa_flags = 0;

      act_exit.sa_handler = sig_exit;
      sigemptyset(&act_exit.sa_mask);
      act_exit.sa_flags = SA_ONESHOT;

      act_debug.sa_handler = sig_debug;
      sigemptyset(&act_debug.sa_mask);
      act_debug.sa_flags = 0;

#ifdef DEBUG
      sigaction(SIGINT, &act_exit, NULL);
#else
      sigaction(SIGHUP, &act_dummy, NULL);
#endif

      sigaction(SIGTERM, &act_exit, NULL);

      /* SIGCHLD gets a no-op handler so that it interrupts select() */
      sigaction(SIGCHLD, &act_dummy, NULL);
      sigaction(SIGPIPE, &act_dummy, NULL);

      sigaction(SIGUSR1, &act_debug, NULL);
      sigaction(SIGUSR2, &act_debug, NULL);

      /* mount expiration */

      if (mount_expiration >= 0) {
            timeout = mount_expiration / MOUNT_EXPIRATION_FREQUENCY;

            /* a zero timeval makes select() return immediately, which would
             * turn the loop below into a busy poll */
            if (mount_expiration > 0 && timeout == 0)
                  timeout = 1;

            tv.tv_sec = timeout;
            tv.tv_usec = 0;
      }

      while (1) {
            fd_set fds;
            int count;

            debug("Waiting for connection");

            FD_ZERO(&fds);
            FD_SET(srvsd, &fds);

            count = select(srvsd + 1, &fds, NULL, NULL, mount_expiration > 0 ? &tv : NULL);
            if (count < 0) {
                  /* EINTR only means that a signal arrived */
                  if (errno != EINTR) {
                        error("Select failed");
                        clean_exit(1);
                  }
            } else if (count == 1) {
                  /* we received a connection */
                  accept_conn(srvsd);
            }

            /* Reap finished children on every pass, not only after EINTR:
             * a SIGCHLD delivered while accept_conn() or expire_mounts() is
             * running does not interrupt select(), so those children would
             * otherwise remain zombies (with their mounts held) until some
             * later signal happened to arrive. */
            while (1) {
                  int status;

                  pid = waitpid(-1, &status, WNOHANG);
                  if (pid <= 0)
                        break;

                  /* release mounts if found */
                  print_status(pid, status);
                  pid_mounts_del(pid);
            }

            /* expire mounts; select() may have modified tv, so reset it */
            if (mount_expiration >= 0) {
                  expire_mounts();

                  tv.tv_sec = timeout;
                  tv.tv_usec = 0;
            }
      }

      return 0;  /* not reached */
}

Generated by  Doxygen 1.6.0   Back to index