ssl: handshake parsing code cleanup

pull/4922/head
Victor Julien 6 years ago
parent d1ada2e13c
commit ab44b5edac

@ -1406,11 +1406,82 @@ error:
return -1;
}
static inline bool
HaveEntireRecord(const SSLStateConnp *curr_connp, const uint32_t input_len)
{
return (curr_connp->bytes_processed + input_len) >=
(curr_connp->record_length + SSLV3_RECORD_HDR_LEN);
}
static inline bool
RecordAlreadyProcessed(const SSLStateConnp *curr_connp)
{
return ((curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
curr_connp->bytes_processed);
}
static inline int SSLv3ParseHandshakeTypeCertificate(SSLState *ssl_state,
const uint8_t * const initial_input,
const uint32_t input_len)
{
if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) {
/* error, skip packet */
ssl_state->curr_connp->bytes_processed += input_len;
return -1;
}
uint32_t write_len = 0;
if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
if (RecordAlreadyProcessed(ssl_state->curr_connp)) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
write_len = (ssl_state->curr_connp->record_length +
SSLV3_RECORD_HDR_LEN) - ssl_state->curr_connp->bytes_processed;
} else {
write_len = input_len;
}
if (SafeMemcpy(ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos,
ssl_state->curr_connp->trec_len,
initial_input, 0, input_len, write_len) != 0) {
return -1;
}
ssl_state->curr_connp->trec_pos += write_len;
int rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos);
if (rc > 0) {
/* do not return normally if the packet was fragmented:
we would return the size of the _entire_ message,
while we expect only the number of bytes parsed bytes
from the _current_ fragment */
if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
uint32_t diff = write_len -
(ssl_state->curr_connp->trec_pos - rc);
ssl_state->curr_connp->bytes_processed += diff;
ssl_state->curr_connp->trec_pos = 0;
ssl_state->curr_connp->handshake_type = 0;
ssl_state->curr_connp->hs_bytes_processed = 0;
ssl_state->curr_connp->message_length = 0;
return diff;
} else {
ssl_state->curr_connp->bytes_processed += write_len;
return write_len;
}
}
static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
uint32_t input_len, uint8_t direction)
{
const uint8_t *initial_input = input;
uint32_t parsed = 0;
int rc;
if (input_len == 0) {
@ -1425,7 +1496,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
if (input_len >= ssl_state->curr_connp->message_length &&
input_len >= 40) {
rc = TLSDecodeHandshakeHello(ssl_state, input, input_len);
if (rc < 0)
return rc;
}
@ -1440,7 +1510,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
input_len >= 40) {
rc = TLSDecodeHandshakeHello(ssl_state, input,
ssl_state->curr_connp->message_length);
if (rc < 0)
return rc;
}
@ -1462,72 +1531,10 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
"direction!");
break;
}
if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) {
/* error, skip packet */
parsed += input_len;
(void)parsed; /* for scan-build */
ssl_state->curr_connp->bytes_processed += input_len;
return -1;
}
uint32_t write_len = 0;
if ((ssl_state->curr_connp->bytes_processed + input_len) >
ssl_state->curr_connp->record_length +
(SSLV3_RECORD_HDR_LEN)) {
if ((ssl_state->curr_connp->record_length +
SSLV3_RECORD_HDR_LEN) <
ssl_state->curr_connp->bytes_processed) {
SSLSetEvent(ssl_state,
TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
write_len = (ssl_state->curr_connp->record_length +
SSLV3_RECORD_HDR_LEN) -
ssl_state->curr_connp->bytes_processed;
} else {
write_len = input_len;
}
if (SafeMemcpy(ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos,
ssl_state->curr_connp->trec_len,
initial_input, 0, input_len, write_len) != 0) {
return -1;
}
ssl_state->curr_connp->trec_pos += write_len;
rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos);
if (rc > 0) {
/* do not return normally if the packet was fragmented:
we would return the size of the _entire_ message,
while we expect only the number of bytes parsed bytes
from the _current_ fragment */
if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
SSLSetEvent(ssl_state,
TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
uint32_t diff = write_len -
(ssl_state->curr_connp->trec_pos - rc);
ssl_state->curr_connp->bytes_processed += diff;
ssl_state->curr_connp->trec_pos = 0;
ssl_state->curr_connp->handshake_type = 0;
ssl_state->curr_connp->hs_bytes_processed = 0;
ssl_state->curr_connp->message_length = 0;
return diff;
} else {
ssl_state->curr_connp->bytes_processed += write_len;
parsed += write_len;
return parsed;
}
return SSLv3ParseHandshakeTypeCertificate(ssl_state,
initial_input, input_len);
break;
case SSLV3_HS_HELLO_REQUEST:
case SSLV3_HS_CERTIFICATE_REQUEST:
case SSLV3_HS_CERTIFICATE_VERIFY:
@ -1546,10 +1553,8 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
ssl_state->flags |= ssl_state->current_flags;
uint32_t write_len = 0;
if ((ssl_state->curr_connp->bytes_processed + input_len) >=
ssl_state->curr_connp->record_length + (SSLV3_RECORD_HDR_LEN)) {
if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
ssl_state->curr_connp->bytes_processed) {
if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
if (RecordAlreadyProcessed(ssl_state->curr_connp)) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
@ -1566,24 +1571,18 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
parsed += ssl_state->curr_connp->message_length -
const uint32_t parsed = ssl_state->curr_connp->message_length -
ssl_state->curr_connp->trec_pos;
ssl_state->curr_connp->bytes_processed +=
ssl_state->curr_connp->message_length -
ssl_state->curr_connp->trec_pos;
ssl_state->curr_connp->bytes_processed += parsed;
ssl_state->curr_connp->handshake_type = 0;
ssl_state->curr_connp->hs_bytes_processed = 0;
ssl_state->curr_connp->message_length = 0;
ssl_state->curr_connp->trec_pos = 0;
return parsed;
} else {
ssl_state->curr_connp->trec_pos += write_len;
ssl_state->curr_connp->bytes_processed += write_len;
parsed += write_len;
return parsed;
return write_len;
}
}
@ -2221,18 +2220,16 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
AppLayerParserState *pstate, const uint8_t *input,
uint32_t input_len)
{
int retval = 0;
uint32_t parsed = 0;
if (ssl_state->curr_connp->bytes_processed < SSLV3_RECORD_HDR_LEN) {
retval = SSLv3ParseRecord(direction, ssl_state, input, input_len);
int retval = SSLv3ParseRecord(direction, ssl_state, input, input_len);
if (retval < 0) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER);
return -1;
} else {
parsed += retval;
input_len -= retval;
}
parsed += retval;
input_len -= retval;
}
if (input_len == 0) {
@ -2296,7 +2293,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
break;
case SSLV3_HANDSHAKE_PROTOCOL:
case SSLV3_HANDSHAKE_PROTOCOL: {
if (ssl_state->flags & SSL_AL_FLAG_CHANGE_CIPHER_SPEC) {
/* In TLSv1.3, ChangeCipherSpec is only used for middlebox
compability (rfc8446, appendix D.4). */
@ -2314,8 +2311,8 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
return -1;
}
retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
input_len, direction);
int retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
input_len, direction);
if (retval < 0) {
SSLSetEvent(ssl_state,
TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE);
@ -2349,15 +2346,15 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
}
break;
case SSLV3_HEARTBEAT_PROTOCOL:
retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed,
}
case SSLV3_HEARTBEAT_PROTOCOL: {
int retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed,
input_len, direction);
if (retval < 0)
return -1;
break;
}
default:
/* \todo fix the event from invalid rule to unknown rule */
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_RECORD_TYPE);
@ -2365,8 +2362,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
return -1;
}
if (input_len + ssl_state->curr_connp->bytes_processed >=
ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) {
if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
ssl_state->curr_connp->bytes_processed) {
/* defensive checks. Something is wrong. */
@ -2392,7 +2388,6 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
ssl_state->curr_connp->bytes_processed += input_len;
return parsed;
}
}
/**

Loading…
Cancel
Save