pgsql/parser: always use fn for parsing PDU length

Some inner parsers were using it, some weren't. Better to standardize
this. Also take the time to avoid magic numbers for representing the
expected lengths for pgsql PDUs.
Also throwing PgsqlParseError and allowing for incomplete results.

Related to
Task #5566
Bug #5524
pull/12625/head
Juliana Fajardini 2 weeks ago committed by Victor Julien
parent 29d3aa7a6a
commit cc841e66db

@ -55,9 +55,22 @@ impl<I> ParseError<I> for PgsqlParseError<I> {
}
}
fn parse_length(i: &[u8]) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
fn parse_gte_length(i: &[u8], expected_length: u32) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| {
x >= PGSQL_LENGTH_FIELD
x >= expected_length
})(i);
match res {
Ok(result) => Ok((result.0, result.1)),
Err(nom7::Err::Incomplete(needed)) => Err(Err::Incomplete(needed)),
Err(_) => Err(Err::Error(PgsqlParseError::InvalidLength)),
}
}
fn parse_exact_length(
i: &[u8], expected_length: u32,
) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| {
x == expected_length
})(i);
match res {
Ok(result) => Ok((result.0, result.1)),
@ -612,7 +625,7 @@ fn parse_sasl_initial_response_payload(
pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
parse_sasl_initial_response_payload,
@ -631,7 +644,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, P
pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let resp = PgsqlFEMessage::SASLResponse(RegularPacket {
identifier,
@ -698,7 +711,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, Pg
// Password can be encrypted or in cleartext
pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((
i,
@ -712,7 +725,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlP
fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((
i,
@ -735,7 +748,7 @@ fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseEr
fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?;
Ok((
i,
PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }),
@ -751,7 +764,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError
b'X' => parse_terminate_message(i)?,
_ => {
let (i, identifier) = be_u8(i)?;
let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket {
identifier,
@ -766,7 +779,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError
fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], PgsqlBEMessage, PgsqlParseError<&'a [u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?;
let (i, length) = verify(be_u32, |&x| x >= 8)(i)?;
let (i, length) = parse_gte_length(i, 8)?;
let (i, auth_type) = be_u32(i)?;
let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| {
match auth_type {
@ -849,7 +862,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq
fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, param) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
pgsql_parse_generic_parameter,
@ -874,7 +887,7 @@ pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse
fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'K')(i)?;
let (i, length) = verify(be_u32, |&x| x == 12)(i)?;
let (i, length) = parse_exact_length(i, 12)?;
let (i, pid) = be_u32(i)?;
let (i, secret_key) = be_u32(i)?;
Ok((
@ -890,7 +903,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pg
fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?;
let (i, length) = parse_length(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?;
Ok((
i,
@ -904,7 +917,7 @@ fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse
fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?;
let (i, length) = verify(be_u32, |&x| x == 5)(i)?;
let (i, length) = parse_exact_length(i, 5)?;
let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?;
Ok((
i,
@ -941,7 +954,7 @@ fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField, PgsqlParseError<&[u8]>>
pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'T')(i)?;
let (i, length) = verify(be_u32, |&x| x > 6)(i)?;
let (i, length) = parse_gte_length(i, 7)?;
let (i, field_count) = be_u16(i)?;
let (i, fields) = map_parser(
take(length - 6),
@ -992,7 +1005,7 @@ fn add_up_data_size(columns: Vec<ColumnFieldValue>) -> u64 {
// Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages
pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'D')(i)?;
let (i, length) = verify(be_u32, |&x| x >= 6)(i)?;
let (i, length) = parse_gte_length(i, 7)?;
let (i, field_count) = be_u16(i)?;
// 6 here is for skipping length + field_count
let (i, rows) = map_parser(
@ -1109,7 +1122,7 @@ pub fn parse_error_notice_fields(
fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?;
let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
let (i, length) = parse_gte_length(i, 11)?;
let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
parse_error_notice_fields(b, true)
})(i)?;
@ -1126,7 +1139,7 @@ fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlP
fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?;
let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
let (i, length) = parse_gte_length(i, 11)?;
let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
parse_error_notice_fields(b, false)
})(i)?;
@ -1143,7 +1156,7 @@ fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pgsql
fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?;
// length (u32) + pid (u32) + at least one byte, for we have two str fields
let (i, length) = verify(be_u32, |&x| x > 9)(i)?;
let (i, length) = parse_gte_length(i, 10)?;
let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
let (b, pid) = be_u32(b)?;
let (b, channel_name) = take_until_and_consume(b"\x00")(b)?;
@ -1175,7 +1188,7 @@ pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlPar
b'D' => parse_consolidated_data_row(i)?,
_ => {
let (i, identifier) = be_u8(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket {
identifier,

Loading…
Cancel
Save