diff --git a/src/unix-manager.c b/src/unix-manager.c index ee87678d54..9b85756242 100644 --- a/src/unix-manager.c +++ b/src/unix-manager.c @@ -58,15 +58,19 @@ typedef struct Task_ { TAILQ_ENTRY(Task_) next; } Task; +typedef struct UnixClient_ { + int fd; + TAILQ_ENTRY(UnixClient_) next; +} UnixClient; + typedef struct UnixCommand_ { time_t start_timestamp; int socket; - int client; struct sockaddr_un client_addr; int select_max; - fd_set select_set; TAILQ_HEAD(, Command_) commands; TAILQ_HEAD(, Task_) tasks; + TAILQ_HEAD(, UnixClient_) clients; } UnixCommand; /** @@ -85,11 +89,11 @@ int UnixNew(UnixCommand * this) this->start_timestamp = time(NULL); this->socket = -1; - this->client = -1; this->select_max = 0; TAILQ_INIT(&this->commands); TAILQ_INIT(&this->tasks); + TAILQ_INIT(&this->clients); /* Create socket dir */ ret = mkdir(SOCKET_PATH, S_IRWXU|S_IXGRP|S_IRGRP); @@ -180,17 +184,47 @@ int UnixNew(UnixCommand * this) return 1; } +void UnixCommandSetMaxFD(UnixCommand *this) { + UnixClient *item; + + if (this == NULL) { + SCLogError(SC_ERR_INVALID_ARGUMENT, "Unix command is NULL, warn devel"); + return; + } + + this->select_max = this->socket + 1; + TAILQ_FOREACH(item, &this->clients, next) { + if (item->fd >= this->select_max) { + this->select_max = item->fd + 1; + } + } +} + /** * \brief Close the unix socket */ -void UnixCommandClose(UnixCommand *this) +void UnixCommandClose(UnixCommand *this, int fd) { - if (this->client == -1) + UnixClient *item; + int found = 0; + + TAILQ_FOREACH(item, &this->clients, next) { + if (item->fd == fd) { + found = 1; + break; + } + } + + if (found == 0) { + SCLogError(SC_ERR_INVALID_VALUE, "No fd found in client list"); return; - SCLogInfo("Unix socket: close client connection"); - close(this->client); - this->client = -1; - this->select_max = this->socket + 1; + } + + TAILQ_REMOVE(&this->clients, item, next); + + close(item->fd); + UnixCommandSetMaxFD(this); + SCFree(item); } /** @@ -198,9 +232,9 @@ void UnixCommandClose(UnixCommand *this) */ int UnixCommandSendCallback(const char *buffer, size_t size, void *data) { - UnixCommand *this = (UnixCommand *) data; + int fd = *(int *) data; - if (send(this->client, buffer, size, MSG_NOSIGNAL) == -1) { + if (send(fd, buffer, size, MSG_NOSIGNAL) == -1) { SCLogInfo("Unable to send block: %s", strerror(errno)); return -1; } @@ -227,13 +261,15 @@ int UnixCommandAccept(UnixCommand *this) json_t *server_msg; json_t *version; json_error_t jerror; + int client; int ret; + UnixClient *uclient = NULL; /* accept client socket */ socklen_t len = sizeof(this->client_addr); - this->client = accept(this->socket, (struct sockaddr *) &this->client_addr, + client = accept(this->socket, (struct sockaddr *) &this->client_addr, &len); - if (this->client < 0) { + if (client < 0) { SCLogInfo("Unix socket: accept() error: %s", strerror(errno)); return 0; @@ -242,30 +278,30 @@ int UnixCommandAccept(UnixCommand *this) /* read client version */ buffer[sizeof(buffer)-1] = 0; - ret = recv(this->client, buffer, sizeof(buffer)-1, 0); + ret = recv(client, buffer, sizeof(buffer)-1, 0); if (ret < 0) { SCLogInfo("Command server: client doesn't send version"); - UnixCommandClose(this); + UnixCommandClose(this, client); return 0; } if (ret >= (int)(sizeof(buffer)-1)) { SCLogInfo("Command server: client message is too long, " "disconnect him."); - UnixCommandClose(this); + UnixCommandClose(this, client); } buffer[ret] = 0; client_msg = json_loads(buffer, 0, &jerror); if (client_msg == NULL) { SCLogInfo("Invalid command, error on line %d: %s\n", jerror.line, jerror.text); - UnixCommandClose(this); + UnixCommandClose(this, client); return 0; } version = json_object_get(client_msg, "version"); if(!json_is_string(version)) { SCLogInfo("error: version is not a string"); - UnixCommandClose(this); + UnixCommandClose(this, client); return 0; } @@ -273,7 +309,7 @@ int UnixCommandAccept(UnixCommand *this) if (strcmp(json_string_value(version), UNIX_PROTO_VERSION) != 0) { SCLogInfo("Unix socket: invalid client version: \"%s\"", json_string_value(version)); - UnixCommandClose(this); + UnixCommandClose(this, client); return 0; } else { SCLogInfo("Unix socket: client version: \"%s\"", @@ -283,23 +319,28 @@ int UnixCommandAccept(UnixCommand *this) /* send answer */ server_msg = json_object(); if (server_msg == NULL) { - UnixCommandClose(this); + UnixCommandClose(this, client); return 0; } json_object_set_new(server_msg, "return", json_string("OK")); - if (json_dump_callback(server_msg, UnixCommandSendCallback, this, 0) == -1) { + if (json_dump_callback(server_msg, UnixCommandSendCallback, &client, 0) == -1) { SCLogWarning(SC_ERR_SOCKET, "Unable to send command"); - UnixCommandClose(this); + UnixCommandClose(this, client); return 0; } /* client connected */ SCLogInfo("Unix socket: client connected"); - if (this->socket < this->client) - this->select_max = this->client + 1; - else - this->select_max = this->socket + 1; + + uclient = SCMalloc(sizeof(UnixClient)); + if (uclient == NULL) { + SCLogError(SC_ERR_MEM_ALLOC, "Can't allocate new cient"); + return 0; + } + uclient->fd = client; + TAILQ_INSERT_TAIL(&this->clients, uclient, next); + UnixCommandSetMaxFD(this); return 1; } @@ -326,7 +367,7 @@ int UnixCommandBackgroundTasks(UnixCommand* this) * * \retval 0 in case of error, 1 in case of success */ -int UnixCommandExecute(UnixCommand * this, char *command) +int UnixCommandExecute(UnixCommand * this, char *command, UnixClient *client) { int ret = 1; json_error_t error; @@ -387,7 +428,7 @@ int UnixCommandExecute(UnixCommand * this, char *command) } /* send answer */ - if (json_dump_callback(server_msg, UnixCommandSendCallback, this, 0) == -1) { + if (json_dump_callback(server_msg, UnixCommandSendCallback, &client->fd, 0) == -1) { SCLogWarning(SC_ERR_SOCKET, "Unable to send command"); goto error_cmd; } @@ -399,15 +440,15 @@ error_cmd: error: json_decref(jsoncmd); json_decref(server_msg); - UnixCommandClose(this); + UnixCommandClose(this, client->fd); return 0; } -void UnixCommandRun(UnixCommand * this) +void UnixCommandRun(UnixCommand * this, UnixClient *client) { char buffer[4096]; int ret; - ret = recv(this->client, buffer, sizeof(buffer) - 1, 0); + ret = recv(client->fd, buffer, sizeof(buffer) - 1, 0); if (ret <= 0) { if (ret == 0) { SCLogInfo("Unix socket: lost connection with client"); @@ -415,16 +456,16 @@ void UnixCommandRun(UnixCommand * this) SCLogInfo("Unix socket: error on recv() from client: %s", strerror(errno)); } - UnixCommandClose(this); + UnixCommandClose(this, client->fd); return; } if (ret >= (int)(sizeof(buffer)-1)) { SCLogInfo("Command server: client command is too long, " "disconnect him."); - UnixCommandClose(this); + UnixCommandClose(this, client->fd); } buffer[ret] = 0; - UnixCommandExecute(this, buffer); + UnixCommandExecute(this, buffer, client); } /** @@ -436,15 +477,19 @@ int UnixMain(UnixCommand * this) { struct timeval tv; int ret; + fd_set select_set; + UnixClient *uclient; /* Wait activity on the socket */ - FD_ZERO(&this->select_set); - FD_SET(this->socket, &this->select_set); - if (0 <= this->client) - FD_SET(this->client, &this->select_set); + FD_ZERO(&select_set); + FD_SET(this->socket, &select_set); + TAILQ_FOREACH(uclient, &this->clients, next) { + FD_SET(uclient->fd, &select_set); + } + tv.tv_sec = 0; tv.tv_usec = 200 * 1000; - ret = select(this->select_max, &this->select_set, NULL, NULL, &tv); + ret = select(this->select_max, &select_set, NULL, NULL, &tv); /* catch select() error */ if (ret == -1) { @@ -457,7 +502,6 @@ int UnixMain(UnixCommand * this) } if (suricata_ctl_flags & (SURICATA_STOP | SURICATA_KILL)) { - UnixCommandClose(this); return 1; } @@ -466,10 +510,13 @@ int UnixMain(UnixCommand * this) return 1; } - if (0 <= this->client && FD_ISSET(this->client, &this->select_set)) { - UnixCommandRun(this); + + TAILQ_FOREACH(uclient, &this->clients, next) { + if (FD_ISSET(uclient->fd, &select_set)) { + UnixCommandRun(this, uclient); + } } - if (FD_ISSET(this->socket, &this->select_set)) { + if (FD_ISSET(this->socket, &select_set)) { if (!UnixCommandAccept(this)) return 0; } @@ -644,6 +691,7 @@ TmEcode UnixManagerRegisterBackgroundTask( void *UnixManagerThread(void *td) { ThreadVars *th_v = (ThreadVars *)td; + int ret; /* set the thread name */ (void) SCSetThreadName(th_v->name); @@ -677,10 +725,17 @@ void *UnixManagerThread(void *td) TmThreadsSetFlag(th_v, THV_INIT_DONE); while (1) { - UnixMain(&command); + ret = UnixMain(&command); + if (ret == 0) { + SCLogError(SC_ERR_FATAL, "Fatal error on unix socket"); + } - if (TmThreadsCheckFlag(th_v, THV_KILL)) { - UnixCommandClose(&command); + if ((ret == 0) || (TmThreadsCheckFlag(th_v, THV_KILL))) { + UnixClient *item; + TAILQ_FOREACH(item, &(&command)->clients, next) { + close(item->fd); + SCFree(item); + } SCPerfSyncCounters(th_v, 0); break; }