diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index 3b2d8b93dd..8a2b0da157 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -36,17 +36,18 @@ static mut PGSQL_MAX_TX: usize = 1024; #[repr(u8)] #[derive(Copy, Clone, PartialOrd, PartialEq, Eq, Debug)] -pub enum PgsqlTransactionState { - Init = 0, - RequestReceived, - ResponseDone, - FlushedOut, +pub enum PgsqlTxProgress { + TxInit = 0, + TxReceived, + TxDone, + TxFlushedOut, } #[derive(Debug)] pub struct PgsqlTransaction { pub tx_id: u64, - pub tx_state: PgsqlTransactionState, + pub tx_req_state: PgsqlTxProgress, + pub tx_res_state: PgsqlTxProgress, pub request: Option, pub responses: Vec, @@ -72,7 +73,8 @@ impl PgsqlTransaction { pub fn new() -> Self { Self { tx_id: 0, - tx_state: PgsqlTransactionState::Init, + tx_req_state: PgsqlTxProgress::TxInit, + tx_res_state: PgsqlTxProgress::TxInit, request: None, responses: Vec::::new(), data_row_cnt: 0, @@ -205,8 +207,11 @@ impl PgsqlState { let mut index = self.tx_index_completed; for tx_old in &mut self.transactions.range_mut(self.tx_index_completed..) { index += 1; - if tx_old.tx_state < PgsqlTransactionState::ResponseDone { - tx_old.tx_state = PgsqlTransactionState::FlushedOut; + if tx_old.tx_res_state < PgsqlTxProgress::TxDone { + // we don't check for TxReqDone for the majority of requests are basically completed + // when they're parsed, as of now + tx_old.tx_req_state = PgsqlTxProgress::TxFlushedOut; + tx_old.tx_res_state = PgsqlTxProgress::TxFlushedOut; //TODO set event break; } @@ -242,26 +247,6 @@ impl PgsqlState { return self.transactions.back_mut(); } - /// Process State progress to decide if PgsqlTransaction is finished - /// - /// As Pgsql transactions are bidirectional and may be comprised of several - /// responses, we must track State progress to decide on tx completion - fn is_tx_completed(&self) -> bool { - if let PgsqlStateProgress::ReadyForQueryReceived - | PgsqlStateProgress::SSLRejectedReceived - | PgsqlStateProgress::SimpleAuthenticationReceived - | PgsqlStateProgress::SASLAuthenticationReceived - | PgsqlStateProgress::SASLAuthenticationContinueReceived - | PgsqlStateProgress::SASLAuthenticationFinalReceived - | PgsqlStateProgress::ConnectionTerminated - | PgsqlStateProgress::Finished = self.state_progress - { - true - } else { - false - } - } - /// Define PgsqlState progression, based on the request received /// /// As PostgreSQL transactions can have multiple messages, State progression @@ -315,6 +300,22 @@ impl PgsqlState { } } + /// Process State progress to decide if request is finished + /// + fn request_is_complete(state: PgsqlStateProgress) -> bool { + match state { + PgsqlStateProgress::SSLRequestReceived + | PgsqlStateProgress::StartupMessageReceived + | PgsqlStateProgress::SimpleQueryReceived + | PgsqlStateProgress::PasswordMessageReceived + | PgsqlStateProgress::SASLInitialResponseReceived + | PgsqlStateProgress::SASLResponseReceived + | PgsqlStateProgress::CancelRequestReceived + | PgsqlStateProgress::ConnectionTerminated => true, + _ => false, + } + } + fn parse_request(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult { // We're not interested in empty requests. if input.is_empty() { @@ -348,14 +349,33 @@ impl PgsqlState { Direction::ToServer as i32, ); start = rem; - if let Some(state) = PgsqlState::request_next_state(&request) { + let new_state = PgsqlState::request_next_state(&request); + + if let Some(state) = new_state { self.state_progress = state; }; - let tx_completed = self.is_tx_completed(); + // PostreSQL progress states can be represented as a finite state machine + // After the connection phase, the backend/ server will be mostly waiting in a state of `ReadyForQuery`, unless + // it's processing some request. + // When the frontend wants to cancel a request, it will send a CancelRequest message over a new connection - to + // which there won't be any responses. + // If the frontend wants to terminate the connection, the backend won't send any confirmation after receiving a + // Terminate request. + // A simplified finite state machine for PostgreSQL v3 can be found at: + // https://samadhiweb.com/blog/2013.04.28.graphviz.postgresv3.html if let Some(tx) = self.find_or_create_tx() { tx.request = Some(request); - if tx_completed { - tx.tx_state = PgsqlTransactionState::ResponseDone; + if let Some(state) = new_state { + if Self::request_is_complete(state) { + // The request is always complete at this point + tx.tx_req_state = PgsqlTxProgress::TxDone; + if state == PgsqlStateProgress::ConnectionTerminated + || state == PgsqlStateProgress::CancelRequestReceived + { + /* The server won't send any responses to such requests, so transaction should be over */ + tx.tx_res_state = PgsqlTxProgress::TxDone; + } + } } } else { // If there isn't a new transaction, we'll consider Suri should move on @@ -455,6 +475,21 @@ impl PgsqlState { } } + /// Process State progress to decide if response is finished + /// + fn response_is_complete(state: PgsqlStateProgress) -> bool { + match state { + PgsqlStateProgress::ReadyForQueryReceived + | PgsqlStateProgress::SSLRejectedReceived + | PgsqlStateProgress::SimpleAuthenticationReceived + | PgsqlStateProgress::SASLAuthenticationReceived + | PgsqlStateProgress::SASLAuthenticationContinueReceived + | PgsqlStateProgress::SASLAuthenticationFinalReceived + | PgsqlStateProgress::Finished => true, + _ => false, + } + } + fn parse_response(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult { // We're not interested in empty responses. if input.is_empty() { @@ -482,30 +517,36 @@ impl PgsqlState { ); start = rem; SCLogDebug!("Response is {:?}", &response); - if let Some(state) = self.response_process_next_state(&response, flow) { + let new_state = self.response_process_next_state(&response, flow); + if let Some(state) = new_state { self.state_progress = state; - }; - let tx_completed = self.is_tx_completed(); - let curr_state = self.state_progress; + } if let Some(tx) = self.find_or_create_tx() { - if curr_state == PgsqlStateProgress::DataRowReceived { - tx.incr_row_cnt(); - } else if curr_state == PgsqlStateProgress::CommandCompletedReceived - && tx.get_row_cnt() > 0 - { - // let's summarize the info from the data_rows in one response - let dummy_resp = - PgsqlBEMessage::ConsolidatedDataRow(ConsolidatedDataRowPacket { - identifier: b'D', - row_cnt: tx.get_row_cnt(), - data_size: tx.data_size, // total byte count of all data_row messages combined - }); - tx.responses.push(dummy_resp); - tx.responses.push(response); - } else { - tx.responses.push(response); - if tx_completed { - tx.tx_state = PgsqlTransactionState::ResponseDone; + if tx.tx_res_state == PgsqlTxProgress::TxInit { + tx.tx_res_state = PgsqlTxProgress::TxReceived; + } + if let Some(state) = new_state { + if state == PgsqlStateProgress::DataRowReceived { + tx.incr_row_cnt(); + } else if state == PgsqlStateProgress::CommandCompletedReceived + && tx.get_row_cnt() > 0 + { + // let's summarize the info from the data_rows in one response + let dummy_resp = PgsqlBEMessage::ConsolidatedDataRow( + ConsolidatedDataRowPacket { + identifier: b'D', + row_cnt: tx.get_row_cnt(), + data_size: tx.data_size, // total byte count of all data_row messages combined + }, + ); + tx.responses.push(dummy_resp); + tx.responses.push(response); + } else { + tx.responses.push(response); + if Self::response_is_complete(state) { + tx.tx_req_state = PgsqlTxProgress::TxDone; + tx.tx_res_state = PgsqlTxProgress::TxDone; + } } } } else { @@ -557,6 +598,22 @@ fn probe_tc(input: &[u8]) -> bool { false } +fn pgsql_tx_get_req_state(tx: *mut std::os::raw::c_void) -> PgsqlTxProgress { + let tx_safe: &mut PgsqlTransaction; + unsafe { + tx_safe = cast_pointer!(tx, PgsqlTransaction); + } + tx_safe.tx_req_state +} + +fn pgsql_tx_get_res_state(tx: *mut std::os::raw::c_void) -> PgsqlTxProgress { + let tx_safe: &mut PgsqlTransaction; + unsafe { + tx_safe = cast_pointer!(tx, PgsqlTransaction); + } + tx_safe.tx_res_state +} + // C exports. /// C entry point for a probing parser. @@ -712,10 +769,14 @@ pub extern "C" fn SCPgsqlStateGetTxCount(state: *mut std::os::raw::c_void) -> u6 #[no_mangle] pub unsafe extern "C" fn SCPgsqlTxGetALStateProgress( - tx: *mut std::os::raw::c_void, _direction: u8, + tx: *mut std::os::raw::c_void, direction: u8, ) -> std::os::raw::c_int { - let tx = cast_pointer!(tx, PgsqlTransaction); - tx.tx_state as i32 + if direction == Direction::ToServer as u8 { + return pgsql_tx_get_req_state(tx) as i32; + } + + // Direction has only two possible values, so we don't need to check for the other one + pgsql_tx_get_res_state(tx) as i32 } export_tx_data_get!(rs_pgsql_get_tx_data, PgsqlTransaction); @@ -743,8 +804,8 @@ pub unsafe extern "C" fn SCRegisterPgsqlParser() { parse_tc: SCPgsqlParseResponse, get_tx_count: SCPgsqlStateGetTxCount, get_tx: SCPgsqlStateGetTx, - tx_comp_st_ts: PgsqlTransactionState::RequestReceived as i32, - tx_comp_st_tc: PgsqlTransactionState::ResponseDone as i32, + tx_comp_st_ts: PgsqlTxProgress::TxDone as i32, + tx_comp_st_tc: PgsqlTxProgress::TxDone as i32, tx_get_progress: SCPgsqlTxGetALStateProgress, get_eventinfo: None, get_eventinfo_byid: None,