smb: Add rust registration function

Get rid of the C glue code and move registration completely to Rust.
pull/6285/head
Shivani Bhardwaj 4 years ago
parent 27af4bb002
commit e5c948df87

@ -28,7 +28,7 @@
use std;
use std::mem::transmute;
use std::str;
use std::ffi::{self, CStr};
use std::ffi::{self, CStr, CString};
use std::collections::HashMap;
@ -37,6 +37,7 @@ use nom;
use crate::core::*;
use crate::applayer;
use crate::applayer::*;
use crate::conf::*;
use crate::filecontainer::*;
use crate::smb::nbss_records::*;
@ -2189,3 +2190,143 @@ pub extern "C" fn rs_smb_state_get_event_info(event_name: *const std::os::raw::c
}
0
}
pub extern "C" fn smb3_probe_tcp(f: *const Flow, dir: u8, input: *const u8, len: u32, rdir: *mut u8) -> u16 {
let retval = rs_smb_probe_tcp(f, dir, input, len, rdir);
let f = cast_pointer!(f, Flow);
if unsafe { retval != ALPROTO_SMB } {
return retval;
}
let (sp, dp) = f.get_ports();
let flags = f.get_flags();
let fsp = if (flags & FLOW_DIR_REVERSED) != 0 { dp } else { sp };
let fdp = if (flags & FLOW_DIR_REVERSED) != 0 { sp } else { dp };
if fsp == 445 && fdp != 445 {
unsafe {
if dir & STREAM_TOSERVER != 0 {
*rdir = STREAM_TOCLIENT;
} else {
*rdir = STREAM_TOSERVER;
}
}
}
return unsafe { ALPROTO_SMB };
}
fn register_pattern_probe() -> i8 {
let mut r = 0;
unsafe {
// SMB1
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
STREAM_TOSERVER, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
STREAM_TOCLIENT, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
// SMB2/3
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
STREAM_TOSERVER, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
STREAM_TOCLIENT, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
// SMB3 encrypted records
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
STREAM_TOSERVER, smb3_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
STREAM_TOCLIENT, smb3_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
}
if r == 0 {
return 0;
} else {
return -1;
}
}
// Parser name as a C style string.
const PARSER_NAME: &'static [u8] = b"smb\0";
#[no_mangle]
pub unsafe extern "C" fn rs_smb_register_parser() {
let default_port = CString::new("445").unwrap();
let mut stream_depth = SMB_CONFIG_DEFAULT_STREAM_DEPTH;
let parser = RustParser {
name: PARSER_NAME.as_ptr() as *const std::os::raw::c_char,
default_port: default_port.as_ptr(),
ipproto: IPPROTO_TCP,
probe_ts: None,
probe_tc: None,
min_depth: 0,
max_depth: 16,
state_new: rs_smb_state_new,
state_free: rs_smb_state_free,
tx_free: rs_smb_state_tx_free,
parse_ts: rs_smb_parse_request_tcp,
parse_tc: rs_smb_parse_response_tcp,
get_tx_count: rs_smb_state_get_tx_count,
get_tx: rs_smb_state_get_tx,
tx_comp_st_ts: 1,
tx_comp_st_tc: 1,
tx_get_progress: rs_smb_tx_get_alstate_progress,
get_de_state: rs_smb_state_get_tx_detect_state,
set_de_state: rs_smb_state_set_tx_detect_state,
get_events: Some(rs_smb_state_get_events),
get_eventinfo: Some(rs_smb_state_get_event_info),
get_eventinfo_byid : Some(rs_smb_state_get_event_info_by_id),
localstorage_new: None,
localstorage_free: None,
get_files: Some(rs_smb_getfiles),
get_tx_iterator: Some(rs_smb_state_get_tx_iterator),
get_tx_data: rs_smb_get_tx_data,
apply_tx_config: None,
flags: APP_LAYER_PARSER_OPT_ACCEPT_GAPS,
truncate: Some(rs_smb_state_truncate),
};
let ip_proto_str = CString::new("tcp").unwrap();
if AppLayerProtoDetectConfProtoDetectionEnabled(
ip_proto_str.as_ptr(),
parser.name,
) != 0
{
let alproto = AppLayerRegisterProtocolDetection(&parser, 1);
ALPROTO_SMB = alproto;
if register_pattern_probe() < 0 {
return;
}
let have_cfg = AppLayerProtoDetectPPParseConfPorts(ip_proto_str.as_ptr(),
IPPROTO_TCP as u8, parser.name, ALPROTO_SMB, 0,
MIN_REC_SIZE, rs_smb_probe_tcp, rs_smb_probe_tcp);
if have_cfg == 0 {
AppLayerProtoDetectPPRegister(IPPROTO_TCP as u8, parser.default_port, ALPROTO_SMB,
0, MIN_REC_SIZE, STREAM_TOSERVER, rs_smb_probe_tcp, rs_smb_probe_tcp);
}
if AppLayerParserConfParserEnabled(
ip_proto_str.as_ptr(),
parser.name,
) != 0
{
let _ = AppLayerRegisterParser(&parser, alproto);
}
SCLogDebug!("Rust SMB parser registered.");
let retval = conf_get("app-layer.protocols.smb.stream-depth");
if let Some(val) = retval {
let val = val.parse::<i32>().unwrap();
if val < 0 {
SCLogError!("invalid value for stream-depth");
} else {
stream_depth = val as u32;
}
AppLayerParserSetStreamDepth(IPPROTO_TCP as u8, ALPROTO_SMB, stream_depth);
}
} else {
SCLogDebug!("Protocol detector and parser disabled for SMB.");
}
}

