diff --git a/rust/src/dns/dns.rs b/rust/src/dns/dns.rs index 777cb10793..c9b688f1ff 100644 --- a/rust/src/dns/dns.rs +++ b/rust/src/dns/dns.rs @@ -431,7 +431,8 @@ impl DNSState { /// Returns the number of messages parsed. pub fn parse_request_tcp(&mut self, input: &[u8]) -> i8 { if self.gap { - if probe_tcp(input) { + let (is_dns, _) = probe_tcp(input); + if is_dns { self.gap = false; } else { return 0 @@ -471,7 +472,8 @@ impl DNSState { /// Returns the number of messages parsed. pub fn parse_response_tcp(&mut self, input: &[u8]) -> i8 { if self.gap { - if probe_tcp(input) { + let (is_dns, _) = probe_tcp(input); + if is_dns { self.gap = false; } else { return 0 @@ -519,19 +521,25 @@ impl DNSState { } /// Probe input to see if it looks like DNS. -fn probe(input: &[u8]) -> bool { - parser::dns_parse_request(input).is_ok() +fn probe(input: &[u8]) -> (bool, bool) { + match parser::dns_parse_request(input) { + Ok((_, request)) => { + let is_request = request.header.flags & 0x8000 == 0; + return (true, is_request); + }, + Err(_) => (false, false), + } } /// Probe TCP input to see if it looks like DNS. -pub fn probe_tcp(input: &[u8]) -> bool { +pub fn probe_tcp(input: &[u8]) -> (bool, bool) { match nom::be_u16(input) { Ok((rem, _)) => { return probe(rem); }, _ => {} } - return false; + return (false, false); } /// Returns *mut DNSState @@ -813,27 +821,46 @@ pub extern "C" fn rs_dns_tx_get_query_rrtype(tx: &mut DNSTransaction, } #[no_mangle] -pub extern "C" fn rs_dns_probe(input: *const u8, len: u32) +pub extern "C" fn rs_dns_probe(input: *const u8, len: u32, rdir: *mut u8) -> u8 { let slice: &[u8] = unsafe { std::slice::from_raw_parts(input as *mut u8, len as usize) }; - if probe(slice) { + let (is_dns, is_request) = probe(slice); + if is_dns { + let dir = if is_request { + core::STREAM_TOSERVER + } else { + core::STREAM_TOCLIENT + }; + unsafe { *rdir = dir }; + return 1; } return 0; } #[no_mangle] -pub extern "C" fn rs_dns_probe_tcp(input: *const u8, - len: u32) +pub extern "C" fn rs_dns_probe_tcp(direction: u8, + input: *const u8, + len: u32, + rdir: *mut u8) -> u8 { let slice: &[u8] = unsafe { std::slice::from_raw_parts(input as *mut u8, len as usize) }; - if probe_tcp(slice) { + let (is_dns, is_request) = probe_tcp(slice); + if is_dns { + let dir = if is_request { + core::STREAM_TOSERVER + } else { + core::STREAM_TOCLIENT + }; + if direction & (core::STREAM_TOSERVER|core::STREAM_TOCLIENT) != dir { + unsafe { *rdir = dir }; + } return 1; } return 0; diff --git a/src/app-layer-dns-tcp-rust.c b/src/app-layer-dns-tcp-rust.c index c58096220c..a6d3efb3a3 100644 --- a/src/app-layer-dns-tcp-rust.c +++ b/src/app-layer-dns-tcp-rust.c @@ -61,7 +61,7 @@ static uint16_t RustDNSTCPProbe(Flow *f, uint8_t direction, } // Validate and return ALPROTO_FAILED if needed. - if (!rs_dns_probe_tcp(input, len)) { + if (!rs_dns_probe_tcp(direction, input, len, rdir)) { return ALPROTO_FAILED; } @@ -126,7 +126,7 @@ void RegisterRustDNSTCPParsers(void) if (RunmodeIsUnittests()) { AppLayerProtoDetectPPRegister(IPPROTO_TCP, "53", ALPROTO_DNS, 0, sizeof(DNSHeader) + 2, STREAM_TOSERVER, RustDNSTCPProbe, - NULL); + RustDNSTCPProbe); } else { int have_cfg = AppLayerProtoDetectPPParseConfPorts("tcp", IPPROTO_TCP, proto_name, ALPROTO_DNS, 0, diff --git a/src/app-layer-dns-udp-rust.c b/src/app-layer-dns-udp-rust.c index cd9e84dfb5..734d410a05 100644 --- a/src/app-layer-dns-udp-rust.c +++ b/src/app-layer-dns-udp-rust.c @@ -58,7 +58,7 @@ static uint16_t DNSUDPProbe(Flow *f, uint8_t direction, } // Validate and return ALPROTO_FAILED if needed. - if (!rs_dns_probe(input, len)) { + if (!rs_dns_probe(input, len, rdir)) { return ALPROTO_FAILED; } @@ -132,11 +132,11 @@ void RegisterRustDNSUDPParsers(void) if (RunmodeIsUnittests()) { AppLayerProtoDetectPPRegister(IPPROTO_UDP, "53", ALPROTO_DNS, 0, sizeof(DNSHeader), STREAM_TOSERVER, DNSUDPProbe, - NULL); + DNSUDPProbe); } else { int have_cfg = AppLayerProtoDetectPPParseConfPorts("udp", IPPROTO_UDP, proto_name, ALPROTO_DNS, 0, sizeof(DNSHeader), - DNSUDPProbe, NULL); + DNSUDPProbe, DNSUDPProbe); /* If no config, enable on port 53. */ if (!have_cfg) { @@ -146,7 +146,7 @@ void RegisterRustDNSUDPParsers(void) #endif AppLayerProtoDetectPPRegister(IPPROTO_UDP, "53", ALPROTO_DNS, 0, sizeof(DNSHeader), STREAM_TOSERVER, - DNSUDPProbe, NULL); + DNSUDPProbe, DNSUDPProbe); } } } else {