From dcccbb11963b350a5f81b47f53f64f4b4a082ce3 Mon Sep 17 00:00:00 2001 From: Juliana Fajardini Date: Thu, 29 Aug 2024 18:02:15 -0300 Subject: [PATCH] pgsql: track transaction progress per direction PGSQL's current implementation tracks the transaction progress without taking into consideration flow direction, and also has indirections that make it harder to understand how the progress is tracked, as well as when a request or response is actually complete. This patch introduces tracking such progress per direction and adds completion status per direction, too. This will help when triggering raw stream reassembly or for unidirectional transactions, and may be useful when we implement sub-protocols that can have multiple requests per transaction, as well. CancelRequests and TerminationRequests are examples of unidirectional transactions. There won't be any responses to those requests, so we can also mark the response side as done, and set their transactions as completed. Bug #7113 --- rust/src/pgsql/pgsql.rs | 181 +++++++++++++++++++++++++++------------- 1 file changed, 121 insertions(+), 60 deletions(-) diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index 3b2d8b93dd..8a2b0da157 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -36,17 +36,18 @@ static mut PGSQL_MAX_TX: usize = 1024; #[repr(u8)] #[derive(Copy, Clone, PartialOrd, PartialEq, Eq, Debug)] -pub enum PgsqlTransactionState { - Init = 0, - RequestReceived, - ResponseDone, - FlushedOut, +pub enum PgsqlTxProgress { + TxInit = 0, + TxReceived, + TxDone, + TxFlushedOut, } #[derive(Debug)] pub struct PgsqlTransaction { pub tx_id: u64, - pub tx_state: PgsqlTransactionState, + pub tx_req_state: PgsqlTxProgress, + pub tx_res_state: PgsqlTxProgress, pub request: Option, pub responses: Vec, @@ -72,7 +73,8 @@ impl PgsqlTransaction { pub fn new() -> Self { Self { tx_id: 0, - tx_state: PgsqlTransactionState::Init, + tx_req_state: PgsqlTxProgress::TxInit, + tx_res_state: PgsqlTxProgress::TxInit, request: None, responses: Vec::::new(), data_row_cnt: 0, @@ -205,8 +207,11 @@ impl PgsqlState { let mut index = self.tx_index_completed; for tx_old in &mut self.transactions.range_mut(self.tx_index_completed..) { index += 1; - if tx_old.tx_state < PgsqlTransactionState::ResponseDone { - tx_old.tx_state = PgsqlTransactionState::FlushedOut; + if tx_old.tx_res_state < PgsqlTxProgress::TxDone { + // we don't check for TxReqDone for the majority of requests are basically completed + // when they're parsed, as of now + tx_old.tx_req_state = PgsqlTxProgress::TxFlushedOut; + tx_old.tx_res_state = PgsqlTxProgress::TxFlushedOut; //TODO set event break; } @@ -242,26 +247,6 @@ impl PgsqlState { return self.transactions.back_mut(); } - /// Process State progress to decide if PgsqlTransaction is finished - /// - /// As Pgsql transactions are bidirectional and may be comprised of several - /// responses, we must track State progress to decide on tx completion - fn is_tx_completed(&self) -> bool { - if let PgsqlStateProgress::ReadyForQueryReceived - | PgsqlStateProgress::SSLRejectedReceived - | PgsqlStateProgress::SimpleAuthenticationReceived - | PgsqlStateProgress::SASLAuthenticationReceived - | PgsqlStateProgress::SASLAuthenticationContinueReceived - | PgsqlStateProgress::SASLAuthenticationFinalReceived - | PgsqlStateProgress::ConnectionTerminated - | PgsqlStateProgress::Finished = self.state_progress - { - true - } else { - false - } - } - /// Define PgsqlState progression, based on the request received /// /// As PostgreSQL transactions can have multiple messages, State progression @@ -315,6 +300,22 @@ impl PgsqlState { } } + /// Process State progress to decide if request is finished + /// + fn request_is_complete(state: PgsqlStateProgress) -> bool { + match state { + PgsqlStateProgress::SSLRequestReceived + | PgsqlStateProgress::StartupMessageReceived + | PgsqlStateProgress::SimpleQueryReceived + | PgsqlStateProgress::PasswordMessageReceived + | PgsqlStateProgress::SASLInitialResponseReceived + | PgsqlStateProgress::SASLResponseReceived + | PgsqlStateProgress::CancelRequestReceived + | PgsqlStateProgress::ConnectionTerminated => true, + _ => false, + } + } + fn parse_request(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult { // We're not interested in empty requests. if input.is_empty() { @@ -348,14 +349,33 @@ impl PgsqlState { Direction::ToServer as i32, ); start = rem; - if let Some(state) = PgsqlState::request_next_state(&request) { + let new_state = PgsqlState::request_next_state(&request); + + if let Some(state) = new_state { self.state_progress = state; }; - let tx_completed = self.is_tx_completed(); + // PostreSQL progress states can be represented as a finite state machine + // After the connection phase, the backend/ server will be mostly waiting in a state of `ReadyForQuery`, unless + // it's processing some request. + // When the frontend wants to cancel a request, it will send a CancelRequest message over a new connection - to + // which there won't be any responses. + // If the frontend wants to terminate the connection, the backend won't send any confirmation after receiving a + // Terminate request. + // A simplified finite state machine for PostgreSQL v3 can be found at: + // https://samadhiweb.com/blog/2013.04.28.graphviz.postgresv3.html if let Some(tx) = self.find_or_create_tx() { tx.request = Some(request); - if tx_completed { - tx.tx_state = PgsqlTransactionState::ResponseDone; + if let Some(state) = new_state { + if Self::request_is_complete(state) { + // The request is always complete at this point + tx.tx_req_state = PgsqlTxProgress::TxDone; + if state == PgsqlStateProgress::ConnectionTerminated + || state == PgsqlStateProgress::CancelRequestReceived + { + /* The server won't send any responses to such requests, so transaction should be over */ + tx.tx_res_state = PgsqlTxProgress::TxDone; + } + } } } else { // If there isn't a new transaction, we'll consider Suri should move on @@ -455,6 +475,21 @@ impl PgsqlState { } } + /// Process State progress to decide if response is finished + /// + fn response_is_complete(state: PgsqlStateProgress) -> bool { + match state { + PgsqlStateProgress::ReadyForQueryReceived + | PgsqlStateProgress::SSLRejectedReceived + | PgsqlStateProgress::SimpleAuthenticationReceived + | PgsqlStateProgress::SASLAuthenticationReceived + | PgsqlStateProgress::SASLAuthenticationContinueReceived + | PgsqlStateProgress::SASLAuthenticationFinalReceived + | PgsqlStateProgress::Finished => true, + _ => false, + } + } + fn parse_response(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult { // We're not interested in empty responses. if input.is_empty() { @@ -482,30 +517,36 @@ impl PgsqlState { ); start = rem; SCLogDebug!("Response is {:?}", &response); - if let Some(state) = self.response_process_next_state(&response, flow) { + let new_state = self.response_process_next_state(&response, flow); + if let Some(state) = new_state { self.state_progress = state; - }; - let tx_completed = self.is_tx_completed(); - let curr_state = self.state_progress; + } if let Some(tx) = self.find_or_create_tx() { - if curr_state == PgsqlStateProgress::DataRowReceived { - tx.incr_row_cnt(); - } else if curr_state == PgsqlStateProgress::CommandCompletedReceived - && tx.get_row_cnt() > 0 - { - // let's summarize the info from the data_rows in one response - let dummy_resp = - PgsqlBEMessage::ConsolidatedDataRow(ConsolidatedDataRowPacket { - identifier: b'D', - row_cnt: tx.get_row_cnt(), - data_size: tx.data_size, // total byte count of all data_row messages combined - }); - tx.responses.push(dummy_resp); - tx.responses.push(response); - } else { - tx.responses.push(response); - if tx_completed { - tx.tx_state = PgsqlTransactionState::ResponseDone; + if tx.tx_res_state == PgsqlTxProgress::TxInit { + tx.tx_res_state = PgsqlTxProgress::TxReceived; + } + if let Some(state) = new_state { + if state == PgsqlStateProgress::DataRowReceived { + tx.incr_row_cnt(); + } else if state == PgsqlStateProgress::CommandCompletedReceived + && tx.get_row_cnt() > 0 + { + // let's summarize the info from the data_rows in one response + let dummy_resp = PgsqlBEMessage::ConsolidatedDataRow( + ConsolidatedDataRowPacket { + identifier: b'D', + row_cnt: tx.get_row_cnt(), + data_size: tx.data_size, // total byte count of all data_row messages combined + }, + ); + tx.responses.push(dummy_resp); + tx.responses.push(response); + } else { + tx.responses.push(response); + if Self::response_is_complete(state) { + tx.tx_req_state = PgsqlTxProgress::TxDone; + tx.tx_res_state = PgsqlTxProgress::TxDone; + } } } } else { @@ -557,6 +598,22 @@ fn probe_tc(input: &[u8]) -> bool { false } +fn pgsql_tx_get_req_state(tx: *mut std::os::raw::c_void) -> PgsqlTxProgress { + let tx_safe: &mut PgsqlTransaction; + unsafe { + tx_safe = cast_pointer!(tx, PgsqlTransaction); + } + tx_safe.tx_req_state +} + +fn pgsql_tx_get_res_state(tx: *mut std::os::raw::c_void) -> PgsqlTxProgress { + let tx_safe: &mut PgsqlTransaction; + unsafe { + tx_safe = cast_pointer!(tx, PgsqlTransaction); + } + tx_safe.tx_res_state +} + // C exports. /// C entry point for a probing parser. @@ -712,10 +769,14 @@ pub extern "C" fn SCPgsqlStateGetTxCount(state: *mut std::os::raw::c_void) -> u6 #[no_mangle] pub unsafe extern "C" fn SCPgsqlTxGetALStateProgress( - tx: *mut std::os::raw::c_void, _direction: u8, + tx: *mut std::os::raw::c_void, direction: u8, ) -> std::os::raw::c_int { - let tx = cast_pointer!(tx, PgsqlTransaction); - tx.tx_state as i32 + if direction == Direction::ToServer as u8 { + return pgsql_tx_get_req_state(tx) as i32; + } + + // Direction has only two possible values, so we don't need to check for the other one + pgsql_tx_get_res_state(tx) as i32 } export_tx_data_get!(rs_pgsql_get_tx_data, PgsqlTransaction); @@ -743,8 +804,8 @@ pub unsafe extern "C" fn SCRegisterPgsqlParser() { parse_tc: SCPgsqlParseResponse, get_tx_count: SCPgsqlStateGetTxCount, get_tx: SCPgsqlStateGetTx, - tx_comp_st_ts: PgsqlTransactionState::RequestReceived as i32, - tx_comp_st_tc: PgsqlTransactionState::ResponseDone as i32, + tx_comp_st_ts: PgsqlTxProgress::TxDone as i32, + tx_comp_st_tc: PgsqlTxProgress::TxDone as i32, tx_get_progress: SCPgsqlTxGetALStateProgress, get_eventinfo: None, get_eventinfo_byid: None,