pgsql: apply rust fmt changes

pull/11804/head
Juliana Fajardini 1 month ago committed by Victor Julien
parent ef63aa50e2
commit 7aeb718dd7

@ -97,10 +97,7 @@ fn log_request(req: &PgsqlFEMessage, flags: u32) -> Result<JsonBuilder, JsonErro
}) => {
js.set_string_from_bytes(req.to_str(), payload)?;
}
PgsqlFEMessage::CancelRequest(CancelRequestMessage {
pid,
backend_key,
}) => {
PgsqlFEMessage::CancelRequest(CancelRequestMessage { pid, backend_key }) => {
js.set_string("message", "cancel_request")?;
js.set_uint("process_id", (*pid).into())?;
js.set_uint("secret_key", (*backend_key).into())?;

@ -279,7 +279,7 @@ impl PgsqlBEMessage {
}
PgsqlBEMessage::ConsolidatedDataRow(_) => "data_row",
PgsqlBEMessage::NotificationResponse(_) => "notification_response",
PgsqlBEMessage::UnknownMessageType(_) => "unknown_message_type"
PgsqlBEMessage::UnknownMessageType(_) => "unknown_message_type",
}
}
@ -534,16 +534,25 @@ fn pgsql_parse_generic_parameter(i: &[u8]) -> IResult<&[u8], PgsqlParameter> {
let (i, _) = tag("\x00")(i)?;
let (i, param_value) = take_until("\x00")(i)?;
let (i, _) = tag("\x00")(i)?;
Ok((i, PgsqlParameter {
name: PgsqlParameters::from(param_name),
value: param_value.to_vec(),
}))
Ok((
i,
PgsqlParameter {
name: PgsqlParameters::from(param_name),
value: param_value.to_vec(),
},
))
}
pub fn pgsql_parse_startup_parameters(i: &[u8]) -> IResult<&[u8], PgsqlStartupParameters> {
let (i, mut optional) = opt(terminated(many1(pgsql_parse_generic_parameter), tag("\x00")))(i)?;
let (i, mut optional) = opt(terminated(
many1(pgsql_parse_generic_parameter),
tag("\x00"),
))(i)?;
if let Some(ref mut params) = optional {
let mut user = PgsqlParameter{name: PgsqlParameters::User, value: Vec::new() };
let mut user = PgsqlParameter {
name: PgsqlParameters::User,
value: Vec::new(),
};
let mut index: usize = 0;
for (j, p) in params.iter().enumerate() {
if p.name == PgsqlParameters::User {
@ -555,17 +564,20 @@ pub fn pgsql_parse_startup_parameters(i: &[u8]) -> IResult<&[u8], PgsqlStartupPa
if user.value.is_empty() {
return Err(Err::Error(make_error(i, ErrorKind::Tag)));
}
return Ok((i, PgsqlStartupParameters{
user,
optional_params: if !params.is_empty() {
optional
} else { None },
}));
return Ok((
i,
PgsqlStartupParameters {
user,
optional_params: if !params.is_empty() { optional } else { None },
},
));
}
return Err(Err::Error(make_error(i, ErrorKind::Tag)));
}
fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenticationMechanism, u32, Vec<u8>)> {
fn parse_sasl_initial_response_payload(
i: &[u8],
) -> IResult<&[u8], (SASLAuthenticationMechanism, u32, Vec<u8>)> {
let (i, sasl_mechanism) = parse_sasl_mechanism(i)?;
let (i, param_length) = be_u32(i)?;
// From RFC 5802 - the client-first-message will always start w/
@ -577,27 +589,31 @@ fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenti
pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = parse_length(i)?;
let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload)(i)?;
Ok((i, PgsqlFEMessage::SASLInitialResponse(
SASLInitialResponsePacket {
identifier,
length,
auth_mechanism: payload.0,
param_length: payload.1,
sasl_param: payload.2,
})))
let (i, payload) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
parse_sasl_initial_response_payload,
)(i)?;
Ok((
i,
PgsqlFEMessage::SASLInitialResponse(SASLInitialResponsePacket {
identifier,
length,
auth_mechanism: payload.0,
param_length: payload.1,
sasl_param: payload.2,
}),
))
}
pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = parse_length(i)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let resp = PgsqlFEMessage::SASLResponse(
RegularPacket {
identifier,
length,
payload: payload.to_vec(),
});
let resp = PgsqlFEMessage::SASLResponse(RegularPacket {
identifier,
length,
payload: payload.to_vec(),
});
Ok((i, resp))
}
@ -605,37 +621,41 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, len) = verify(be_u32, |&x| x >= 8)(i)?;
let (i, proto_major) = peek(be_u16)(i)?;
let (i, b) = take(len - PGSQL_LENGTH_FIELD)(i)?;
let (_, message) =
match proto_major {
1..=3 => {
let (b, proto_major) = be_u16(b)?;
let (b, proto_minor) = be_u16(b)?;
let (b, params) = pgsql_parse_startup_parameters(b)?;
(b, PgsqlFEMessage::StartupMessage(StartupPacket{
let (_, message) = match proto_major {
1..=3 => {
let (b, proto_major) = be_u16(b)?;
let (b, proto_minor) = be_u16(b)?;
let (b, params) = pgsql_parse_startup_parameters(b)?;
(
b,
PgsqlFEMessage::StartupMessage(StartupPacket {
length: len,
proto_major,
proto_minor,
params}))
},
PGSQL_DUMMY_PROTO_MAJOR => {
let (b, proto_major) = be_u16(b)?;
let (b, proto_minor) = be_u16(b)?;
let (b, message) = match proto_minor {
PGSQL_DUMMY_PROTO_CANCEL_REQUEST => {
parse_cancel_request(b)?
},
PGSQL_DUMMY_PROTO_MINOR_SSL => (b, PgsqlFEMessage::SSLRequest(DummyStartupPacket{
params,
}),
)
}
PGSQL_DUMMY_PROTO_MAJOR => {
let (b, proto_major) = be_u16(b)?;
let (b, proto_minor) = be_u16(b)?;
let (b, message) = match proto_minor {
PGSQL_DUMMY_PROTO_CANCEL_REQUEST => parse_cancel_request(b)?,
PGSQL_DUMMY_PROTO_MINOR_SSL => (
b,
PgsqlFEMessage::SSLRequest(DummyStartupPacket {
length: len,
proto_major,
proto_minor
})),
_ => return Err(Err::Error(make_error(b, ErrorKind::Switch))),
};
proto_minor,
}),
),
_ => return Err(Err::Error(make_error(b, ErrorKind::Switch))),
};
(b, message)
}
_ => return Err(Err::Error(make_error(b, ErrorKind::Switch))),
};
(b, message)
}
_ => return Err(Err::Error(make_error(b, ErrorKind::Switch))),
};
Ok((i, message))
}
@ -655,42 +675,47 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = parse_length(i)?;
let (i, password) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
take_until1("\x00")
)(i)?;
Ok((i, PgsqlFEMessage::PasswordMessage(
RegularPacket{
identifier,
length,
payload: password.to_vec(),
})))
let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((
i,
PgsqlFEMessage::PasswordMessage(RegularPacket {
identifier,
length,
payload: password.to_vec(),
}),
))
}
fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?;
let (i, length) = parse_length(i)?;
let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((i, PgsqlFEMessage::SimpleQuery(RegularPacket {
identifier,
length,
payload: query.to_vec(),
})))
Ok((
i,
PgsqlFEMessage::SimpleQuery(RegularPacket {
identifier,
length,
payload: query.to_vec(),
}),
))
}
fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, pid) = be_u32(i)?;
let (i, backend_key) = be_u32(i)?;
Ok((i, PgsqlFEMessage::CancelRequest(CancelRequestMessage {
pid,
backend_key,
})))
Ok((
i,
PgsqlFEMessage::CancelRequest(CancelRequestMessage { pid, backend_key }),
))
}
fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?;
let (i, length) = parse_length(i)?;
Ok((i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length })))
Ok((
i,
PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }),
))
}
// Messages that begin with 'p' but are not password ones are not parsed here
@ -704,7 +729,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = be_u8(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlFEMessage::UnknownMessageType (RegularPacket{
let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket {
identifier,
length,
payload: payload.to_vec(),
@ -719,95 +744,108 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq
let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?;
let (i, length) = verify(be_u32, |&x| x >= 8)(i)?;
let (i, auth_type) = be_u32(i)?;
let (i, message) = map_parser(
take(length - 8),
|b: &'a [u8]| {
match auth_type {
0 => Ok((b, PgsqlBEMessage::AuthenticationOk(
AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}))),
3 => Ok((b, PgsqlBEMessage::AuthenticationCleartextPassword(
AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}))),
5 => {
let (b, salt) = all_consuming(take(4_usize))(b)?;
Ok((b, PgsqlBEMessage::AuthenticationMD5Password(
AuthenticationMessage {
identifier,
length,
auth_type,
payload: salt.to_vec(),
})))
}
9 => Ok((b, PgsqlBEMessage::AuthenticationSSPI(
AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}))),
// TODO - For SASL, should we parse specific details of the challenge itself? (as seen in: https://github.com/launchbadge/sqlx/blob/master/sqlx-core/src/postgres/message/authentication.rs )
10 => {
let (b, auth_mechanisms) = parse_sasl_mechanisms(b)?;
Ok((b, PgsqlBEMessage::AuthenticationSASL(
AuthenticationSASLMechanismMessage {
identifier,
length,
auth_type,
auth_mechanisms,
})))
}
11 => {
Ok((b, PgsqlBEMessage::AuthenticationSASLContinue(
AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
})))
},
12 => {
Ok((b, PgsqlBEMessage::AuthenticationSASLFinal(
AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}
)))
}
// TODO add other authentication messages
_ => return Err(Err::Error(make_error(i, ErrorKind::Switch))),
}
let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| {
match auth_type {
0 => Ok((
b,
PgsqlBEMessage::AuthenticationOk(AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}),
)),
3 => Ok((
b,
PgsqlBEMessage::AuthenticationCleartextPassword(AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}),
)),
5 => {
let (b, salt) = all_consuming(take(4_usize))(b)?;
Ok((
b,
PgsqlBEMessage::AuthenticationMD5Password(AuthenticationMessage {
identifier,
length,
auth_type,
payload: salt.to_vec(),
}),
))
}
9 => Ok((
b,
PgsqlBEMessage::AuthenticationSSPI(AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}),
)),
// TODO - For SASL, should we parse specific details of the challenge itself? (as seen in: https://github.com/launchbadge/sqlx/blob/master/sqlx-core/src/postgres/message/authentication.rs )
10 => {
let (b, auth_mechanisms) = parse_sasl_mechanisms(b)?;
Ok((
b,
PgsqlBEMessage::AuthenticationSASL(AuthenticationSASLMechanismMessage {
identifier,
length,
auth_type,
auth_mechanisms,
}),
))
}
11 => Ok((
b,
PgsqlBEMessage::AuthenticationSASLContinue(AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}),
)),
12 => Ok((
b,
PgsqlBEMessage::AuthenticationSASLFinal(AuthenticationMessage {
identifier,
length,
auth_type,
payload: b.to_vec(),
}),
)),
// TODO add other authentication messages
_ => return Err(Err::Error(make_error(i, ErrorKind::Switch))),
}
)(i)?;
})(i)?;
Ok((i, message))
}
fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?;
let (i, length) = parse_length(i)?;
let (i, param) = map_parser(take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter)(i)?;
Ok((i, PgsqlBEMessage::ParameterStatus(ParameterStatusMessage {
identifier,
length,
param,
})))
let (i, param) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
pgsql_parse_generic_parameter,
)(i)?;
Ok((
i,
PgsqlBEMessage::ParameterStatus(ParameterStatusMessage {
identifier,
length,
param,
}),
))
}
pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, tag) = alt((char('N'), char('S')))(i)?;
Ok((i, PgsqlBEMessage::SSLResponse(
SSLResponseMessage::from(tag))
))
Ok((
i,
PgsqlBEMessage::SSLResponse(SSLResponseMessage::from(tag)),
))
}
fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
@ -815,34 +853,43 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, length) = verify(be_u32, |&x| x == 12)(i)?;
let (i, pid) = be_u32(i)?;
let (i, secret_key) = be_u32(i)?;
Ok((i, PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage {
identifier,
length,
backend_pid: pid,
secret_key,
})))
Ok((
i,
PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage {
identifier,
length,
backend_pid: pid,
secret_key,
}),
))
}
fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?;
let (i, length) = parse_length(i)?;
let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?;
Ok((i, PgsqlBEMessage::CommandComplete(RegularPacket {
identifier,
length,
payload: payload.to_vec(),
})))
Ok((
i,
PgsqlBEMessage::CommandComplete(RegularPacket {
identifier,
length,
payload: payload.to_vec(),
}),
))
}
fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?;
let (i, length) = verify(be_u32, |&x| x == 5)(i)?;
let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?;
Ok((i, PgsqlBEMessage::ReadyForQuery(ReadyForQueryMessage {
identifier,
length,
transaction_status: status,
})))
Ok((
i,
PgsqlBEMessage::ReadyForQuery(ReadyForQueryMessage {
identifier,
length,
transaction_status: status,
}),
))
}
fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField> {
@ -854,15 +901,18 @@ fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField> {
let (i, data_type_size) = be_i16(i)?;
let (i, type_modifier) = be_i32(i)?;
let (i, format_code) = be_u16(i)?;
Ok((i, RowField {
field_name: field_name.to_vec(),
table_oid,
column_index,
data_type_oid,
data_type_size,
type_modifier,
format_code,
}))
Ok((
i,
RowField {
field_name: field_name.to_vec(),
table_oid,
column_index,
data_type_oid,
data_type_size,
type_modifier,
format_code,
},
))
}
pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
@ -871,29 +921,34 @@ pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, field_count) = be_u16(i)?;
let (i, fields) = map_parser(
take(length - 6),
many_m_n(0, field_count.into(), parse_row_field)
many_m_n(0, field_count.into(), parse_row_field),
)(i)?;
Ok((i, PgsqlBEMessage::RowDescription(
RowDescriptionMessage {
identifier,
length,
field_count,
fields,
})))
Ok((
i,
PgsqlBEMessage::RowDescription(RowDescriptionMessage {
identifier,
length,
field_count,
fields,
}),
))
}
fn parse_data_row_value(i: &[u8]) -> IResult<&[u8], ColumnFieldValue> {
let (i, value_length) = be_i32(i)?;
let (i, value) = cond(value_length >= 0, take(value_length as usize))(i)?;
Ok((i, ColumnFieldValue {
value_length,
value: {
match value {
Some(data) => data.to_vec(),
None => [].to_vec(),
}
Ok((
i,
ColumnFieldValue {
value_length,
value: {
match value {
Some(data) => data.to_vec(),
None => [].to_vec(),
}
},
},
}))
))
}
/// For each column, add up the data size. Return the total
@ -916,14 +971,18 @@ pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, length) = verify(be_u32, |&x| x >= 6)(i)?;
let (i, field_count) = be_u16(i)?;
// 6 here is for skipping length + field_count
let (i, rows) = map_parser(take(length - 6), many_m_n(0, field_count.into(), parse_data_row_value))(i)?;
Ok((i, PgsqlBEMessage::ConsolidatedDataRow(
ConsolidatedDataRowPacket {
identifier,
row_cnt: 1,
data_size: add_up_data_size(rows),
}
)))
let (i, rows) = map_parser(
take(length - 6),
many_m_n(0, field_count.into(), parse_data_row_value),
)(i)?;
Ok((
i,
PgsqlBEMessage::ConsolidatedDataRow(ConsolidatedDataRowPacket {
identifier,
row_cnt: 1,
data_size: add_up_data_size(rows),
}),
))
}
fn parse_sasl_mechanism(i: &[u8]) -> IResult<&[u8], SASLAuthenticationMechanism> {
@ -945,10 +1004,13 @@ fn parse_sasl_mechanisms(i: &[u8]) -> IResult<&[u8], Vec<SASLAuthenticationMecha
pub fn parse_error_response_code(i: &[u8]) -> IResult<&[u8], PgsqlErrorNoticeMessageField> {
let (i, _field_type) = char('C')(i)?;
let (i, field_value) = map_parser(take(6_usize), alphanumeric1)(i)?;
Ok((i, PgsqlErrorNoticeMessageField{
field_type: PgsqlErrorNoticeFieldType::CodeSqlStateCode,
field_value: field_value.to_vec(),
}))
Ok((
i,
PgsqlErrorNoticeMessageField {
field_type: PgsqlErrorNoticeFieldType::CodeSqlStateCode,
field_value: field_value.to_vec(),
},
))
}
// Parse an error response with non-localizeable severity message.
@ -957,10 +1019,13 @@ pub fn parse_error_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNotic
let (i, field_type) = char('V')(i)?;
let (i, field_value) = alt((tag("ERROR"), tag("FATAL"), tag("PANIC")))(i)?;
let (i, _) = tag("\x00")(i)?;
Ok((i, PgsqlErrorNoticeMessageField{
field_type: PgsqlErrorNoticeFieldType::from(field_type),
field_value: field_value.to_vec(),
}))
Ok((
i,
PgsqlErrorNoticeMessageField {
field_type: PgsqlErrorNoticeFieldType::from(field_type),
field_value: field_value.to_vec(),
},
))
}
// The non-localizable version of Severity field has different values,
@ -968,16 +1033,20 @@ pub fn parse_error_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNotic
pub fn parse_notice_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNoticeMessageField> {
let (i, field_type) = char('V')(i)?;
let (i, field_value) = alt((
tag("WARNING"),
tag("NOTICE"),
tag("DEBUG"),
tag("INFO"),
tag("LOG")))(i)?;
tag("WARNING"),
tag("NOTICE"),
tag("DEBUG"),
tag("INFO"),
tag("LOG"),
))(i)?;
let (i, _) = tag("\x00")(i)?;
Ok((i, PgsqlErrorNoticeMessageField{
field_type: PgsqlErrorNoticeFieldType::from(field_type),
field_value: field_value.to_vec(),
}))
Ok((
i,
PgsqlErrorNoticeMessageField {
field_type: PgsqlErrorNoticeFieldType::from(field_type),
field_value: field_value.to_vec(),
},
))
}
pub fn parse_error_response_field(
@ -1007,7 +1076,9 @@ pub fn parse_error_response_field(
Ok((i, data))
}
pub fn parse_error_notice_fields(i: &[u8], is_err_msg: bool) -> IResult<&[u8], Vec<PgsqlErrorNoticeMessageField>> {
pub fn parse_error_notice_fields(
i: &[u8], is_err_msg: bool,
) -> IResult<&[u8], Vec<PgsqlErrorNoticeMessageField>> {
let (i, data) = many_till(|b| parse_error_response_field(b, is_err_msg), tag("\x00"))(i)?;
Ok((i, data.0))
}
@ -1015,45 +1086,47 @@ pub fn parse_error_notice_fields(i: &[u8], is_err_msg: bool) -> IResult<&[u8], V
fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?;
let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
let (i, message_body) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
|b| parse_error_notice_fields(b, true)
)(i)?;
let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
parse_error_notice_fields(b, true)
})(i)?;
Ok((i, PgsqlBEMessage::ErrorResponse(ErrorNoticeMessage {
identifier,
length,
message_body,
})))
Ok((
i,
PgsqlBEMessage::ErrorResponse(ErrorNoticeMessage {
identifier,
length,
message_body,
}),
))
}
fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?;
let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
let (i, message_body) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
|b| parse_error_notice_fields(b, false)
)(i)?;
Ok((i, PgsqlBEMessage::NoticeResponse(ErrorNoticeMessage {
identifier,
length,
message_body,
})))
let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
parse_error_notice_fields(b, false)
})(i)?;
Ok((
i,
PgsqlBEMessage::NoticeResponse(ErrorNoticeMessage {
identifier,
length,
message_body,
}),
))
}
fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?;
// length (u32) + pid (u32) + at least one byte, for we have two str fields
let (i, length) = verify(be_u32, |&x| x > 9)(i)?;
let (i, data) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
|b| {
let (b, pid) = be_u32(b)?;
let (b, channel_name) = take_until_and_consume(b"\x00")(b)?;
let (b, payload) = take_until_and_consume(b"\x00")(b)?;
Ok((b, (pid, channel_name, payload)))
})(i)?;
let msg = PgsqlBEMessage::NotificationResponse(NotificationResponse{
let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
let (b, pid) = be_u32(b)?;
let (b, channel_name) = take_until_and_consume(b"\x00")(b)?;
let (b, payload) = take_until_and_consume(b"\x00")(b)?;
Ok((b, (pid, channel_name, payload)))
})(i)?;
let msg = PgsqlBEMessage::NotificationResponse(NotificationResponse {
identifier,
length,
pid: data.0,
@ -1065,31 +1138,29 @@ fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, pseudo_header) = peek(tuple((be_u8, be_u32)))(i)?;
let (i, message) =
match pseudo_header.0 {
b'E' => pgsql_parse_error_response(i)?,
b'K' => parse_backend_key_data_message(i)?,
b'N' => pgsql_parse_notice_response(i)?,
b'R' => pgsql_parse_authentication_message(i)?,
b'S' => parse_parameter_status_message(i)?,
b'C' => parse_command_complete(i)?,
b'Z' => parse_ready_for_query(i)?,
b'T' => parse_row_description(i)?,
b'A' => parse_notification_response(i)?,
b'D' => parse_consolidated_data_row(i)?,
_ => {
let (i, identifier) = be_u8(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlBEMessage::UnknownMessageType (RegularPacket{
identifier,
length,
payload: payload.to_vec(),
});
(i, unknown)
}
};
let (i, message) = match pseudo_header.0 {
b'E' => pgsql_parse_error_response(i)?,
b'K' => parse_backend_key_data_message(i)?,
b'N' => pgsql_parse_notice_response(i)?,
b'R' => pgsql_parse_authentication_message(i)?,
b'S' => parse_parameter_status_message(i)?,
b'C' => parse_command_complete(i)?,
b'Z' => parse_ready_for_query(i)?,
b'T' => parse_row_description(i)?,
b'A' => parse_notification_response(i)?,
b'D' => parse_consolidated_data_row(i)?,
_ => {
let (i, identifier) = be_u8(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket {
identifier,
length,
payload: payload.to_vec(),
});
(i, unknown)
}
};
Ok((i, message))
}
@ -1279,7 +1350,6 @@ mod tests {
let result = parse_request(&buf[0..3]);
assert!(result.is_err());
}
#[test]
@ -1289,7 +1359,8 @@ mod tests {
0x00, 0x00, 0x00, 0x10, // length: 16 (fixed)
0x04, 0xd2, 0x16, 0x2e, // 1234.5678 - identifies a cancel request
0x00, 0x00, 0x76, 0x31, // PID: 30257
0x23, 0x84, 0xf7, 0x2d]; // Backend key: 595916589
0x23, 0x84, 0xf7, 0x2d,
]; // Backend key: 595916589
let result = parse_cancel_request(buf);
assert!(result.is_ok());
@ -1309,8 +1380,6 @@ mod tests {
assert!(fail_result.is_err());
}
#[test]
fn test_parse_error_response_code() {
let buf: &[u8] = &[0x43, 0x32, 0x38, 0x30, 0x30, 0x30, 0x00];
@ -1916,7 +1985,8 @@ mod tests {
0x2b, 0x4a, 0x36, 0x79, 0x78, 0x72, 0x66, 0x77, 0x2f, 0x7a, 0x7a, 0x70, 0x38, 0x59,
0x54, 0x39, 0x65, 0x78, 0x56, 0x37, 0x73, 0x38, 0x3d,
];
let (remainder, result) = pgsql_parse_response(bad_buf).expect("parsing sasl final response failed");
let (remainder, result) =
pgsql_parse_response(bad_buf).expect("parsing sasl final response failed");
let res = PgsqlBEMessage::UnknownMessageType(RegularPacket {
identifier: b'`',
length: 54,
@ -2128,55 +2198,34 @@ mod tests {
// S #standard_conforming_strings on S ·TimeZone Europe/Paris
// K ···O··Z ·I
let buf = &[
0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x00, 0x53, 0x00, 0x00, 0x00, 0x16, 0x61, 0x70,
0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f,
0x6e, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x00, 0x00,
0x53, 0x00, 0x00, 0x00, 0x19, 0x63, 0x6c, 0x69,
0x65, 0x6e, 0x74, 0x5f, 0x65, 0x6e, 0x63, 0x6f,
0x64, 0x69, 0x6e, 0x67, 0x00, 0x55, 0x54, 0x46,
0x38, 0x00, 0x53, 0x00, 0x00, 0x00, 0x17, 0x44,
0x61, 0x74, 0x65, 0x53, 0x74, 0x79, 0x6c, 0x65,
0x00, 0x49, 0x53, 0x4f, 0x2c, 0x20, 0x4d, 0x44,
0x59, 0x00, 0x53, 0x00, 0x00, 0x00, 0x26, 0x64,
0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x74,
0x72, 0x61, 0x6e, 0x73, 0x61, 0x63, 0x74, 0x69,
0x6f, 0x6e, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x5f,
0x6f, 0x6e, 0x6c, 0x79, 0x00, 0x6f, 0x66, 0x66,
0x00, 0x53, 0x00, 0x00, 0x00, 0x17, 0x69, 0x6e,
0x5f, 0x68, 0x6f, 0x74, 0x5f, 0x73, 0x74, 0x61,
0x6e, 0x64, 0x62, 0x79, 0x00, 0x6f, 0x66, 0x66,
0x00, 0x53, 0x00, 0x00, 0x00, 0x19, 0x69, 0x6e,
0x74, 0x65, 0x67, 0x65, 0x72, 0x5f, 0x64, 0x61,
0x74, 0x65, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x00,
0x6f, 0x6e, 0x00, 0x53, 0x00, 0x00, 0x00, 0x1b,
0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c,
0x53, 0x74, 0x79, 0x6c, 0x65, 0x00, 0x70, 0x6f,
0x73, 0x74, 0x67, 0x72, 0x65, 0x73, 0x00, 0x53,
0x00, 0x00, 0x00, 0x15, 0x69, 0x73, 0x5f, 0x73,
0x75, 0x70, 0x65, 0x72, 0x75, 0x73, 0x65, 0x72,
0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00,
0x00, 0x19, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x5f, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e,
0x67, 0x00, 0x55, 0x54, 0x46, 0x38, 0x00, 0x53,
0x00, 0x00, 0x00, 0x18, 0x73, 0x65, 0x72, 0x76,
0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69,
0x6f, 0x6e, 0x00, 0x31, 0x34, 0x2e, 0x35, 0x00,
0x53, 0x00, 0x00, 0x00, 0x22, 0x73, 0x65, 0x73,
0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x75, 0x74,
0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x00, 0x63, 0x74, 0x66, 0x70, 0x6f,
0x73, 0x74, 0x00, 0x53, 0x00, 0x00, 0x00, 0x23,
0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64,
0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x6f, 0x72, 0x6d,
0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x69,
0x6e, 0x67, 0x73, 0x00, 0x6f, 0x6e, 0x00, 0x53,
0x00, 0x00, 0x00, 0x1a, 0x54, 0x69, 0x6d, 0x65,
0x5a, 0x6f, 0x6e, 0x65, 0x00, 0x45, 0x75, 0x72,
0x6f, 0x70, 0x65, 0x2f, 0x50, 0x61, 0x72, 0x69,
0x73, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x0c, 0x00,
0x00, 0x0b, 0x8d, 0xcf, 0x4f, 0xb6, 0xcf, 0x5a,
0x00, 0x00, 0x00, 0x05, 0x49
0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x16,
0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6e, 0x61,
0x6d, 0x65, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x19, 0x63, 0x6c, 0x69, 0x65, 0x6e,
0x74, 0x5f, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x00, 0x55, 0x54, 0x46,
0x38, 0x00, 0x53, 0x00, 0x00, 0x00, 0x17, 0x44, 0x61, 0x74, 0x65, 0x53, 0x74, 0x79,
0x6c, 0x65, 0x00, 0x49, 0x53, 0x4f, 0x2c, 0x20, 0x4d, 0x44, 0x59, 0x00, 0x53, 0x00,
0x00, 0x00, 0x26, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x74, 0x72, 0x61,
0x6e, 0x73, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x5f,
0x6f, 0x6e, 0x6c, 0x79, 0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00, 0x00, 0x17,
0x69, 0x6e, 0x5f, 0x68, 0x6f, 0x74, 0x5f, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x62, 0x79,
0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00, 0x00, 0x19, 0x69, 0x6e, 0x74, 0x65,
0x67, 0x65, 0x72, 0x5f, 0x64, 0x61, 0x74, 0x65, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x00,
0x6f, 0x6e, 0x00, 0x53, 0x00, 0x00, 0x00, 0x1b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76,
0x61, 0x6c, 0x53, 0x74, 0x79, 0x6c, 0x65, 0x00, 0x70, 0x6f, 0x73, 0x74, 0x67, 0x72,
0x65, 0x73, 0x00, 0x53, 0x00, 0x00, 0x00, 0x15, 0x69, 0x73, 0x5f, 0x73, 0x75, 0x70,
0x65, 0x72, 0x75, 0x73, 0x65, 0x72, 0x00, 0x6f, 0x66, 0x66, 0x00, 0x53, 0x00, 0x00,
0x00, 0x19, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x65, 0x6e, 0x63, 0x6f, 0x64,
0x69, 0x6e, 0x67, 0x00, 0x55, 0x54, 0x46, 0x38, 0x00, 0x53, 0x00, 0x00, 0x00, 0x18,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
0x00, 0x31, 0x34, 0x2e, 0x35, 0x00, 0x53, 0x00, 0x00, 0x00, 0x22, 0x73, 0x65, 0x73,
0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x00, 0x63, 0x74, 0x66, 0x70, 0x6f, 0x73, 0x74, 0x00, 0x53,
0x00, 0x00, 0x00, 0x23, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x5f, 0x63,
0x6f, 0x6e, 0x66, 0x6f, 0x72, 0x6d, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x69,
0x6e, 0x67, 0x73, 0x00, 0x6f, 0x6e, 0x00, 0x53, 0x00, 0x00, 0x00, 0x1a, 0x54, 0x69,
0x6d, 0x65, 0x5a, 0x6f, 0x6e, 0x65, 0x00, 0x45, 0x75, 0x72, 0x6f, 0x70, 0x65, 0x2f,
0x50, 0x61, 0x72, 0x69, 0x73, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x0b,
0x8d, 0xcf, 0x4f, 0xb6, 0xcf, 0x5a, 0x00, 0x00, 0x00, 0x05, 0x49,
];
let result = pgsql_parse_response(buf);

@ -22,11 +22,11 @@
use super::parser::{self, ConsolidatedDataRowPacket, PgsqlBEMessage, PgsqlFEMessage};
use crate::applayer::*;
use crate::conf::*;
use crate::core::{AppProto, Direction, Flow, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP, *};
use nom7::{Err, IResult};
use std;
use std::collections::VecDeque;
use std::ffi::CString;
use crate::core::{Flow, AppProto, Direction, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP, *};
pub const PGSQL_CONFIG_DEFAULT_STREAM_DEPTH: u32 = 0;
@ -341,7 +341,10 @@ impl PgsqlState {
);
match PgsqlState::state_based_req_parsing(self.state_progress, start) {
Ok((rem, request)) => {
sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32);
sc_app_layer_parser_trigger_raw_stream_reassembly(
flow,
Direction::ToServer as i32,
);
start = rem;
if let Some(state) = PgsqlState::request_next_state(&request) {
self.state_progress = state;
@ -471,7 +474,10 @@ impl PgsqlState {
while !start.is_empty() {
match PgsqlState::state_based_resp_parsing(self.state_progress, start) {
Ok((rem, response)) => {
sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32);
sc_app_layer_parser_trigger_raw_stream_reassembly(
flow,
Direction::ToClient as i32,
);
start = rem;
SCLogDebug!("Response is {:?}", &response);
if let Some(state) = self.response_process_next_state(&response, flow) {
@ -557,7 +563,6 @@ pub unsafe extern "C" fn rs_pgsql_probing_parser_ts(
_flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8,
) -> AppProto {
if input_len >= 1 && !input.is_null() {
let slice: &[u8] = build_slice!(input, input_len as usize);
match parser::parse_request(slice) {
@ -584,7 +589,6 @@ pub unsafe extern "C" fn rs_pgsql_probing_parser_tc(
_flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8,
) -> AppProto {
if input_len >= 1 && !input.is_null() {
let slice: &[u8] = build_slice!(input, input_len as usize);
if parser::parse_ssl_response(slice).is_ok() {
@ -647,7 +651,6 @@ pub unsafe extern "C" fn rs_pgsql_parse_request(
}
}
let state_safe: &mut PgsqlState = cast_pointer!(state, PgsqlState);
if stream_slice.is_gap() {

Loading…
Cancel
Save