#include <sys/types.h>
#include <sys/errno.h>
#include <sys/fcntl.h>
#include <sys/poll.h>
#include <sys/sendfile.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/* Print the command-line synopsis to stderr and exit with status 2. */
static void usage(void);
/* Write all `size` bytes of `data` to `connection`; -1 on error. */
static int write_all(int connection, const char *data, size_t size);
/* sendfile() the first `file_size` bytes of `fd` to `connection`; -1 on error. */
static int sendfile_all(int connection, int fd, off_t file_size);

/* Not referenced in this file: the listening port is taken from argv. */
#define LISTENER_PORT 8080
/*
 * Minimal HTTP/1.0 status line followed by the blank line that ends the
 * header section; main() sends the file contents right after it as the body.
 */
static const char response_header[] = "HTTP/1.0 200 OK\r\n\r\n";
/*
 * Socket options selected on the command line; applied to every accepted
 * connection by configure_connection().
 */
struct connection_options {
  int toggle_tcp_nodelay;  /* set TCP_NODELAY, then immediately clear it */
  int enable_tcp_nodelay;  /* set TCP_NODELAY and leave it set */
  int enable_tcp_cork;     /* set TCP_CORK */
  int enable_nonblocking;  /* switch the socket to O_NONBLOCK */
};

/*
 * Parse the port-number argument.  Unlike atoi(), this rejects empty
 * strings, trailing garbage and values outside 1..65535 instead of
 * silently turning them into 0 (which would make bind() pick an arbitrary
 * ephemeral port).  On invalid input prints a diagnostic and calls
 * usage(), which exits.
 */
static int
parse_port(const char *text)
{
  char *end;
  long value;

  errno = 0;
  value = strtol(text, &end, 10);
  if (end == text || *end != '\0' || errno == ERANGE
      || value < 1 || value > 65535) {
    fprintf(stderr, "invalid port number: %s\n", text);
    usage();
  }
  return (int)value;
}

/*
 * Create a TCP socket bound to INADDR_ANY:port_num and put it into the
 * listening state.  A failing SO_REUSEADDR is only reported; every other
 * failure terminates the process with status 1.
 */
static int
create_listener(int port_num)
{
  struct sockaddr_in addr;
  int flag = 1;
  /* Protocol 0 selects the default protocol for SOCK_STREAM (TCP). */
  int listener = socket(AF_INET, SOCK_STREAM, 0);

  if (listener == -1) {
    perror("socket");
    exit(1);
  }
  if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof flag) == -1)
    perror("setsocketopt/SO_REUSEADDR");

  /* Zero the whole struct so sin_zero is not left uninitialized. */
  memset(&addr, 0, sizeof addr);
  addr.sin_family = AF_INET;
  addr.sin_port = htons((uint16_t)port_num);
  addr.sin_addr.s_addr = htonl(INADDR_ANY);
  if (bind(listener, (struct sockaddr *)&addr, sizeof addr) == -1) {
    perror("bind");
    exit(1);
  }
  if (listen(listener, 1024) == -1) {
    perror("listen");
    exit(1);
  }
  return listener;
}

/*
 * Apply the requested options to an accepted connection.
 * Returns -1 if switching to non-blocking mode fails (fatal for the
 * server loop); setsockopt() failures are only reported.  Returns 0
 * otherwise.
 */
static int
configure_connection(int connection, const struct connection_options *opts)
{
  if (opts->enable_nonblocking) {
    int flags = fcntl(connection, F_GETFL, 0);
    if (flags == -1) {
      perror("fcntl(F_GETFL)");
      return -1;
    }
    if (fcntl(connection, F_SETFL, flags | O_NONBLOCK) == -1) {
      perror("fcntl(F_SETFL)");
      return -1;
    }
  }
  if (opts->enable_tcp_nodelay || opts->toggle_tcp_nodelay) {
    int flag = 1;
    if (setsockopt(connection, IPPROTO_TCP, TCP_NODELAY, &flag,
                   sizeof flag) == -1)
      perror("setsockopt(TCP_NODELAY=1)");
  }
  /* --toggle-nodelay clears the flag again right after setting it. */
  if (opts->toggle_tcp_nodelay) {
    int flag = 0;
    if (setsockopt(connection, IPPROTO_TCP, TCP_NODELAY, &flag,
                   sizeof flag) == -1)
      perror("setsockopt(TCP_NODELAY=0)");
  }
  if (opts->enable_tcp_cork) {
    int flag = 1;
    if (setsockopt(connection, IPPROTO_TCP, TCP_CORK, &flag,
                   sizeof flag) == -1)
      perror("setsockopt(TCP_CORK)");
  }
  return 0;
}

/*
 * usage: sendfile_test [options] filename portnum
 *
 * Serves `filename` to every client that connects on `portnum`: sends
 * response_header, then the file body via sendfile(), then closes the
 * connection.  Runs until accept() or connection setup fails fatally, in
 * which case it exits with status 1.
 */
