mpm: run engines as few times as possible

In various scenarios buffers would be checked my MPM more than
once. This was because the buffers would be inspected for a
certain progress value or higher.

For example, for each packet in a file upload, the engine would
not just rerun the 'http client body' MPM on the new data, it
would also rerun the method, uri, headers, cookie, etc MPMs.

This was obviously inefficent, so this patch changes the logic.

The patch only runs the MPM engines when the progress is exactly
the intended progress. If the progress is beyond the desired
value, it is run once. A tracker is added to the app layer API,
where the completed MPMs are tracked.

Implemented for HTTP, TLS and SSH.
pull/2673/head
Victor Julien 9 years ago
parent d304be5bc3
commit a0fad6bb7f

@ -2718,6 +2718,28 @@ static int HTPSetTxDetectState(void *alstate, void *vtx, DetectEngineState *s)
return 0;
}
static uint64_t HTPGetTxMpmIDs(void *vtx)
{
htp_tx_t *tx = (htp_tx_t *)vtx;
HtpTxUserData *tx_ud = htp_tx_get_user_data(tx);
return tx_ud ? tx_ud->mpm_ids : 0;
}
static int HTPSetTxMpmIDs(void *vtx, uint64_t mpm_ids)
{
htp_tx_t *tx = (htp_tx_t *)vtx;
HtpTxUserData *tx_ud = htp_tx_get_user_data(tx);
if (tx_ud == NULL) {
tx_ud = HTPMalloc(sizeof(*tx_ud));
if (unlikely(tx_ud == NULL))
return -ENOMEM;
memset(tx_ud, 0, sizeof(*tx_ud));
htp_tx_set_user_data(tx, tx_ud);
}
tx_ud->mpm_ids = mpm_ids;
return 0;
}
static int HTPRegisterPatternsForProtocolDetection(void)
{
char *methods[] = { "GET", "PUT", "POST", "HEAD", "TRACE", "OPTIONS",
@ -2806,6 +2828,8 @@ void RegisterHTPParsers(void)
AppLayerParserRegisterDetectStateFuncs(IPPROTO_TCP, ALPROTO_HTTP,
HTPStateHasTxDetectState,
HTPGetTxDetectState, HTPSetTxDetectState);
AppLayerParserRegisterMpmIDsFuncs(IPPROTO_TCP, ALPROTO_HTTP,
HTPGetTxMpmIDs, HTPSetTxMpmIDs);
AppLayerParserRegisterParser(IPPROTO_TCP, ALPROTO_HTTP, STREAM_TOSERVER,
HTPHandleRequestData);

@ -188,6 +188,9 @@ typedef struct HtpBody_ {
/** Now the Body Chunks will be stored per transaction, at
* the tx user data */
typedef struct HtpTxUserData_ {
/** flags to track which mpm has run */
uint64_t mpm_ids;
/* Body of the request (if any) */
uint8_t request_body_init;
uint8_t response_body_init;
@ -228,7 +231,6 @@ typedef struct HtpTxUserData_ {
} HtpTxUserData;
typedef struct HtpState_ {
/* Connection parser structure for each connection */
htp_connp_t *connp;
/* Connection structure for each connection */

@ -116,6 +116,9 @@ typedef struct AppLayerParserProtoCtx_
DetectEngineState *(*GetTxDetectState)(void *tx);
int (*SetTxDetectState)(void *alstate, void *tx, DetectEngineState *);
uint64_t (*GetTxMpmIDs)(void *tx);
int (*SetTxMpmIDs)(void *tx, uint64_t);
/* each app-layer has its own value */
uint32_t stream_depth;
@ -537,6 +540,18 @@ void AppLayerParserRegisterDetectStateFuncs(uint8_t ipproto, AppProto alproto,
SCReturn;
}
void AppLayerParserRegisterMpmIDsFuncs(uint8_t ipproto, AppProto alproto,
uint64_t(*GetTxMpmIDs)(void *tx),
int (*SetTxMpmIDs)(void *tx, uint64_t))
{
SCEnter();
alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].GetTxMpmIDs = GetTxMpmIDs;
alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].SetTxMpmIDs = SetTxMpmIDs;
SCReturn;
}
/***** Get and transaction functions *****/
void *AppLayerParserGetProtocolParserLocalStorage(uint8_t ipproto, AppProto alproto)
@ -929,6 +944,24 @@ int AppLayerParserSetTxDetectState(uint8_t ipproto, AppProto alproto,
SCReturnInt(r);
}
uint64_t AppLayerParserGetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx)
{
if (alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].GetTxMpmIDs != NULL) {
return alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].GetTxMpmIDs(tx);
}
return 0ULL;
}
int AppLayerParserSetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx, uint64_t mpm_ids)
{
int r = 0;
if (alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].SetTxMpmIDs != NULL) {
r = alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].SetTxMpmIDs(tx, mpm_ids);
}
SCReturnInt(r);
}
/***** General *****/
int AppLayerParserParse(ThreadVars *tv, AppLayerParserThreadCtx *alp_tctx, Flow *f, AppProto alproto,

@ -153,6 +153,9 @@ void AppLayerParserRegisterDetectStateFuncs(uint8_t ipproto, AppProto alproto,
void AppLayerParserRegisterGetStreamDepth(uint8_t ipproto,
AppProto alproto,
uint32_t (*GetStreamDepth)(void));
void AppLayerParserRegisterMpmIDsFuncs(uint8_t ipproto, AppProto alproto,
uint64_t (*GetTxMpmIDs)(void *tx),
int (*SetTxMpmIDs)(void *tx, uint64_t));
/***** Get and transaction functions *****/
@ -195,6 +198,9 @@ int AppLayerParserHasTxDetectState(uint8_t ipproto, AppProto alproto, void *alst
DetectEngineState *AppLayerParserGetTxDetectState(uint8_t ipproto, AppProto alproto, void *tx);
int AppLayerParserSetTxDetectState(uint8_t ipproto, AppProto alproto, void *alstate, void *tx, DetectEngineState *s);
uint64_t AppLayerParserGetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx);
int AppLayerParserSetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx, uint64_t);
/***** General *****/
int AppLayerParserParse(ThreadVars *tv, AppLayerParserThreadCtx *tctx, Flow *f, AppProto alproto,

@ -557,6 +557,19 @@ static int SSHGetTxLogged(void *state, void *tx, uint32_t logger)
return 0;
}
static uint64_t SSHGetTxMpmIDs(void *vtx)
{
SshState *ssh_state = (SshState *)vtx;
return ssh_state->mpm_ids;
}
static int SSHSetTxMpmIDs(void *vtx, uint64_t mpm_ids)
{
SshState *ssh_state = (SshState *)vtx;
ssh_state->mpm_ids = mpm_ids;
return 0;
}
static int SSHGetAlstateProgressCompletionStatus(uint8_t direction)
{
return SSH_STATE_FINISHED;
@ -632,6 +645,8 @@ void RegisterSSHParsers(void)
AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_SSH, SSHGetAlstateProgress);
AppLayerParserRegisterLoggerFuncs(IPPROTO_TCP, ALPROTO_SSH, SSHGetTxLogged, SSHSetTxLogged);
AppLayerParserRegisterMpmIDsFuncs(IPPROTO_TCP, ALPROTO_SSH,
SSHGetTxMpmIDs, SSHSetTxMpmIDs);
AppLayerParserRegisterGetStateProgressCompletionStatus(ALPROTO_SSH,
SSHGetAlstateProgressCompletionStatus);

@ -76,6 +76,9 @@ typedef struct SshState_ {
/* specifies which loggers are done logging */
uint32_t logged;
/* bit flags of mpms that have already run */
uint64_t mpm_ids;
DetectEngineState *de_state;
} SshState;

@ -245,6 +245,19 @@ int SSLGetAlstateProgress(void *tx, uint8_t direction)
return TLS_STATE_IN_PROGRESS;
}
static uint64_t SSLGetTxMpmIDs(void *vtx)
{
SSLState *ssl_state = (SSLState *)vtx;
return ssl_state->mpm_ids;
}
static int SSLSetTxMpmIDs(void *vtx, uint64_t mpm_ids)
{
SSLState *ssl_state = (SSLState *)vtx;
ssl_state->mpm_ids = mpm_ids;
return 0;
}
static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
uint32_t input_len)
{
@ -1832,6 +1845,8 @@ void RegisterSSLParsers(void)
AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_TLS, SSLGetAlstateProgress);
AppLayerParserRegisterLoggerFuncs(IPPROTO_TCP, ALPROTO_TLS, SSLGetTxLogged, SSLSetTxLogged);
AppLayerParserRegisterMpmIDsFuncs(IPPROTO_TCP, ALPROTO_TLS,
SSLGetTxMpmIDs, SSLSetTxMpmIDs);
AppLayerParserRegisterGetStateProgressCompletionStatus(ALPROTO_TLS,
SSLGetAlstateProgressCompletionStatus);

