From 370ac0541984791978f7e92db2d90f9e9ade6ec5 Mon Sep 17 00:00:00 2001
From: Philippe Antoine
Date: Fri, 22 Dec 2023 11:59:35 +0100
Subject: [PATCH] detect/integer: rust derive for enumerations

Ticket: 6647

Allows keywords using integers to use strings in signature
parsing based on a rust enumeration with a derive.
---
 rust/derive/src/lib.rs        |  22 +++++++
 rust/derive/src/stringenum.rs | 106 ++++++++++++++++++++++++++++++++++
 rust/src/detect/mod.rs        |  43 ++++++++++++++
 rust/src/detect/uint.rs       |  42 ++++++++++++++
 4 files changed, 213 insertions(+)
 create mode 100644 rust/derive/src/stringenum.rs

diff --git a/rust/derive/src/lib.rs b/rust/derive/src/lib.rs
index a2b7a6ad04..a36f19390c 100644
--- a/rust/derive/src/lib.rs
+++ b/rust/derive/src/lib.rs
@@ -23,6 +23,7 @@ use proc_macro::TokenStream;
 
 mod applayerevent;
 mod applayerframetype;
+mod stringenum;
 
 /// The `AppLayerEvent` derive macro generates a `AppLayerEvent` trait
 /// implementation for enums that define AppLayerEvents.
@@ -50,3 +51,24 @@ pub fn derive_app_layer_event(input: TokenStream) -> TokenStream {
 pub fn derive_app_layer_frame_type(input: TokenStream) -> TokenStream {
     applayerframetype::derive_app_layer_frame_type(input)
 }
+
+/// The `EnumStringU8` derive macro generates an `EnumString<u8>` trait
+/// implementation for enums whose variants all have explicit `u8` values.
+#[proc_macro_derive(EnumStringU8, attributes(name))]
+pub fn derive_enum_string_u8(input: TokenStream) -> TokenStream {
+    stringenum::derive_enum_string::<u8>(input, "u8")
+}
+
+/// The `EnumStringU16` derive macro generates an `EnumString<u16>` trait
+/// implementation for enums whose variants all have explicit `u16` values.
+#[proc_macro_derive(EnumStringU16, attributes(name))]
+pub fn derive_enum_string_u16(input: TokenStream) -> TokenStream {
+    stringenum::derive_enum_string::<u16>(input, "u16")
+}
+
+/// The `EnumStringU32` derive macro generates an `EnumString<u32>` trait
+/// implementation for enums whose variants all have explicit `u32` values.
+#[proc_macro_derive(EnumStringU32, attributes(name))]
+pub fn derive_enum_string_u32(input: TokenStream) -> TokenStream {
+    stringenum::derive_enum_string::<u32>(input, "u32")
+}
diff --git a/rust/derive/src/stringenum.rs b/rust/derive/src/stringenum.rs
new file mode 100644
index 0000000000..5344b934cd
--- /dev/null
+++ b/rust/derive/src/stringenum.rs
@@ -0,0 +1,106 @@
+/* Copyright (C) 2023 Open Information Security Foundation
+ *
+ * You can copy, redistribute or modify this Program under the terms of
+ * the GNU General Public License version 2 as published by the Free
+ * Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * version 2 along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+ * 02110-1301, USA.
+ */
+
+extern crate proc_macro;
+use super::applayerevent::transform_name;
+use proc_macro::TokenStream;
+use quote::quote;
+use syn::{self, parse_macro_input, DeriveInput};
+use std::str::FromStr;
+
+/// Generates the `EnumString` trait implementation behind the
+/// `EnumStringU8`, `EnumStringU16` and `EnumStringU32` derives.
+///
+/// `T` is the integer type each discriminant literal must parse as, and
+/// `ustr` is the name of that same type (e.g. "u8") used in the generated code.
+///
+/// Panics, which fails compilation at the derive site, if the input is not an
+/// enum or if a variant lacks an explicit integer literal value parsable as `T`.
+pub fn derive_enum_string<T: std::str::FromStr + quote::ToTokens>(input: TokenStream, ustr: &str) -> TokenStream where <T as FromStr>::Err: std::fmt::Display {
+    let input = parse_macro_input!(input as DeriveInput);
+    let name = input.ident;
+    let mut values = Vec::new();
+    let mut names = Vec::new();
+    let mut fields = Vec::new();
+
+    if let syn::Data::Enum(ref data) = input.data {
+        for v in &data.variants {
+            if let Some((_, val)) = &v.discriminant {
+                let fname = transform_name(&v.ident.to_string());
+                names.push(fname);
+                fields.push(v.ident.clone());
+                if let syn::Expr::Lit(l) = val {
+                    if let syn::Lit::Int(li) = &l.lit {
+                        if let Ok(value) = li.base10_parse::<T>() {
+                            values.push(value);
+                        } else {
+                            panic!("EnumString requires explicit {}", ustr);
+                        }
+                    } else {
+                        panic!("EnumString requires explicit literal integer");
+                    }
+                } else {
+                    panic!("EnumString requires explicit literal");
+                }
+            } else {
+                panic!("EnumString requires explicit values");
+            }
+        }
+    } else {
+        panic!("EnumString can only be derived for enums");
+    }
+
+    // Within the `suricata` crate itself the trait is reached through `crate`;
+    // any other crate using the derive reaches it through `suricata`.
+    let is_suricata = std::env::var("CARGO_PKG_NAME").map(|var| var == "suricata").unwrap_or(false);
+    let crate_id = if is_suricata {
+        syn::Ident::new("crate", proc_macro2::Span::call_site())
+    } else {
+        syn::Ident::new("suricata", proc_macro2::Span::call_site())
+    };
+
+    let utype_str = syn::Ident::new(ustr, proc_macro2::Span::call_site());
+
+    let expanded = quote! {
+        impl #crate_id::detect::EnumString<#utype_str> for #name {
+            fn from_u(v: #utype_str) -> Option<Self> {
+                match v {
+                    #( #values => Some(#name::#fields) ,)*
+                    _ => None,
+                }
+            }
+            fn into_u(self) -> #utype_str {
+                match self {
+                    #( #name::#fields => #values ,)*
+                }
+            }
+            fn to_str(&self) -> &'static str {
+                match *self {
+                    #( #name::#fields => #names ,)*
+                }
+            }
+            fn from_str(s: &str) -> Option<Self> {
+                match s {
+                    #( #names => Some(#name::#fields) ,)*
+                    _ => None
+                }
+            }
+        }
+    };
+
+    proc_macro::TokenStream::from(expanded)
+}
diff --git a/rust/src/detect/mod.rs b/rust/src/detect/mod.rs
index d33c9ae7fa..cad086f161 100644
--- a/rust/src/detect/mod.rs
+++ b/rust/src/detect/mod.rs
@@ -25,3 +25,46 @@ pub mod stream_size;
 pub mod uint;
 pub mod uri;
 pub mod requires;
