diff --git a/rust/src/pgsql/logger.rs b/rust/src/pgsql/logger.rs index bcfcb5a8e9..bec73b3829 100644 --- a/rust/src/pgsql/logger.rs +++ b/rust/src/pgsql/logger.rs @@ -97,10 +97,7 @@ fn log_request(req: &PgsqlFEMessage, flags: u32) -> Result { js.set_string_from_bytes(req.to_str(), payload)?; } - PgsqlFEMessage::CancelRequest(CancelRequestMessage { - pid, - backend_key, - }) => { + PgsqlFEMessage::CancelRequest(CancelRequestMessage { pid, backend_key }) => { js.set_string("message", "cancel_request")?; js.set_uint("process_id", (*pid).into())?; js.set_uint("secret_key", (*backend_key).into())?; diff --git a/rust/src/pgsql/parser.rs b/rust/src/pgsql/parser.rs index 886ee4c5dc..aa34acb4e5 100644 --- a/rust/src/pgsql/parser.rs +++ b/rust/src/pgsql/parser.rs @@ -279,7 +279,7 @@ impl PgsqlBEMessage { } PgsqlBEMessage::ConsolidatedDataRow(_) => "data_row", PgsqlBEMessage::NotificationResponse(_) => "notification_response", - PgsqlBEMessage::UnknownMessageType(_) => "unknown_message_type" + PgsqlBEMessage::UnknownMessageType(_) => "unknown_message_type", } } @@ -534,16 +534,25 @@ fn pgsql_parse_generic_parameter(i: &[u8]) -> IResult<&[u8], PgsqlParameter> { let (i, _) = tag("\x00")(i)?; let (i, param_value) = take_until("\x00")(i)?; let (i, _) = tag("\x00")(i)?; - Ok((i, PgsqlParameter { - name: PgsqlParameters::from(param_name), - value: param_value.to_vec(), - })) + Ok(( + i, + PgsqlParameter { + name: PgsqlParameters::from(param_name), + value: param_value.to_vec(), + }, + )) } pub fn pgsql_parse_startup_parameters(i: &[u8]) -> IResult<&[u8], PgsqlStartupParameters> { - let (i, mut optional) = opt(terminated(many1(pgsql_parse_generic_parameter), tag("\x00")))(i)?; + let (i, mut optional) = opt(terminated( + many1(pgsql_parse_generic_parameter), + tag("\x00"), + ))(i)?; if let Some(ref mut params) = optional { - let mut user = PgsqlParameter{name: PgsqlParameters::User, value: Vec::new() }; + let mut user = PgsqlParameter { + name: PgsqlParameters::User, + value: Vec::new(), + }; let mut index: usize = 0; for (j, p) in params.iter().enumerate() { if p.name == PgsqlParameters::User { @@ -555,17 +564,20 @@ pub fn pgsql_parse_startup_parameters(i: &[u8]) -> IResult<&[u8], PgsqlStartupPa if user.value.is_empty() { return Err(Err::Error(make_error(i, ErrorKind::Tag))); } - return Ok((i, PgsqlStartupParameters{ - user, - optional_params: if !params.is_empty() { - optional - } else { None }, - })); + return Ok(( + i, + PgsqlStartupParameters { + user, + optional_params: if !params.is_empty() { optional } else { None }, + }, + )); } return Err(Err::Error(make_error(i, ErrorKind::Tag))); } -fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenticationMechanism, u32, Vec)> { +fn parse_sasl_initial_response_payload( + i: &[u8], +) -> IResult<&[u8], (SASLAuthenticationMechanism, u32, Vec)> { let (i, sasl_mechanism) = parse_sasl_mechanism(i)?; let (i, param_length) = be_u32(i)?; // From RFC 5802 - the client-first-message will always start w/ @@ -577,27 +589,31 @@ fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenti pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; let (i, length) = parse_length(i)?; - let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload)(i)?; - Ok((i, PgsqlFEMessage::SASLInitialResponse( - SASLInitialResponsePacket { - identifier, - length, - auth_mechanism: payload.0, - param_length: payload.1, - sasl_param: payload.2, - }))) + let (i, payload) = map_parser( + take(length - PGSQL_LENGTH_FIELD), + parse_sasl_initial_response_payload, + )(i)?; + Ok(( + i, + PgsqlFEMessage::SASLInitialResponse(SASLInitialResponsePacket { + identifier, + length, + auth_mechanism: payload.0, + param_length: payload.1, + sasl_param: payload.2, + }), + )) } pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; let (i, length) = parse_length(i)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; - let resp = PgsqlFEMessage::SASLResponse( - RegularPacket { - identifier, - length, - payload: payload.to_vec(), - }); + let resp = PgsqlFEMessage::SASLResponse(RegularPacket { + identifier, + length, + payload: payload.to_vec(), + }); Ok((i, resp)) } @@ -605,37 +621,41 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, len) = verify(be_u32, |&x| x >= 8)(i)?; let (i, proto_major) = peek(be_u16)(i)?; let (i, b) = take(len - PGSQL_LENGTH_FIELD)(i)?; - let (_, message) = - match proto_major { - 1..=3 => { - let (b, proto_major) = be_u16(b)?; - let (b, proto_minor) = be_u16(b)?; - let (b, params) = pgsql_parse_startup_parameters(b)?; - (b, PgsqlFEMessage::StartupMessage(StartupPacket{ + let (_, message) = match proto_major { + 1..=3 => { + let (b, proto_major) = be_u16(b)?; + let (b, proto_minor) = be_u16(b)?; + let (b, params) = pgsql_parse_startup_parameters(b)?; + ( + b, + PgsqlFEMessage::StartupMessage(StartupPacket { length: len, proto_major, proto_minor, - params})) - }, - PGSQL_DUMMY_PROTO_MAJOR => { - let (b, proto_major) = be_u16(b)?; - let (b, proto_minor) = be_u16(b)?; - let (b, message) = match proto_minor { - PGSQL_DUMMY_PROTO_CANCEL_REQUEST => { - parse_cancel_request(b)? - }, - PGSQL_DUMMY_PROTO_MINOR_SSL => (b, PgsqlFEMessage::SSLRequest(DummyStartupPacket{ + params, + }), + ) + } + PGSQL_DUMMY_PROTO_MAJOR => { + let (b, proto_major) = be_u16(b)?; + let (b, proto_minor) = be_u16(b)?; + let (b, message) = match proto_minor { + PGSQL_DUMMY_PROTO_CANCEL_REQUEST => parse_cancel_request(b)?, + PGSQL_DUMMY_PROTO_MINOR_SSL => ( + b, + PgsqlFEMessage::SSLRequest(DummyStartupPacket { length: len, proto_major, - proto_minor - })), - _ => return Err(Err::Error(make_error(b, ErrorKind::Switch))), - }; + proto_minor, + }), + ), + _ => return Err(Err::Error(make_error(b, ErrorKind::Switch))), + }; - (b, message) - } - _ => return Err(Err::Error(make_error(b, ErrorKind::Switch))), - }; + (b, message) + } + _ => return Err(Err::Error(make_error(b, ErrorKind::Switch))), + }; Ok((i, message)) } @@ -655,42 +675,47 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; let (i, length) = parse_length(i)?; - let (i, password) = map_parser( - take(length - PGSQL_LENGTH_FIELD), - take_until1("\x00") - )(i)?; - Ok((i, PgsqlFEMessage::PasswordMessage( - RegularPacket{ - identifier, - length, - payload: password.to_vec(), - }))) + let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; + Ok(( + i, + PgsqlFEMessage::PasswordMessage(RegularPacket { + identifier, + length, + payload: password.to_vec(), + }), + )) } fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?; let (i, length) = parse_length(i)?; let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; - Ok((i, PgsqlFEMessage::SimpleQuery(RegularPacket { - identifier, - length, - payload: query.to_vec(), - }))) + Ok(( + i, + PgsqlFEMessage::SimpleQuery(RegularPacket { + identifier, + length, + payload: query.to_vec(), + }), + )) } fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, pid) = be_u32(i)?; let (i, backend_key) = be_u32(i)?; - Ok((i, PgsqlFEMessage::CancelRequest(CancelRequestMessage { - pid, - backend_key, - }))) + Ok(( + i, + PgsqlFEMessage::CancelRequest(CancelRequestMessage { pid, backend_key }), + )) } fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?; let (i, length) = parse_length(i)?; - Ok((i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }))) + Ok(( + i, + PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }), + )) } // Messages that begin with 'p' but are not password ones are not parsed here @@ -704,7 +729,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = be_u8(i)?; let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; - let unknown = PgsqlFEMessage::UnknownMessageType (RegularPacket{ + let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket { identifier, length, payload: payload.to_vec(), @@ -719,95 +744,108 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?; let (i, length) = verify(be_u32, |&x| x >= 8)(i)?; let (i, auth_type) = be_u32(i)?; - let (i, message) = map_parser( - take(length - 8), - |b: &'a [u8]| { - match auth_type { - 0 => Ok((b, PgsqlBEMessage::AuthenticationOk( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: b.to_vec(), - }))), - 3 => Ok((b, PgsqlBEMessage::AuthenticationCleartextPassword( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: b.to_vec(), - }))), - 5 => { - let (b, salt) = all_consuming(take(4_usize))(b)?; - Ok((b, PgsqlBEMessage::AuthenticationMD5Password( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: salt.to_vec(), - }))) - } - 9 => Ok((b, PgsqlBEMessage::AuthenticationSSPI( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: b.to_vec(), - }))), - // TODO - For SASL, should we parse specific details of the challenge itself? (as seen in: https://github.com/launchbadge/sqlx/blob/master/sqlx-core/src/postgres/message/authentication.rs ) - 10 => { - let (b, auth_mechanisms) = parse_sasl_mechanisms(b)?; - Ok((b, PgsqlBEMessage::AuthenticationSASL( - AuthenticationSASLMechanismMessage { - identifier, - length, - auth_type, - auth_mechanisms, - }))) - } - 11 => { - Ok((b, PgsqlBEMessage::AuthenticationSASLContinue( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: b.to_vec(), - }))) - }, - 12 => { - Ok((b, PgsqlBEMessage::AuthenticationSASLFinal( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: b.to_vec(), - } - ))) - } - // TODO add other authentication messages - _ => return Err(Err::Error(make_error(i, ErrorKind::Switch))), - } + let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| { + match auth_type { + 0 => Ok(( + b, + PgsqlBEMessage::AuthenticationOk(AuthenticationMessage { + identifier, + length, + auth_type, + payload: b.to_vec(), + }), + )), + 3 => Ok(( + b, + PgsqlBEMessage::AuthenticationCleartextPassword(AuthenticationMessage { + identifier, + length, + auth_type, + payload: b.to_vec(), + }), + )), + 5 => { + let (b, salt) = all_consuming(take(4_usize))(b)?; + Ok(( + b, + PgsqlBEMessage::AuthenticationMD5Password(AuthenticationMessage { + identifier, + length, + auth_type, + payload: salt.to_vec(), + }), + )) + } + 9 => Ok(( + b, + PgsqlBEMessage::AuthenticationSSPI(AuthenticationMessage { + identifier, + length, + auth_type, + payload: b.to_vec(), + }), + )), + // TODO - For SASL, should we parse specific details of the challenge itself? (as seen in: https://github.com/launchbadge/sqlx/blob/master/sqlx-core/src/postgres/message/authentication.rs ) + 10 => { + let (b, auth_mechanisms) = parse_sasl_mechanisms(b)?; + Ok(( + b, + PgsqlBEMessage::AuthenticationSASL(AuthenticationSASLMechanismMessage { + identifier, + length, + auth_type, + auth_mechanisms, + }), + )) + } + 11 => Ok(( + b, + PgsqlBEMessage::AuthenticationSASLContinue(AuthenticationMessage { + identifier, + length, + auth_type, + payload: b.to_vec(), + }), + )), + 12 => Ok(( + b, + PgsqlBEMessage::AuthenticationSASLFinal(AuthenticationMessage { + identifier, + length, + auth_type, + payload: b.to_vec(), + }), + )), + // TODO add other authentication messages + _ => return Err(Err::Error(make_error(i, ErrorKind::Switch))), } - )(i)?; + })(i)?; Ok((i, message)) } fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?; let (i, length) = parse_length(i)?; - let (i, param) = map_parser(take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter)(i)?; - Ok((i, PgsqlBEMessage::ParameterStatus(ParameterStatusMessage { - identifier, - length, - param, - }))) + let (i, param) = map_parser( + take(length - PGSQL_LENGTH_FIELD), + pgsql_parse_generic_parameter, + )(i)?; + Ok(( + i, + PgsqlBEMessage::ParameterStatus(ParameterStatusMessage { + identifier, + length, + param, + }), + )) } pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, tag) = alt((char('N'), char('S')))(i)?; - Ok((i, PgsqlBEMessage::SSLResponse( - SSLResponseMessage::from(tag)) - )) + Ok(( + i, + PgsqlBEMessage::SSLResponse(SSLResponseMessage::from(tag)), + )) } fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { @@ -815,34 +853,43 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, length) = verify(be_u32, |&x| x == 12)(i)?; let (i, pid) = be_u32(i)?; let (i, secret_key) = be_u32(i)?; - Ok((i, PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage { - identifier, - length, - backend_pid: pid, - secret_key, - }))) + Ok(( + i, + PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage { + identifier, + length, + backend_pid: pid, + secret_key, + }), + )) } fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?; let (i, length) = parse_length(i)?; let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?; - Ok((i, PgsqlBEMessage::CommandComplete(RegularPacket { - identifier, - length, - payload: payload.to_vec(), - }))) + Ok(( + i, + PgsqlBEMessage::CommandComplete(RegularPacket { + identifier, + length, + payload: payload.to_vec(), + }), + )) } fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?; let (i, length) = verify(be_u32, |&x| x == 5)(i)?; let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?; - Ok((i, PgsqlBEMessage::ReadyForQuery(ReadyForQueryMessage { - identifier, - length, - transaction_status: status, - }))) + Ok(( + i, + PgsqlBEMessage::ReadyForQuery(ReadyForQueryMessage { + identifier, + length, + transaction_status: status, + }), + )) } fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField> { @@ -854,15 +901,18 @@ fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField> { let (i, data_type_size) = be_i16(i)?; let (i, type_modifier) = be_i32(i)?; let (i, format_code) = be_u16(i)?; - Ok((i, RowField { - field_name: field_name.to_vec(), - table_oid, - column_index, - data_type_oid, - data_type_size, - type_modifier, - format_code, - })) + Ok(( + i, + RowField { + field_name: field_name.to_vec(), + table_oid, + column_index, + data_type_oid, + data_type_size, + type_modifier, + format_code, + }, + )) } pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { @@ -871,29 +921,34 @@ pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, field_count) = be_u16(i)?; let (i, fields) = map_parser( take(length - 6), - many_m_n(0, field_count.into(), parse_row_field) + many_m_n(0, field_count.into(), parse_row_field), )(i)?; - Ok((i, PgsqlBEMessage::RowDescription( - RowDescriptionMessage { - identifier, - length, - field_count, - fields, - }))) + Ok(( + i, + PgsqlBEMessage::RowDescription(RowDescriptionMessage { + identifier, + length, + field_count, + fields, + }), + )) } fn parse_data_row_value(i: &[u8]) -> IResult<&[u8], ColumnFieldValue> { let (i, value_length) = be_i32(i)?; let (i, value) = cond(value_length >= 0, take(value_length as usize))(i)?; - Ok((i, ColumnFieldValue { - value_length, - value: { - match value { - Some(data) => data.to_vec(), - None => [].to_vec(), - } + Ok(( + i, + ColumnFieldValue { + value_length, + value: { + match value { + Some(data) => data.to_vec(), + None => [].to_vec(), + } + }, }, - })) + )) } /// For each column, add up the data size. Return the total @@ -916,14 +971,18 @@ pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, length) = verify(be_u32, |&x| x >= 6)(i)?; let (i, field_count) = be_u16(i)?; // 6 here is for skipping length + field_count - let (i, rows) = map_parser(take(length - 6), many_m_n(0, field_count.into(), parse_data_row_value))(i)?; - Ok((i, PgsqlBEMessage::ConsolidatedDataRow( - ConsolidatedDataRowPacket { - identifier, - row_cnt: 1, - data_size: add_up_data_size(rows), - } - ))) + let (i, rows) = map_parser( + take(length - 6), + many_m_n(0, field_count.into(), parse_data_row_value), + )(i)?; + Ok(( + i, + PgsqlBEMessage::ConsolidatedDataRow(ConsolidatedDataRowPacket { + identifier, + row_cnt: 1, + data_size: add_up_data_size(rows), + }), + )) } fn parse_sasl_mechanism(i: &[u8]) -> IResult<&[u8], SASLAuthenticationMechanism> { @@ -945,10 +1004,13 @@ fn parse_sasl_mechanisms(i: &[u8]) -> IResult<&[u8], Vec IResult<&[u8], PgsqlErrorNoticeMessageField> { let (i, _field_type) = char('C')(i)?; let (i, field_value) = map_parser(take(6_usize), alphanumeric1)(i)?; - Ok((i, PgsqlErrorNoticeMessageField{ - field_type: PgsqlErrorNoticeFieldType::CodeSqlStateCode, - field_value: field_value.to_vec(), - })) + Ok(( + i, + PgsqlErrorNoticeMessageField { + field_type: PgsqlErrorNoticeFieldType::CodeSqlStateCode, + field_value: field_value.to_vec(), + }, + )) } // Parse an error response with non-localizeable severity message. @@ -957,10 +1019,13 @@ pub fn parse_error_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNotic let (i, field_type) = char('V')(i)?; let (i, field_value) = alt((tag("ERROR"), tag("FATAL"), tag("PANIC")))(i)?; let (i, _) = tag("\x00")(i)?; - Ok((i, PgsqlErrorNoticeMessageField{ - field_type: PgsqlErrorNoticeFieldType::from(field_type), - field_value: field_value.to_vec(), - })) + Ok(( + i, + PgsqlErrorNoticeMessageField { + field_type: PgsqlErrorNoticeFieldType::from(field_type), + field_value: field_value.to_vec(), + }, + )) } // The non-localizable version of Severity field has different values, @@ -968,16 +1033,20 @@ pub fn parse_error_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNotic pub fn parse_notice_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNoticeMessageField> { let (i, field_type) = char('V')(i)?; let (i, field_value) = alt(( - tag("WARNING"), - tag("NOTICE"), - tag("DEBUG"), - tag("INFO"), - tag("LOG")))(i)?; + tag("WARNING"), + tag("NOTICE"), + tag("DEBUG"), + tag("INFO"), + tag("LOG"), + ))(i)?; let (i, _) = tag("\x00")(i)?; - Ok((i, PgsqlErrorNoticeMessageField{ - field_type: PgsqlErrorNoticeFieldType::from(field_type), - field_value: field_value.to_vec(), - })) + Ok(( + i, + PgsqlErrorNoticeMessageField { + field_type: PgsqlErrorNoticeFieldType::from(field_type), + field_value: field_value.to_vec(), + }, + )) } pub fn parse_error_response_field( @@ -1007,7 +1076,9 @@ pub fn parse_error_response_field( Ok((i, data)) } -pub fn parse_error_notice_fields(i: &[u8], is_err_msg: bool) -> IResult<&[u8], Vec> { +pub fn parse_error_notice_fields( + i: &[u8], is_err_msg: bool, +) -> IResult<&[u8], Vec> { let (i, data) = many_till(|b| parse_error_response_field(b, is_err_msg), tag("\x00"))(i)?; Ok((i, data.0)) } @@ -1015,45 +1086,47 @@ pub fn parse_error_notice_fields(i: &[u8], is_err_msg: bool) -> IResult<&[u8], V fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?; let (i, length) = verify(be_u32, |&x| x > 10)(i)?; - let (i, message_body) = map_parser( - take(length - PGSQL_LENGTH_FIELD), - |b| parse_error_notice_fields(b, true) - )(i)?; + let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { + parse_error_notice_fields(b, true) + })(i)?; - Ok((i, PgsqlBEMessage::ErrorResponse(ErrorNoticeMessage { - identifier, - length, - message_body, - }))) + Ok(( + i, + PgsqlBEMessage::ErrorResponse(ErrorNoticeMessage { + identifier, + length, + message_body, + }), + )) } fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?; let (i, length) = verify(be_u32, |&x| x > 10)(i)?; - let (i, message_body) = map_parser( - take(length - PGSQL_LENGTH_FIELD), - |b| parse_error_notice_fields(b, false) - )(i)?; - Ok((i, PgsqlBEMessage::NoticeResponse(ErrorNoticeMessage { - identifier, - length, - message_body, - }))) + let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { + parse_error_notice_fields(b, false) + })(i)?; + Ok(( + i, + PgsqlBEMessage::NoticeResponse(ErrorNoticeMessage { + identifier, + length, + message_body, + }), + )) } fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?; // length (u32) + pid (u32) + at least one byte, for we have two str fields let (i, length) = verify(be_u32, |&x| x > 9)(i)?; - let (i, data) = map_parser( - take(length - PGSQL_LENGTH_FIELD), - |b| { - let (b, pid) = be_u32(b)?; - let (b, channel_name) = take_until_and_consume(b"\x00")(b)?; - let (b, payload) = take_until_and_consume(b"\x00")(b)?; - Ok((b, (pid, channel_name, payload))) - })(i)?; - let msg = PgsqlBEMessage::NotificationResponse(NotificationResponse{ + let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { + let (b, pid) = be_u32(b)?; + let (b, channel_name) = take_until_and_consume(b"\x00")(b)?; + let (b, payload) = take_until_and_consume(b"\x00")(b)?; + Ok((b, (pid, channel_name, payload))) + })(i)?; + let msg = PgsqlBEMessage::NotificationResponse(NotificationResponse { identifier, length, pid: data.0, @@ -1065,31 +1138,29 @@ fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, pseudo_header) = peek(tuple((be_u8, be_u32)))(i)?; - let (i, message) = - match pseudo_header.0 { - b'E' => pgsql_parse_error_response(i)?, - b'K' => parse_backend_key_data_message(i)?, - b'N' => pgsql_parse_notice_response(i)?, - b'R' => pgsql_parse_authentication_message(i)?, - b'S' => parse_parameter_status_message(i)?, - b'C' => parse_command_complete(i)?, - b'Z' => parse_ready_for_query(i)?, - b'T' => parse_row_description(i)?, - b'A' => parse_notification_response(i)?, - b'D' => parse_consolidated_data_row(i)?, - _ => { - let (i, identifier) = be_u8(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; - let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; - let unknown = PgsqlBEMessage::UnknownMessageType (RegularPacket{ - identifier, - length, - payload: payload.to_vec(), - }); - (i, unknown) - } - - }; + let (i, message) = match pseudo_header.0 { + b'E' => pgsql_parse_error_response(i)?, + b'K' => parse_backend_key_data_message(i)?, + b'N' => pgsql_parse_notice_response(i)?, + b'R' => pgsql_parse_authentication_message(i)?, + b'S' => parse_parameter_status_message(i)?, + b'C' => parse_command_complete(i)?, + b'Z' => parse_ready_for_query(i)?, + b'T' => parse_row_description(i)?, + b'A' => parse_notification_response(i)?, + b'D' => parse_consolidated_data_row(i)?, + _ => { + let (i, identifier) = be_u8(i)?; + let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; + let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket { + identifier, + length, + payload: payload.to_vec(), + }); + (i, unknown) + } + }; Ok((i, message)) } @@ -1279,7 +1350,6 @@ mod tests { let result = parse_request(&buf[0..3]); assert!(result.is_err()); - } #[test] @@ -1289,7 +1359,8 @@ mod tests { 0x00, 0x00, 0x00, 0x10, // length: 16 (fixed) 0x04, 0xd2, 0x16, 0x2e, // 1234.5678 - identifies a cancel request 0x00, 0x00, 0x76, 0x31, // PID: 30257 - 0x23, 0x84, 0xf7, 0x2d]; // Backend key: 595916589 + 0x23, 0x84, 0xf7, 0x2d, + ]; // Backend key: 595916589 let result = parse_cancel_request(buf); assert!(result.is_ok()); @@ -1309,8 +1380,6 @@ mod tests { assert!(fail_result.is_err()); } - - #[test] fn test_parse_error_response_code() { let buf: &[u8] = &[0x43, 0x32, 0x38, 0x30, 0x30, 0x30, 0x00]; @@ -1916,7 +1985,8 @@ mod tests { 0x2b, 0x4a, 0x36, 0x79, 0x78, 0x72, 0x66, 0x77, 0x2f, 0x7a, 0x7a, 0x70, 0x38, 0x59, 0x54, 0x39, 0x65, 0x78, 0x56, 0x37, 0x73, 0x38, 0x3d, ]; - let (remainder, result) = pgsql_parse_response(bad_buf).expect("parsing sasl final response failed"); + let (remainder, result) = + pgsql_parse_response(bad_buf).expect("parsing sasl final response failed"); let res = PgsqlBEMessage::UnknownMessageType(RegularPacket { identifier: b'`', length: 54, @@ -2128,55 +2198,34 @@ mod tests { // S #standard_conforming_strings on S ·TimeZone Europe/Paris // K ···O··Z ·I let buf = &[ - 0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x00, 0x53, 0x00, 0x00, 0x00, 0x16, 0x61, 0x70, - 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x00, 0x00, - 0x53, 0x00, 0x00, 0x00, 0x19, 0x63, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x5f, 0x65, 0x6e, 0x63, 0x6f, - 0x64, 0x69, 0x6e, 0x67, 0x00, 0x55, 0x54, 0x46, - 0x38, 0x00, 0x53, 0x00, 0x00, 0x00, 0x17, 0x44, - 0x61, 0x74, 0x65, 0x53, 0x74, 0x79, 0x6c, 0x65, - 0x00, 0x49, 0x53, 0x4f, 0x2c, 0x20, 0x4d, 0x44, - 0x59, 0x00, 0x53, 0x00, 0x00, 0x00, 0x26, 0x64, - 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x61, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x5f, - 0x6f, 0x6e, 0x6c, 0x79, 0x00, 0x6f, 0x66, 0x66, - 0x00, 0x53, 0x00, 0x00, 0x00, 0x17, 0x69, 0x6e, - 0x5f, 0x68, 0x6f, 0x74, 0x5f, 0x73, 0x74, 0x61, - 0x6e, 0x64, 0x62, 0x79, 0x00, 0x6f, 0x66, 0x66, - 0x00, 0x53, 0x00, 0x00, 0x00, 0x19, 0x69, 0x6e, - 0x74, 0x65, 0x67, 0x65, 0x72, 0x5f, 0x64, 0x61, - 0x74, 0x65, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x00, - 0x6f, 0x6e, 0x00, 0x53, 0x00, 0x00, 0x00, 0x1b, - 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, - 0x53, 0x74, 0x79, 0x6c, 0x65, 0x00, 0x70, 0x6f, - 0x73, 0x74, 0x67, 0x72, 0x65, 0x73, 0x00, 0x53, - 0x00, 0x00, 0x00, 0x15, 0x69, 0x73, 0x5f, 0x73, - 0x75, 0x70, 0x65, 0x72, 0x75, 0x73, 0x65, 0x72, - 0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00, - 0x00, 0x19, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x5f, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, - 0x67, 0x00, 0x55, 0x54, 0x46, 0x38, 0x00, 0x53, - 0x00, 0x00, 0x00, 0x18, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x00, 0x31, 0x34, 0x2e, 0x35, 0x00, - 0x53, 0x00, 0x00, 0x00, 0x22, 0x73, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x00, 0x63, 0x74, 0x66, 0x70, 0x6f, - 0x73, 0x74, 0x00, 0x53, 0x00, 0x00, 0x00, 0x23, - 0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, - 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x6f, 0x72, 0x6d, - 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x73, 0x00, 0x6f, 0x6e, 0x00, 0x53, - 0x00, 0x00, 0x00, 0x1a, 0x54, 0x69, 0x6d, 0x65, - 0x5a, 0x6f, 0x6e, 0x65, 0x00, 0x45, 0x75, 0x72, - 0x6f, 0x70, 0x65, 0x2f, 0x50, 0x61, 0x72, 0x69, - 0x73, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x0c, 0x00, - 0x00, 0x0b, 0x8d, 0xcf, 0x4f, 0xb6, 0xcf, 0x5a, - 0x00, 0x00, 0x00, 0x05, 0x49 + 0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x16, + 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x19, 0x63, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x5f, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x00, 0x55, 0x54, 0x46, + 0x38, 0x00, 0x53, 0x00, 0x00, 0x00, 0x17, 0x44, 0x61, 0x74, 0x65, 0x53, 0x74, 0x79, + 0x6c, 0x65, 0x00, 0x49, 0x53, 0x4f, 0x2c, 0x20, 0x4d, 0x44, 0x59, 0x00, 0x53, 0x00, + 0x00, 0x00, 0x26, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x5f, + 0x6f, 0x6e, 0x6c, 0x79, 0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00, 0x00, 0x17, + 0x69, 0x6e, 0x5f, 0x68, 0x6f, 0x74, 0x5f, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x62, 0x79, + 0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00, 0x00, 0x19, 0x69, 0x6e, 0x74, 0x65, + 0x67, 0x65, 0x72, 0x5f, 0x64, 0x61, 0x74, 0x65, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x00, + 0x6f, 0x6e, 0x00, 0x53, 0x00, 0x00, 0x00, 0x1b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x53, 0x74, 0x79, 0x6c, 0x65, 0x00, 0x70, 0x6f, 0x73, 0x74, 0x67, 0x72, + 0x65, 0x73, 0x00, 0x53, 0x00, 0x00, 0x00, 0x15, 0x69, 0x73, 0x5f, 0x73, 0x75, 0x70, + 0x65, 0x72, 0x75, 0x73, 0x65, 0x72, 0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00, + 0x00, 0x19, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x65, 0x6e, 0x63, 0x6f, 0x64, + 0x69, 0x6e, 0x67, 0x00, 0x55, 0x54, 0x46, 0x38, 0x00, 0x53, 0x00, 0x00, 0x00, 0x18, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x00, 0x31, 0x34, 0x2e, 0x35, 0x00, 0x53, 0x00, 0x00, 0x00, 0x22, 0x73, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x00, 0x63, 0x74, 0x66, 0x70, 0x6f, 0x73, 0x74, 0x00, 0x53, + 0x00, 0x00, 0x00, 0x23, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x5f, 0x63, + 0x6f, 0x6e, 0x66, 0x6f, 0x72, 0x6d, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x73, 0x00, 0x6f, 0x6e, 0x00, 0x53, 0x00, 0x00, 0x00, 0x1a, 0x54, 0x69, + 0x6d, 0x65, 0x5a, 0x6f, 0x6e, 0x65, 0x00, 0x45, 0x75, 0x72, 0x6f, 0x70, 0x65, 0x2f, + 0x50, 0x61, 0x72, 0x69, 0x73, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x0b, + 0x8d, 0xcf, 0x4f, 0xb6, 0xcf, 0x5a, 0x00, 0x00, 0x00, 0x05, 0x49, ]; let result = pgsql_parse_response(buf); diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index 105bdcdb12..0a733d3c22 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -22,11 +22,11 @@ use super::parser::{self, ConsolidatedDataRowPacket, PgsqlBEMessage, PgsqlFEMessage}; use crate::applayer::*; use crate::conf::*; +use crate::core::{AppProto, Direction, Flow, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP, *}; use nom7::{Err, IResult}; use std; use std::collections::VecDeque; use std::ffi::CString; -use crate::core::{Flow, AppProto, Direction, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP, *}; pub const PGSQL_CONFIG_DEFAULT_STREAM_DEPTH: u32 = 0; @@ -341,7 +341,10 @@ impl PgsqlState { ); match PgsqlState::state_based_req_parsing(self.state_progress, start) { Ok((rem, request)) => { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); + sc_app_layer_parser_trigger_raw_stream_reassembly( + flow, + Direction::ToServer as i32, + ); start = rem; if let Some(state) = PgsqlState::request_next_state(&request) { self.state_progress = state; @@ -471,7 +474,10 @@ impl PgsqlState { while !start.is_empty() { match PgsqlState::state_based_resp_parsing(self.state_progress, start) { Ok((rem, response)) => { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); + sc_app_layer_parser_trigger_raw_stream_reassembly( + flow, + Direction::ToClient as i32, + ); start = rem; SCLogDebug!("Response is {:?}", &response); if let Some(state) = self.response_process_next_state(&response, flow) { @@ -557,7 +563,6 @@ pub unsafe extern "C" fn rs_pgsql_probing_parser_ts( _flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8, ) -> AppProto { if input_len >= 1 && !input.is_null() { - let slice: &[u8] = build_slice!(input, input_len as usize); match parser::parse_request(slice) { @@ -584,7 +589,6 @@ pub unsafe extern "C" fn rs_pgsql_probing_parser_tc( _flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8, ) -> AppProto { if input_len >= 1 && !input.is_null() { - let slice: &[u8] = build_slice!(input, input_len as usize); if parser::parse_ssl_response(slice).is_ok() { @@ -647,7 +651,6 @@ pub unsafe extern "C" fn rs_pgsql_parse_request( } } - let state_safe: &mut PgsqlState = cast_pointer!(state, PgsqlState); if stream_slice.is_gap() {