int
main(int argc, char **argv)
{
  struct connection_options opts = { 0, 0, 0, 0 };
  struct stat file_info;
  int listener;
  int fd;
  int status = 0;
  int i;

  if (argc < 3) {
    usage();
  }
  /* Options precede the two positional arguments (filename, portnum). */
  for (i = 1; i < argc - 2; i++) {
    if (!strcmp(argv[i], "--toggle-nodelay")) {
      opts.toggle_tcp_nodelay = 1;
    }
    else if (!strcmp(argv[i], "--nodelay")) {
      opts.enable_tcp_nodelay = 1;
    }
    else if (!strcmp(argv[i], "--cork")) {
      opts.enable_tcp_cork = 1;
    }
    else if (!strcmp(argv[i], "--nonblock")) {
      opts.enable_nonblocking = 1;
    }
    else {
      usage();
    }
  }

  fd = open(argv[i], O_RDONLY);
  if (fd == -1) {
    perror("open");
    exit(1);
  }
  if (fstat(fd, &file_info) == -1) {
    perror("fstat");
    exit(1);
  }

  listener = create_listener(parse_port(argv[i + 1]));

  /*
   * A client that disconnects mid-transfer would otherwise make write()
   * or sendfile() raise SIGPIPE, whose default action kills the server.
   * With it ignored, those calls fail with EPIPE and we move on.
   */
  if (signal(SIGPIPE, SIG_IGN) == SIG_ERR)
    perror("signal(SIGPIPE)");

  for (;;) {
    int connection = accept(listener, NULL, NULL);
    if (connection == -1) {
      /* Transient conditions: interrupted call, or the peer gave up
         before we accepted.  Keep serving. */
      if (errno == EINTR || errno == ECONNABORTED)
        continue;
      perror("accept");
      status = 1;
      break;
    }

    if (configure_connection(connection, &opts) == -1) {
      close(connection);
      status = 1;
      break;
    }

    if (write_all(connection, response_header, sizeof(response_header) - 1)
        == -1) {
      perror("write");
    }
    else if (sendfile_all(connection, fd, file_info.st_size) == -1) {
      perror("sendfile");
    }
    close(connection);
  }

  close(listener);
  close(fd);
  return status;
}

/*
 * Print the command-line synopsis to stderr and terminate with exit
 * status 2.  The option list mirrors the flags main() parses:
 * --toggle-nodelay, --nodelay, --cork and --nonblock.
 */
static void
usage(void)
{
  fprintf(stderr, "usage: sendfile_test [--toggle-nodelay] [--nodelay] [--cork] [--nonblock] filename portnum\n");
  exit(2);
}

/*
 * write_all: write exactly `size` bytes from `data` to `connection`.
 *
 * Short writes are continued.  A write interrupted by a signal (EINTR) is
 * restarted.  When the descriptor is non-blocking and would block
 * (EAGAIN/EWOULDBLOCK), the call waits in poll() until it is writable.
 * POLLERR/POLLHUP need not be requested: poll() always reports them, and
 * the following write() then fails with the real error.
 *
 * Returns the number of bytes written (clamped to INT_MAX because of the
 * int return type) on success.  Returns -1 on failure, with errno set by
 * the failing write() or poll().  A poll() failure is reported as an
 * error rather than as a partial success.
 */
static int
write_all(int connection, const char *data, size_t size)
{
  size_t offset = 0;

  while (offset < size) {
    ssize_t rv = write(connection, data + offset, size - offset);

    if (rv == -1) {
      if (errno == EINTR)
        continue;
      if (errno == EAGAIN || errno == EWOULDBLOCK) {
        struct pollfd poll_fd;
        poll_fd.fd = connection;
        poll_fd.events = POLLOUT;
        poll_fd.revents = 0;
        if (poll(&poll_fd, 1, -1) == -1 && errno != EINTR) {
          perror("poll");
          return -1;
        }
        continue;
      }
      return -1;
    }
    offset += (size_t)rv;
  }
  return offset > (size_t)INT_MAX ? INT_MAX : (int)offset;
}

/*
 * sendfile_all: copy the first `file_size` bytes of `fd` to `connection`
 * with sendfile(2).
 *
 * The copy always starts at file offset 0.  Because an explicit offset
 * pointer is passed to sendfile(), fd's own file position is left
 * untouched, and sendfile() itself advances `offset` by the number of
 * bytes sent.  Short transfers are continued.  EINTR restarts the call.
 * EAGAIN/EWOULDBLOCK on a non-blocking socket waits in poll() until the
 * socket is writable.
 *
 * Returns the number of bytes sent (clamped to INT_MAX).  This is less
 * than file_size only if the file ended early (sendfile() returned 0).
 * Returns -1 on failure, with errno set by the failing sendfile() or
 * poll().
 */
static int
sendfile_all(int connection, int fd, off_t file_size)
{
  off_t offset = 0;

  while (offset < file_size) {
    off_t remaining = file_size - offset;
    /* Cap each request so the off_t -> size_t conversion cannot truncate. */
    size_t count = remaining > (off_t)INT_MAX ? (size_t)INT_MAX
                                              : (size_t)remaining;
    ssize_t rv = sendfile(connection, fd, &offset, count);

    if (rv == -1) {
      if (errno == EINTR)
        continue;
      if (errno == EAGAIN || errno == EWOULDBLOCK) {
        struct pollfd poll_fd;
        poll_fd.fd = connection;
        poll_fd.events = POLLOUT;
        poll_fd.revents = 0;
        if (poll(&poll_fd, 1, -1) == -1 && errno != EINTR) {
          perror("poll");
          return -1;
        }
        continue;
      }
      return -1;
    }
    if (rv == 0)
      break;  /* EOF: the file holds fewer than file_size bytes */
    /* No "offset += rv": sendfile() has already advanced offset. */
  }
  return offset > (off_t)INT_MAX ? INT_MAX : (int)offset;
}