+
+/// EnumString trait that will be implemented on enums that
+/// derive `EnumStringU8`, `EnumStringU16` or `EnumStringU32`.
+pub trait EnumString<T> {
+    /// Return the enum variant of the given numeric value.
+    fn from_u(v: T) -> Option<Self> where Self: Sized;
+
+    /// Convert the enum variant to the numeric value.
+    fn into_u(self) -> T;
+
+    /// Return the string for logging the enum value.
+    fn to_str(&self) -> &'static str;
+
+    /// Get an enum variant from parsing a string.
+    fn from_str(s: &str) -> Option<Self> where Self: Sized;
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use suricata_derive::EnumStringU8;
+
+    #[derive(Clone, Debug, PartialEq, EnumStringU8)]
+    #[repr(u8)]
+    pub enum TestEnum {
+        Zero = 0,
+        BestValueEver = 42,
+    }
+
+    #[test]
+    fn test_enum_string_u8() {
+        assert_eq!(TestEnum::from_u(0), Some(TestEnum::Zero));
+        assert_eq!(TestEnum::from_u(1), None);
+        assert_eq!(TestEnum::from_u(42), Some(TestEnum::BestValueEver));
+        assert_eq!(TestEnum::Zero.into_u(), 0);
+        assert_eq!(TestEnum::BestValueEver.into_u(), 42);
+        assert_eq!(TestEnum::Zero.to_str(), "zero");
+        assert_eq!(TestEnum::BestValueEver.to_str(), "best_value_ever");
+        assert_eq!(TestEnum::from_str("zero"), Some(TestEnum::Zero));
+        assert_eq!(TestEnum::from_str("nope"), None);
+        assert_eq!(TestEnum::from_str("best_value_ever"), Some(TestEnum::BestValueEver));
+    }
+}
diff --git a/rust/src/detect/uint.rs b/rust/src/detect/uint.rs
index 0d813bfd08..fd6079a536 100644
--- a/rust/src/detect/uint.rs
+++ b/rust/src/detect/uint.rs
@@ -23,6 +25,8 @@ use nom7::error::{make_error, ErrorKind};
 use nom7::Err;
 use nom7::IResult;
 
+use super::EnumString;
+
 use std::ffi::CStr;
 
 #[derive(PartialEq, Eq, Clone, Debug)]
@@ -46,6 +48,26 @@ pub struct DetectUintData<T> {
     pub mode: DetectUintMode,
 }
 
+/// Parses a string for detection with integers, using enumeration strings.
+///
+/// `T1` is the integer type (like `u8`) and `T2` is the enumeration whose
+/// `EnumString<T1>` implementation provides the accepted strings.
+/// The numerical syntax of any integer detection keyword is tried first;
+/// if that fails, the whole string is looked up as an enumeration string,
+/// which yields an equality match on the corresponding value.
+///
+/// Returns `Some(DetectUintData)` on success, `None` on failure.
+pub fn detect_parse_uint_enum<T1: DetectIntType, T2: EnumString<T1>>(s: &str) -> Option<DetectUintData<T1>> {
+    if let Ok((_, ctx)) = detect_parse_uint::<T1>(s) {
+        return Some(ctx);
+    }
+    T2::from_str(s).map(|enum_val| DetectUintData::<T1> {
+        arg1: enum_val.into_u(),
+        arg2: T1::min_value(),
+        mode: DetectUintMode::DetectUintModeEqual,
+    })
+}
+
 pub trait DetectIntType:
     std::str::FromStr
     + std::cmp::PartialOrd
@@ -442,6 +464,26 @@ pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData<u16>) {
 mod tests {
     use super::*;
 
+    use suricata_derive::EnumStringU8;
+
+    #[derive(Clone, Debug, PartialEq, EnumStringU8)]
+    #[repr(u8)]
+    pub enum TestEnum {
+        Zero = 0,
+        BestValueEver = 42,
+    }
+
+    #[test]
+    fn test_detect_parse_uint_enum() {
+        let ctx = detect_parse_uint_enum::<u8, TestEnum>("best_value_ever").unwrap();
+        assert_eq!(ctx.arg1, 42);
+        assert_eq!(ctx.mode, DetectUintMode::DetectUintModeEqual);
+
+        let ctx = detect_parse_uint_enum::<u8, TestEnum>(">1").unwrap();
+        assert_eq!(ctx.arg1, 1);
+        assert_eq!(ctx.mode, DetectUintMode::DetectUintModeGt);
+    }
+
     #[test]
     fn test_parse_uint_hex() {
         let (_, val) = detect_parse_uint::<u64>("0x100").unwrap();