@ -191,6 +191,9 @@ typedef struct SSLState_ {
/* specifies which loggers are done logging */
uint32_t logged;
/* MPM/prefilter Id's */
uint64_t mpm_ids;
/* there might be a better place to store this*/
uint16_t hb_record_len;

@ -118,6 +118,7 @@ static inline void PrefilterTx(DetectEngineThreadCtx *det_ctx,
if (tx == NULL)
continue;
uint64_t mpm_ids = AppLayerParserGetTxMpmIDs(ipproto, alproto, tx);
const int tx_progress = AppLayerParserGetStateProgress(ipproto, alproto, tx, flags);
SCLogDebug("tx %p progress %d", tx, tx_progress);
@ -127,16 +128,30 @@ static inline void PrefilterTx(DetectEngineThreadCtx *det_ctx,
goto next;
if (engine->tx_min_progress > tx_progress)
goto next;
if (tx_progress > engine->tx_min_progress) {
if (mpm_ids & (1<<(engine->gid))) {
goto next;
}
}
PROFILING_PREFILTER_START(p);
engine->cb.PrefilterTx(det_ctx, engine->pectx,
p, p->flow, tx, idx, flags);
PROFILING_PREFILTER_END(p, engine->gid);
if (tx_progress > engine->tx_min_progress) {
mpm_ids |= (1<<(engine->gid));
}
next:
if (engine->is_last)
break;
engine++;
} while (1);
if (mpm_ids != 0) {
//SCLogNotice("tx %p Mpm IDs: %"PRIx64, tx, mpm_ids);
AppLayerParserSetTxMpmIDs(ipproto, alproto, tx, mpm_ids);
}
}
}

Loading…
Cancel
Save