ssl: handshake parsing code cleanup

pull/4922/head
Victor Julien 6 years ago
parent d1ada2e13c
commit ab44b5edac

@ -1406,11 +1406,82 @@ error:
return -1; return -1;
} }
static inline bool
HaveEntireRecord(const SSLStateConnp *curr_connp, const uint32_t input_len)
{
return (curr_connp->bytes_processed + input_len) >=
(curr_connp->record_length + SSLV3_RECORD_HDR_LEN);
}
static inline bool
RecordAlreadyProcessed(const SSLStateConnp *curr_connp)
{
return ((curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
curr_connp->bytes_processed);
}
static inline int SSLv3ParseHandshakeTypeCertificate(SSLState *ssl_state,
const uint8_t * const initial_input,
const uint32_t input_len)
{
if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) {
/* error, skip packet */
ssl_state->curr_connp->bytes_processed += input_len;
return -1;
}
uint32_t write_len = 0;
if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
if (RecordAlreadyProcessed(ssl_state->curr_connp)) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
write_len = (ssl_state->curr_connp->record_length +
SSLV3_RECORD_HDR_LEN) - ssl_state->curr_connp->bytes_processed;
} else {
write_len = input_len;
}
if (SafeMemcpy(ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos,
ssl_state->curr_connp->trec_len,
initial_input, 0, input_len, write_len) != 0) {
return -1;
}
ssl_state->curr_connp->trec_pos += write_len;
int rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos);
if (rc > 0) {
/* do not return normally if the packet was fragmented:
we would return the size of the _entire_ message,
while we expect only the number of bytes parsed bytes
from the _current_ fragment */
if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
uint32_t diff = write_len -
(ssl_state->curr_connp->trec_pos - rc);
ssl_state->curr_connp->bytes_processed += diff;
ssl_state->curr_connp->trec_pos = 0;
ssl_state->curr_connp->handshake_type = 0;
ssl_state->curr_connp->hs_bytes_processed = 0;
ssl_state->curr_connp->message_length = 0;
return diff;
} else {
ssl_state->curr_connp->bytes_processed += write_len;
return write_len;
}
}
static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
uint32_t input_len, uint8_t direction) uint32_t input_len, uint8_t direction)
{ {
const uint8_t *initial_input = input; const uint8_t *initial_input = input;
uint32_t parsed = 0;
int rc; int rc;
if (input_len == 0) { if (input_len == 0) {
@ -1425,7 +1496,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
if (input_len >= ssl_state->curr_connp->message_length && if (input_len >= ssl_state->curr_connp->message_length &&
input_len >= 40) { input_len >= 40) {
rc = TLSDecodeHandshakeHello(ssl_state, input, input_len); rc = TLSDecodeHandshakeHello(ssl_state, input, input_len);
if (rc < 0) if (rc < 0)
return rc; return rc;
} }
@ -1440,7 +1510,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
input_len >= 40) { input_len >= 40) {
rc = TLSDecodeHandshakeHello(ssl_state, input, rc = TLSDecodeHandshakeHello(ssl_state, input,
ssl_state->curr_connp->message_length); ssl_state->curr_connp->message_length);
if (rc < 0) if (rc < 0)
return rc; return rc;
} }
@ -1462,72 +1531,10 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
"direction!"); "direction!");
break; break;
} }
return SSLv3ParseHandshakeTypeCertificate(ssl_state,
if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) { initial_input, input_len);
/* error, skip packet */
parsed += input_len;
(void)parsed; /* for scan-build */
ssl_state->curr_connp->bytes_processed += input_len;
return -1;
}
uint32_t write_len = 0;
if ((ssl_state->curr_connp->bytes_processed + input_len) >
ssl_state->curr_connp->record_length +
(SSLV3_RECORD_HDR_LEN)) {
if ((ssl_state->curr_connp->record_length +
SSLV3_RECORD_HDR_LEN) <
ssl_state->curr_connp->bytes_processed) {
SSLSetEvent(ssl_state,
TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
write_len = (ssl_state->curr_connp->record_length +
SSLV3_RECORD_HDR_LEN) -
ssl_state->curr_connp->bytes_processed;
} else {
write_len = input_len;
}
if (SafeMemcpy(ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos,
ssl_state->curr_connp->trec_len,
initial_input, 0, input_len, write_len) != 0) {
return -1;
}
ssl_state->curr_connp->trec_pos += write_len;
rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
ssl_state->curr_connp->trec_pos);
if (rc > 0) {
/* do not return normally if the packet was fragmented:
we would return the size of the _entire_ message,
while we expect only the number of bytes parsed bytes
from the _current_ fragment */
if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
SSLSetEvent(ssl_state,
TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1;
}
uint32_t diff = write_len -
(ssl_state->curr_connp->trec_pos - rc);
ssl_state->curr_connp->bytes_processed += diff;
ssl_state->curr_connp->trec_pos = 0;
ssl_state->curr_connp->handshake_type = 0;
ssl_state->curr_connp->hs_bytes_processed = 0;
ssl_state->curr_connp->message_length = 0;
return diff;
} else {
ssl_state->curr_connp->bytes_processed += write_len;
parsed += write_len;
return parsed;
}
break; break;
case SSLV3_HS_HELLO_REQUEST: case SSLV3_HS_HELLO_REQUEST:
case SSLV3_HS_CERTIFICATE_REQUEST: case SSLV3_HS_CERTIFICATE_REQUEST:
case SSLV3_HS_CERTIFICATE_VERIFY: case SSLV3_HS_CERTIFICATE_VERIFY:
@ -1546,10 +1553,8 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
ssl_state->flags |= ssl_state->current_flags; ssl_state->flags |= ssl_state->current_flags;
uint32_t write_len = 0; uint32_t write_len = 0;
if ((ssl_state->curr_connp->bytes_processed + input_len) >= if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
ssl_state->curr_connp->record_length + (SSLV3_RECORD_HDR_LEN)) { if (RecordAlreadyProcessed(ssl_state->curr_connp)) {
if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
ssl_state->curr_connp->bytes_processed) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD); SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1; return -1;
} }
@ -1566,24 +1571,18 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD); SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
return -1; return -1;
} }
parsed += ssl_state->curr_connp->message_length - const uint32_t parsed = ssl_state->curr_connp->message_length -
ssl_state->curr_connp->trec_pos; ssl_state->curr_connp->trec_pos;
ssl_state->curr_connp->bytes_processed += parsed;
ssl_state->curr_connp->bytes_processed +=
ssl_state->curr_connp->message_length -
ssl_state->curr_connp->trec_pos;
ssl_state->curr_connp->handshake_type = 0; ssl_state->curr_connp->handshake_type = 0;
ssl_state->curr_connp->hs_bytes_processed = 0; ssl_state->curr_connp->hs_bytes_processed = 0;
ssl_state->curr_connp->message_length = 0; ssl_state->curr_connp->message_length = 0;
ssl_state->curr_connp->trec_pos = 0; ssl_state->curr_connp->trec_pos = 0;
return parsed; return parsed;
} else { } else {
ssl_state->curr_connp->trec_pos += write_len; ssl_state->curr_connp->trec_pos += write_len;
ssl_state->curr_connp->bytes_processed += write_len; ssl_state->curr_connp->bytes_processed += write_len;
parsed += write_len; return write_len;
return parsed;
} }
} }
@ -2221,18 +2220,16 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
AppLayerParserState *pstate, const uint8_t *input, AppLayerParserState *pstate, const uint8_t *input,
uint32_t input_len) uint32_t input_len)
{ {
int retval = 0;
uint32_t parsed = 0; uint32_t parsed = 0;
if (ssl_state->curr_connp->bytes_processed < SSLV3_RECORD_HDR_LEN) { if (ssl_state->curr_connp->bytes_processed < SSLV3_RECORD_HDR_LEN) {
retval = SSLv3ParseRecord(direction, ssl_state, input, input_len); int retval = SSLv3ParseRecord(direction, ssl_state, input, input_len);
if (retval < 0) { if (retval < 0) {
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER); SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER);
return -1; return -1;
} else {
parsed += retval;
input_len -= retval;
} }
parsed += retval;
input_len -= retval;
} }
if (input_len == 0) { if (input_len == 0) {
@ -2296,7 +2293,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
break; break;
case SSLV3_HANDSHAKE_PROTOCOL: case SSLV3_HANDSHAKE_PROTOCOL: {
if (ssl_state->flags & SSL_AL_FLAG_CHANGE_CIPHER_SPEC) { if (ssl_state->flags & SSL_AL_FLAG_CHANGE_CIPHER_SPEC) {
/* In TLSv1.3, ChangeCipherSpec is only used for middlebox /* In TLSv1.3, ChangeCipherSpec is only used for middlebox
compability (rfc8446, appendix D.4). */ compability (rfc8446, appendix D.4). */
@ -2314,8 +2311,8 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
return -1; return -1;
} }
retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed, int retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
input_len, direction); input_len, direction);
if (retval < 0) { if (retval < 0) {
SSLSetEvent(ssl_state, SSLSetEvent(ssl_state,
TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE); TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE);
@ -2349,15 +2346,15 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
} }
break; break;
}
case SSLV3_HEARTBEAT_PROTOCOL: case SSLV3_HEARTBEAT_PROTOCOL: {
retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed, int retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed,
input_len, direction); input_len, direction);
if (retval < 0) if (retval < 0)
return -1; return -1;
break; break;
}
default: default:
/* \todo fix the event from invalid rule to unknown rule */ /* \todo fix the event from invalid rule to unknown rule */
SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_RECORD_TYPE); SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_RECORD_TYPE);
@ -2365,8 +2362,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
return -1; return -1;
} }
if (input_len + ssl_state->curr_connp->bytes_processed >= if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) {
if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) < if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
ssl_state->curr_connp->bytes_processed) { ssl_state->curr_connp->bytes_processed) {
/* defensive checks. Something is wrong. */ /* defensive checks. Something is wrong. */
@ -2392,7 +2388,6 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
ssl_state->curr_connp->bytes_processed += input_len; ssl_state->curr_connp->bytes_processed += input_len;
return parsed; return parsed;
} }
} }
/** /**

Loading…
Cancel
Save