Add memcmp api with a plain memcmp function and a SSE3 accelerated memcmp.

remotes/origin/master-1.1.x
Victor Julien 15 years ago
parent 94898a91cc
commit 1859ed54c7

@ -175,6 +175,7 @@ util-privs.c util-privs.h \
util-decode-asn1.c util-decode-asn1.h \
util-ringbuffer.c util-ringbuffer.h \
util-validate.h \
util-memcmp.c util-memcmp.h \
tm-modules.c tm-modules.h \
tm-queues.c tm-queues.h \
tm-queuehandlers.c tm-queuehandlers.h \

@ -47,6 +47,7 @@
#include "util-spm.h"
#include "util-unittest.h"
#include "util-debug.h"
#include "util-memcmp.h"
/**
* \brief This function is called to determine and set which command is being
@ -62,15 +63,8 @@ static int FTPParseRequestCommand(void *ftp_state, uint8_t *input,
SCEnter();
FtpState *fstate = (FtpState *)ftp_state;
char inputlower[5];
if (input_len >= 4) {
memcpy(inputlower,input,4);
int i = 0;
for (; i < 4; i++)
inputlower[i] = tolower(inputlower[i]);
if (memcmp(inputlower, "port", 4) == 0) {
if (SCMemcmpLowercase("port", input, 4) == 0) {
fstate->command = FTP_COMMAND_PORT;
}

@ -43,6 +43,7 @@
#include "util-spm.h"
#include "util-unittest.h"
#include "util-memcmp.h"
#include "app-layer-smb.h"
@ -834,7 +835,7 @@ static int SMBParseHeader(Flow *f, void *smb_state,
switch (sstate->bytesprocessed) {
case 4:
if (input_len >= SMB_HDR_LEN) {
if (memcmp(p, "\xff\x53\x4d\x42", 4) != 0) {
if (SCMemcmp(p, "\xff\x53\x4d\x42", 4) != 0) {
SCLogDebug("SMB Header did not validate");
SCReturnInt(-1);
}

@ -43,6 +43,7 @@
#include "util-spm.h"
#include "util-unittest.h"
#include "util-debug.h"
#include "util-memcmp.h"
#include "app-layer-smb2.h"
@ -101,7 +102,7 @@ static uint32_t SMB2ParseHeader(void *smb2_state, AppLayerParserState *pstate,
switch (sstate->bytesprocessed) {
case 4:
if (input_len >= SMB2_HDR_LEN) {
if (memcmp(p, "\xfe\x53\x4d\x42", 4) != 0) {
if (SCMemcmp(p, "\xfe\x53\x4d\x42", 4) != 0) {
//printf("SMB2 Header did not validate\n");
return 0;
}

@ -49,6 +49,7 @@
#include "flow-private.h"
#include "util-byte.h"
#include "util-memcmp.h"
/**
* \brief Function to parse the SSH version string of the server
@ -131,7 +132,7 @@ static int SSHParseServerVersion(Flow *f, void *ssh_state, AppLayerParserState *
}
/* is it the version line? */
if (memcmp("SSH-", line_ptr, 4) == 0) {
if (SCMemcmp("SSH-", line_ptr, 4) == 0) {
if (line_len > 255) {
SCLogDebug("Invalid version string, it should be less than 255 characters including <CR><NL>");
SCReturnInt(-1);
@ -464,7 +465,7 @@ static int SSHParseClientVersion(Flow *f, void *ssh_state, AppLayerParserState *
}
/* is it the version line? */
if (memcmp("SSH-", line_ptr, 4) == 0) {
if (SCMemcmp("SSH-", line_ptr, 4) == 0) {
if (line_len > 255) {
SCLogDebug("Invalid version string, it should be less than 255 characters including <CR><NL>");
SCReturnInt(-1);

@ -52,6 +52,7 @@
#include "util-enum.h"
#include "util-debug.h"
#include "util-print.h"
#include "util-memcmp.h"
/** \todo make it possible to use multiple pattern matcher algorithms next to
eachother. */
@ -440,7 +441,7 @@ char ContentHashCompareFunc(void *data1, uint16_t len1, void *data2, uint16_t le
DetectContentData *co2 = ch2->ptr;
if (co1->content_len == co2->content_len &&
memcmp(co1->content, co2->content, co1->content_len) == 0)
SCMemcmp(co1->content, co2->content, co1->content_len) == 0)
return 1;
return 0;
@ -453,7 +454,7 @@ char UricontentHashCompareFunc(void *data1, uint16_t len1, void *data2, uint16_t
DetectUricontentData *ud2 = ch2->ptr;
if (ud1->uricontent_len == ud2->uricontent_len &&
memcmp(ud1->uricontent, ud2->uricontent, ud1->uricontent_len) == 0)
SCMemcmp(ud1->uricontent, ud2->uricontent, ud1->uricontent_len) == 0)
return 1;
return 0;
@ -1523,7 +1524,7 @@ static char MpmPatternIdCompare(void *p1, uint16_t len1, void *p2, uint16_t len2
SCReturnInt(0);
}
if (memcmp(e1->pattern, e2->pattern, e1->pattern_len) != 0) {
if (SCMemcmp(e1->pattern, e2->pattern, e1->pattern_len) != 0) {
SCReturnInt(0);
}

@ -48,7 +48,7 @@
#include "util-cidr.h"
#include "util-unittest.h"
#include "util-unittest-helper.h"
#include "util-memcmp.h"
/* prototypes */
int SigGroupHeadClearSigs(SigGroupHead *);
@ -240,7 +240,7 @@ char SigGroupHeadMpmCompareFunc(void *data1, uint16_t len1, void *data2,
if (sgh1->init->content_size != sgh2->init->content_size)
return 0;
if (memcmp(sgh1->init->content_array, sgh2->init->content_array,
if (SCMemcmp(sgh1->init->content_array, sgh2->init->content_array,
sgh1->init->content_size) != 0) {
return 0;
}
@ -367,7 +367,7 @@ char SigGroupHeadMpmUriCompareFunc(void *data1, uint16_t len1, void *data2,
if (sgh1->init->uri_content_size != sgh2->init->uri_content_size)
return 0;
if (memcmp(sgh1->init->uri_content_array, sgh2->init->uri_content_array,
if (SCMemcmp(sgh1->init->uri_content_array, sgh2->init->uri_content_array,
sgh1->init->uri_content_size) != 0) {
return 0;
}
@ -494,7 +494,7 @@ char SigGroupHeadMpmStreamCompareFunc(void *data1, uint16_t len1, void *data2,
if (sgh1->init->stream_content_size != sgh2->init->stream_content_size)
return 0;
if (memcmp(sgh1->init->stream_content_array, sgh2->init->stream_content_array,
if (SCMemcmp(sgh1->init->stream_content_array, sgh2->init->stream_content_array,
sgh1->init->stream_content_size) != 0) {
return 0;
}
@ -626,7 +626,7 @@ char SigGroupHeadCompareFunc(void *data1, uint16_t len1, void *data2,
if (sgh1->init->sig_size != sgh2->init->sig_size)
return 0;
if (memcmp(sgh1->init->sig_array, sgh2->init->sig_array, sgh1->init->sig_size) != 0)
if (SCMemcmp(sgh1->init->sig_array, sgh2->init->sig_array, sgh1->init->sig_size) != 0)
return 0;
return 1;

@ -144,6 +144,7 @@
#include "util-ringbuffer.h"
#include "util-mem.h"
#include "util-memcmp.h"
/*
* we put this here, because we only use it here in main.
@ -934,6 +935,8 @@ int main(int argc, char **argv)
#endif
DeStateRegisterTests();
DetectRingBufferRegisterTests();
MemcmpRegisterTests();
if (list_unittests) {
UtListTests(regex_arg);
}

@ -29,6 +29,7 @@
#include "suricata-common.h"
#include "util-hash.h"
#include "util-unittest.h"
#include "util-memcmp.h"
HashTable* HashTableInit(uint32_t size, uint32_t (*Hash)(struct HashTable_ *, void *, uint16_t), char (*Compare)(void *, uint16_t, void *, uint16_t), void (*Free)(void *)) {
@ -220,7 +221,7 @@ char HashTableDefaultCompare(void *data1, uint16_t len1, void *data2, uint16_t l
if (len1 != len2)
return 0;
if (memcmp(data1,data2,len1) != 0)
if (SCMemcmp(data1,data2,len1) != 0)
return 0;
return 1;
@ -288,7 +289,7 @@ static char HashTableDefaultCompareTest(void *data1, uint16_t len1, void *data2,
if (len1 != len2)
return 0;
if (memcmp(data1,data2,len1) != 0)
if (SCMemcmp(data1,data2,len1) != 0)
return 0;
return 1;

@ -30,6 +30,7 @@
#include "util-hashlist.h"
#include "util-unittest.h"
#include "util-debug.h"
#include "util-memcmp.h"
HashListTable* HashListTableInit(uint32_t size, uint32_t (*Hash)(struct HashListTable_ *, void *, uint16_t), char (*Compare)(void *, uint16_t, void *, uint16_t), void (*Free)(void *)) {
@ -233,7 +234,7 @@ char HashListTableDefaultCompare(void *data1, uint16_t len1, void *data2, uint16
if (len1 != len2)
return 0;
if (memcmp(data1,data2,len1) != 0)
if (SCMemcmp(data1,data2,len1) != 0)
return 0;
return 1;

@ -0,0 +1,185 @@
/* Copyright (C) 2007-2010 Open Information Security Foundation
*
* You can copy, redistribute or modify this Program under the terms of
* the GNU General Public License version 2 as published by the Free
* Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* version 2 along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301, USA.
*/
/**
* \file
*
* \author Victor Julien <victor@inliniac.net>
*
* Memcmp implementations.
*/
#include "suricata-common.h"
#include "util-memcmp.h"
#include "util-unittest.h"
/* code is implemented in util-memcmp.h as it's all inlined */
/* UNITTESTS */
#ifdef UNITTESTS
static int MemcmpTest01 (void) {
uint8_t a[] = "abcd";
uint8_t b[] = "abcd";
if (SCMemcmp(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
static int MemcmpTest02 (void) {
uint8_t a[] = "abcdabcdabcdabcd";
uint8_t b[] = "abcdabcdabcdabcd";
if (SCMemcmp(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
static int MemcmpTest03 (void) {
uint8_t a[] = "abcdabcd";
uint8_t b[] = "abcdabcd";
if (SCMemcmp(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
static int MemcmpTest04 (void) {
uint8_t a[] = "abcd";
uint8_t b[] = "abcD";
if (SCMemcmp(a, b, sizeof(a)-1) != 1)
return 0;
return 1;
}
static int MemcmpTest05 (void) {
uint8_t a[] = "abcdabcdabcdabcd";
uint8_t b[] = "abcDabcdabcdabcd";
if (SCMemcmp(a, b, sizeof(a)-1) != 1)
return 0;
return 1;
}
static int MemcmpTest06 (void) {
uint8_t a[] = "abcdabcd";
uint8_t b[] = "abcDabcd";
if (SCMemcmp(a, b, sizeof(a)-1) != 1)
return 0;
return 1;
}
static int MemcmpTest07 (void) {
uint8_t a[] = "abcd";
uint8_t b[] = "abcde";
if (SCMemcmp(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
static int MemcmpTest08 (void) {
uint8_t a[] = "abcdabcdabcdabcd";
uint8_t b[] = "abcdabcdabcdabcde";
if (SCMemcmp(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
static int MemcmpTest09 (void) {
uint8_t a[] = "abcdabcd";
uint8_t b[] = "abcdabcde";
if (SCMemcmp(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
static int MemcmpTest10 (void) {
uint8_t a[] = "abcd";
uint8_t b[] = "Zbcde";
if (SCMemcmp(a, b, sizeof(a)-1) != 1)
return 0;
return 1;
}
static int MemcmpTest11 (void) {
uint8_t a[] = "abcdabcdabcdabcd";
uint8_t b[] = "Zbcdabcdabcdabcde";
if (SCMemcmp(a, b, sizeof(a)-1) != 1)
return 0;
return 1;
}
static int MemcmpTest12 (void) {
uint8_t a[] = "abcdabcd";
uint8_t b[] = "Zbcdabcde";
if (SCMemcmp(a, b, sizeof(a)-1) != 1)
return 0;
return 1;
}
static int MemcmpTest13 (void) {
uint8_t a[] = "abcdefgh";
uint8_t b[] = "AbCdEfGhIjK";
if (SCMemcmpLowercase(a, b, sizeof(a)-1) != 0)
return 0;
return 1;
}
#endif /* UNITTESTS */
void MemcmpRegisterTests(void) {
#ifdef UNITTESTS
UtRegisterTest("MemcmpTest01", MemcmpTest01, 1);
UtRegisterTest("MemcmpTest02", MemcmpTest02, 1);
UtRegisterTest("MemcmpTest03", MemcmpTest03, 1);
UtRegisterTest("MemcmpTest04", MemcmpTest04, 1);
UtRegisterTest("MemcmpTest05", MemcmpTest05, 1);
UtRegisterTest("MemcmpTest06", MemcmpTest06, 1);
UtRegisterTest("MemcmpTest07", MemcmpTest07, 1);
UtRegisterTest("MemcmpTest08", MemcmpTest08, 1);
UtRegisterTest("MemcmpTest09", MemcmpTest09, 1);
UtRegisterTest("MemcmpTest10", MemcmpTest10, 1);
UtRegisterTest("MemcmpTest11", MemcmpTest11, 1);
UtRegisterTest("MemcmpTest12", MemcmpTest12, 1);
UtRegisterTest("MemcmpTest13", MemcmpTest13, 1);
#endif /* UNITTESTS */
}

@ -0,0 +1,153 @@
/* Copyright (C) 2007-2010 Open Information Security Foundation
*
* You can copy, redistribute or modify this Program under the terms of
* the GNU General Public License version 2 as published by the Free
* Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* version 2 along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301, USA.
*/
/**
* \file
*
* \author Victor Julien <victor@inliniac.net>
*
* Memcmp implementations.
*/
#ifndef __UTIL_MEMCMP_H__
#define __UTIL_MEMCMP_H__
void MemcmpRegisterTests(void);
#if defined(__SSE3__)
#include <pmmintrin.h> /* for SSE3 */
#define SCMEMCMP_BYTES 16
static inline int SCMemcmp(void *, void *, size_t);
static inline int SCMemcmpLowercase(void *, void *, size_t);
static inline int SCMemcmp(void *s1, void *s2, size_t len) {
size_t offset = 0;
__m128i b1, b2, c;
do {
/* do unaligned loads using _mm_loadu_si128. On my Core2 E6600 using
* _mm_lddqu_si128 was about 2% slower even though it's supposed to
* be faster. */
b1 = _mm_loadu_si128((const __m128i *) s1);
b2 = _mm_loadu_si128((const __m128i *) s2);
c = _mm_cmpeq_epi8(b1, b2);
int diff = len - offset;
if (diff < 16) {
int rmask = ~(0xFFFFFFFF << diff);
if ((_mm_movemask_epi8(c) & rmask) != rmask) {
return 1;
}
} else {
if (_mm_movemask_epi8(c) != 0x0000FFFF) {
return 1;
}
}
offset += SCMEMCMP_BYTES;
s1 += SCMEMCMP_BYTES;
s2 += SCMEMCMP_BYTES;
} while (len > offset);
return 0;
}
#define UPPER_LOW 0x40 /* "A" - 1 */
#define UPPER_HIGH 0x5B /* "Z" + 1 */
#define UPPER_DELTA 0xDF /* 0xFF - 0x20 */
static inline int SCMemcmpLowercase(void *s1, void *s2, size_t len) {
size_t offset = 0;
__m128i b1, b2, mask1, mask2, upper1, upper2, delta;
/* setup registers for upper to lower conversion */
upper1 = _mm_set1_epi8(UPPER_LOW);
upper2 = _mm_set1_epi8(UPPER_HIGH);
delta = _mm_set1_epi8(UPPER_DELTA);
do {
/* unaligned loading of the bytes to compare */
b1 = _mm_loadu_si128((const __m128i *) s1);
b2 = _mm_loadu_si128((const __m128i *) s2);
/* mark all chars bigger than upper1 */
mask1 = _mm_cmpgt_epi8(b2, upper1);
/* mark all chars lower than upper2 */
mask2 = _mm_cmplt_epi8(b2, upper2);
/* merge the two, leaving only those that are true in both */
mask1 = _mm_cmpeq_epi8(mask1, mask2);
/* sub delta leaves 0x20 only for uppercase positions, the
rest is 0x00 due to the saturation (reuse mask1 reg)*/
mask1 = _mm_subs_epu8(mask1, delta);
/* add to b2, converting uppercase to lowercase */
b2 = _mm_add_epi8(b2, mask1);
/* now all is lowercase, let's do the actual compare (reuse mask1 reg) */
mask1 = _mm_cmpeq_epi8(b1, b2);
int diff = len - offset;
if (diff < 16) {
int rmask = ~(0xFFFFFFFF << diff);
if ((_mm_movemask_epi8(mask1) & rmask) != rmask) {
return 1;
}
} else {
if (_mm_movemask_epi8(mask1) != 0x0000FFFF) {
return 1;
}
}
offset += SCMEMCMP_BYTES;
s1 += SCMEMCMP_BYTES;
s2 += SCMEMCMP_BYTES;
} while (len > offset);
return 0;
}
#else
/* No SIMD support */
#define SCMemcmp memcmp
static inline int
SCMemcmpLowercase(void *s1, void *s2, size_t n) {
size_t i;
/* check backwards because we already tested the first
* 2 to 4 chars. This way we are more likely to detect
* a miss and thus speed up a little... */
for (i = n - 1; i; i--) {
if (((uint8_t *)s1)[i] != u8_tolower(*(((uint8_t *)s2)+i)))
return 1;
}
return 0;
}
#endif /* __SSE3__ */
#endif /* __UTIL_MEMCMP_H__ */

@ -41,6 +41,7 @@
#include "util-debug.h"
#include "util-unittest.h"
#include "util-memcmp.h"
#include "conf.h"
#define INIT_HASH_SIZE 65536
@ -244,7 +245,7 @@ static inline int B2gCmpPattern(B2gPattern *p, uint8_t *pat, uint16_t patlen, ch
if (p->flags != flags)
return 0;
if (memcmp(p->cs, pat, patlen) != 0)
if (SCMemcmp(p->cs, pat, patlen) != 0)
return 0;
return 1;
@ -320,7 +321,7 @@ static int B2gAddPattern(MpmCtx *mpm_ctx, uint8_t *pat, uint16_t patlen, uint16_
/* nocase means no difference between cs and ci */
p->cs = p->ci;
} else {
if (memcmp(p->ci,pat,p->len) == 0) {
if (SCMemcmp(p->ci,pat,p->len) == 0) {
/* no diff between cs and ci: pat is lowercase */
p->cs = p->ci;
} else {
@ -675,21 +676,6 @@ void B2gPrintSearchStats(MpmThreadCtx *mpm_thread_ctx) {
#endif /* B2G_COUNTERS */
}
static inline int
memcmp_lowercase(uint8_t *s1, uint8_t *s2, uint16_t n) {
size_t i;
/* check backwards because we already tested the first
* 2 to 4 chars. This way we are more likely to detect
* a miss and thus speed up a little... */
for (i = n - 1; i; i--) {
if (u8_tolower(*(s2+i)) != s1[i])
return 1;
}
return 0;
}
/**
* \brief Function to get the user defined values for b2g algorithm from the
* config file 'suricata.yaml'
@ -978,7 +964,8 @@ uint32_t B2gSearchBNDMq(MpmCtx *mpm_ctx, MpmThreadCtx *mpm_thread_ctx, PatternMa
if (thi->flags & MPM_PATTERN_FLAG_NOCASE) {
if (memcmp_lowercase(thi->ci, buf+j, thi->len) == 0) {
//if (memcmp_lowercase(thi->ci, buf+j, thi->len) == 0) {
if (SCMemcmpLowercase(thi->ci, buf+j, thi->len) == 0) {
#ifdef PRINTMATCH
printf("CI Exact match: "); prt(p->ci, p->len); printf("\n");
#endif
@ -989,7 +976,8 @@ uint32_t B2gSearchBNDMq(MpmCtx *mpm_ctx, MpmThreadCtx *mpm_thread_ctx, PatternMa
COUNT(tctx->stat_loop_no_match++);
}
} else {
if (memcmp(thi->cs, buf+j, thi->len) == 0) {
if (SCMemcmp(thi->cs, buf+j, thi->len) == 0) {
//if (memcmp(thi->cs, buf+j, thi->len) == 0) {
#ifdef PRINTMATCH
printf("CS Exact match: "); prt(p->cs, p->len); printf("\n");
#endif
@ -1088,7 +1076,7 @@ uint32_t B2gSearch(MpmCtx *mpm_ctx, MpmThreadCtx *mpm_thread_ctx, PatternMatcher
if (thi->flags & MPM_PATTERN_FLAG_NOCASE) {
if (memcmp_lowercase(thi->ci, buf+pos, thi->len) == 0) {
if (SCMemcmpLowercase(thi->ci, buf+pos, thi->len) == 0) {
COUNT(tctx->stat_loop_match++);
matches += MpmVerifyMatch(mpm_thread_ctx, pmq, thi->id);
@ -1096,7 +1084,7 @@ uint32_t B2gSearch(MpmCtx *mpm_ctx, MpmThreadCtx *mpm_thread_ctx, PatternMatcher
COUNT(tctx->stat_loop_no_match++);
}
} else {
if (memcmp(thi->cs, buf+pos, thi->len) == 0) {
if (SCMemcmp(thi->cs, buf+pos, thi->len) == 0) {
COUNT(tctx->stat_loop_match++);
matches += MpmVerifyMatch(mpm_thread_ctx, pmq, thi->id);

@ -41,6 +41,7 @@
#include "util-debug.h"
#include "util-error.h"
#include "util-unittest.h"
#include "util-memcmp.h"
/**
* \brief Validates an IPV4 address and returns the network endian arranged
@ -1224,7 +1225,7 @@ static void SCRadixRemoveKey(uint8_t *key_stream, uint16_t key_bitlen,
}
i = prefix->bitlen / 8;
if (memcmp(node->prefix->stream, prefix->stream, i) == 0) {
if (SCMemcmp(node->prefix->stream, prefix->stream, i) == 0) {
mask = -1 << (8 - prefix->bitlen % 8);
if (prefix->bitlen % 8 == 0 ||
@ -1439,7 +1440,7 @@ static inline SCRadixNode *SCRadixFindKeyIPNetblock(uint8_t *key_stream, uint8_t
if (node->bit != key_bitlen || node->prefix == NULL)
return NULL;
if (memcmp(node->prefix->stream, key_stream, bytes) == 0) {
if (SCMemcmp(node->prefix->stream, key_stream, bytes) == 0) {
mask = -1 << (8 - key_bitlen % 8);
if (key_bitlen % 8 == 0 ||
@ -1497,7 +1498,7 @@ static SCRadixNode *SCRadixFindKey(uint8_t *key_stream, uint16_t key_bitlen,
}
bytes = key_bitlen / 8;
if (memcmp(node->prefix->stream, tmp_stream, bytes) == 0) {
if (SCMemcmp(node->prefix->stream, tmp_stream, bytes) == 0) {
mask = -1 << (8 - key_bitlen % 8);
if (key_bitlen % 8 == 0 ||

Loading…
Cancel
Save