detect/byte_math: Permit var name for bytes value

Issue: 6145

Modifications to permit a variable name to be used for the byte_math
bytes value.
pull/9198/head
Jeff Lucovsky 3 years ago committed by Victor Julien
parent fb847d8bb0
commit 690b65ae88

@ -33,6 +33,7 @@ pub const DETECT_BYTEMATH_FLAG_STRING: u8 = 0x02;
pub const DETECT_BYTEMATH_FLAG_BITMASK: u8 = 0x04;
pub const DETECT_BYTEMATH_FLAG_ENDIAN: u8 = 0x08;
pub const DETECT_BYTEMATH_FLAG_RVALUE_VAR: u8 = 0x10;
pub const DETECT_BYTEMATH_FLAG_NBYTES_VAR: u8 = 0x20;
// Ensure required values are provided
const DETECT_BYTEMATH_FLAG_NBYTES: u8 = 0x1;
@ -98,6 +99,7 @@ enum ResultValue {
pub struct DetectByteMathData {
rvalue_str: *const c_char,
result: *const c_char,
nbytes_str: *const c_char,
rvalue: u32,
offset: i32,
bitmask_val: u32,
@ -120,6 +122,9 @@ impl Drop for DetectByteMathData {
if !self.rvalue_str.is_null() {
let _ = CString::from_raw(self.rvalue_str as *mut c_char);
}
if !self.nbytes_str.is_null() {
let _ = CString::from_raw(self.nbytes_str as *mut c_char);
}
}
}
}
@ -133,6 +138,7 @@ impl Default for DetectByteMathData {
offset: 0,
oper: ByteMathOperator::OperatorNone,
rvalue_str: std::ptr::null_mut(),
nbytes_str: std::ptr::null_mut(),
rvalue: 0,
result: std::ptr::null_mut(),
endian: DETECT_BYTEMATH_ENDIAN_DEFAULT,
@ -190,12 +196,12 @@ fn get_endian_value(value: &str) -> Result<ByteMathEndian, ()> {
// Parsed as a u64 for validation with u32 {min,max} so values greater than uint32
// are not treated as a string value.
fn parse_rvalue(input: &str) -> IResult<&str, ResultValue, RuleParseError<&str>> {
let (input, rvalue) = parse_token(input)?;
if let Ok(val) = rvalue.parse::<u64>() {
fn parse_var(input: &str) -> IResult<&str, ResultValue, RuleParseError<&str>> {
let (input, value) = parse_token(input)?;
if let Ok(val) = value.parse::<u64>() {
Ok((input, ResultValue::Numeric(val)))
} else {
Ok((input, ResultValue::String(rvalue.to_string())))
Ok((input, ResultValue::String(value.to_string())))
}
}
@ -259,7 +265,7 @@ fn parse_bytemath(input: &str) -> IResult<&str, DetectByteMathData, RuleParseErr
if 0 != (required_flags & DETECT_BYTEMATH_FLAG_RVALUE) {
return Err(make_error("rvalue already set".to_string()));
}
let (_, res) = parse_rvalue(val)?;
let (_, res) = parse_var(val)?;
match res {
ResultValue::Numeric(val) => {
if val >= u32::MIN.into() && val <= u32::MAX.into() {
@ -358,14 +364,29 @@ fn parse_bytemath(input: &str) -> IResult<&str, DetectByteMathData, RuleParseErr
if 0 != (required_flags & DETECT_BYTEMATH_FLAG_NBYTES) {
return Err(make_error("nbytes already set".to_string()));
}
byte_math.nbytes = val
.parse()
.map_err(|_| make_error(format!("invalid bytes value: {}", val)))?;
if byte_math.nbytes < 1 || byte_math.nbytes > 10 {
return Err(make_error(format!(
"invalid bytes value: must be between 1 and 10: {}",
byte_math.nbytes
)));
let (_, res) = parse_var(val)?;
match res {
ResultValue::Numeric(val) => {
if (1..=10).contains(&val) {
byte_math.nbytes = val as u8
} else {
return Err(make_error(format!(
"invalid nbytes value: must be between 1 and 10: {}",
val
)));
}
}
ResultValue::String(val) => match CString::new(val) {
Ok(newval) => {
byte_math.nbytes_str = newval.into_raw();
byte_math.flags |= DETECT_BYTEMATH_FLAG_NBYTES_VAR;
}
_ => {
return Err(make_error(
"parse string not safely convertible to C".to_string(),
))
}
},
}
required_flags |= DETECT_BYTEMATH_FLAG_NBYTES;
}
@ -439,6 +460,14 @@ mod tests {
return false;
}
if !self.nbytes_str.is_null() && !other.nbytes_str.is_null() {
let s_val = unsafe { CStr::from_ptr(self.nbytes_str) };
let o_val = unsafe { CStr::from_ptr(other.nbytes_str) };
res = s_val == o_val;
} else if !self.nbytes_str.is_null() || !other.nbytes_str.is_null() {
return false;
}
if !self.result.is_null() && !self.result.is_null() {
let s_val = unsafe { CStr::from_ptr(self.result) };
let o_val = unsafe { CStr::from_ptr(other.result) };
@ -462,7 +491,7 @@ mod tests {
}
fn valid_test(
args: &str, nbytes: u8, offset: i32, oper: ByteMathOperator, rvalue_str: &str, rvalue: u32,
args: &str, nbytes: u8, offset: i32, oper: ByteMathOperator, rvalue_str: &str, nbytes_str: &str, rvalue: u32,
result: &str, base: ByteMathBase, endian: ByteMathEndian, bitmask_val: u32, flags: u8,
) {
let bmd = DetectByteMathData {
@ -474,6 +503,11 @@ mod tests {
} else {
std::ptr::null_mut()
},
nbytes_str: if !nbytes_str.is_empty() {
CString::new(nbytes_str).unwrap().into_raw()
} else {
std::ptr::null_mut()
},
rvalue,
result: CString::new(result).unwrap().into_raw(),
base,
@ -501,6 +535,7 @@ mod tests {
3933,
ByteMathOperator::Addition,
"myrvalue",
"",
0,
"myresult",
ByteMathBase::BaseDec,
@ -517,6 +552,7 @@ mod tests {
3933,
ByteMathOperator::Addition,
"",
"",
99,
"other",
ByteMathBase::BaseDec,
@ -531,6 +567,7 @@ mod tests {
-3933,
ByteMathOperator::Addition,
"rvalue",
"",
0,
"foo",
BASE_DEFAULT,
@ -539,6 +576,21 @@ mod tests {
DETECT_BYTEMATH_FLAG_RVALUE_VAR,
);
valid_test(
"bytes nbytes_var, offset -3933, oper +, rvalue myrvalue, result foo",
0,
-3933,
ByteMathOperator::Addition,
"rvalue",
"nbytes_var",
0,
"foo",
BASE_DEFAULT,
ByteMathEndian::BigEndian,
0,
DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_NBYTES_VAR,
);
// Out of order
valid_test(
"string dec, endian big, result other, rvalue 99, oper +, offset 3933, bytes 4",
@ -546,6 +598,7 @@ mod tests {
3933,
ByteMathOperator::Addition,
"",
"",
99,
"other",
ByteMathBase::BaseDec,

@ -76,20 +76,30 @@ void DetectBytemathRegister(void)
#endif
}
static inline bool DetectByteMathValidateNbytesOnly(const DetectByteMathData *data, int32_t nbytes)
{
return nbytes >= 1 &&
(((data->flags & DETECT_BYTEMATH_FLAG_STRING) && nbytes <= 10) || (nbytes <= 4));
}
int DetectByteMathDoMatch(DetectEngineThreadCtx *det_ctx, const SigMatchData *smd,
const Signature *s, const uint8_t *payload,
uint16_t payload_len, uint64_t rvalue, uint64_t *value, uint8_t endian)
const Signature *s, const uint8_t *payload, uint16_t payload_len, uint8_t nbytes,
uint64_t rvalue, uint64_t *value, uint8_t endian)
{
const DetectByteMathData *data = (DetectByteMathData *)smd->ctx;
if (payload_len == 0) {
return 0;
}
if (!DetectByteMathValidateNbytesOnly(data, nbytes)) {
return 0;
}
const uint8_t *ptr;
int32_t len;
uint64_t val;
int extbytes;
if (payload_len == 0) {
return 0;
}
/* Calculate the ptr value for the byte-math op and length remaining in
* the packet from that point.
*/
@ -116,33 +126,30 @@ int DetectByteMathDoMatch(DetectEngineThreadCtx *det_ctx, const SigMatchData *sm
}
/* Validate that the to-be-extracted is within the packet */
if (ptr < payload || data->nbytes > len) {
SCLogDebug("Data not within payload pkt=%p, ptr=%p, len=%"PRIu32", nbytes=%d",
payload, ptr, len, data->nbytes);
if (ptr < payload || nbytes > len) {
SCLogDebug("Data not within payload pkt=%p, ptr=%p, len=%" PRIu32 ", nbytes=%d", payload,
ptr, len, nbytes);
return 0;
}
/* Extract the byte data */
if (data->flags & DETECT_BYTEMATH_FLAG_STRING) {
extbytes = ByteExtractStringUint64(&val, data->base,
data->nbytes, (const char *)ptr);
extbytes = ByteExtractStringUint64(&val, data->base, nbytes, (const char *)ptr);
if (extbytes <= 0) {
if (val == 0) {
SCLogDebug("No Numeric value");
return 0;
} else {
SCLogDebug("error extracting %d bytes of string data: %d",
data->nbytes, extbytes);
SCLogDebug("error extracting %d bytes of string data: %d", nbytes, extbytes);
return -1;
}
}
} else {
ByteMathEndian bme = endian;
int endianness = (bme == BigEndian) ? BYTE_BIG_ENDIAN : BYTE_LITTLE_ENDIAN;
extbytes = ByteExtractUint64(&val, endianness, data->nbytes, ptr);
if (extbytes != data->nbytes) {
SCLogDebug("error extracting %d bytes of numeric data: %d",
data->nbytes, extbytes);
extbytes = ByteExtractUint64(&val, endianness, nbytes, ptr);
if (extbytes != nbytes) {
SCLogDebug("error extracting %d bytes of numeric data: %d", nbytes, extbytes);
return 0;
}
}
@ -206,7 +213,8 @@ int DetectByteMathDoMatch(DetectEngineThreadCtx *det_ctx, const SigMatchData *sm
* \retval bmd On success an instance containing the parsed data.
* On failure, NULL.
*/
static DetectByteMathData *DetectByteMathParse(DetectEngineCtx *de_ctx, const char *arg, char **rvalue)
static DetectByteMathData *DetectByteMathParse(
DetectEngineCtx *de_ctx, const char *arg, char **nbytes, char **rvalue)
{
DetectByteMathData *bmd;
if ((bmd = ScByteMathParse(arg)) == NULL) {
@ -214,6 +222,19 @@ static DetectByteMathData *DetectByteMathParse(DetectEngineCtx *de_ctx, const ch
return NULL;
}
if (bmd->nbytes_str) {
if (nbytes == NULL) {
SCLogError("byte_math supplied with "
"var name for nbytes. \"nbytes\" argument supplied to "
"this function must be non-NULL");
goto error;
}
*nbytes = SCStrdup(bmd->nbytes_str);
if (*nbytes == NULL) {
goto error;
}
}
if (bmd->rvalue_str) {
if (rvalue == NULL) {
SCLogError("byte_math supplied with "
@ -262,9 +283,10 @@ static int DetectByteMathSetup(DetectEngineCtx *de_ctx, Signature *s, const char
SigMatch *prev_pm = NULL;
DetectByteMathData *data;
char *rvalue = NULL;
char *nbytes = NULL;
int ret = -1;
data = DetectByteMathParse(de_ctx, arg, &rvalue);
data = DetectByteMathParse(de_ctx, arg, &nbytes, &rvalue);
if (data == NULL)
goto error;
@ -336,6 +358,18 @@ static int DetectByteMathSetup(DetectEngineCtx *de_ctx, Signature *s, const char
}
}
if (nbytes != NULL) {
DetectByteIndexType index;
if (!DetectByteRetrieveSMVar(nbytes, s, &index)) {
SCLogError("unknown byte_ keyword var seen in byte_math - %s", nbytes);
goto error;
}
data->nbytes = index;
data->flags |= DETECT_BYTEMATH_FLAG_NBYTES_VAR;
SCFree(nbytes);
nbytes = NULL;
}
if (rvalue != NULL) {
DetectByteIndexType index;
if (!DetectByteRetrieveSMVar(rvalue, s, &index)) {
@ -386,6 +420,8 @@ static int DetectByteMathSetup(DetectEngineCtx *de_ctx, Signature *s, const char
error:
if (rvalue)
SCFree(rvalue);
if (nbytes)
SCFree(nbytes);
DetectByteMathFree(de_ctx, data);
return ret;
}
@ -448,8 +484,10 @@ SigMatch *DetectByteMathRetrieveSMVar(const char *arg, const Signature *s)
static int DetectByteMathParseTest01(void)
{
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue 10, result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue 10, result bar",
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -468,8 +506,10 @@ static int DetectByteMathParseTest01(void)
static int DetectByteMathParseTest02(void)
{
/* bytes value invalid */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 257, offset 2, oper +, "
"rvalue 39, result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 257, offset 2, oper +, "
"rvalue 39, result bar",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
@ -479,8 +519,10 @@ static int DetectByteMathParseTest02(void)
static int DetectByteMathParseTest03(void)
{
/* bytes value invalid */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 11, offset 2, oper +, "
"rvalue 39, result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 11, offset 2, oper +, "
"rvalue 39, result bar",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
PASS;
@ -489,8 +531,10 @@ static int DetectByteMathParseTest03(void)
static int DetectByteMathParseTest04(void)
{
/* offset value invalid */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 70000, oper +,"
" rvalue 39, result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 70000, oper +,"
" rvalue 39, result bar",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
@ -500,8 +544,10 @@ static int DetectByteMathParseTest04(void)
static int DetectByteMathParseTest05(void)
{
/* oper value invalid */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 11, offset 16, oper &,"
"rvalue 39, result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 11, offset 16, oper &,"
"rvalue 39, result bar",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
PASS;
@ -512,9 +558,10 @@ static int DetectByteMathParseTest06(void)
uint8_t flags = DETECT_BYTEMATH_FLAG_RELATIVE;
char *rvalue = NULL;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 0, oper +,"
"rvalue 248, result var, relative",
&rvalue);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 0, oper +,"
"rvalue 248, result var, relative",
NULL, &rvalue);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -535,9 +582,10 @@ static int DetectByteMathParseTest07(void)
{
char *rvalue = NULL;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue foo, result bar",
&rvalue);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue foo, result bar",
NULL, &rvalue);
FAIL_IF_NOT(rvalue);
FAIL_IF_NOT(bmd->nbytes == 4);
FAIL_IF_NOT(bmd->offset == 2);
@ -557,8 +605,10 @@ static int DetectByteMathParseTest07(void)
static int DetectByteMathParseTest08(void)
{
/* ensure Parse checks the pointer value when rvalue is a var */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue foo, result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue foo, result bar",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
PASS;
@ -568,9 +618,10 @@ static int DetectByteMathParseTest09(void)
{
uint8_t flags = DETECT_BYTEMATH_FLAG_RELATIVE;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue 39, result bar, relative",
NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue 39, result bar, relative",
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -591,9 +642,11 @@ static int DetectByteMathParseTest10(void)
{
uint8_t flags = DETECT_BYTEMATH_FLAG_ENDIAN;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue 39, result bar, endian"
" big", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue 39, result bar, endian"
" big",
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -614,9 +667,10 @@ static int DetectByteMathParseTest11(void)
{
uint8_t flags = DETECT_BYTEMATH_FLAG_ENDIAN;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +, "
"rvalue 39, result bar, dce",
NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +, "
"rvalue 39, result bar, dce",
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -637,9 +691,11 @@ static int DetectByteMathParseTest12(void)
{
uint8_t flags = DETECT_BYTEMATH_FLAG_RELATIVE | DETECT_BYTEMATH_FLAG_STRING;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue 39, result bar, "
"relative, string dec", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue 39, result bar, "
"relative, string dec",
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -662,10 +718,12 @@ static int DetectByteMathParseTest13(void)
DETECT_BYTEMATH_FLAG_RELATIVE |
DETECT_BYTEMATH_FLAG_BITMASK;
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +, "
"rvalue 39, result bar, "
"relative, string dec, bitmask "
"0x8f40", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +, "
"rvalue 39, result bar, "
"relative, string dec, bitmask "
"0x8f40",
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);
@ -688,8 +746,10 @@ static int DetectByteMathParseTest13(void)
static int DetectByteMathParseTest14(void)
{
/* incomplete */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +,"
"rvalue foo", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +,"
"rvalue foo",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
@ -700,8 +760,10 @@ static int DetectByteMathParseTest15(void)
{
/* incomplete */
DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +, "
"result bar", NULL);
DetectByteMathData *bmd = DetectByteMathParse(NULL,
"bytes 4, offset 2, oper +, "
"result bar",
NULL, NULL);
FAIL_IF_NOT(bmd == NULL);
@ -718,7 +780,7 @@ static int DetectByteMathParseTest16(void)
"rvalue 39, result bar, "
"relative, string dec, bitmask "
"0x8f40",
NULL);
NULL, NULL);
FAIL_IF(bmd == NULL);
FAIL_IF_NOT(bmd->nbytes == 4);

@ -28,6 +28,6 @@ void DetectBytemathRegister(void);
SigMatch *DetectByteMathRetrieveSMVar(const char *, const Signature *);
int DetectByteMathDoMatch(DetectEngineThreadCtx *, const SigMatchData *, const Signature *,
const uint8_t *, uint16_t, uint64_t, uint64_t *, uint8_t);
const uint8_t *, uint16_t, uint8_t, uint64_t, uint64_t *, uint8_t);
#endif /* __DETECT_BYTEMATH_H__ */

@ -588,8 +588,15 @@ uint8_t DetectEngineContentInspection(DetectEngineCtx *de_ctx, DetectEngineThrea
rvalue = bmd->rvalue;
}
uint8_t nbytes;
if (bmd->flags & DETECT_BYTEMATH_FLAG_NBYTES_VAR) {
nbytes = (uint8_t)det_ctx->byte_values[bmd->nbytes];
} else {
nbytes = bmd->nbytes;
}
DEBUG_VALIDATE_BUG_ON(buffer_len > UINT16_MAX);
if (DetectByteMathDoMatch(det_ctx, smd, s, buffer, (uint16_t)buffer_len, rvalue,
if (DetectByteMathDoMatch(det_ctx, smd, s, buffer, (uint16_t)buffer_len, nbytes, rvalue,
&det_ctx->byte_values[bmd->local_id], endian) != 1) {
goto no_match;
}

Loading…
Cancel
Save