diff --git a/src/source-ipfw.c b/src/source-ipfw.c index 7f31de71a8..1a32eaef93 100644 --- a/src/source-ipfw.c +++ b/src/source-ipfw.c @@ -76,6 +76,7 @@ void TmModuleVerdictIPFWRegister (void) tmm_modules[TMM_VERDICTIPFW].Func = NULL; tmm_modules[TMM_VERDICTIPFW].ThreadExitPrintStats = NULL; tmm_modules[TMM_VERDICTIPFW].ThreadDeinit = NULL; + tmm_modules[TMM_VERDICTIPFW].flags = TM_FLAG_VERDICT_TM; } void TmModuleDecodeIPFWRegister (void) @@ -181,6 +182,7 @@ void TmModuleVerdictIPFWRegister (void) tmm_modules[TMM_VERDICTIPFW].ThreadDeinit = VerdictIPFWThreadDeinit; tmm_modules[TMM_VERDICTIPFW].cap_flags = SC_CAP_NET_ADMIN | SC_CAP_NET_RAW | SC_CAP_NET_BIND_SERVICE; /** \todo untested */ + tmm_modules[TMM_VERDICTIPFW].flags = TM_FLAG_VERDICT_TM; } /** diff --git a/src/source-nfq.c b/src/source-nfq.c index 5293f2529a..13eb936585 100644 --- a/src/source-nfq.c +++ b/src/source-nfq.c @@ -75,6 +75,7 @@ void TmModuleVerdictNFQRegister (void) tmm_modules[TMM_VERDICTNFQ].ThreadExitPrintStats = NULL; tmm_modules[TMM_VERDICTNFQ].ThreadDeinit = NULL; tmm_modules[TMM_VERDICTNFQ].cap_flags = SC_CAP_NET_ADMIN; + tmm_modules[TMM_VERDICTNFQ].flags = TM_FLAG_VERDICT_TM; } void TmModuleDecodeNFQRegister (void) @@ -187,6 +188,7 @@ void TmModuleVerdictNFQRegister (void) tmm_modules[TMM_VERDICTNFQ].ThreadInit = VerdictNFQThreadInit; tmm_modules[TMM_VERDICTNFQ].Func = VerdictNFQ; tmm_modules[TMM_VERDICTNFQ].ThreadDeinit = VerdictNFQThreadDeinit; + tmm_modules[TMM_VERDICTNFQ].flags = TM_FLAG_VERDICT_TM; } void TmModuleDecodeNFQRegister (void) diff --git a/src/source-windivert.c b/src/source-windivert.c index 347d2e7a0f..bd9f5fe7ef 100644 --- a/src/source-windivert.c +++ b/src/source-windivert.c @@ -69,6 +69,7 @@ void TmModuleVerdictWinDivertRegister(void) { tmm_modules[TMM_VERDICTWINDIVERT].name = "VerdictWinDivert"; tmm_modules[TMM_VERDICTWINDIVERT].ThreadInit = NoWinDivertSupportExit; + tmm_modules[TMM_VERDICTWINDIVERT].flags = TM_FLAG_VERDICT_TM; } void TmModuleDecodeWinDivertRegister(void) @@ -382,6 +383,7 @@ void TmModuleVerdictWinDivertRegister(void) tm_ptr->ThreadInit = VerdictWinDivertThreadInit; tm_ptr->Func = VerdictWinDivert; tm_ptr->ThreadDeinit = VerdictWinDivertThreadDeinit; + tm_ptr->flags = TM_FLAG_VERDICT_TM; } void TmModuleDecodeWinDivertRegister(void) diff --git a/src/suricata.c b/src/suricata.c index 0afb332d61..7b18541285 100644 --- a/src/suricata.c +++ b/src/suricata.c @@ -2277,12 +2277,13 @@ void PostRunDeinit(const int runmode, struct timeval *start_time) FlowDisableFlowManagerThread(); /* disable capture */ TmThreadDisableReceiveThreads(); - /* tell packet threads to enter flow timeout loop */ - TmThreadDisablePacketThreads(THV_REQ_FLOW_LOOP, THV_FLOW_LOOP); + /* tell relevant packet threads to enter flow timeout loop */ + TmThreadDisablePacketThreads( + THV_REQ_FLOW_LOOP, THV_FLOW_LOOP, (TM_FLAG_RECEIVE_TM | TM_FLAG_DETECT_TM)); /* run cleanup on the flow hash */ FlowForceReassembly(); - /* gracefully shut down packet threads */ - TmThreadDisablePacketThreads(THV_KILL, THV_RUNNING_DONE); + /* gracefully shut down all packet threads */ + TmThreadDisablePacketThreads(THV_KILL, THV_RUNNING_DONE, TM_FLAG_PACKET_ALL); SCPrintElapsedTime(start_time); FlowDisableFlowRecyclerThread(); diff --git a/src/tm-modules.h b/src/tm-modules.h index 4642ff46a6..2e2e70ce65 100644 --- a/src/tm-modules.h +++ b/src/tm-modules.h @@ -35,6 +35,12 @@ #define TM_FLAG_LOGAPI_TM 0x10 /**< TM is run by Log API */ #define TM_FLAG_MANAGEMENT_TM 0x20 #define TM_FLAG_COMMAND_TM 0x40 +#define TM_FLAG_VERDICT_TM 0x80 + +/* all packet modules combined */ +#define TM_FLAG_PACKET_ALL \ + (TM_FLAG_RECEIVE_TM | TM_FLAG_DECODE_TM | TM_FLAG_STREAM_TM | TM_FLAG_DETECT_TM | \ + TM_FLAG_VERDICT_TM) typedef TmEcode (*ThreadInitFunc)(ThreadVars *, const void *, void **); typedef TmEcode (*ThreadDeinitFunc)(ThreadVars *, void *); diff --git a/src/tm-threads.c b/src/tm-threads.c index 1885b16cbe..c8f746b735 100644 --- a/src/tm-threads.c +++ b/src/tm-threads.c @@ -507,7 +507,7 @@ static void *TmThreadsSlotVar(void *td) TmThreadsHandleInjectedPackets(tv); } - if (TmThreadsCheckFlag(tv, THV_REQ_FLOW_LOOP)) { + if (TmThreadsCheckFlag(tv, (THV_KILL | THV_REQ_FLOW_LOOP))) { run = 0; } } @@ -1482,10 +1482,21 @@ static void TmThreadDebugValidateNoMorePackets(void) #endif } +/** \internal + * \brief check if a thread has any of the modules indicated by TM_FLAG_* + * \param tv thread + * \param flags TM_FLAG_*'s + * \retval bool true if at least on of the flags is present */ +static inline bool CheckModuleFlags(const ThreadVars *tv, const uint8_t flags) +{ + return (tv->tmm_flags & flags) != 0; +} + /** * \brief Disable all packet threads * \param set flag to set * \param check flag to check + * \param module_flags bitflags of TmModule's to apply the `set` flag to. * * Support 2 stages in shutting down the packet threads: * 1. set THV_REQ_FLOW_LOOP and wait for THV_FLOW_LOOP @@ -1493,8 +1504,11 @@ static void TmThreadDebugValidateNoMorePackets(void) * * During step 1 the main loop is exited, and the flow loop logic is entered. * During step 2, the flow loop logic is done and the thread closes. + * + * `module_flags` limits which threads are disabled */ -void TmThreadDisablePacketThreads(const uint16_t set, const uint16_t check) +void TmThreadDisablePacketThreads( + const uint16_t set, const uint16_t check, const uint8_t module_flags) { struct timeval start_ts; struct timeval cur_ts; @@ -1518,6 +1532,11 @@ again: /* loop through the packet threads and kill them */ SCMutexLock(&tv_root_lock); for (ThreadVars *tv = tv_root[TVT_PPT]; tv != NULL; tv = tv->next) { + /* only set flow worker threads to THV_REQ_FLOW_LOOP */ + if (!CheckModuleFlags(tv, module_flags)) { + SCLogDebug("%s does not have any of the modules %02x, skip", tv->name, module_flags); + continue; + } TmThreadsSetFlag(tv, set); /* separate worker threads (autofp) will still wait at their @@ -1533,8 +1552,9 @@ again: } /* wait for it to reach the expected state */ - while (!TmThreadsCheckFlag(tv, check)) { + if (!TmThreadsCheckFlag(tv, check)) { SCMutexUnlock(&tv_root_lock); + SCLogDebug("%s did not reach state %u, again", tv->name, check); SleepMsec(1); goto again; diff --git a/src/tm-threads.h b/src/tm-threads.h index 7698c3519a..9250a4d603 100644 --- a/src/tm-threads.h +++ b/src/tm-threads.h @@ -121,7 +121,8 @@ void TmThreadWaitForFlag(ThreadVars *, uint32_t); TmEcode TmThreadsSlotVarRun (ThreadVars *tv, Packet *p, TmSlot *slot); -void TmThreadDisablePacketThreads(const uint16_t set, const uint16_t check); +void TmThreadDisablePacketThreads( + const uint16_t set, const uint16_t check, const uint8_t module_flags); void TmThreadDisableReceiveThreads(void); uint32_t TmThreadCountThreadsByTmmFlags(uint8_t flags);