diff --git a/Makefile b/Makefile index 8d824d0..b9236bb 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,8 @@ LDFLAGS+=-pthread OBJ=build/femtotcp.o \ build/port/posix/linux_tap.o -EXE=build/tcpecho build/test-evloop +EXE=build/tcpecho build/tcp_netcat_poll build/tcp_netcat_select \ + build/test-evloop LIB=libfemtotcp.so @@ -47,6 +48,14 @@ build/tcpecho: $(OBJ) build/port/posix/bsd_socket.o build/test/tcp_echo.o @echo "[LD] $@" @$(CC) $(CFLAGS) $(LDFLAGS) -o $@ -Wl,--start-group $(^) -Wl,--end-group +build/tcp_netcat_poll: $(OBJ) build/port/posix/bsd_socket.o build/test/tcp_netcat_poll.o + @echo "[LD] $@" + @$(CC) $(CFLAGS) $(LDFLAGS) -o $@ -Wl,--start-group $(^) -Wl,--end-group + +build/tcp_netcat_select: $(OBJ) build/port/posix/bsd_socket.o build/test/tcp_netcat_select.o + @echo "[LD] $@" + @$(CC) $(CFLAGS) $(LDFLAGS) -o $@ -Wl,--start-group $(^) -Wl,--end-group + build/%.o: src/%.c @mkdir -p `dirname $@` || true @echo "[CC] $<" diff --git a/src/femtotcp.c b/src/femtotcp.c index 6bc34b9..2bb9ded 100644 --- a/src/femtotcp.c +++ b/src/femtotcp.c @@ -1047,13 +1047,13 @@ static void tcp_input(struct ipstack *S, struct ipstack_tcp_seg *tcp, uint32_t f t->sock.tcp.state = TCP_CLOSE_WAIT; t->sock.tcp.ack = ee32(tcp->seq) + 1; tcp_send_ack(t); - t->events |= CB_EVENT_CLOSED; + t->events |= CB_EVENT_CLOSED | CB_EVENT_READABLE; } else if (t->sock.tcp.state == TCP_FIN_WAIT_1) { t->sock.tcp.state = TCP_CLOSING; t->sock.tcp.ack = ee32(tcp->seq) + 1; tcp_send_ack(t); - t->events |= CB_EVENT_CLOSED; + t->events |= CB_EVENT_CLOSED | CB_EVENT_READABLE; } } @@ -1105,7 +1105,7 @@ static void tcp_input(struct ipstack *S, struct ipstack_tcp_seg *tcp, uint32_t f t->sock.tcp.state = TCP_CLOSING; } t->sock.tcp.ack = ee32(tcp->seq) + 1; - t->events |= CB_EVENT_CLOSED; + t->events |= CB_EVENT_CLOSED | CB_EVENT_READABLE; tcp_send_ack(t); } if (tcp->flags & 0x10) { diff --git a/src/port/posix/bsd_socket.c b/src/port/posix/bsd_socket.c index b35a721..6921cd6 100644 --- a/src/port/posix/bsd_socket.c +++ b/src/port/posix/bsd_socket.c @@ -10,12 +10,14 @@ #include #include #include +#include #define FEMTO_POSIX #include "config.h" #include "femtotcp.h" static __thread int in_the_stack = 1; static struct ipstack *IPSTACK = NULL; +pthread_mutex_t ipstack_mutex = PTHREAD_MUTEX_INITIALIZER; /* host_ functions are the original functions from the libc */ static int (*host_socket ) (int domain, int type, int protocol) = NULL; @@ -54,14 +56,18 @@ static int (*host_fcntl) (int fd, int cmd, ...); if(in_the_stack) { \ return host_##call(fd, ## __VA_ARGS__); \ } else { \ + pthread_mutex_lock(&ipstack_mutex); \ if ((fd & (MARK_TCP_SOCKET | MARK_UDP_SOCKET)) != 0) { \ int __femto_retval = ft_##call(IPSTACK, fd, ## __VA_ARGS__); \ if (__femto_retval < 0) { \ errno = __femto_retval; \ + pthread_mutex_unlock(&ipstack_mutex); \ return -1; \ } \ + pthread_mutex_unlock(&ipstack_mutex); \ return __femto_retval; \ }else { \ + pthread_mutex_unlock(&ipstack_mutex); \ return host_##call(fd, ## __VA_ARGS__); \ } \ } @@ -70,6 +76,7 @@ static int (*host_fcntl) (int fd, int cmd, ...); if(in_the_stack) { \ return host_##call(fd, ## __VA_ARGS__); \ } else { \ + pthread_mutex_lock(&ipstack_mutex); \ if ((fd & (MARK_TCP_SOCKET | MARK_UDP_SOCKET)) != 0) { \ int __femto_retval; \ do { \ @@ -80,10 +87,13 @@ static int (*host_fcntl) (int fd, int cmd, ...); } while (__femto_retval == -11); \ if (__femto_retval < 0) { \ errno = __femto_retval; \ + pthread_mutex_unlock(&ipstack_mutex); \ return -1; \ } \ + pthread_mutex_unlock(&ipstack_mutex); \ return __femto_retval; \ }else { \ + pthread_mutex_unlock(&ipstack_mutex); \ return host_##call(fd, ## __VA_ARGS__); \ } \ } @@ -126,10 +136,14 @@ int fcntl(int fd, int cmd, ...) { va_start(ap, cmd); arg = va_arg(ap, int); va_end(ap); + int ret; if (in_the_stack) { return host_fcntl(fd, cmd, arg); } else { - return ft_fcntl(IPSTACK, fd, cmd, arg); + pthread_mutex_lock(&ipstack_mutex); + ret = ft_fcntl(IPSTACK, fd, cmd, arg); + pthread_mutex_unlock(&ipstack_mutex); + return ret; } } @@ -171,10 +185,13 @@ void poller_callback(int fd, uint16_t event, void *arg) int ft_poll(struct ipstack *ipstack, struct pollfd *fds, nfds_t nfds, int timeout) { nfds_t i; int fd; + int ret; printf("Called poll\n"); if (in_the_stack) { return host_poll(fds, nfds, timeout); } + memset(tcp_pollers, 0, sizeof(tcp_pollers)); + memset(udp_pollers, 0, sizeof(udp_pollers)); for (i = 0; i < nfds; i++) { fd = fds[i].fd; struct bsd_poll_helper *poller; @@ -201,7 +218,9 @@ int ft_poll(struct ipstack *ipstack, struct pollfd *fds, nfds_t nfds, int timeou } /* Call the original poll */ printf("Calling host_poll\n"); - int ret = host_poll(fds, nfds, timeout); + pthread_mutex_unlock(&ipstack_mutex); + ret = host_poll(fds, nfds, timeout); + pthread_mutex_lock(&ipstack_mutex); for (i = 0; i < nfds; i++) { struct bsd_poll_helper *poller; if ((fds[i].fd & MARK_TCP_SOCKET) != 0) @@ -213,7 +232,7 @@ int ft_poll(struct ipstack *ipstack, struct pollfd *fds, nfds_t nfds, int timeou printf("Replacing events\n"); if (fds[i].revents & POLLIN) { char c; - read(poller->pipefds[0], &c, 1); + host_read(poller->pipefds[0], &c, 1); switch(c) { case 'r': fds[i].revents |= POLLIN; @@ -231,8 +250,8 @@ int ft_poll(struct ipstack *ipstack, struct pollfd *fds, nfds_t nfds, int timeou fds[i].revents &= fds[i].events & (POLLHUP | POLLERR); } printf("Closing pipes\n"); - close(poller->pipefds[0]); - close(poller->pipefds[1]); + host_close(poller->pipefds[0]); + host_close(poller->pipefds[1]); poller->fd = 0; fds[i].fd = poller->fd; fds[i].events = poller->events; @@ -246,74 +265,85 @@ int ft_select(struct ipstack *ipstack, int nfds, fd_set *readfds, fd_set *writef int i; int maxfd; int ret; - printf("Called select\n"); + fd_set readfds_local; /* Assume MARK_TCP_SOCKET < MARK_UDP_SOCKET */ if (nfds < MARK_TCP_SOCKET + 1) { return host_select(nfds, readfds, writefds, exceptfds, timeout); } - for (i = 0; i < MARK_TCP_SOCKET && i < nfds; i++) { - if (FD_ISSET(i, readfds) || FD_ISSET(i, writefds) || FD_ISSET(i, exceptfds)) { + memset(tcp_pollers, 0, sizeof(tcp_pollers)); + memset(udp_pollers, 0, sizeof(udp_pollers)); + for (i = 0; (i < MARK_TCP_SOCKET) && (i < nfds); i++) { + if ((readfds && FD_ISSET(i, readfds)) || + (writefds && FD_ISSET(i, writefds)) || + (exceptfds && FD_ISSET(i, exceptfds))) { maxfd = i; } } - for (i = MARK_TCP_SOCKET; i < nfds && i < MAX_TCPSOCKETS; i++) { - if (FD_ISSET(i, readfds) || FD_ISSET(i, writefds) || FD_ISSET(i, exceptfds)) { - int tcp_pos = i & (~MARK_TCP_SOCKET); - printf("Found TCP %d\n", i); - pipe(tcp_pollers[tcp_pos].pipefds); + /* At this point, we do need a fd_set to read from pipes */ + if (!readfds) { + FD_ZERO(&readfds_local); + readfds = &readfds_local; + } + for (i = MARK_TCP_SOCKET; i < nfds && i < (MARK_TCP_SOCKET | MAX_TCPSOCKETS); i++) { + int tcp_pos = i & (~MARK_TCP_SOCKET); + if ((readfds && (FD_ISSET(i, readfds))) || (writefds && (FD_ISSET(i, writefds))) || (exceptfds && (FD_ISSET(i, exceptfds)))) { + if (pipe(tcp_pollers[tcp_pos].pipefds) < 0) + return -1; tcp_pollers[tcp_pos].fd = i; tcp_pollers[tcp_pos].events = 0; ipstack_register_callback(ipstack, i, poller_callback, ipstack); - if (FD_ISSET(i, readfds)) { + if (readfds && (FD_ISSET(i, readfds))) { tcp_pollers[tcp_pos].events |= POLLIN; FD_CLR(i, readfds); - FD_SET(tcp_pollers[tcp_pos].pipefds[1], readfds); + FD_SET(tcp_pollers[tcp_pos].pipefds[0], readfds); } - if (FD_ISSET(i, writefds)) { + if (writefds && (FD_ISSET(i, writefds))) { tcp_pollers[tcp_pos].events |= POLLOUT; FD_CLR(i, writefds); - FD_SET(tcp_pollers[tcp_pos].pipefds[1], writefds); + FD_SET(tcp_pollers[tcp_pos].pipefds[0], writefds); } - if (FD_ISSET(i, exceptfds)) { + if (exceptfds && (FD_ISSET(i, exceptfds))) { tcp_pollers[tcp_pos].events |= POLLERR | POLLHUP; FD_CLR(i, exceptfds); - FD_SET(tcp_pollers[tcp_pos].pipefds[1], exceptfds); + FD_SET(tcp_pollers[tcp_pos].pipefds[0], exceptfds); } - if (maxfd < tcp_pollers[tcp_pos].pipefds[1]) { - maxfd = tcp_pollers[tcp_pos].pipefds[1]; + if (maxfd < tcp_pollers[tcp_pos].pipefds[0]) { + maxfd = tcp_pollers[tcp_pos].pipefds[0]; } + } else { } } - for (i = MARK_UDP_SOCKET; i < nfds && i < MAX_UDPSOCKETS; i++) { + for (i = MARK_UDP_SOCKET; i < nfds && i < (MARK_UDP_SOCKET | MAX_UDPSOCKETS); i++) { + int udp_pos = i & (~MARK_UDP_SOCKET); if (FD_ISSET(i, readfds) || FD_ISSET(i, writefds) || FD_ISSET(i, exceptfds)) { - int udp_pos = i & (~MARK_UDP_SOCKET); - printf("Found TCP %d\n", i); pipe(udp_pollers[udp_pos].pipefds); udp_pollers[udp_pos].fd = i; udp_pollers[udp_pos].events = 0; ipstack_register_callback(ipstack, i, poller_callback, ipstack); - if (FD_ISSET(i, readfds)) { + if (readfds && FD_ISSET(i, readfds)) { udp_pollers[udp_pos].events |= POLLIN; FD_CLR(i, readfds); - FD_SET(udp_pollers[udp_pos].pipefds[1], readfds); + FD_SET(udp_pollers[udp_pos].pipefds[0], readfds); } - if (FD_ISSET(i, writefds)) { + if (writefds && FD_ISSET(i, writefds)) { udp_pollers[udp_pos].events |= POLLOUT; FD_CLR(i, writefds); - FD_SET(udp_pollers[udp_pos].pipefds[1], writefds); + FD_SET(udp_pollers[udp_pos].pipefds[0], writefds); } - if (FD_ISSET(i, exceptfds)) { + if (exceptfds && FD_ISSET(i, exceptfds)) { udp_pollers[udp_pos].events |= POLLERR | POLLHUP; FD_CLR(i, exceptfds); - FD_SET(udp_pollers[udp_pos].pipefds[1], exceptfds); + FD_SET(udp_pollers[udp_pos].pipefds[0], exceptfds); } - if (maxfd < udp_pollers[udp_pos].pipefds[1]) { - maxfd = udp_pollers[udp_pos].pipefds[1]; + if (maxfd < udp_pollers[udp_pos].pipefds[0]) { + maxfd = udp_pollers[udp_pos].pipefds[0]; } } } /* Call the original select */ + pthread_mutex_unlock(&ipstack_mutex); ret = host_select(maxfd + 1, readfds, writefds, exceptfds, timeout); + pthread_mutex_lock(&ipstack_mutex); if (ret <= 0) { return ret; } @@ -322,39 +352,39 @@ int ft_select(struct ipstack *ipstack, int nfds, fd_set *readfds, fd_set *writef if (tcp_pollers[i].fd == 0) { continue; } - if (FD_ISSET(tcp_pollers[i].pipefds[1], readfds)) { + if (FD_ISSET(tcp_pollers[i].pipefds[0], readfds)) { char c; - read(tcp_pollers[i].pipefds[1], &c, 1); - if (c == 'r') { + host_read(tcp_pollers[i].pipefds[0], &c, 1); + if (readfds && (c == 'r')) { FD_SET(tcp_pollers[i].fd, readfds); - } else if (c == 'w') { + } else if (writefds && (c == 'w')) { FD_SET(tcp_pollers[i].fd, writefds); - } else if (c == 'e') { + } else if (exceptfds && (c == 'e')) { FD_SET(tcp_pollers[i].fd, exceptfds); } } - close(tcp_pollers[i].pipefds[0]); - close(tcp_pollers[i].pipefds[1]); ipstack_register_callback(ipstack, tcp_pollers[i].fd, NULL, NULL); + host_close(tcp_pollers[i].pipefds[0]); + host_close(tcp_pollers[i].pipefds[1]); tcp_pollers[i].fd = 0; } for (i = 0; i < MAX_UDPSOCKETS; i++) { if (udp_pollers[i].fd == 0) { continue; } - if (FD_ISSET(udp_pollers[i].pipefds[1], readfds)) { + if (FD_ISSET(udp_pollers[i].pipefds[0], readfds)) { char c; - read(udp_pollers[i].pipefds[1], &c, 1); - if (c == 'r') { + read(udp_pollers[i].pipefds[0], &c, 1); + if (readfds && (c == 'r')) { FD_SET(udp_pollers[i].fd, readfds); - } else if (c == 'w') { + } else if (writefds && (c == 'w')) { FD_SET(udp_pollers[i].fd, writefds); - } else if (c == 'e') { + } else if (exceptfds && (c == 'e')) { FD_SET(udp_pollers[i].fd, exceptfds); } } - close(udp_pollers[i].pipefds[0]); - close(udp_pollers[i].pipefds[1]); + host_close(udp_pollers[i].pipefds[0]); + host_close(udp_pollers[i].pipefds[1]); ipstack_register_callback(ipstack, tcp_pollers[i].fd, NULL, NULL); udp_pollers[i].fd = 0; } @@ -362,10 +392,14 @@ int ft_select(struct ipstack *ipstack, int nfds, fd_set *readfds, fd_set *writef } int select(int nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds, struct timeval *timeout) { + int ret; if (in_the_stack) { return host_select(nfds, readfds, writefds, exceptfds, timeout); } else { - return ft_select(IPSTACK, nfds, readfds, writefds, exceptfds, timeout); + pthread_mutex_lock(&ipstack_mutex); + ret = ft_select(IPSTACK, nfds, readfds, writefds, exceptfds, timeout); + pthread_mutex_unlock(&ipstack_mutex); + return ret; } } @@ -415,6 +449,7 @@ int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { } ssize_t recvfrom(int sockfd, void *buf, size_t len, int flags, struct sockaddr *addr, socklen_t *addrlen) { + printf("Called recv(): len=%lu\n", len); conditional_steal_blocking_call(recvfrom, sockfd, buf, len, flags, addr, addrlen); } @@ -439,10 +474,14 @@ ssize_t write(int sockfd, const void *buf, size_t len) { } int poll(struct pollfd *fds, nfds_t nfds, int timeout) { + int ret; if (in_the_stack) { return host_poll(fds, nfds, timeout); } else { - return ft_poll(IPSTACK, fds, nfds, timeout); + pthread_mutex_lock(&ipstack_mutex); + ret = ft_poll(IPSTACK, fds, nfds, timeout); + pthread_mutex_unlock(&ipstack_mutex); + return ret; } } @@ -458,8 +497,10 @@ void *ft_posix_ip_loop(void *arg) { uint32_t ms_next; struct timeval tv; while (1) { + pthread_mutex_lock(&ipstack_mutex); gettimeofday(&tv, NULL); ms_next = ipstack_poll(ipstack, tv.tv_sec * 1000 + tv.tv_usec / 1000); + pthread_mutex_unlock(&ipstack_mutex); usleep(ms_next * 1000); in_the_stack = 1; } diff --git a/src/test/tcp_netcat_select.c b/src/test/tcp_netcat_select.c new file mode 100644 index 0000000..032b2d0 --- /dev/null +++ b/src/test/tcp_netcat_select.c @@ -0,0 +1,102 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#define PORT 12346 + +int main() { + int server_fd, new_socket = -1; + struct sockaddr_in server_addr; + fd_set readfds; + int max_fd; + + // Create a TCP socket + server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd == -1) { + perror("Socket creation failed"); + exit(EXIT_FAILURE); + } + printf("server socket: %d\n", server_fd); + + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(PORT); + + // Bind the socket to the address + if (bind(server_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) { + perror("Bind failed"); + close(server_fd); + exit(EXIT_FAILURE); + } + + // Listen for incoming connections + if (listen(server_fd, 3) == -1) { + perror("Listen failed"); + close(server_fd); + exit(EXIT_FAILURE); + } + + // Initialize file descriptor sets + FD_ZERO(&readfds); + FD_SET(STDIN_FILENO, &readfds); // Monitor stdin + FD_SET(server_fd, &readfds); // Monitor the server socket + max_fd = server_fd; + + while (1) { + fd_set tempfds = readfds; + int activity = select(max_fd + 1, &tempfds, NULL, NULL, NULL); + if (activity == -1) { + perror("Select error"); + close(server_fd); + exit(EXIT_FAILURE); + } + + if (FD_ISSET(STDIN_FILENO, &tempfds)) { + // Data available on stdin + char buffer[1024]; + ssize_t bytes_read = read(STDIN_FILENO, buffer, sizeof(buffer)); + if (bytes_read > 0 && new_socket != -1) { + // Write stdin data to the socket + send(new_socket, buffer, bytes_read, 0); + } + } + + if ((new_socket == -1) && FD_ISSET(server_fd, &tempfds)) { + printf("Server socket activity\n"); + // New connection on the socket + if (new_socket == -1) { + new_socket = accept(server_fd, NULL, NULL); + if (new_socket == -1) { + perror("Accept failed"); + continue; + } + printf("New connection established\n"); + FD_SET(new_socket, &readfds); // Monitor the new socket + max_fd = (new_socket > max_fd) ? new_socket : max_fd; + continue; + } + } + if ((new_socket != -1) && FD_ISSET(new_socket, &tempfds)) { + // Data available on the socket + char buffer[1024]; + ssize_t bytes_received = recv(new_socket, buffer, sizeof(buffer), 0); + if (bytes_received > 0) { + write(STDOUT_FILENO, buffer, bytes_received); + } else if (bytes_received == 0) { + // Connection closed by the client + close(new_socket); + FD_CLR(new_socket, &readfds); // Stop monitoring the socket + new_socket = -1; + printf("Connection closed\n"); + } + } + } + + close(server_fd); + return 0; +}