diff --git a/rust/src/smb/smb.rs b/rust/src/smb/smb.rs index f4009c1082..6cdc5163b9 100644 --- a/rust/src/smb/smb.rs +++ b/rust/src/smb/smb.rs @@ -32,6 +32,7 @@ use std::ffi::{self, CString}; use std::collections::HashMap; use nom7::{Err, Needed}; +use nom7::error::{make_error, ErrorKind}; use crate::core::*; use crate::applayer; @@ -2072,9 +2073,19 @@ pub extern "C" fn rs_smb_parse_response_tcp_gap( state.parse_tcp_data_tc_gap(input_len as u32) } -fn smb_probe_tcp_midstream(direction: Direction, slice: &[u8], rdir: *mut u8) -> i8 +fn smb_probe_tcp_midstream(direction: Direction, slice: &[u8], rdir: *mut u8, begins: bool) -> i8 { - match search_smb_record(slice) { + let r = if begins { + // if pattern was found in the beginning, just check first byte + if slice[0] == NBSS_MSGTYPE_SESSION_MESSAGE { + Ok((&slice[..4], &slice[4..])) + } else { + Err(Err::Error(make_error(slice, ErrorKind::Eof))) + } + } else { + search_smb_record(slice) + }; + match r { Ok((_, data)) => { SCLogDebug!("smb found"); match parse_smb_version(data) { @@ -2135,27 +2146,18 @@ fn smb_probe_tcp_midstream(direction: Direction, slice: &[u8], rdir: *mut u8) -> return 0; } -// probing parser -// return 1 if found, 0 is not found -#[no_mangle] -pub unsafe extern "C" fn rs_smb_probe_tcp(_f: *const Flow, - flags: u8, input: *const u8, len: u32, rdir: *mut u8) - -> AppProto +fn smb_probe_tcp(flags: u8, slice: &[u8], rdir: *mut u8, begins: bool) -> AppProto { - if len < MIN_REC_SIZE as u32 { - return ALPROTO_UNKNOWN; - } - let slice = build_slice!(input, len as usize); if flags & STREAM_MIDSTREAM == STREAM_MIDSTREAM { - if smb_probe_tcp_midstream(flags.into(), slice, rdir) == 1 { - return ALPROTO_SMB; + if smb_probe_tcp_midstream(flags.into(), slice, rdir, begins) == 1 { + unsafe { return ALPROTO_SMB; } } } match parse_nbss_record_partial(slice) { Ok((_, ref hdr)) => { if hdr.is_smb() { SCLogDebug!("smb found"); - return ALPROTO_SMB; + unsafe { return ALPROTO_SMB; } } else if hdr.needs_more(){ return 0; } else if hdr.is_valid() && @@ -2168,7 +2170,7 @@ pub unsafe extern "C" fn rs_smb_probe_tcp(_f: *const Flow, Ok((_, ref hdr2)) => { if hdr2.is_smb() { SCLogDebug!("smb found"); - return ALPROTO_SMB; + unsafe { return ALPROTO_SMB; } } } _ => {} @@ -2183,7 +2185,35 @@ pub unsafe extern "C" fn rs_smb_probe_tcp(_f: *const Flow, _ => { }, } SCLogDebug!("no smb"); - return ALPROTO_FAILED; + unsafe { return ALPROTO_FAILED; } +} + +// probing confirmation parser +// return 1 if found, 0 is not found +#[no_mangle] +pub unsafe extern "C" fn rs_smb_probe_begins_tcp(_f: *const Flow, + flags: u8, input: *const u8, len: u32, rdir: *mut u8) + -> AppProto +{ + if len < MIN_REC_SIZE as u32 { + return ALPROTO_UNKNOWN; + } + let slice = build_slice!(input, len as usize); + return smb_probe_tcp(flags, slice, rdir, true); +} + +// probing parser +// return 1 if found, 0 is not found +#[no_mangle] +pub unsafe extern "C" fn rs_smb_probe_tcp(_f: *const Flow, + flags: u8, input: *const u8, len: u32, rdir: *mut u8) + -> AppProto +{ + if len < MIN_REC_SIZE as u32 { + return ALPROTO_UNKNOWN; + } + let slice = build_slice!(input, len as usize); + return smb_probe_tcp(flags, slice, rdir, false); } #[no_mangle] @@ -2313,17 +2343,17 @@ fn register_pattern_probe() -> i8 { // SMB1 r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, - Direction::ToServer as u8, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + Direction::ToServer as u8, rs_smb_probe_begins_tcp, MIN_REC_SIZE, MIN_REC_SIZE); r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, - Direction::ToClient as u8, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + Direction::ToClient as u8, rs_smb_probe_begins_tcp, MIN_REC_SIZE, MIN_REC_SIZE); // SMB2/3 r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, - Direction::ToServer as u8, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + Direction::ToServer as u8, rs_smb_probe_begins_tcp, MIN_REC_SIZE, MIN_REC_SIZE); r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, - Direction::ToClient as u8, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + Direction::ToClient as u8, rs_smb_probe_begins_tcp, MIN_REC_SIZE, MIN_REC_SIZE); // SMB3 encrypted records r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,