@ -28,199 +28,6 @@
#include "app-layer-smb.h"
#include "util-misc.h"
#define MIN_REC_SIZE 32+4 // SMB hdr + nbss hdr
static AppLayerResult SMBTCPParseRequest(Flow *f, void *state,
AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len,
void *local_data, const uint8_t flags)
{
SCLogDebug("SMBTCPParseRequest");
uint16_t file_flags = FileFlowToFlags(f, STREAM_TOSERVER);
rs_smb_setfileflags(0, state, file_flags|FILE_USE_DETECT);
if (input == NULL && input_len > 0) {
AppLayerResult res = rs_smb_parse_request_tcp_gap(state, input_len);
SCLogDebug("SMB request GAP of %u bytes, retval %d", input_len, res.status);
SCReturnStruct(res);
} else {
AppLayerResult res = rs_smb_parse_request_tcp(f, state, pstate,
input, input_len, local_data, flags);
SCLogDebug("SMB request%s of %u bytes, retval %d",
(input == NULL && input_len > 0) ? " is GAP" : "", input_len, res.status);
SCReturnStruct(res);
}
}
static AppLayerResult SMBTCPParseResponse(Flow *f, void *state,
AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len,
void *local_data, const uint8_t flags)
{
SCLogDebug("SMBTCPParseResponse");
uint16_t file_flags = FileFlowToFlags(f, STREAM_TOCLIENT);
rs_smb_setfileflags(1, state, file_flags|FILE_USE_DETECT);
SCLogDebug("SMBTCPParseResponse %p/%u", input, input_len);
if (input == NULL && input_len > 0) {
AppLayerResult res = rs_smb_parse_response_tcp_gap(state, input_len);
SCLogDebug("SMB response GAP of %u bytes, retval %d", input_len, res.status);
SCReturnStruct(res);
} else {
AppLayerResult res = rs_smb_parse_response_tcp(f, state, pstate,
input, input_len, local_data, flags);
SCReturnStruct(res);
}
}
static uint16_t SMBTCPProbe(Flow *f, uint8_t direction,
const uint8_t *input, uint32_t len, uint8_t *rdir)
{
SCLogDebug("SMBTCPProbe");
if (len < MIN_REC_SIZE) {
return ALPROTO_UNKNOWN;
}
const int r = rs_smb_probe_tcp(f, direction, input, len, rdir);
switch (r) {
case 1:
return ALPROTO_SMB;
case 0:
return ALPROTO_UNKNOWN;
case -1:
default:
return ALPROTO_FAILED;
}
}
/** \internal
* \brief as SMB3 records have no direction indicator, fall
* back to the port numbers for a hint
*/
static uint16_t SMB3TCPProbe(Flow *f, uint8_t direction,
const uint8_t *input, uint32_t len, uint8_t *rdir)
{
SCEnter();
AppProto p = SMBTCPProbe(f, direction, input, len, rdir);
if (p != ALPROTO_SMB) {
SCReturnUInt(p);
}
uint16_t fsp = (f->flags & FLOW_DIR_REVERSED) ? f->dp : f->sp;
uint16_t fdp = (f->flags & FLOW_DIR_REVERSED) ? f->sp : f->dp;
SCLogDebug("direction %s flow sp %u dp %u fsp %u fdp %u",
(direction & STREAM_TOSERVER) ? "toserver" : "toclient",
f->sp, f->dp, fsp, fdp);
if (fsp == 445 && fdp != 445) {
if (direction & STREAM_TOSERVER) {
*rdir = STREAM_TOCLIENT;
} else {
*rdir = STREAM_TOSERVER;
}
}
SCLogDebug("returning ALPROTO_SMB for dir %s with rdir %s",
(direction & STREAM_TOSERVER) ? "toserver" : "toclient",
(*rdir == STREAM_TOSERVER) ? "toserver" : "toclient");
SCReturnUInt(ALPROTO_SMB);
}
static int SMBGetAlstateProgress(void *tx, uint8_t direction)
{
return rs_smb_tx_get_alstate_progress(tx, direction);
}
static uint64_t SMBGetTxCnt(void *alstate)
{
return rs_smb_state_get_tx_count(alstate);
}
static void *SMBGetTx(void *alstate, uint64_t tx_id)
{
return rs_smb_state_get_tx(alstate, tx_id);
}
static AppLayerGetTxIterTuple SMBGetTxIterator(
const uint8_t ipproto, const AppProto alproto,
void *alstate, uint64_t min_tx_id, uint64_t max_tx_id,
AppLayerGetTxIterState *istate)
{
return rs_smb_state_get_tx_iterator(
ipproto, alproto, alstate, min_tx_id, max_tx_id, (uint64_t *)istate);
}
static void SMBStateTransactionFree(void *state, uint64_t tx_id)
{
rs_smb_state_tx_free(state, tx_id);
}
static DetectEngineState *SMBGetTxDetectState(void *tx)
{
return rs_smb_state_get_tx_detect_state(tx);
}
static int SMBSetTxDetectState(void *tx, DetectEngineState *s)
{
rs_smb_state_set_tx_detect_state(tx, s);
return 0;
}
static FileContainer *SMBGetFiles(void *state, uint8_t direction)
{
return rs_smb_getfiles(state, direction);
}
static AppLayerDecoderEvents *SMBGetEvents(void *tx)
{
return rs_smb_state_get_events(tx);
}
static int SMBGetEventInfoById(int event_id, const char **event_name,
AppLayerEventType *event_type)
{
return rs_smb_state_get_event_info_by_id(event_id, event_name, event_type);
}
static int SMBGetEventInfo(const char *event_name, int *event_id,
AppLayerEventType *event_type)
{
return rs_smb_state_get_event_info(event_name, event_id, event_type);
}
static void SMBStateTruncate(void *state, uint8_t direction)
{
return rs_smb_state_truncate(state, direction);
}
static int SMBRegisterPatternsForProtocolDetection(void)
{
int r = 0;
/* SMB1 */
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
"|ff|SMB", 8, 4, STREAM_TOSERVER, SMBTCPProbe,
MIN_REC_SIZE, MIN_REC_SIZE);
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
"|ff|SMB", 8, 4, STREAM_TOCLIENT, SMBTCPProbe,
MIN_REC_SIZE, MIN_REC_SIZE);
/* SMB2/3 */
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
"|fe|SMB", 8, 4, STREAM_TOSERVER, SMBTCPProbe,
MIN_REC_SIZE, MIN_REC_SIZE);
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
"|fe|SMB", 8, 4, STREAM_TOCLIENT, SMBTCPProbe,
MIN_REC_SIZE, MIN_REC_SIZE);
/* SMB3 encrypted records */
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
"|fd|SMB", 8, 4, STREAM_TOSERVER, SMB3TCPProbe,
MIN_REC_SIZE, MIN_REC_SIZE);
r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
"|fd|SMB", 8, 4, STREAM_TOCLIENT, SMB3TCPProbe,
MIN_REC_SIZE, MIN_REC_SIZE);
return r == 0 ? 0 : -1;
}
static StreamingBufferConfig sbcfg = STREAMING_BUFFER_CONFIG_INITIALIZER;
static SuricataFileContext sfc = { &sbcfg };
@ -231,95 +38,11 @@ static SuricataFileContext sfc = { &sbcfg };
static void SMBParserRegisterTests(void);
#endif
static uint32_t stream_depth = SMB_CONFIG_DEFAULT_STREAM_DEPTH;
void RegisterSMBParsers(void)
{
const char *proto_name = "smb";
/** SMB */
if (AppLayerProtoDetectConfProtoDetectionEnabled("tcp", proto_name)) {
AppLayerProtoDetectRegisterProtocol(ALPROTO_SMB, proto_name);
if (SMBRegisterPatternsForProtocolDetection() < 0)
return;
rs_smb_init(&sfc);
if (RunmodeIsUnittests()) {
AppLayerProtoDetectPPRegister(IPPROTO_TCP, "445", ALPROTO_SMB, 0,
MIN_REC_SIZE, STREAM_TOSERVER, SMBTCPProbe,
SMBTCPProbe);
} else {
int have_cfg = AppLayerProtoDetectPPParseConfPorts("tcp",
IPPROTO_TCP, proto_name, ALPROTO_SMB, 0,
MIN_REC_SIZE, SMBTCPProbe, SMBTCPProbe);
/* if we have no config, we enable the default port 445 */
if (!have_cfg) {
SCLogConfig("no SMB TCP config found, enabling SMB detection "
"on port 445.");
AppLayerProtoDetectPPRegister(IPPROTO_TCP, "445", ALPROTO_SMB, 0,
MIN_REC_SIZE, STREAM_TOSERVER, SMBTCPProbe,
SMBTCPProbe);
}
}
} else {
SCLogConfig("Protocol detection and parser disabled for %s protocol.",
proto_name);
return;
}
if (AppLayerParserConfParserEnabled("tcp", proto_name)) {
AppLayerParserRegisterParser(IPPROTO_TCP, ALPROTO_SMB, STREAM_TOSERVER,
SMBTCPParseRequest);
AppLayerParserRegisterParser(IPPROTO_TCP , ALPROTO_SMB, STREAM_TOCLIENT,
SMBTCPParseResponse);
AppLayerParserRegisterStateFuncs(IPPROTO_TCP, ALPROTO_SMB,
rs_smb_state_new, rs_smb_state_free);
AppLayerParserRegisterTxFreeFunc(IPPROTO_TCP, ALPROTO_SMB,
SMBStateTransactionFree);
AppLayerParserRegisterGetEventsFunc(IPPROTO_TCP, ALPROTO_SMB,
SMBGetEvents);
AppLayerParserRegisterGetEventInfo(IPPROTO_TCP, ALPROTO_SMB,
SMBGetEventInfo);
AppLayerParserRegisterGetEventInfoById(IPPROTO_TCP, ALPROTO_SMB,
SMBGetEventInfoById);
AppLayerParserRegisterDetectStateFuncs(IPPROTO_TCP, ALPROTO_SMB,
SMBGetTxDetectState, SMBSetTxDetectState);
AppLayerParserRegisterGetTx(IPPROTO_TCP, ALPROTO_SMB, SMBGetTx);
AppLayerParserRegisterGetTxIterator(IPPROTO_TCP, ALPROTO_SMB, SMBGetTxIterator);
AppLayerParserRegisterGetTxCnt(IPPROTO_TCP, ALPROTO_SMB,
SMBGetTxCnt);
AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_SMB,
SMBGetAlstateProgress);
AppLayerParserRegisterStateProgressCompletionStatus(ALPROTO_SMB, 1, 1);
AppLayerParserRegisterTruncateFunc(IPPROTO_TCP, ALPROTO_SMB,
SMBStateTruncate);
AppLayerParserRegisterGetFilesFunc(IPPROTO_TCP, ALPROTO_SMB, SMBGetFiles);
AppLayerParserRegisterTxDataFunc(IPPROTO_TCP, ALPROTO_SMB, rs_smb_get_tx_data);
/* This parser accepts gaps. */
AppLayerParserRegisterOptionFlags(IPPROTO_TCP, ALPROTO_SMB,
APP_LAYER_PARSER_OPT_ACCEPT_GAPS);
ConfNode *p = ConfGetNode("app-layer.protocols.smb.stream-depth");
if (p != NULL) {
uint32_t value;
if (ParseSizeStringU32(p->val, &value) < 0) {
SCLogError(SC_ERR_SMB_CONFIG, "invalid value for stream-depth %s", p->val);
} else {
stream_depth = value;
}
}
SCLogConfig("SMB stream depth: %u", stream_depth);
rs_smb_init(&sfc);
rs_smb_register_parser();
AppLayerParserSetStreamDepth(IPPROTO_TCP, ALPROTO_SMB, stream_depth);
} else {
SCLogConfig("Parsed disabled for %s protocol. Protocol detection"
"still on.", proto_name);
}
#ifdef UNITTESTS
AppLayerParserRegisterProtocolUnittests(IPPROTO_TCP, ALPROTO_SMB, SMBParserRegisterTests);
#endif

Loading…
Cancel
Save