diff --git a/src/detect-ttl.c b/src/detect-ttl.c index 051131f950..97e9055af4 100644 --- a/src/detect-ttl.c +++ b/src/detect-ttl.c @@ -29,6 +29,7 @@ #include "detect.h" #include "detect-parse.h" +#include "detect-engine-prefilter-common.h" #include "detect-ttl.h" #include "util-debug.h" @@ -47,6 +48,9 @@ static int DetectTtlSetup (DetectEngineCtx *, Signature *, char *); void DetectTtlFree (void *); void DetectTtlRegisterTests (void); +static int PrefilterSetupTtl(SigGroupHead *sgh); +static _Bool PrefilterTtlIsPrefilterable(const Signature *s); + /** * \brief Registration function for ttl: keyword */ @@ -61,10 +65,29 @@ void DetectTtlRegister(void) sigmatch_table[DETECT_TTL].Free = DetectTtlFree; sigmatch_table[DETECT_TTL].RegisterTests = DetectTtlRegisterTests; + sigmatch_table[DETECT_TTL].SupportsPrefilter = PrefilterTtlIsPrefilterable; + sigmatch_table[DETECT_TTL].SetupPrefilter = PrefilterSetupTtl; + DetectSetupParseRegexes(PARSE_REGEX, &parse_regex, &parse_regex_study); return; } +static inline int TtlMatch(const uint8_t pttl, const uint8_t mode, + const uint8_t dttl1, const uint8_t dttl2) +{ + if (mode == DETECT_TTL_EQ && pttl == dttl1) + return 1; + else if (mode == DETECT_TTL_LT && pttl < dttl1) + return 1; + else if (mode == DETECT_TTL_GT && pttl > dttl1) + return 1; + else if (mode == DETECT_TTL_RA && (pttl > dttl1 && pttl < dttl2)) + return 1; + + return 0; + +} + /** * \brief This function is used to match TTL rule option on a packet with those passed via ttl: * @@ -79,32 +102,21 @@ void DetectTtlRegister(void) int DetectTtlMatch (ThreadVars *t, DetectEngineThreadCtx *det_ctx, Packet *p, Signature *s, const SigMatchCtx *ctx) { - int ret = 0; - uint8_t pttl; - const DetectTtlData *ttld = (const DetectTtlData *)ctx; - if (PKT_IS_PSEUDOPKT(p)) return 0; + uint8_t pttl; if (PKT_IS_IPV4(p)) { pttl = IPV4_GET_IPTTL(p); } else if (PKT_IS_IPV6(p)) { pttl = IPV6_GET_HLIM(p); } else { SCLogDebug("Packet is of not IPv4 or IPv6"); - return ret; + return 0; } - if (ttld->mode == DETECT_TTL_EQ && pttl == ttld->ttl1) - ret = 1; - else if (ttld->mode == DETECT_TTL_LT && pttl < ttld->ttl1) - ret = 1; - else if (ttld->mode == DETECT_TTL_GT && pttl > ttld->ttl1) - ret = 1; - else if (ttld->mode == DETECT_TTL_RA && (pttl > ttld->ttl1 && pttl < ttld->ttl2)) - ret = 1; - - return ret; + const DetectTtlData *ttld = (const DetectTtlData *)ctx; + return TtlMatch(pttl, ttld->mode, ttld->ttl1, ttld->ttl2); } /** @@ -293,6 +305,74 @@ void DetectTtlFree(void *ptr) SCFree(ttld); } +/* prefilter code */ + +static void +PrefilterPacketTtlMatch(DetectEngineThreadCtx *det_ctx, Packet *p, const void *pectx) +{ + if (PKT_IS_PSEUDOPKT(p)) { + SCReturn; + } + + uint8_t pttl; + if (PKT_IS_IPV4(p)) { + pttl = IPV4_GET_IPTTL(p); + } else if (PKT_IS_IPV6(p)) { + pttl = IPV6_GET_HLIM(p); + } else { + SCLogDebug("Packet is of not IPv4 or IPv6"); + return; + } + + const PrefilterPacketHeaderCtx *ctx = pectx; + + if (TtlMatch(pttl, ctx->v1.u8[0], ctx->v1.u8[1], ctx->v1.u8[2])) + { + SCLogDebug("packet matches ttl/hl %u", pttl); + PrefilterAddSids(&det_ctx->pmq, ctx->sigs_array, ctx->sigs_cnt); + } +} + +static void +PrefilterPacketTtlSet(PrefilterPacketHeaderValue *v, void *smctx) +{ + const DetectTtlData *a = smctx; + v->u8[0] = a->mode; + v->u8[1] = a->ttl1; + v->u8[2] = a->ttl2; +} + +static _Bool +PrefilterPacketTtlCompare(PrefilterPacketHeaderValue v, void *smctx) +{ + const DetectTtlData *a = smctx; + if (v.u8[0] == a->mode && + v.u8[1] == a->ttl1 && + v.u8[2] == a->ttl2) + return TRUE; + return FALSE; +} + +static int PrefilterSetupTtl(SigGroupHead *sgh) +{ + return PrefilterSetupPacketHeader(sgh, DETECT_TTL, + PrefilterPacketTtlSet, + PrefilterPacketTtlCompare, + PrefilterPacketTtlMatch); +} + +static _Bool PrefilterTtlIsPrefilterable(const Signature *s) +{ + const SigMatch *sm; + for (sm = s->sm_lists[DETECT_SM_LIST_MATCH] ; sm != NULL; sm = sm->next) { + switch (sm->type) { + case DETECT_TTL: + return TRUE; + } + } + return FALSE; +} + #ifdef UNITTESTS #include "detect-engine.h" #include "detect-engine-mpm.h"