diff options
Diffstat (limited to 'ext/olm/src')
| -rw-r--r-- | ext/olm/src/account.cpp | 580 | ||||
| -rw-r--r-- | ext/olm/src/base64.cpp | 187 | ||||
| -rw-r--r-- | ext/olm/src/cipher.cpp | 152 | ||||
| -rw-r--r-- | ext/olm/src/crypto.cpp | 299 | ||||
| -rw-r--r-- | ext/olm/src/ed25519.c | 22 | ||||
| -rw-r--r-- | ext/olm/src/error.c | 46 | ||||
| -rw-r--r-- | ext/olm/src/inbound_group_session.c | 540 | ||||
| -rw-r--r-- | ext/olm/src/megolm.c | 154 | ||||
| -rw-r--r-- | ext/olm/src/memory.cpp | 45 | ||||
| -rw-r--r-- | ext/olm/src/message.cpp | 406 | ||||
| -rw-r--r-- | ext/olm/src/olm.cpp | 846 | ||||
| -rw-r--r-- | ext/olm/src/outbound_group_session.c | 390 | ||||
| -rw-r--r-- | ext/olm/src/pickle.cpp | 274 | ||||
| -rw-r--r-- | ext/olm/src/pickle_encoding.c | 92 | ||||
| -rw-r--r-- | ext/olm/src/pk.cpp | 542 | ||||
| -rw-r--r-- | ext/olm/src/ratchet.cpp | 625 | ||||
| -rw-r--r-- | ext/olm/src/sas.c | 229 | ||||
| -rw-r--r-- | ext/olm/src/session.cpp | 531 | ||||
| -rw-r--r-- | ext/olm/src/utility.cpp | 57 |
19 files changed, 6017 insertions, 0 deletions
diff --git a/ext/olm/src/account.cpp b/ext/olm/src/account.cpp new file mode 100644 index 0000000..41b7188 --- /dev/null +++ b/ext/olm/src/account.cpp @@ -0,0 +1,580 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/account.hh" +#include "olm/base64.hh" +#include "olm/pickle.h" +#include "olm/pickle.hh" +#include "olm/memory.hh" + +olm::Account::Account( +) : num_fallback_keys(0), + next_one_time_key_id(0), + last_error(OlmErrorCode::OLM_SUCCESS) { +} + + +olm::OneTimeKey const * olm::Account::lookup_key( + _olm_curve25519_public_key const & public_key +) { + for (olm::OneTimeKey const & key : one_time_keys) { + if (olm::array_equal(key.key.public_key.public_key, public_key.public_key)) { + return &key; + } + } + if (num_fallback_keys >= 1 + && olm::array_equal( + current_fallback_key.key.public_key.public_key, public_key.public_key + ) + ) { + return ¤t_fallback_key; + } + if (num_fallback_keys >= 2 + && olm::array_equal( + prev_fallback_key.key.public_key.public_key, public_key.public_key + ) + ) { + return &prev_fallback_key; + } + return 0; +} + +std::size_t olm::Account::remove_key( + _olm_curve25519_public_key const & public_key +) { + OneTimeKey * i; + for (i = one_time_keys.begin(); i != one_time_keys.end(); ++i) { + if (olm::array_equal(i->key.public_key.public_key, public_key.public_key)) { + std::uint32_t id = i->id; + one_time_keys.erase(i); + return id; + } + } + // check if the key is a fallback key, to avoid returning an error, but + // don't actually remove it + if (num_fallback_keys >= 1 + && olm::array_equal( + current_fallback_key.key.public_key.public_key, public_key.public_key + ) + ) { + return current_fallback_key.id; + } + if (num_fallback_keys >= 2 + && olm::array_equal( + prev_fallback_key.key.public_key.public_key, public_key.public_key + ) + ) { + return prev_fallback_key.id; + } + return std::size_t(-1); +} + +std::size_t olm::Account::new_account_random_length() const { + return ED25519_RANDOM_LENGTH + CURVE25519_RANDOM_LENGTH; +} + +std::size_t olm::Account::new_account( + uint8_t const * random, std::size_t random_length +) { + if (random_length < new_account_random_length()) { + last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; + return std::size_t(-1); + } + + _olm_crypto_ed25519_generate_key(random, &identity_keys.ed25519_key); + random += ED25519_RANDOM_LENGTH; + _olm_crypto_curve25519_generate_key(random, &identity_keys.curve25519_key); + + return 0; +} + +namespace { + +uint8_t KEY_JSON_ED25519[] = "\"ed25519\":"; +uint8_t KEY_JSON_CURVE25519[] = "\"curve25519\":"; + +template<typename T> +static std::uint8_t * write_string( + std::uint8_t * pos, + T const & value +) { + std::memcpy(pos, value, sizeof(T) - 1); + return pos + (sizeof(T) - 1); +} + +} + + +std::size_t olm::Account::get_identity_json_length() const { + std::size_t length = 0; + length += 1; /* { */ + length += sizeof(KEY_JSON_CURVE25519) - 1; + length += 1; /* " */ + length += olm::encode_base64_length( + sizeof(identity_keys.curve25519_key.public_key) + ); + length += 2; /* ", */ + length += sizeof(KEY_JSON_ED25519) - 1; + length += 1; /* " */ + length += olm::encode_base64_length( + sizeof(identity_keys.ed25519_key.public_key) + ); + length += 2; /* "} */ + return length; +} + + +std::size_t olm::Account::get_identity_json( + std::uint8_t * identity_json, std::size_t identity_json_length +) { + std::uint8_t * pos = identity_json; + size_t expected_length = get_identity_json_length(); + + if (identity_json_length < expected_length) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + *(pos++) = '{'; + pos = write_string(pos, KEY_JSON_CURVE25519); + *(pos++) = '\"'; + pos = olm::encode_base64( + identity_keys.curve25519_key.public_key.public_key, + sizeof(identity_keys.curve25519_key.public_key.public_key), + pos + ); + *(pos++) = '\"'; *(pos++) = ','; + pos = write_string(pos, KEY_JSON_ED25519); + *(pos++) = '\"'; + pos = olm::encode_base64( + identity_keys.ed25519_key.public_key.public_key, + sizeof(identity_keys.ed25519_key.public_key.public_key), + pos + ); + *(pos++) = '\"'; *(pos++) = '}'; + return pos - identity_json; +} + + +std::size_t olm::Account::signature_length( +) const { + return ED25519_SIGNATURE_LENGTH; +} + + +std::size_t olm::Account::sign( + std::uint8_t const * message, std::size_t message_length, + std::uint8_t * signature, std::size_t signature_length +) { + if (signature_length < this->signature_length()) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + _olm_crypto_ed25519_sign( + &identity_keys.ed25519_key, message, message_length, signature + ); + return this->signature_length(); +} + + +std::size_t olm::Account::get_one_time_keys_json_length( +) const { + std::size_t length = 0; + bool is_empty = true; + for (auto const & key : one_time_keys) { + if (key.published) { + continue; + } + is_empty = false; + length += 2; /* {" */ + length += olm::encode_base64_length(_olm_pickle_uint32_length(key.id)); + length += 3; /* ":" */ + length += olm::encode_base64_length(sizeof(key.key.public_key)); + length += 1; /* " */ + } + if (is_empty) { + length += 1; /* { */ + } + length += 3; /* }{} */ + length += sizeof(KEY_JSON_CURVE25519) - 1; + return length; +} + + +std::size_t olm::Account::get_one_time_keys_json( + std::uint8_t * one_time_json, std::size_t one_time_json_length +) { + std::uint8_t * pos = one_time_json; + if (one_time_json_length < get_one_time_keys_json_length()) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + *(pos++) = '{'; + pos = write_string(pos, KEY_JSON_CURVE25519); + std::uint8_t sep = '{'; + for (auto const & key : one_time_keys) { + if (key.published) { + continue; + } + *(pos++) = sep; + *(pos++) = '\"'; + std::uint8_t key_id[_olm_pickle_uint32_length(key.id)]; + _olm_pickle_uint32(key_id, key.id); + pos = olm::encode_base64(key_id, sizeof(key_id), pos); + *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"'; + pos = olm::encode_base64( + key.key.public_key.public_key, sizeof(key.key.public_key.public_key), pos + ); + *(pos++) = '\"'; + sep = ','; + } + if (sep != ',') { + /* The list was empty */ + *(pos++) = sep; + } + *(pos++) = '}'; + *(pos++) = '}'; + return pos - one_time_json; +} + + +std::size_t olm::Account::mark_keys_as_published( +) { + std::size_t count = 0; + for (auto & key : one_time_keys) { + if (!key.published) { + key.published = true; + count++; + } + } + current_fallback_key.published = true; + return count; +} + + +std::size_t olm::Account::max_number_of_one_time_keys( +) const { + return olm::MAX_ONE_TIME_KEYS; +} + +std::size_t olm::Account::generate_one_time_keys_random_length( + std::size_t number_of_keys +) const { + return CURVE25519_RANDOM_LENGTH * number_of_keys; +} + +std::size_t olm::Account::generate_one_time_keys( + std::size_t number_of_keys, + std::uint8_t const * random, std::size_t random_length +) { + if (random_length < generate_one_time_keys_random_length(number_of_keys)) { + last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; + return std::size_t(-1); + } + for (unsigned i = 0; i < number_of_keys; ++i) { + OneTimeKey & key = *one_time_keys.insert(one_time_keys.begin()); + key.id = ++next_one_time_key_id; + key.published = false; + _olm_crypto_curve25519_generate_key(random, &key.key); + random += CURVE25519_RANDOM_LENGTH; + } + return number_of_keys; +} + +std::size_t olm::Account::generate_fallback_key_random_length() const { + return CURVE25519_RANDOM_LENGTH; +} + +std::size_t olm::Account::generate_fallback_key( + std::uint8_t const * random, std::size_t random_length +) { + if (random_length < generate_fallback_key_random_length()) { + last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; + return std::size_t(-1); + } + if (num_fallback_keys < 2) { + num_fallback_keys++; + } + prev_fallback_key = current_fallback_key; + current_fallback_key.id = ++next_one_time_key_id; + current_fallback_key.published = false; + _olm_crypto_curve25519_generate_key(random, ¤t_fallback_key.key); + return 1; +} + + +std::size_t olm::Account::get_fallback_key_json_length( +) const { + std::size_t length = 4 + sizeof(KEY_JSON_CURVE25519) - 1; /* {"curve25519":{}} */ + if (num_fallback_keys >= 1) { + const OneTimeKey & key = current_fallback_key; + length += 1; /* " */ + length += olm::encode_base64_length(_olm_pickle_uint32_length(key.id)); + length += 3; /* ":" */ + length += olm::encode_base64_length(sizeof(key.key.public_key)); + length += 1; /* " */ + } + return length; +} + +std::size_t olm::Account::get_fallback_key_json( + std::uint8_t * fallback_json, std::size_t fallback_json_length +) { + std::uint8_t * pos = fallback_json; + if (fallback_json_length < get_fallback_key_json_length()) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + *(pos++) = '{'; + pos = write_string(pos, KEY_JSON_CURVE25519); + *(pos++) = '{'; + OneTimeKey & key = current_fallback_key; + if (num_fallback_keys >= 1) { + *(pos++) = '\"'; + std::uint8_t key_id[_olm_pickle_uint32_length(key.id)]; + _olm_pickle_uint32(key_id, key.id); + pos = olm::encode_base64(key_id, sizeof(key_id), pos); + *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"'; + pos = olm::encode_base64( + key.key.public_key.public_key, sizeof(key.key.public_key.public_key), pos + ); + *(pos++) = '\"'; + } + *(pos++) = '}'; + *(pos++) = '}'; + return pos - fallback_json; +} + +std::size_t olm::Account::get_unpublished_fallback_key_json_length( +) const { + std::size_t length = 4 + sizeof(KEY_JSON_CURVE25519) - 1; /* {"curve25519":{}} */ + const OneTimeKey & key = current_fallback_key; + if (num_fallback_keys >= 1 && !key.published) { + length += 1; /* " */ + length += olm::encode_base64_length(_olm_pickle_uint32_length(key.id)); + length += 3; /* ":" */ + length += olm::encode_base64_length(sizeof(key.key.public_key)); + length += 1; /* " */ + } + return length; +} + +std::size_t olm::Account::get_unpublished_fallback_key_json( + std::uint8_t * fallback_json, std::size_t fallback_json_length +) { + std::uint8_t * pos = fallback_json; + if (fallback_json_length < get_unpublished_fallback_key_json_length()) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + *(pos++) = '{'; + pos = write_string(pos, KEY_JSON_CURVE25519); + *(pos++) = '{'; + OneTimeKey & key = current_fallback_key; + if (num_fallback_keys >= 1 && !key.published) { + *(pos++) = '\"'; + std::uint8_t key_id[_olm_pickle_uint32_length(key.id)]; + _olm_pickle_uint32(key_id, key.id); + pos = olm::encode_base64(key_id, sizeof(key_id), pos); + *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"'; + pos = olm::encode_base64( + key.key.public_key.public_key, sizeof(key.key.public_key.public_key), pos + ); + *(pos++) = '\"'; + } + *(pos++) = '}'; + *(pos++) = '}'; + return pos - fallback_json; +} + +void olm::Account::forget_old_fallback_key( +) { + if (num_fallback_keys >= 2) { + num_fallback_keys = 1; + olm::unset(&prev_fallback_key, sizeof(prev_fallback_key)); + } +} + +namespace olm { + +static std::size_t pickle_length( + olm::IdentityKeys const & value +) { + size_t length = 0; + length += _olm_pickle_ed25519_key_pair_length(&value.ed25519_key); + length += olm::pickle_length(value.curve25519_key); + return length; +} + + +static std::uint8_t * pickle( + std::uint8_t * pos, + olm::IdentityKeys const & value +) { + pos = _olm_pickle_ed25519_key_pair(pos, &value.ed25519_key); + pos = olm::pickle(pos, value.curve25519_key); + return pos; +} + + +static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::IdentityKeys & value +) { + pos = _olm_unpickle_ed25519_key_pair(pos, end, &value.ed25519_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.curve25519_key); UNPICKLE_OK(pos); + return pos; +} + + +static std::size_t pickle_length( + olm::OneTimeKey const & value +) { + std::size_t length = 0; + length += olm::pickle_length(value.id); + length += olm::pickle_length(value.published); + length += olm::pickle_length(value.key); + return length; +} + + +static std::uint8_t * pickle( + std::uint8_t * pos, + olm::OneTimeKey const & value +) { + pos = olm::pickle(pos, value.id); + pos = olm::pickle(pos, value.published); + pos = olm::pickle(pos, value.key); + return pos; +} + + +static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::OneTimeKey & value +) { + pos = olm::unpickle(pos, end, value.id); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.published); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.key); UNPICKLE_OK(pos); + return pos; +} + +} // namespace olm + +namespace { +// pickle version 1 used only 32 bytes for the ed25519 private key. +// Any keys thus used should be considered compromised. +// pickle version 2 does not have fallback keys. +// pickle version 3 does not store whether the current fallback key is published. +static const std::uint32_t ACCOUNT_PICKLE_VERSION = 4; +} + + +std::size_t olm::pickle_length( + olm::Account const & value +) { + std::size_t length = 0; + length += olm::pickle_length(ACCOUNT_PICKLE_VERSION); + length += olm::pickle_length(value.identity_keys); + length += olm::pickle_length(value.one_time_keys); + length += olm::pickle_length(value.num_fallback_keys); + if (value.num_fallback_keys >= 1) { + length += olm::pickle_length(value.current_fallback_key); + if (value.num_fallback_keys >= 2) { + length += olm::pickle_length(value.prev_fallback_key); + } + } + length += olm::pickle_length(value.next_one_time_key_id); + return length; +} + + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + olm::Account const & value +) { + pos = olm::pickle(pos, ACCOUNT_PICKLE_VERSION); + pos = olm::pickle(pos, value.identity_keys); + pos = olm::pickle(pos, value.one_time_keys); + pos = olm::pickle(pos, value.num_fallback_keys); + if (value.num_fallback_keys >= 1) { + pos = olm::pickle(pos, value.current_fallback_key); + if (value.num_fallback_keys >= 2) { + pos = olm::pickle(pos, value.prev_fallback_key); + } + } + pos = olm::pickle(pos, value.next_one_time_key_id); + return pos; +} + + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::Account & value +) { + uint32_t pickle_version; + + pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos); + + switch (pickle_version) { + case ACCOUNT_PICKLE_VERSION: + case 3: + case 2: + break; + case 1: + value.last_error = OlmErrorCode::OLM_BAD_LEGACY_ACCOUNT_PICKLE; + return nullptr; + default: + value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION; + return nullptr; + } + + pos = olm::unpickle(pos, end, value.identity_keys); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.one_time_keys); UNPICKLE_OK(pos); + + if (pickle_version <= 2) { + // version 2 did not have fallback keys + value.num_fallback_keys = 0; + } else if (pickle_version == 3) { + // version 3 used the published flag to indicate how many fallback keys + // were present (we'll have to assume that the keys were published) + pos = olm::unpickle(pos, end, value.current_fallback_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.prev_fallback_key); UNPICKLE_OK(pos); + if (value.current_fallback_key.published) { + if (value.prev_fallback_key.published) { + value.num_fallback_keys = 2; + } else { + value.num_fallback_keys = 1; + } + } else { + value.num_fallback_keys = 0; + } + } else { + pos = olm::unpickle(pos, end, value.num_fallback_keys); UNPICKLE_OK(pos); + if (value.num_fallback_keys >= 1) { + pos = olm::unpickle(pos, end, value.current_fallback_key); UNPICKLE_OK(pos); + if (value.num_fallback_keys >= 2) { + pos = olm::unpickle(pos, end, value.prev_fallback_key); UNPICKLE_OK(pos); + if (value.num_fallback_keys >= 3) { + value.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE; + return nullptr; + } + } + } + } + + pos = olm::unpickle(pos, end, value.next_one_time_key_id); UNPICKLE_OK(pos); + + return pos; +} diff --git a/ext/olm/src/base64.cpp b/ext/olm/src/base64.cpp new file mode 100644 index 0000000..0e195fb --- /dev/null +++ b/ext/olm/src/base64.cpp @@ -0,0 +1,187 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <cassert> + +#include "olm/base64.h" +#include "olm/base64.hh" + +namespace { + +static const std::uint8_t ENCODE_BASE64[64] = { + 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, + 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, + 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, + 0x59, 0x5A, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, + 0x67, 0x68, 0x69, 0x6A, 0x6B, 0x6C, 0x6D, 0x6E, + 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, + 0x77, 0x78, 0x79, 0x7A, 0x30, 0x31, 0x32, 0x33, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x2B, 0x2F, +}; + +static const std::uint8_t E = -1; + +static const std::uint8_t DECODE_BASE64[128] = { +/* 0x0 0x1 0x2 0x3 0x4 0x5 0x6 0x7 0x8 0x9 0xA 0xB 0xC 0xD 0xE 0xF */ + E, E, E, E, E, E, E, E, E, E, E, E, E, E, E, E, + E, E, E, E, E, E, E, E, E, E, E, E, E, E, E, E, + E, E, E, E, E, E, E, E, E, E, E, 62, E, E, E, 63, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, E, E, E, E, E, E, + E, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, E, E, E, E, E, + E, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, E, E, E, E, E, +}; + +} // namespace + + +std::size_t olm::encode_base64_length( + std::size_t input_length +) { + return 4 * ((input_length + 2) / 3) + (input_length + 2) % 3 - 2; +} + +std::uint8_t * olm::encode_base64( + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output +) { + std::uint8_t const * end = input + (input_length / 3) * 3; + std::uint8_t const * pos = input; + while (pos != end) { + unsigned value = pos[0]; + value <<= 8; value |= pos[1]; + value <<= 8; value |= pos[2]; + pos += 3; + output[3] = ENCODE_BASE64[value & 0x3F]; + value >>= 6; output[2] = ENCODE_BASE64[value & 0x3F]; + value >>= 6; output[1] = ENCODE_BASE64[value & 0x3F]; + value >>= 6; output[0] = ENCODE_BASE64[value]; + output += 4; + } + unsigned remainder = input + input_length - pos; + std::uint8_t * result = output; + if (remainder) { + unsigned value = pos[0]; + if (remainder == 2) { + value <<= 8; value |= pos[1]; + value <<= 2; + output[2] = ENCODE_BASE64[value & 0x3F]; + value >>= 6; + result += 3; + } else { + value <<= 4; + result += 2; + } + output[1] = ENCODE_BASE64[value & 0x3F]; + value >>= 6; + output[0] = ENCODE_BASE64[value]; + } + return result; +} + + +std::size_t olm::decode_base64_length( + std::size_t input_length +) { + if (input_length % 4 == 1) { + return std::size_t(-1); + } else { + return 3 * ((input_length + 2) / 4) + (input_length + 2) % 4 - 2; + } +} + + +std::size_t olm::decode_base64( + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output +) { + size_t raw_length = olm::decode_base64_length(input_length); + + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + + std::uint8_t const * end = input + (input_length / 4) * 4; + std::uint8_t const * pos = input; + + while (pos != end) { + unsigned value = DECODE_BASE64[pos[0] & 0x7F]; + value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F]; + value <<= 6; value |= DECODE_BASE64[pos[2] & 0x7F]; + value <<= 6; value |= DECODE_BASE64[pos[3] & 0x7F]; + pos += 4; + output[2] = value; + value >>= 8; output[1] = value; + value >>= 8; output[0] = value; + output += 3; + } + + unsigned remainder = input + input_length - pos; + if (remainder) { + /* A base64 payload with a single byte remainder cannot occur because + * a single base64 character only encodes 6 bits, which is less than + * a full byte. Therefore, a minimum of two base64 characters are + * required to construct a single output byte and payloads with + * a remainder of 1 are illegal. + * + * Should never be the case due to length check above. + */ + assert(remainder != 1); + + unsigned value = DECODE_BASE64[pos[0] & 0x7F]; + value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F]; + if (remainder == 3) { + value <<= 6; value |= DECODE_BASE64[pos[2] & 0x7F]; + value >>= 2; + output[1] = value; + value >>= 8; + } else { + value >>= 4; + } + output[0] = value; + } + + return raw_length; +} + + +// implementations of base64.h + +size_t _olm_encode_base64_length( + size_t input_length +) { + return olm::encode_base64_length(input_length); +} + +size_t _olm_encode_base64( + uint8_t const * input, size_t input_length, + uint8_t * output +) { + uint8_t * r = olm::encode_base64(input, input_length, output); + return r - output; +} + +size_t _olm_decode_base64_length( + size_t input_length +) { + return olm::decode_base64_length(input_length); +} + +size_t _olm_decode_base64( + uint8_t const * input, size_t input_length, + uint8_t * output +) { + return olm::decode_base64(input, input_length, output); +} diff --git a/ext/olm/src/cipher.cpp b/ext/olm/src/cipher.cpp new file mode 100644 index 0000000..2312b84 --- /dev/null +++ b/ext/olm/src/cipher.cpp @@ -0,0 +1,152 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/cipher.h" +#include "olm/crypto.h" +#include "olm/memory.hh" +#include <cstring> + +const std::size_t HMAC_KEY_LENGTH = 32; + +namespace { + +struct DerivedKeys { + _olm_aes256_key aes_key; + std::uint8_t mac_key[HMAC_KEY_LENGTH]; + _olm_aes256_iv aes_iv; +}; + + +static void derive_keys( + std::uint8_t const * kdf_info, std::size_t kdf_info_length, + std::uint8_t const * key, std::size_t key_length, + DerivedKeys & keys +) { + std::uint8_t derived_secrets[ + AES256_KEY_LENGTH + HMAC_KEY_LENGTH + AES256_IV_LENGTH + ]; + _olm_crypto_hkdf_sha256( + key, key_length, + nullptr, 0, + kdf_info, kdf_info_length, + derived_secrets, sizeof(derived_secrets) + ); + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(keys.aes_key.key, pos); + pos = olm::load_array(keys.mac_key, pos); + pos = olm::load_array(keys.aes_iv.iv, pos); + olm::unset(derived_secrets); +} + +static const std::size_t MAC_LENGTH = 8; + +size_t aes_sha_256_cipher_mac_length(const struct _olm_cipher *cipher) { + return MAC_LENGTH; +} + +size_t aes_sha_256_cipher_encrypt_ciphertext_length( + const struct _olm_cipher *cipher, size_t plaintext_length +) { + return _olm_crypto_aes_encrypt_cbc_length(plaintext_length); +} + +size_t aes_sha_256_cipher_encrypt( + const struct _olm_cipher *cipher, + uint8_t const * key, size_t key_length, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * ciphertext, size_t ciphertext_length, + uint8_t * output, size_t output_length +) { + auto *c = reinterpret_cast<const _olm_cipher_aes_sha_256 *>(cipher); + + if (ciphertext_length + < aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length) + || output_length < MAC_LENGTH) { + return std::size_t(-1); + } + + struct DerivedKeys keys; + std::uint8_t mac[SHA256_OUTPUT_LENGTH]; + + derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys); + + _olm_crypto_aes_encrypt_cbc( + &keys.aes_key, &keys.aes_iv, plaintext, plaintext_length, ciphertext + ); + + _olm_crypto_hmac_sha256( + keys.mac_key, HMAC_KEY_LENGTH, output, output_length - MAC_LENGTH, mac + ); + + std::memcpy(output + output_length - MAC_LENGTH, mac, MAC_LENGTH); + + olm::unset(keys); + return output_length; +} + + +size_t aes_sha_256_cipher_decrypt_max_plaintext_length( + const struct _olm_cipher *cipher, + size_t ciphertext_length +) { + return ciphertext_length; +} + +size_t aes_sha_256_cipher_decrypt( + const struct _olm_cipher *cipher, + uint8_t const * key, size_t key_length, + uint8_t const * input, size_t input_length, + uint8_t const * ciphertext, size_t ciphertext_length, + uint8_t * plaintext, size_t max_plaintext_length +) { + if (max_plaintext_length + < aes_sha_256_cipher_decrypt_max_plaintext_length(cipher, ciphertext_length) + || input_length < MAC_LENGTH) { + return std::size_t(-1); + } + + auto *c = reinterpret_cast<const _olm_cipher_aes_sha_256 *>(cipher); + + DerivedKeys keys; + std::uint8_t mac[SHA256_OUTPUT_LENGTH]; + + derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys); + + _olm_crypto_hmac_sha256( + keys.mac_key, HMAC_KEY_LENGTH, input, input_length - MAC_LENGTH, mac + ); + + std::uint8_t const * input_mac = input + input_length - MAC_LENGTH; + if (!olm::is_equal(input_mac, mac, MAC_LENGTH)) { + olm::unset(keys); + return std::size_t(-1); + } + + std::size_t plaintext_length = _olm_crypto_aes_decrypt_cbc( + &keys.aes_key, &keys.aes_iv, ciphertext, ciphertext_length, plaintext + ); + + olm::unset(keys); + return plaintext_length; +} + +} // namespace + +const struct _olm_cipher_ops _olm_cipher_aes_sha_256_ops = { + aes_sha_256_cipher_mac_length, + aes_sha_256_cipher_encrypt_ciphertext_length, + aes_sha_256_cipher_encrypt, + aes_sha_256_cipher_decrypt_max_plaintext_length, + aes_sha_256_cipher_decrypt, +}; diff --git a/ext/olm/src/crypto.cpp b/ext/olm/src/crypto.cpp new file mode 100644 index 0000000..e297513 --- /dev/null +++ b/ext/olm/src/crypto.cpp @@ -0,0 +1,299 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/crypto.h" +#include "olm/memory.hh" + +#include <cstring> + +extern "C" { + +#include "crypto-algorithms/aes.h" +#include "crypto-algorithms/sha256.h" + +} + +#include "ed25519/src/ed25519.h" +#include "curve25519-donna.h" + +namespace { + +static const std::uint8_t CURVE25519_BASEPOINT[32] = {9}; +static const std::size_t AES_KEY_SCHEDULE_LENGTH = 60; +static const std::size_t AES_KEY_BITS = 8 * AES256_KEY_LENGTH; +static const std::size_t AES_BLOCK_LENGTH = 16; +static const std::size_t SHA256_BLOCK_LENGTH = 64; +static const std::uint8_t HKDF_DEFAULT_SALT[32] = {}; + + +template<std::size_t block_size> +inline static void xor_block( + std::uint8_t * block, + std::uint8_t const * input +) { + for (std::size_t i = 0; i < block_size; ++i) { + block[i] ^= input[i]; + } +} + + +inline static void hmac_sha256_key( + std::uint8_t const * input_key, std::size_t input_key_length, + std::uint8_t * hmac_key +) { + std::memset(hmac_key, 0, SHA256_BLOCK_LENGTH); + if (input_key_length > SHA256_BLOCK_LENGTH) { + ::SHA256_CTX context; + ::sha256_init(&context); + ::sha256_update(&context, input_key, input_key_length); + ::sha256_final(&context, hmac_key); + } else { + std::memcpy(hmac_key, input_key, input_key_length); + } +} + + +inline static void hmac_sha256_init( + ::SHA256_CTX * context, + std::uint8_t const * hmac_key +) { + std::uint8_t i_pad[SHA256_BLOCK_LENGTH]; + std::memcpy(i_pad, hmac_key, SHA256_BLOCK_LENGTH); + for (std::size_t i = 0; i < SHA256_BLOCK_LENGTH; ++i) { + i_pad[i] ^= 0x36; + } + ::sha256_init(context); + ::sha256_update(context, i_pad, SHA256_BLOCK_LENGTH); + olm::unset(i_pad); +} + + +inline static void hmac_sha256_final( + ::SHA256_CTX * context, + std::uint8_t const * hmac_key, + std::uint8_t * output +) { + std::uint8_t o_pad[SHA256_BLOCK_LENGTH + SHA256_OUTPUT_LENGTH]; + std::memcpy(o_pad, hmac_key, SHA256_BLOCK_LENGTH); + for (std::size_t i = 0; i < SHA256_BLOCK_LENGTH; ++i) { + o_pad[i] ^= 0x5C; + } + ::sha256_final(context, o_pad + SHA256_BLOCK_LENGTH); + ::SHA256_CTX final_context; + ::sha256_init(&final_context); + ::sha256_update(&final_context, o_pad, sizeof(o_pad)); + ::sha256_final(&final_context, output); + olm::unset(final_context); + olm::unset(o_pad); +} + +} // namespace + +void _olm_crypto_curve25519_generate_key( + uint8_t const * random_32_bytes, + struct _olm_curve25519_key_pair *key_pair +) { + std::memcpy( + key_pair->private_key.private_key, random_32_bytes, + CURVE25519_KEY_LENGTH + ); + ::curve25519_donna( + key_pair->public_key.public_key, + key_pair->private_key.private_key, + CURVE25519_BASEPOINT + ); +} + + +void _olm_crypto_curve25519_shared_secret( + const struct _olm_curve25519_key_pair *our_key, + const struct _olm_curve25519_public_key * their_key, + std::uint8_t * output +) { + ::curve25519_donna(output, our_key->private_key.private_key, their_key->public_key); +} + + +void _olm_crypto_ed25519_generate_key( + std::uint8_t const * random_32_bytes, + struct _olm_ed25519_key_pair *key_pair +) { + ::ed25519_create_keypair( + key_pair->public_key.public_key, key_pair->private_key.private_key, + random_32_bytes + ); +} + + +void _olm_crypto_ed25519_sign( + const struct _olm_ed25519_key_pair *our_key, + std::uint8_t const * message, std::size_t message_length, + std::uint8_t * output +) { + ::ed25519_sign( + output, + message, message_length, + our_key->public_key.public_key, + our_key->private_key.private_key + ); +} + + +int _olm_crypto_ed25519_verify( + const struct _olm_ed25519_public_key *their_key, + std::uint8_t const * message, std::size_t message_length, + std::uint8_t const * signature +) { + return 0 != ::ed25519_verify( + signature, + message, message_length, + their_key->public_key + ); +} + + +std::size_t _olm_crypto_aes_encrypt_cbc_length( + std::size_t input_length +) { + return input_length + AES_BLOCK_LENGTH - input_length % AES_BLOCK_LENGTH; +} + + +void _olm_crypto_aes_encrypt_cbc( + _olm_aes256_key const *key, + _olm_aes256_iv const *iv, + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output +) { + std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH]; + ::_olm_aes_key_setup(key->key, key_schedule, AES_KEY_BITS); + std::uint8_t input_block[AES_BLOCK_LENGTH]; + std::memcpy(input_block, iv->iv, AES_BLOCK_LENGTH); + while (input_length >= AES_BLOCK_LENGTH) { + xor_block<AES_BLOCK_LENGTH>(input_block, input); + ::_olm_aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS); + std::memcpy(input_block, output, AES_BLOCK_LENGTH); + input += AES_BLOCK_LENGTH; + output += AES_BLOCK_LENGTH; + input_length -= AES_BLOCK_LENGTH; + } + std::size_t i = 0; + for (; i < input_length; ++i) { + input_block[i] ^= input[i]; + } + for (; i < AES_BLOCK_LENGTH; ++i) { + input_block[i] ^= AES_BLOCK_LENGTH - input_length; + } + ::_olm_aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS); + olm::unset(key_schedule); + olm::unset(input_block); +} + + +std::size_t _olm_crypto_aes_decrypt_cbc( + _olm_aes256_key const *key, + _olm_aes256_iv const *iv, + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output +) { + std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH]; + ::_olm_aes_key_setup(key->key, key_schedule, AES_KEY_BITS); + std::uint8_t block1[AES_BLOCK_LENGTH]; + std::uint8_t block2[AES_BLOCK_LENGTH]; + std::memcpy(block1, iv->iv, AES_BLOCK_LENGTH); + for (std::size_t i = 0; i < input_length; i += AES_BLOCK_LENGTH) { + std::memcpy(block2, &input[i], AES_BLOCK_LENGTH); + ::_olm_aes_decrypt(&input[i], &output[i], key_schedule, AES_KEY_BITS); + xor_block<AES_BLOCK_LENGTH>(&output[i], block1); + std::memcpy(block1, block2, AES_BLOCK_LENGTH); + } + olm::unset(key_schedule); + olm::unset(block1); + olm::unset(block2); + std::size_t padding = output[input_length - 1]; + return (padding > input_length) ? std::size_t(-1) : (input_length - padding); +} + + +void _olm_crypto_sha256( + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output +) { + ::SHA256_CTX context; + ::sha256_init(&context); + ::sha256_update(&context, input, input_length); + ::sha256_final(&context, output); + olm::unset(context); +} + + +void _olm_crypto_hmac_sha256( + std::uint8_t const * key, std::size_t key_length, + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output +) { + std::uint8_t hmac_key[SHA256_BLOCK_LENGTH]; + ::SHA256_CTX context; + hmac_sha256_key(key, key_length, hmac_key); + hmac_sha256_init(&context, hmac_key); + ::sha256_update(&context, input, input_length); + hmac_sha256_final(&context, hmac_key, output); + olm::unset(hmac_key); + olm::unset(context); +} + + +void _olm_crypto_hkdf_sha256( + std::uint8_t const * input, std::size_t input_length, + std::uint8_t const * salt, std::size_t salt_length, + std::uint8_t const * info, std::size_t info_length, + std::uint8_t * output, std::size_t output_length +) { + ::SHA256_CTX context; + std::uint8_t hmac_key[SHA256_BLOCK_LENGTH]; + std::uint8_t step_result[SHA256_OUTPUT_LENGTH]; + std::size_t bytes_remaining = output_length; + std::uint8_t iteration = 1; + if (!salt) { + salt = HKDF_DEFAULT_SALT; + salt_length = sizeof(HKDF_DEFAULT_SALT); + } + /* Extract */ + hmac_sha256_key(salt, salt_length, hmac_key); + hmac_sha256_init(&context, hmac_key); + ::sha256_update(&context, input, input_length); + hmac_sha256_final(&context, hmac_key, step_result); + hmac_sha256_key(step_result, SHA256_OUTPUT_LENGTH, hmac_key); + + /* Expand */ + hmac_sha256_init(&context, hmac_key); + ::sha256_update(&context, info, info_length); + ::sha256_update(&context, &iteration, 1); + hmac_sha256_final(&context, hmac_key, step_result); + while (bytes_remaining > SHA256_OUTPUT_LENGTH) { + std::memcpy(output, step_result, SHA256_OUTPUT_LENGTH); + output += SHA256_OUTPUT_LENGTH; + bytes_remaining -= SHA256_OUTPUT_LENGTH; + iteration ++; + hmac_sha256_init(&context, hmac_key); + ::sha256_update(&context, step_result, SHA256_OUTPUT_LENGTH); + ::sha256_update(&context, info, info_length); + ::sha256_update(&context, &iteration, 1); + hmac_sha256_final(&context, hmac_key, step_result); + } + std::memcpy(output, step_result, bytes_remaining); + olm::unset(context); + olm::unset(hmac_key); + olm::unset(step_result); +} diff --git a/ext/olm/src/ed25519.c b/ext/olm/src/ed25519.c new file mode 100644 index 0000000..c7a1a8e --- /dev/null +++ b/ext/olm/src/ed25519.c @@ -0,0 +1,22 @@ +/* Copyright 2015-6 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#define select ed25519_select +#include "ed25519/src/fe.c" +#include "ed25519/src/sc.c" +#include "ed25519/src/ge.c" +#include "ed25519/src/keypair.c" +#include "ed25519/src/sha512.c" +#include "ed25519/src/verify.c" +#include "ed25519/src/sign.c" diff --git a/ext/olm/src/error.c b/ext/olm/src/error.c new file mode 100644 index 0000000..6775eee --- /dev/null +++ b/ext/olm/src/error.c @@ -0,0 +1,46 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/error.h" + +static const char * ERRORS[] = { + "SUCCESS", + "NOT_ENOUGH_RANDOM", + "OUTPUT_BUFFER_TOO_SMALL", + "BAD_MESSAGE_VERSION", + "BAD_MESSAGE_FORMAT", + "BAD_MESSAGE_MAC", + "BAD_MESSAGE_KEY_ID", + "INVALID_BASE64", + "BAD_ACCOUNT_KEY", + "UNKNOWN_PICKLE_VERSION", + "CORRUPTED_PICKLE", + "BAD_SESSION_KEY", + "UNKNOWN_MESSAGE_INDEX", + "BAD_LEGACY_ACCOUNT_PICKLE", + "BAD_SIGNATURE", + "OLM_INPUT_BUFFER_TOO_SMALL", + "OLM_SAS_THEIR_KEY_NOT_SET", + "OLM_PICKLE_EXTRA_DATA" +}; + +const char * _olm_error_to_string(enum OlmErrorCode error) +{ + if (error < (sizeof(ERRORS)/sizeof(ERRORS[0]))) { + return ERRORS[error]; + } else { + return "UNKNOWN_ERROR"; + } +} diff --git a/ext/olm/src/inbound_group_session.c b/ext/olm/src/inbound_group_session.c new file mode 100644 index 0000000..d6f73b7 --- /dev/null +++ b/ext/olm/src/inbound_group_session.c @@ -0,0 +1,540 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/inbound_group_session.h" + +#include <string.h> + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/crypto.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/memory.h" +#include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" + + +#define OLM_PROTOCOL_VERSION 3 +#define GROUP_SESSION_ID_LENGTH ED25519_PUBLIC_KEY_LENGTH +#define PICKLE_VERSION 2 +#define SESSION_KEY_VERSION 2 +#define SESSION_EXPORT_VERSION 1 + +struct OlmInboundGroupSession { + /** our earliest known ratchet value */ + Megolm initial_ratchet; + + /** The most recent ratchet value */ + Megolm latest_ratchet; + + /** The ed25519 signing key */ + struct _olm_ed25519_public_key signing_key; + + /** + * Have we ever seen any evidence that this is a valid session? + * (either because the original session share was signed, or because we + * have subsequently successfully decrypted a message) + * + * (We don't do anything with this currently, but we may want to bear it in + * mind when we consider handling key-shares for sessions we already know + * about.) + */ + int signing_key_verified; + + enum OlmErrorCode last_error; +}; + +size_t olm_inbound_group_session_size(void) { + return sizeof(OlmInboundGroupSession); +} + +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +) { + OlmInboundGroupSession *session = memory; + olm_clear_inbound_group_session(session); + return session; +} + +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +enum OlmErrorCode olm_inbound_group_session_last_error_code( + const OlmInboundGroupSession *session +) { + return session->last_error; +} + +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +) { + _olm_unset(session, sizeof(OlmInboundGroupSession)); + return sizeof(OlmInboundGroupSession); +} + +#define SESSION_EXPORT_RAW_LENGTH \ + (1 + 4 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH) + +#define SESSION_KEY_RAW_LENGTH \ + (1 + 4 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH\ + + ED25519_SIGNATURE_LENGTH) + +static size_t _init_group_session_keys( + OlmInboundGroupSession *session, + const uint8_t *key_buf, + int export_format +) { + const uint8_t expected_version = + (export_format ? SESSION_EXPORT_VERSION : SESSION_KEY_VERSION); + const uint8_t *ptr = key_buf; + size_t version = *ptr++; + + if (version != expected_version) { + session->last_error = OLM_BAD_SESSION_KEY; + return (size_t)-1; + } + + uint32_t counter = 0; + // Decode counter as a big endian 32-bit number. + for (unsigned i = 0; i < 4; i++) { + counter <<= 8; counter |= *ptr++; + } + + megolm_init(&session->initial_ratchet, ptr, counter); + megolm_init(&session->latest_ratchet, ptr, counter); + + ptr += MEGOLM_RATCHET_LENGTH; + memcpy( + session->signing_key.public_key, ptr, ED25519_PUBLIC_KEY_LENGTH + ); + ptr += ED25519_PUBLIC_KEY_LENGTH; + + if (!export_format) { + if (!_olm_crypto_ed25519_verify(&session->signing_key, key_buf, + ptr - key_buf, ptr)) { + session->last_error = OLM_BAD_SIGNATURE; + return (size_t)-1; + } + + /* signed keyshare */ + session->signing_key_verified = 1; + } + return 0; +} + +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + const uint8_t * session_key, size_t session_key_length +) { + uint8_t key_buf[SESSION_KEY_RAW_LENGTH]; + size_t raw_length = _olm_decode_base64_length(session_key_length); + size_t result; + + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + if (raw_length != SESSION_KEY_RAW_LENGTH) { + session->last_error = OLM_BAD_SESSION_KEY; + return (size_t)-1; + } + + _olm_decode_base64(session_key, session_key_length, key_buf); + result = _init_group_session_keys(session, key_buf, 0); + _olm_unset(key_buf, SESSION_KEY_RAW_LENGTH); + return result; +} + +size_t olm_import_inbound_group_session( + OlmInboundGroupSession *session, + const uint8_t * session_key, size_t session_key_length +) { + uint8_t key_buf[SESSION_EXPORT_RAW_LENGTH]; + size_t raw_length = _olm_decode_base64_length(session_key_length); + size_t result; + + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + if (raw_length != SESSION_EXPORT_RAW_LENGTH) { + session->last_error = OLM_BAD_SESSION_KEY; + return (size_t)-1; + } + + _olm_decode_base64(session_key, session_key_length, key_buf); + result = _init_group_session_keys(session, key_buf, 1); + _olm_unset(key_buf, SESSION_EXPORT_RAW_LENGTH); + return result; +} + +static size_t raw_pickle_length( + const OlmInboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&session->initial_ratchet); + length += megolm_pickle_length(&session->latest_ratchet); + length += _olm_pickle_ed25519_public_key_length(&session->signing_key); + length += _olm_pickle_bool_length(session->signing_key_verified); + return length; +} + +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + pos = _olm_enc_output_pos(pickled, raw_length); + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&session->initial_ratchet, pos); + pos = megolm_pickle(&session->latest_ratchet, pos); + pos = _olm_pickle_ed25519_public_key(pos, &session->signing_key); + pos = _olm_pickle_bool(pos, session->signing_key_verified); + + return _olm_enc_output(key, key_length, pickled, raw_length); +} + +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + if (pickle_version < 1 || pickle_version > PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + + pos = megolm_unpickle(&session->initial_ratchet, pos, end); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + pos = megolm_unpickle(&session->latest_ratchet, pos, end); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + pos = _olm_unpickle_ed25519_public_key(pos, end, &session->signing_key); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + if (pickle_version == 1) { + /* pickle v1 had no signing_key_verified field (all keyshares were + * verified at import time) */ + session->signing_key_verified = 1; + } else { + pos = _olm_unpickle_bool(pos, end, &(session->signing_key_verified)); + } + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + if (pos != end) { + /* Input was longer than expected. */ + session->last_error = OLM_PICKLE_EXTRA_DATA; + return (size_t)-1; + } + + return pickled_length; +} + +/** + * get the max plaintext length in an un-base64-ed message + */ +static size_t _decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +) { + struct _OlmDecodeGroupMessageResults decoded_results; + + _olm_decode_group_message( + message, message_length, + megolm_cipher->ops->mac_length(megolm_cipher), + ED25519_SIGNATURE_LENGTH, + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.ciphertext) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + return megolm_cipher->ops->decrypt_max_plaintext_length( + megolm_cipher, decoded_results.ciphertext_length); +} + +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +) { + size_t raw_length; + + raw_length = _olm_decode_base64(message, message_length, message); + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + return _decrypt_max_plaintext_length( + session, message, raw_length + ); +} + +/** + * get a copy of the megolm ratchet, advanced + * to the relevant index. Returns 0 on success, -1 on error + */ +static size_t _get_megolm( + OlmInboundGroupSession *session, uint32_t message_index, Megolm *result +) { + /* pick a megolm instance to use. If we're at or beyond the latest ratchet + * value, use that */ + if ((message_index - session->latest_ratchet.counter) < (1U << 31)) { + megolm_advance_to(&session->latest_ratchet, message_index); + *result = session->latest_ratchet; + return 0; + } else if ((message_index - session->initial_ratchet.counter) >= (1U << 31)) { + /* the counter is before our intial ratchet - we can't decode this. */ + session->last_error = OLM_UNKNOWN_MESSAGE_INDEX; + return (size_t)-1; + } else { + /* otherwise, start from the initial megolm. Take a copy so that we + * don't overwrite the initial megolm */ + *result = session->initial_ratchet; + megolm_advance_to(result, message_index); + return 0; + } +} + +/** + * decrypt an un-base64-ed message + */ +static size_t _decrypt( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length, + uint8_t * plaintext, size_t max_plaintext_length, + uint32_t * message_index +) { + struct _OlmDecodeGroupMessageResults decoded_results; + size_t max_length, r; + Megolm megolm; + + _olm_decode_group_message( + message, message_length, + megolm_cipher->ops->mac_length(megolm_cipher), + ED25519_SIGNATURE_LENGTH, + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.has_message_index || !decoded_results.ciphertext) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + if (message_index != NULL) { + *message_index = decoded_results.message_index; + } + + /* verify the signature. We could do this before decoding the message, but + * we allow for the possibility of future protocol versions which use a + * different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION" + * than "BAD_SIGNATURE" in this case. + */ + message_length -= ED25519_SIGNATURE_LENGTH; + r = _olm_crypto_ed25519_verify( + &session->signing_key, + message, message_length, + message + message_length + ); + if (!r) { + session->last_error = OLM_BAD_SIGNATURE; + return (size_t)-1; + } + + max_length = megolm_cipher->ops->decrypt_max_plaintext_length( + megolm_cipher, + decoded_results.ciphertext_length + ); + if (max_plaintext_length < max_length) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + r = _get_megolm(session, decoded_results.message_index, &megolm); + if (r == (size_t)-1) { + return r; + } + + /* now try checking the mac, and decrypting */ + r = megolm_cipher->ops->decrypt( + megolm_cipher, + megolm_get_data(&megolm), MEGOLM_RATCHET_LENGTH, + message, message_length, + decoded_results.ciphertext, decoded_results.ciphertext_length, + plaintext, max_plaintext_length + ); + + _olm_unset(&megolm, sizeof(megolm)); + if (r == (size_t)-1) { + session->last_error = OLM_BAD_MESSAGE_MAC; + return r; + } + + /* once we have successfully decrypted a message, set a flag to say the + * session appears valid. */ + session->signing_key_verified = 1; + + return r; +} + +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length, + uint8_t * plaintext, size_t max_plaintext_length, + uint32_t * message_index +) { + size_t raw_message_length; + + raw_message_length = _olm_decode_base64(message, message_length, message); + if (raw_message_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + return _decrypt( + session, message, raw_message_length, + plaintext, max_plaintext_length, + message_index + ); +} + +size_t olm_inbound_group_session_id_length( + const OlmInboundGroupSession *session +) { + return _olm_encode_base64_length(GROUP_SESSION_ID_LENGTH); +} + +size_t olm_inbound_group_session_id( + OlmInboundGroupSession *session, + uint8_t * id, size_t id_length +) { + if (id_length < olm_inbound_group_session_id_length(session)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + return _olm_encode_base64( + session->signing_key.public_key, GROUP_SESSION_ID_LENGTH, id + ); +} + +uint32_t olm_inbound_group_session_first_known_index( + const OlmInboundGroupSession *session +) { + return session->initial_ratchet.counter; +} + +int olm_inbound_group_session_is_verified( + const OlmInboundGroupSession *session +) { + return session->signing_key_verified; +} + +size_t olm_export_inbound_group_session_length( + const OlmInboundGroupSession *session +) { + return _olm_encode_base64_length(SESSION_EXPORT_RAW_LENGTH); +} + +size_t olm_export_inbound_group_session( + OlmInboundGroupSession *session, + uint8_t * key, size_t key_length, uint32_t message_index +) { + uint8_t *raw; + uint8_t *ptr; + Megolm megolm; + size_t r; + size_t encoded_length = olm_export_inbound_group_session_length(session); + + if (key_length < encoded_length) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + r = _get_megolm(session, message_index, &megolm); + if (r == (size_t)-1) { + return r; + } + + /* put the raw data at the end of the output buffer. */ + raw = ptr = key + encoded_length - SESSION_EXPORT_RAW_LENGTH; + *ptr++ = SESSION_EXPORT_VERSION; + + // Encode message index as a big endian 32-bit number. + for (unsigned i = 0; i < 4; i++) { + *ptr++ = 0xFF & (message_index >> 24); message_index <<= 8; + } + + memcpy(ptr, megolm_get_data(&megolm), MEGOLM_RATCHET_LENGTH); + ptr += MEGOLM_RATCHET_LENGTH; + + memcpy( + ptr, session->signing_key.public_key, + ED25519_PUBLIC_KEY_LENGTH + ); + ptr += ED25519_PUBLIC_KEY_LENGTH; + + return _olm_encode_base64(raw, SESSION_EXPORT_RAW_LENGTH, key); +} diff --git a/ext/olm/src/megolm.c b/ext/olm/src/megolm.c new file mode 100644 index 0000000..c4d1110 --- /dev/null +++ b/ext/olm/src/megolm.c @@ -0,0 +1,154 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "olm/megolm.h" + +#include <string.h> + +#include "olm/cipher.h" +#include "olm/crypto.h" +#include "olm/pickle.h" + +static const struct _olm_cipher_aes_sha_256 MEGOLM_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256("MEGOLM_KEYS"); +const struct _olm_cipher *megolm_cipher = OLM_CIPHER_BASE(&MEGOLM_CIPHER); + +/* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet. + */ +#define HASH_KEY_SEED_LENGTH 1 +static uint8_t HASH_KEY_SEEDS[MEGOLM_RATCHET_PARTS][HASH_KEY_SEED_LENGTH] = { + {0x00}, + {0x01}, + {0x02}, + {0x03} +}; + +static void rehash_part( + uint8_t data[MEGOLM_RATCHET_PARTS][MEGOLM_RATCHET_PART_LENGTH], + int rehash_from_part, int rehash_to_part +) { + _olm_crypto_hmac_sha256( + data[rehash_from_part], + MEGOLM_RATCHET_PART_LENGTH, + HASH_KEY_SEEDS[rehash_to_part], HASH_KEY_SEED_LENGTH, + data[rehash_to_part] + ); +} + + + +void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) { + megolm->counter = counter; + memcpy(megolm->data, random_data, MEGOLM_RATCHET_LENGTH); +} + +size_t megolm_pickle_length(const Megolm *megolm) { + size_t length = 0; + length += _olm_pickle_bytes_length(megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH); + length += _olm_pickle_uint32_length(megolm->counter); + return length; + +} + +uint8_t * megolm_pickle(const Megolm *megolm, uint8_t *pos) { + pos = _olm_pickle_bytes(pos, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH); + pos = _olm_pickle_uint32(pos, megolm->counter); + return pos; +} + +const uint8_t * megolm_unpickle(Megolm *megolm, const uint8_t *pos, + const uint8_t *end) { + pos = _olm_unpickle_bytes(pos, end, (uint8_t *)(megolm->data), + MEGOLM_RATCHET_LENGTH); + UNPICKLE_OK(pos); + + pos = _olm_unpickle_uint32(pos, end, &megolm->counter); + UNPICKLE_OK(pos); + + return pos; +} + +/* simplistic implementation for a single step */ +void megolm_advance(Megolm *megolm) { + uint32_t mask = 0x00FFFFFF; + int h = 0; + int i; + + megolm->counter++; + + /* figure out how much we need to rekey */ + while (h < (int)MEGOLM_RATCHET_PARTS) { + if (!(megolm->counter & mask)) + break; + h++; + mask >>= 8; + } + + /* now update R(h)...R(3) based on R(h) */ + for (i = MEGOLM_RATCHET_PARTS-1; i >= h; i--) { + rehash_part(megolm->data, h, i); + } +} + +void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { + int j; + + /* starting with R0, see if we need to update each part of the hash */ + for (j = 0; j < (int)MEGOLM_RATCHET_PARTS; j++) { + int shift = (MEGOLM_RATCHET_PARTS-j-1) * 8; + uint32_t mask = (~(uint32_t)0) << shift; + int k; + + /* how many times do we need to rehash this part? + * + * '& 0xff' ensures we handle integer wraparound correctly + */ + unsigned int steps = + ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; + + if (steps == 0) { + /* deal with the edge case where megolm->counter is slightly larger + * than advance_to. This should only happen for R(0), and implies + * that advance_to has wrapped around and we need to advance R(0) + * 256 times. + */ + if (advance_to < megolm->counter) { + steps = 0x100; + } else { + continue; + } + } + + /* for all but the last step, we can just bump R(j) without regard + * to R(j+1)...R(3). + */ + while (steps > 1) { + rehash_part(megolm->data, j, j); + steps --; + } + + /* on the last step we also need to bump R(j+1)...R(3). + * + * (Theoretically, we could skip bumping R(j+2) if we're going to bump + * R(j+1) again, but the code to figure that out is a bit baroque and + * doesn't save us much). + */ + for (k = 3; k >= j; k--) { + rehash_part(megolm->data, j, k); + } + megolm->counter = advance_to & mask; + } +} diff --git a/ext/olm/src/memory.cpp b/ext/olm/src/memory.cpp new file mode 100644 index 0000000..20e0683 --- /dev/null +++ b/ext/olm/src/memory.cpp @@ -0,0 +1,45 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/memory.hh" +#include "olm/memory.h" + +void _olm_unset( + void volatile * buffer, size_t buffer_length +) { + olm::unset(buffer, buffer_length); +} + +void olm::unset( + void volatile * buffer, std::size_t buffer_length +) { + char volatile * pos = reinterpret_cast<char volatile *>(buffer); + char volatile * end = pos + buffer_length; + while (pos != end) { + *(pos++) = 0; + } +} + + +bool olm::is_equal( + std::uint8_t const * buffer_a, + std::uint8_t const * buffer_b, + std::size_t length +) { + std::uint8_t volatile result = 0; + while (length--) { + result |= (*(buffer_a++)) ^ (*(buffer_b++)); + } + return result == 0; +} diff --git a/ext/olm/src/message.cpp b/ext/olm/src/message.cpp new file mode 100644 index 0000000..e5e63f0 --- /dev/null +++ b/ext/olm/src/message.cpp @@ -0,0 +1,406 @@ +/* Copyright 2015-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/message.hh" + +#include "olm/memory.hh" + +namespace { + +template<typename T> +static std::size_t varint_length( + T value +) { + std::size_t result = 1; + while (value >= 128U) { + ++result; + value >>= 7; + } + return result; +} + + +template<typename T> +static std::uint8_t * varint_encode( + std::uint8_t * output, + T value +) { + while (value >= 128U) { + *(output++) = (0x7F & value) | 0x80; + value >>= 7; + } + (*output++) = value; + return output; +} + + +template<typename T> +static T varint_decode( + std::uint8_t const * varint_start, + std::uint8_t const * varint_end +) { + T value = 0; + if (varint_end == varint_start) { + return 0; + } + do { + value <<= 7; + value |= 0x7F & *(--varint_end); + } while (varint_end != varint_start); + return value; +} + + +static std::uint8_t const * varint_skip( + std::uint8_t const * input, + std::uint8_t const * input_end +) { + while (input != input_end) { + std::uint8_t tmp = *(input++); + if ((tmp & 0x80) == 0) { + return input; + } + } + return input; +} + + +static std::size_t varstring_length( + std::size_t string_length +) { + return varint_length(string_length) + string_length; +} + +static std::size_t const VERSION_LENGTH = 1; +static std::uint8_t const RATCHET_KEY_TAG = 012; +static std::uint8_t const COUNTER_TAG = 020; +static std::uint8_t const CIPHERTEXT_TAG = 042; + +static std::uint8_t * encode( + std::uint8_t * pos, + std::uint8_t tag, + std::uint32_t value +) { + *(pos++) = tag; + return varint_encode(pos, value); +} + +static std::uint8_t * encode( + std::uint8_t * pos, + std::uint8_t tag, + std::uint8_t * & value, std::size_t value_length +) { + *(pos++) = tag; + pos = varint_encode(pos, value_length); + value = pos; + return pos + value_length; +} + +static std::uint8_t const * decode( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t tag, + std::uint32_t & value, bool & has_value +) { + if (pos != end && *pos == tag) { + ++pos; + std::uint8_t const * value_start = pos; + pos = varint_skip(pos, end); + value = varint_decode<std::uint32_t>(value_start, pos); + has_value = true; + } + return pos; +} + + +static std::uint8_t const * decode( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t tag, + std::uint8_t const * & value, std::size_t & value_length +) { + if (pos != end && *pos == tag) { + ++pos; + std::uint8_t const * len_start = pos; + pos = varint_skip(pos, end); + std::size_t len = varint_decode<std::size_t>(len_start, pos); + if (len > std::size_t(end - pos)) return end; + value = pos; + value_length = len; + pos += len; + } + return pos; +} + +static std::uint8_t const * skip_unknown( + std::uint8_t const * pos, std::uint8_t const * end +) { + if (pos != end) { + uint8_t tag = *pos; + if ((tag & 0x7) == 0) { + pos = varint_skip(pos, end); + pos = varint_skip(pos, end); + } else if ((tag & 0x7) == 2) { + pos = varint_skip(pos, end); + std::uint8_t const * len_start = pos; + pos = varint_skip(pos, end); + std::size_t len = varint_decode<std::size_t>(len_start, pos); + if (len > std::size_t(end - pos)) return end; + pos += len; + } else { + return end; + } + } + return pos; +} + +} // namespace + + +std::size_t olm::encode_message_length( + std::uint32_t counter, + std::size_t ratchet_key_length, + std::size_t ciphertext_length, + std::size_t mac_length +) { + std::size_t length = VERSION_LENGTH; + length += 1 + varstring_length(ratchet_key_length); + length += 1 + varint_length(counter); + length += 1 + varstring_length(ciphertext_length); + length += mac_length; + return length; +} + + +void olm::encode_message( + olm::MessageWriter & writer, + std::uint8_t version, + std::uint32_t counter, + std::size_t ratchet_key_length, + std::size_t ciphertext_length, + std::uint8_t * output +) { + std::uint8_t * pos = output; + *(pos++) = version; + pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length); + pos = encode(pos, COUNTER_TAG, counter); + pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length); +} + + +void olm::decode_message( + olm::MessageReader & reader, + std::uint8_t const * input, std::size_t input_length, + std::size_t mac_length +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length - mac_length; + std::uint8_t const * unknown = nullptr; + + reader.version = 0; + reader.has_counter = false; + reader.counter = 0; + reader.input = input; + reader.input_length = input_length; + reader.ratchet_key = nullptr; + reader.ratchet_key_length = 0; + reader.ciphertext = nullptr; + reader.ciphertext_length = 0; + + if (input_length < mac_length) return; + + if (pos == end) return; + reader.version = *(pos++); + + while (pos != end) { + unknown = pos; + pos = decode( + pos, end, RATCHET_KEY_TAG, + reader.ratchet_key, reader.ratchet_key_length + ); + pos = decode( + pos, end, COUNTER_TAG, + reader.counter, reader.has_counter + ); + pos = decode( + pos, end, CIPHERTEXT_TAG, + reader.ciphertext, reader.ciphertext_length + ); + if (unknown == pos) { + pos = skip_unknown(pos, end); + } + } +} + + +namespace { + +static std::uint8_t const ONE_TIME_KEY_ID_TAG = 012; +static std::uint8_t const BASE_KEY_TAG = 022; +static std::uint8_t const IDENTITY_KEY_TAG = 032; +static std::uint8_t const MESSAGE_TAG = 042; + +} // namespace + + +std::size_t olm::encode_one_time_key_message_length( + std::size_t one_time_key_length, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length +) { + std::size_t length = VERSION_LENGTH; + length += 1 + varstring_length(one_time_key_length); + length += 1 + varstring_length(identity_key_length); + length += 1 + varstring_length(base_key_length); + length += 1 + varstring_length(message_length); + return length; +} + + +void olm::encode_one_time_key_message( + olm::PreKeyMessageWriter & writer, + std::uint8_t version, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t one_time_key_length, + std::size_t message_length, + std::uint8_t * output +) { + std::uint8_t * pos = output; + *(pos++) = version; + pos = encode(pos, ONE_TIME_KEY_ID_TAG, writer.one_time_key, one_time_key_length); + pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length); + pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length); + pos = encode(pos, MESSAGE_TAG, writer.message, message_length); +} + + +void olm::decode_one_time_key_message( + PreKeyMessageReader & reader, + std::uint8_t const * input, std::size_t input_length +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length; + std::uint8_t const * unknown = nullptr; + + reader.version = 0; + reader.one_time_key = nullptr; + reader.one_time_key_length = 0; + reader.identity_key = nullptr; + reader.identity_key_length = 0; + reader.base_key = nullptr; + reader.base_key_length = 0; + reader.message = nullptr; + reader.message_length = 0; + + if (pos == end) return; + reader.version = *(pos++); + + while (pos != end) { + unknown = pos; + pos = decode( + pos, end, ONE_TIME_KEY_ID_TAG, + reader.one_time_key, reader.one_time_key_length + ); + pos = decode( + pos, end, BASE_KEY_TAG, + reader.base_key, reader.base_key_length + ); + pos = decode( + pos, end, IDENTITY_KEY_TAG, + reader.identity_key, reader.identity_key_length + ); + pos = decode( + pos, end, MESSAGE_TAG, + reader.message, reader.message_length + ); + if (unknown == pos) { + pos = skip_unknown(pos, end); + } + } +} + + + +static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 010; +static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022; + +size_t _olm_encode_group_message_length( + uint32_t message_index, + size_t ciphertext_length, + size_t mac_length, + size_t signature_length +) { + size_t length = VERSION_LENGTH; + length += 1 + varint_length(message_index); + length += 1 + varstring_length(ciphertext_length); + length += mac_length; + length += signature_length; + return length; +} + + +size_t _olm_encode_group_message( + uint8_t version, + uint32_t message_index, + size_t ciphertext_length, + uint8_t *output, + uint8_t **ciphertext_ptr +) { + std::uint8_t * pos = output; + + *(pos++) = version; + pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index); + pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); + return pos-output; +} + +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, size_t signature_length, + struct _OlmDecodeGroupMessageResults *results +) { + std::uint8_t const * pos = input; + std::size_t trailer_length = mac_length + signature_length; + std::uint8_t const * end = input + input_length - trailer_length; + std::uint8_t const * unknown = nullptr; + + bool has_message_index = false; + results->version = 0; + results->message_index = 0; + results->has_message_index = (int)has_message_index; + results->ciphertext = nullptr; + results->ciphertext_length = 0; + + if (input_length < trailer_length) return; + + if (pos == end) return; + results->version = *(pos++); + + while (pos != end) { + unknown = pos; + pos = decode( + pos, end, GROUP_MESSAGE_INDEX_TAG, + results->message_index, has_message_index + ); + pos = decode( + pos, end, GROUP_CIPHERTEXT_TAG, + results->ciphertext, results->ciphertext_length + ); + if (unknown == pos) { + pos = skip_unknown(pos, end); + } + } + + results->has_message_index = (int)has_message_index; +} diff --git a/ext/olm/src/olm.cpp b/ext/olm/src/olm.cpp new file mode 100644 index 0000000..3a30f7a --- /dev/null +++ b/ext/olm/src/olm.cpp @@ -0,0 +1,846 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/olm.h" +#include "olm/session.hh" +#include "olm/account.hh" +#include "olm/cipher.h" +#include "olm/pickle_encoding.h" +#include "olm/utility.hh" +#include "olm/base64.hh" +#include "olm/memory.hh" + +#include <new> +#include <cstring> + +namespace { + +static OlmAccount * to_c(olm::Account * account) { + return reinterpret_cast<OlmAccount *>(account); +} + +static OlmSession * to_c(olm::Session * session) { + return reinterpret_cast<OlmSession *>(session); +} + +static OlmUtility * to_c(olm::Utility * utility) { + return reinterpret_cast<OlmUtility *>(utility); +} + +static olm::Account * from_c(OlmAccount * account) { + return reinterpret_cast<olm::Account *>(account); +} + +static const olm::Account * from_c(OlmAccount const * account) { + return reinterpret_cast<olm::Account const *>(account); +} + +static olm::Session * from_c(OlmSession * session) { + return reinterpret_cast<olm::Session *>(session); +} + +static const olm::Session * from_c(OlmSession const * session) { + return reinterpret_cast<const olm::Session *>(session); +} + +static olm::Utility * from_c(OlmUtility * utility) { + return reinterpret_cast<olm::Utility *>(utility); +} + +static const olm::Utility * from_c(OlmUtility const * utility) { + return reinterpret_cast<const olm::Utility *>(utility); +} + +static std::uint8_t * from_c(void * bytes) { + return reinterpret_cast<std::uint8_t *>(bytes); +} + +static std::uint8_t const * from_c(void const * bytes) { + return reinterpret_cast<std::uint8_t const *>(bytes); +} + +std::size_t b64_output_length( + size_t raw_length +) { + return olm::encode_base64_length(raw_length); +} + +std::uint8_t * b64_output_pos( + std::uint8_t * output, + size_t raw_length +) { + return output + olm::encode_base64_length(raw_length) - raw_length; +} + +std::size_t b64_output( + std::uint8_t * output, size_t raw_length +) { + std::size_t base64_length = olm::encode_base64_length(raw_length); + std::uint8_t * raw_output = output + base64_length - raw_length; + olm::encode_base64(raw_output, raw_length, output); + return base64_length; +} + +std::size_t b64_input( + std::uint8_t * input, size_t b64_length, + OlmErrorCode & last_error +) { + std::size_t raw_length = olm::decode_base64_length(b64_length); + if (raw_length == std::size_t(-1)) { + last_error = OlmErrorCode::OLM_INVALID_BASE64; + return std::size_t(-1); + } + olm::decode_base64(input, b64_length, input); + return raw_length; +} + +} // namespace + + +extern "C" { + +void olm_get_library_version(uint8_t *major, uint8_t *minor, uint8_t *patch) { + if (major != NULL) *major = OLMLIB_VERSION_MAJOR; + if (minor != NULL) *minor = OLMLIB_VERSION_MINOR; + if (patch != NULL) *patch = OLMLIB_VERSION_PATCH; +} + +size_t olm_error(void) { + return std::size_t(-1); +} + + +const char * olm_account_last_error( + const OlmAccount * account +) { + auto error = from_c(account)->last_error; + return _olm_error_to_string(error); +} + +enum OlmErrorCode olm_account_last_error_code( + const OlmAccount * account +) { + return from_c(account)->last_error; +} + +const char * olm_session_last_error( + const OlmSession * session +) { + auto error = from_c(session)->last_error; + return _olm_error_to_string(error); +} + +enum OlmErrorCode olm_session_last_error_code( + OlmSession const * session +) { + return from_c(session)->last_error; +} + +const char * olm_utility_last_error( + OlmUtility const * utility +) { + auto error = from_c(utility)->last_error; + return _olm_error_to_string(error); +} + +enum OlmErrorCode olm_utility_last_error_code( + OlmUtility const * utility +) { + return from_c(utility)->last_error; +} + +size_t olm_account_size(void) { + return sizeof(olm::Account); +} + + +size_t olm_session_size(void) { + return sizeof(olm::Session); +} + +size_t olm_utility_size(void) { + return sizeof(olm::Utility); +} + +OlmAccount * olm_account( + void * memory +) { + olm::unset(memory, sizeof(olm::Account)); + return to_c(new(memory) olm::Account()); +} + + +OlmSession * olm_session( + void * memory +) { + olm::unset(memory, sizeof(olm::Session)); + return to_c(new(memory) olm::Session()); +} + + +OlmUtility * olm_utility( + void * memory +) { + olm::unset(memory, sizeof(olm::Utility)); + return to_c(new(memory) olm::Utility()); +} + + +size_t olm_clear_account( + OlmAccount * account +) { + /* Clear the memory backing the account */ + olm::unset(account, sizeof(olm::Account)); + /* Initialise a fresh account object in case someone tries to use it */ + new(account) olm::Account(); + return sizeof(olm::Account); +} + + +size_t olm_clear_session( + OlmSession * session +) { + /* Clear the memory backing the session */ + olm::unset(session, sizeof(olm::Session)); + /* Initialise a fresh session object in case someone tries to use it */ + new(session) olm::Session(); + return sizeof(olm::Session); +} + + +size_t olm_clear_utility( + OlmUtility * utility +) { + /* Clear the memory backing the session */ + olm::unset(utility, sizeof(olm::Utility)); + /* Initialise a fresh session object in case someone tries to use it */ + new(utility) olm::Utility(); + return sizeof(olm::Utility); +} + + +size_t olm_pickle_account_length( + OlmAccount const * account +) { + return _olm_enc_output_length(pickle_length(*from_c(account))); +} + + +size_t olm_pickle_session_length( + OlmSession const * session +) { + return _olm_enc_output_length(pickle_length(*from_c(session))); +} + + +size_t olm_pickle_account( + OlmAccount * account, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + olm::Account & object = *from_c(account); + std::size_t raw_length = pickle_length(object); + if (pickled_length < _olm_enc_output_length(raw_length)) { + object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return size_t(-1); + } + pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object); + return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length); +} + + +size_t olm_pickle_session( + OlmSession * session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + olm::Session & object = *from_c(session); + std::size_t raw_length = pickle_length(object); + if (pickled_length < _olm_enc_output_length(raw_length)) { + object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return size_t(-1); + } + pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object); + return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length); +} + + +size_t olm_unpickle_account( + OlmAccount * account, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + olm::Account & object = *from_c(account); + std::uint8_t * input = from_c(pickled); + std::size_t raw_length = _olm_enc_input( + from_c(key), key_length, input, pickled_length, &object.last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + + std::uint8_t const * pos = input; + std::uint8_t const * end = pos + raw_length; + + pos = unpickle(pos, end, object); + + if (!pos) { + /* Input was corrupted. */ + if (object.last_error == OlmErrorCode::OLM_SUCCESS) { + object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE; + } + return std::size_t(-1); + } else if (pos != end) { + /* Input was longer than expected. */ + object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA; + return std::size_t(-1); + } + + return pickled_length; +} + + +size_t olm_unpickle_session( + OlmSession * session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + olm::Session & object = *from_c(session); + std::uint8_t * input = from_c(pickled); + std::size_t raw_length = _olm_enc_input( + from_c(key), key_length, input, pickled_length, &object.last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + + std::uint8_t const * pos = input; + std::uint8_t const * end = pos + raw_length; + + pos = unpickle(pos, end, object); + + if (!pos) { + /* Input was corrupted. */ + if (object.last_error == OlmErrorCode::OLM_SUCCESS) { + object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE; + } + return std::size_t(-1); + } else if (pos != end) { + /* Input was longer than expected. */ + object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA; + return std::size_t(-1); + } + + return pickled_length; +} + + +size_t olm_create_account_random_length( + OlmAccount const * account +) { + return from_c(account)->new_account_random_length(); +} + + +size_t olm_create_account( + OlmAccount * account, + void * random, size_t random_length +) { + size_t result = from_c(account)->new_account(from_c(random), random_length); + olm::unset(random, random_length); + return result; +} + + +size_t olm_account_identity_keys_length( + OlmAccount const * account +) { + return from_c(account)->get_identity_json_length(); +} + + +size_t olm_account_identity_keys( + OlmAccount * account, + void * identity_keys, size_t identity_key_length +) { + return from_c(account)->get_identity_json( + from_c(identity_keys), identity_key_length + ); +} + + +size_t olm_account_signature_length( + OlmAccount const * account +) { + return b64_output_length(from_c(account)->signature_length()); +} + + +size_t olm_account_sign( + OlmAccount * account, + void const * message, size_t message_length, + void * signature, size_t signature_length +) { + std::size_t raw_length = from_c(account)->signature_length(); + if (signature_length < b64_output_length(raw_length)) { + from_c(account)->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + from_c(account)->sign( + from_c(message), message_length, + b64_output_pos(from_c(signature), raw_length), raw_length + ); + return b64_output(from_c(signature), raw_length); +} + + +size_t olm_account_one_time_keys_length( + OlmAccount const * account +) { + return from_c(account)->get_one_time_keys_json_length(); +} + + +size_t olm_account_one_time_keys( + OlmAccount * account, + void * one_time_keys_json, size_t one_time_key_json_length +) { + return from_c(account)->get_one_time_keys_json( + from_c(one_time_keys_json), one_time_key_json_length + ); +} + + +size_t olm_account_mark_keys_as_published( + OlmAccount * account +) { + return from_c(account)->mark_keys_as_published(); +} + + +size_t olm_account_max_number_of_one_time_keys( + OlmAccount const * account +) { + return from_c(account)->max_number_of_one_time_keys(); +} + + +size_t olm_account_generate_one_time_keys_random_length( + OlmAccount const * account, + size_t number_of_keys +) { + return from_c(account)->generate_one_time_keys_random_length(number_of_keys); +} + + +size_t olm_account_generate_one_time_keys( + OlmAccount * account, + size_t number_of_keys, + void * random, size_t random_length +) { + size_t result = from_c(account)->generate_one_time_keys( + number_of_keys, + from_c(random), random_length + ); + olm::unset(random, random_length); + return result; +} + + +size_t olm_account_generate_fallback_key_random_length( + OlmAccount const * account +) { + return from_c(account)->generate_fallback_key_random_length(); +} + + +size_t olm_account_generate_fallback_key( + OlmAccount * account, + void * random, size_t random_length +) { + size_t result = from_c(account)->generate_fallback_key( + from_c(random), random_length + ); + olm::unset(random, random_length); + return result; +} + + +size_t olm_account_fallback_key_length( + OlmAccount const * account +) { + return from_c(account)->get_fallback_key_json_length(); +} + + +size_t olm_account_fallback_key( + OlmAccount * account, + void * fallback_key_json, size_t fallback_key_json_length +) { + return from_c(account)->get_fallback_key_json( + from_c(fallback_key_json), fallback_key_json_length + ); +} + + +size_t olm_account_unpublished_fallback_key_length( + OlmAccount const * account +) { + return from_c(account)->get_unpublished_fallback_key_json_length(); +} + + +size_t olm_account_unpublished_fallback_key( + OlmAccount * account, + void * fallback_key_json, size_t fallback_key_json_length +) { + return from_c(account)->get_unpublished_fallback_key_json( + from_c(fallback_key_json), fallback_key_json_length + ); +} + + +void olm_account_forget_old_fallback_key( + OlmAccount * account +) { + return from_c(account)->forget_old_fallback_key(); +} + + +size_t olm_create_outbound_session_random_length( + OlmSession const * session +) { + return from_c(session)->new_outbound_session_random_length(); +} + + +size_t olm_create_outbound_session( + OlmSession * session, + OlmAccount const * account, + void const * their_identity_key, size_t their_identity_key_length, + void const * their_one_time_key, size_t their_one_time_key_length, + void * random, size_t random_length +) { + std::uint8_t const * id_key = from_c(their_identity_key); + std::uint8_t const * ot_key = from_c(their_one_time_key); + std::size_t id_key_length = their_identity_key_length; + std::size_t ot_key_length = their_one_time_key_length; + + if (olm::decode_base64_length(id_key_length) != CURVE25519_KEY_LENGTH + || olm::decode_base64_length(ot_key_length) != CURVE25519_KEY_LENGTH + ) { + from_c(session)->last_error = OlmErrorCode::OLM_INVALID_BASE64; + return std::size_t(-1); + } + _olm_curve25519_public_key identity_key; + _olm_curve25519_public_key one_time_key; + + olm::decode_base64(id_key, id_key_length, identity_key.public_key); + olm::decode_base64(ot_key, ot_key_length, one_time_key.public_key); + + size_t result = from_c(session)->new_outbound_session( + *from_c(account), identity_key, one_time_key, + from_c(random), random_length + ); + olm::unset(random, random_length); + return result; +} + + +size_t olm_create_inbound_session( + OlmSession * session, + OlmAccount * account, + void * one_time_key_message, size_t message_length +) { + std::size_t raw_length = b64_input( + from_c(one_time_key_message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + return from_c(session)->new_inbound_session( + *from_c(account), nullptr, from_c(one_time_key_message), raw_length + ); +} + + +size_t olm_create_inbound_session_from( + OlmSession * session, + OlmAccount * account, + void const * their_identity_key, size_t their_identity_key_length, + void * one_time_key_message, size_t message_length +) { + std::uint8_t const * id_key = from_c(their_identity_key); + std::size_t id_key_length = their_identity_key_length; + + if (olm::decode_base64_length(id_key_length) != CURVE25519_KEY_LENGTH) { + from_c(session)->last_error = OlmErrorCode::OLM_INVALID_BASE64; + return std::size_t(-1); + } + _olm_curve25519_public_key identity_key; + olm::decode_base64(id_key, id_key_length, identity_key.public_key); + + std::size_t raw_length = b64_input( + from_c(one_time_key_message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + return from_c(session)->new_inbound_session( + *from_c(account), &identity_key, + from_c(one_time_key_message), raw_length + ); +} + + +size_t olm_session_id_length( + OlmSession const * session +) { + return b64_output_length(from_c(session)->session_id_length()); +} + +size_t olm_session_id( + OlmSession * session, + void * id, size_t id_length +) { + std::size_t raw_length = from_c(session)->session_id_length(); + if (id_length < b64_output_length(raw_length)) { + from_c(session)->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::size_t result = from_c(session)->session_id( + b64_output_pos(from_c(id), raw_length), raw_length + ); + if (result == std::size_t(-1)) { + return result; + } + return b64_output(from_c(id), raw_length); +} + + +int olm_session_has_received_message( + OlmSession const * session +) { + return from_c(session)->received_message; +} + +void olm_session_describe( + OlmSession * session, char *buf, size_t buflen +) { + from_c(session)->describe(buf, buflen); +} + +size_t olm_matches_inbound_session( + OlmSession * session, + void * one_time_key_message, size_t message_length +) { + std::size_t raw_length = b64_input( + from_c(one_time_key_message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + bool matches = from_c(session)->matches_inbound_session( + nullptr, from_c(one_time_key_message), raw_length + ); + return matches ? 1 : 0; +} + + +size_t olm_matches_inbound_session_from( + OlmSession * session, + void const * their_identity_key, size_t their_identity_key_length, + void * one_time_key_message, size_t message_length +) { + std::uint8_t const * id_key = from_c(their_identity_key); + std::size_t id_key_length = their_identity_key_length; + + if (olm::decode_base64_length(id_key_length) != CURVE25519_KEY_LENGTH) { + from_c(session)->last_error = OlmErrorCode::OLM_INVALID_BASE64; + return std::size_t(-1); + } + _olm_curve25519_public_key identity_key; + olm::decode_base64(id_key, id_key_length, identity_key.public_key); + + std::size_t raw_length = b64_input( + from_c(one_time_key_message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + bool matches = from_c(session)->matches_inbound_session( + &identity_key, from_c(one_time_key_message), raw_length + ); + return matches ? 1 : 0; +} + + +size_t olm_remove_one_time_keys( + OlmAccount * account, + OlmSession * session +) { + size_t result = from_c(account)->remove_key( + from_c(session)->bob_one_time_key + ); + if (result == std::size_t(-1)) { + from_c(account)->last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID; + } + return result; +} + + +size_t olm_encrypt_message_type( + OlmSession const * session +) { + return size_t(from_c(session)->encrypt_message_type()); +} + + +size_t olm_encrypt_random_length( + OlmSession const * session +) { + return from_c(session)->encrypt_random_length(); +} + + +size_t olm_encrypt_message_length( + OlmSession const * session, + size_t plaintext_length +) { + return b64_output_length( + from_c(session)->encrypt_message_length(plaintext_length) + ); +} + + +size_t olm_encrypt( + OlmSession * session, + void const * plaintext, size_t plaintext_length, + void * random, size_t random_length, + void * message, size_t message_length +) { + std::size_t raw_length = from_c(session)->encrypt_message_length( + plaintext_length + ); + if (message_length < b64_output_length(raw_length)) { + from_c(session)->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::size_t result = from_c(session)->encrypt( + from_c(plaintext), plaintext_length, + from_c(random), random_length, + b64_output_pos(from_c(message), raw_length), raw_length + ); + olm::unset(random, random_length); + if (result == std::size_t(-1)) { + return result; + } + return b64_output(from_c(message), raw_length); +} + + +size_t olm_decrypt_max_plaintext_length( + OlmSession * session, + size_t message_type, + void * message, size_t message_length +) { + std::size_t raw_length = b64_input( + from_c(message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + return from_c(session)->decrypt_max_plaintext_length( + olm::MessageType(message_type), from_c(message), raw_length + ); +} + + +size_t olm_decrypt( + OlmSession * session, + size_t message_type, + void * message, size_t message_length, + void * plaintext, size_t max_plaintext_length +) { + std::size_t raw_length = b64_input( + from_c(message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + return from_c(session)->decrypt( + olm::MessageType(message_type), from_c(message), raw_length, + from_c(plaintext), max_plaintext_length + ); +} + + +size_t olm_sha256_length( + OlmUtility const * utility +) { + return b64_output_length(from_c(utility)->sha256_length()); +} + + +size_t olm_sha256( + OlmUtility * utility, + void const * input, size_t input_length, + void * output, size_t output_length +) { + std::size_t raw_length = from_c(utility)->sha256_length(); + if (output_length < b64_output_length(raw_length)) { + from_c(utility)->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::size_t result = from_c(utility)->sha256( + from_c(input), input_length, + b64_output_pos(from_c(output), raw_length), raw_length + ); + if (result == std::size_t(-1)) { + return result; + } + return b64_output(from_c(output), raw_length); +} + + +size_t olm_ed25519_verify( + OlmUtility * utility, + void const * key, size_t key_length, + void const * message, size_t message_length, + void * signature, size_t signature_length +) { + if (olm::decode_base64_length(key_length) != CURVE25519_KEY_LENGTH) { + from_c(utility)->last_error = OlmErrorCode::OLM_INVALID_BASE64; + return std::size_t(-1); + } + _olm_ed25519_public_key verify_key; + olm::decode_base64(from_c(key), key_length, verify_key.public_key); + std::size_t raw_signature_length = b64_input( + from_c(signature), signature_length, from_c(utility)->last_error + ); + if (raw_signature_length == std::size_t(-1)) { + return std::size_t(-1); + } + return from_c(utility)->ed25519_verify( + verify_key, + from_c(message), message_length, + from_c(signature), raw_signature_length + ); +} + +} diff --git a/ext/olm/src/outbound_group_session.c b/ext/olm/src/outbound_group_session.c new file mode 100644 index 0000000..cbbba9c --- /dev/null +++ b/ext/olm/src/outbound_group_session.c @@ -0,0 +1,390 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/outbound_group_session.h" + +#include <string.h> + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/crypto.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/memory.h" +#include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" + +#define OLM_PROTOCOL_VERSION 3 +#define GROUP_SESSION_ID_LENGTH ED25519_PUBLIC_KEY_LENGTH +#define PICKLE_VERSION 1 +#define SESSION_KEY_VERSION 2 + +struct OlmOutboundGroupSession { + /** the Megolm ratchet providing the encryption keys */ + Megolm ratchet; + + /** The ed25519 keypair used for signing the messages */ + struct _olm_ed25519_key_pair signing_key; + + enum OlmErrorCode last_error; +}; + + +size_t olm_outbound_group_session_size(void) { + return sizeof(OlmOutboundGroupSession); +} + +OlmOutboundGroupSession * olm_outbound_group_session( + void *memory +) { + OlmOutboundGroupSession *session = memory; + olm_clear_outbound_group_session(session); + return session; +} + +const char *olm_outbound_group_session_last_error( + const OlmOutboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +enum OlmErrorCode olm_outbound_group_session_last_error_code( + const OlmOutboundGroupSession *session +) { + return session->last_error; +} + +size_t olm_clear_outbound_group_session( + OlmOutboundGroupSession *session +) { + _olm_unset(session, sizeof(OlmOutboundGroupSession)); + return sizeof(OlmOutboundGroupSession); +} + +static size_t raw_pickle_length( + const OlmOutboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&(session->ratchet)); + length += _olm_pickle_ed25519_key_pair_length(&(session->signing_key)); + return length; +} + +size_t olm_pickle_outbound_group_session_length( + const OlmOutboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + +#ifndef OLM_FUZZING + pos = _olm_enc_output_pos(pickled, raw_length); +#else + pos = pickled; +#endif + + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&(session->ratchet), pos); + pos = _olm_pickle_ed25519_key_pair(pos, &(session->signing_key)); + +#ifndef OLM_FUZZING + return _olm_enc_output(key, key_length, pickled, raw_length); +#else + return raw_length; +#endif +} + +size_t olm_unpickle_outbound_group_session( + OlmOutboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + +#ifndef OLM_FUZZING + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); +#else + size_t raw_length = pickled_length; +#endif + + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + if (pickle_version != PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + + pos = megolm_unpickle(&(session->ratchet), pos, end); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + pos = _olm_unpickle_ed25519_key_pair(pos, end, &(session->signing_key)); + FAIL_ON_CORRUPTED_PICKLE(pos, session); + + if (pos != end) { + /* Input was longer than expected. */ + session->last_error = OLM_PICKLE_EXTRA_DATA; + return (size_t)-1; + } + + return pickled_length; +} + + +size_t olm_init_outbound_group_session_random_length( + const OlmOutboundGroupSession *session +) { + /* we need data to initialize the megolm ratchet, plus some more for the + * session id. + */ + return MEGOLM_RATCHET_LENGTH + + ED25519_RANDOM_LENGTH; +} + +size_t olm_init_outbound_group_session( + OlmOutboundGroupSession *session, + uint8_t *random, size_t random_length +) { + const uint8_t *random_ptr = random; + + if (random_length < olm_init_outbound_group_session_random_length(session)) { + /* Insufficient random data for new session */ + session->last_error = OLM_NOT_ENOUGH_RANDOM; + return (size_t)-1; + } + + megolm_init(&(session->ratchet), random_ptr, 0); + random_ptr += MEGOLM_RATCHET_LENGTH; + + _olm_crypto_ed25519_generate_key(random_ptr, &(session->signing_key)); + random_ptr += ED25519_RANDOM_LENGTH; + + _olm_unset(random, random_length); + return 0; +} + +static size_t raw_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length) +{ + size_t ciphertext_length, mac_length; + + ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( + megolm_cipher, plaintext_length + ); + + mac_length = megolm_cipher->ops->mac_length(megolm_cipher); + + return _olm_encode_group_message_length( + session->ratchet.counter, + ciphertext_length, mac_length, ED25519_SIGNATURE_LENGTH + ); +} + +size_t olm_group_encrypt_message_length( + OlmOutboundGroupSession *session, + size_t plaintext_length +) { + size_t message_length = raw_message_length(session, plaintext_length); + return _olm_encode_base64_length(message_length); +} + +/** write an un-base64-ed message to the buffer */ +static size_t _encrypt( + OlmOutboundGroupSession *session, uint8_t const * plaintext, size_t plaintext_length, + uint8_t * buffer +) { + size_t ciphertext_length, mac_length, message_length; + size_t result; + uint8_t *ciphertext_ptr; + + ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( + megolm_cipher, + plaintext_length + ); + + mac_length = megolm_cipher->ops->mac_length(megolm_cipher); + + /* first we build the message structure, then we encrypt + * the plaintext into it. + */ + message_length = _olm_encode_group_message( + OLM_PROTOCOL_VERSION, + session->ratchet.counter, + ciphertext_length, + buffer, + &ciphertext_ptr); + + message_length += mac_length; + + result = megolm_cipher->ops->encrypt( + megolm_cipher, + megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH, + plaintext, plaintext_length, + ciphertext_ptr, ciphertext_length, + buffer, message_length + ); + + if (result == (size_t)-1) { + return result; + } + + megolm_advance(&(session->ratchet)); + + /* sign the whole thing with the ed25519 key. */ + _olm_crypto_ed25519_sign( + &(session->signing_key), + buffer, message_length, + buffer + message_length + ); + + return result; +} + +size_t olm_group_encrypt( + OlmOutboundGroupSession *session, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * message, size_t max_message_length +) { + size_t rawmsglen; + size_t result; + uint8_t *message_pos; + + rawmsglen = raw_message_length(session, plaintext_length); + + if (max_message_length < _olm_encode_base64_length(rawmsglen)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + /* we construct the message at the end of the buffer, so that + * we have room to base64-encode it once we're done. + */ + message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen; + + /* write the message, and encrypt it, at message_pos */ + result = _encrypt(session, plaintext, plaintext_length, message_pos); + if (result == (size_t)-1) { + return result; + } + + /* bas64-encode it */ + return _olm_encode_base64( + message_pos, rawmsglen, message + ); +} + + +size_t olm_outbound_group_session_id_length( + const OlmOutboundGroupSession *session +) { + return _olm_encode_base64_length(GROUP_SESSION_ID_LENGTH); +} + +size_t olm_outbound_group_session_id( + OlmOutboundGroupSession *session, + uint8_t * id, size_t id_length +) { + if (id_length < olm_outbound_group_session_id_length(session)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + return _olm_encode_base64( + session->signing_key.public_key.public_key, GROUP_SESSION_ID_LENGTH, id + ); +} + +uint32_t olm_outbound_group_session_message_index( + OlmOutboundGroupSession *session +) { + return session->ratchet.counter; +} + +#define SESSION_KEY_RAW_LENGTH \ + (1 + 4 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH\ + + ED25519_SIGNATURE_LENGTH) + +size_t olm_outbound_group_session_key_length( + const OlmOutboundGroupSession *session +) { + return _olm_encode_base64_length(SESSION_KEY_RAW_LENGTH); +} + +size_t olm_outbound_group_session_key( + OlmOutboundGroupSession *session, + uint8_t * key, size_t key_length +) { + uint8_t *raw; + uint8_t *ptr; + size_t encoded_length = olm_outbound_group_session_key_length(session); + + if (key_length < encoded_length) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + /* put the raw data at the end of the output buffer. */ + raw = ptr = key + encoded_length - SESSION_KEY_RAW_LENGTH; + *ptr++ = SESSION_KEY_VERSION; + + uint32_t counter = session->ratchet.counter; + // Encode counter as a big endian 32-bit number. + for (unsigned i = 0; i < 4; i++) { + *ptr++ = 0xFF & (counter >> 24); counter <<= 8; + } + + memcpy(ptr, megolm_get_data(&session->ratchet), MEGOLM_RATCHET_LENGTH); + ptr += MEGOLM_RATCHET_LENGTH; + + memcpy( + ptr, session->signing_key.public_key.public_key, + ED25519_PUBLIC_KEY_LENGTH + ); + ptr += ED25519_PUBLIC_KEY_LENGTH; + + /* sign the whole thing with the ed25519 key. */ + _olm_crypto_ed25519_sign( + &(session->signing_key), + raw, ptr - raw, ptr + ); + + return _olm_encode_base64(raw, SESSION_KEY_RAW_LENGTH, key); +} diff --git a/ext/olm/src/pickle.cpp b/ext/olm/src/pickle.cpp new file mode 100644 index 0000000..3bffb36 --- /dev/null +++ b/ext/olm/src/pickle.cpp @@ -0,0 +1,274 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/pickle.hh" +#include "olm/pickle.h" + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + std::uint32_t value +) { + pos += 4; + for (unsigned i = 4; i--;) { *(--pos) = value; value >>= 8; } + return pos + 4; +} + + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint32_t & value +) { + value = 0; + if (!pos || end < pos + 4) return nullptr; + for (unsigned i = 4; i--;) { value <<= 8; value |= *(pos++); } + return pos; +} + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + std::uint8_t value +) { + *(pos++) = value; + return pos; +} + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t & value +) { + if (!pos || pos == end) return nullptr; + value = *(pos++); + return pos; +} + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + bool value +) { + *(pos++) = value ? 1 : 0; + return pos; +} + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + bool & value +) { + if (!pos || pos == end) return nullptr; + value = *(pos++); + return pos; +} + +std::uint8_t * olm::pickle_bytes( + std::uint8_t * pos, + std::uint8_t const * bytes, std::size_t bytes_length +) { + std::memcpy(pos, bytes, bytes_length); + return pos + bytes_length; +} + +std::uint8_t const * olm::unpickle_bytes( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t * bytes, std::size_t bytes_length +) { + if (!pos || end < pos + bytes_length) return nullptr; + std::memcpy(bytes, pos, bytes_length); + return pos + bytes_length; +} + + +std::size_t olm::pickle_length( + const _olm_curve25519_public_key & value +) { + return sizeof(value.public_key); +} + + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + const _olm_curve25519_public_key & value +) { + pos = olm::pickle_bytes( + pos, value.public_key, sizeof(value.public_key) + ); + return pos; +} + + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + _olm_curve25519_public_key & value +) { + return olm::unpickle_bytes( + pos, end, value.public_key, sizeof(value.public_key) + ); +} + + +std::size_t olm::pickle_length( + const _olm_curve25519_key_pair & value +) { + return sizeof(value.public_key.public_key) + + sizeof(value.private_key.private_key); +} + + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + const _olm_curve25519_key_pair & value +) { + pos = olm::pickle_bytes( + pos, value.public_key.public_key, + sizeof(value.public_key.public_key) + ); + pos = olm::pickle_bytes( + pos, value.private_key.private_key, + sizeof(value.private_key.private_key) + ); + return pos; +} + + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + _olm_curve25519_key_pair & value +) { + pos = olm::unpickle_bytes( + pos, end, value.public_key.public_key, + sizeof(value.public_key.public_key) + ); + if (!pos) return nullptr; + + pos = olm::unpickle_bytes( + pos, end, value.private_key.private_key, + sizeof(value.private_key.private_key) + ); + if (!pos) return nullptr; + + return pos; +} + +////// pickle.h implementations + +std::size_t _olm_pickle_ed25519_public_key_length( + const _olm_ed25519_public_key * value +) { + return sizeof(value->public_key); +} + + +std::uint8_t * _olm_pickle_ed25519_public_key( + std::uint8_t * pos, + const _olm_ed25519_public_key *value +) { + return olm::pickle_bytes( + pos, value->public_key, sizeof(value->public_key) + ); +} + + +std::uint8_t const * _olm_unpickle_ed25519_public_key( + std::uint8_t const * pos, std::uint8_t const * end, + _olm_ed25519_public_key * value +) { + return olm::unpickle_bytes( + pos, end, value->public_key, sizeof(value->public_key) + ); +} + + +std::size_t _olm_pickle_ed25519_key_pair_length( + const _olm_ed25519_key_pair *value +) { + return sizeof(value->public_key.public_key) + + sizeof(value->private_key.private_key); +} + + +std::uint8_t * _olm_pickle_ed25519_key_pair( + std::uint8_t * pos, + const _olm_ed25519_key_pair *value +) { + pos = olm::pickle_bytes( + pos, value->public_key.public_key, + sizeof(value->public_key.public_key) + ); + pos = olm::pickle_bytes( + pos, value->private_key.private_key, + sizeof(value->private_key.private_key) + ); + return pos; +} + + +std::uint8_t const * _olm_unpickle_ed25519_key_pair( + std::uint8_t const * pos, std::uint8_t const * end, + _olm_ed25519_key_pair *value +) { + pos = olm::unpickle_bytes( + pos, end, value->public_key.public_key, + sizeof(value->public_key.public_key) + ); + if (!pos) return nullptr; + + pos = olm::unpickle_bytes( + pos, end, value->private_key.private_key, + sizeof(value->private_key.private_key) + ); + if (!pos) return nullptr; + + return pos; +} + +uint8_t * _olm_pickle_uint32(uint8_t * pos, uint32_t value) { + return olm::pickle(pos, value); +} + +uint8_t const * _olm_unpickle_uint32( + uint8_t const * pos, uint8_t const * end, + uint32_t *value +) { + return olm::unpickle(pos, end, *value); +} + +uint8_t * _olm_pickle_uint8(uint8_t * pos, uint8_t value) { + return olm::pickle(pos, value); +} + +uint8_t const * _olm_unpickle_uint8( + uint8_t const * pos, uint8_t const * end, + uint8_t *value +) { + return olm::unpickle(pos, end, *value); +} + +uint8_t * _olm_pickle_bool(uint8_t * pos, int value) { + return olm::pickle(pos, (bool)value); +} + +uint8_t const * _olm_unpickle_bool( + uint8_t const * pos, uint8_t const * end, + int *value +) { + return olm::unpickle(pos, end, *reinterpret_cast<bool *>(value)); +} + +uint8_t * _olm_pickle_bytes(uint8_t * pos, uint8_t const * bytes, + size_t bytes_length) { + return olm::pickle_bytes(pos, bytes, bytes_length); +} + +uint8_t const * _olm_unpickle_bytes(uint8_t const * pos, uint8_t const * end, + uint8_t * bytes, size_t bytes_length) { + return olm::unpickle_bytes(pos, end, bytes, bytes_length); +} diff --git a/ext/olm/src/pickle_encoding.c b/ext/olm/src/pickle_encoding.c new file mode 100644 index 0000000..a56e9e3 --- /dev/null +++ b/ext/olm/src/pickle_encoding.c @@ -0,0 +1,92 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/pickle_encoding.h" + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/olm.h" + +static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256("Pickle"); + +size_t _olm_enc_output_length( + size_t raw_length +) { + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); + length += cipher->ops->mac_length(cipher); + return _olm_encode_base64_length(length); +} + +uint8_t * _olm_enc_output_pos( + uint8_t * output, + size_t raw_length +) { + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); + length += cipher->ops->mac_length(cipher); + return output + _olm_encode_base64_length(length) - length; +} + +size_t _olm_enc_output( + uint8_t const * key, size_t key_length, + uint8_t * output, size_t raw_length +) { + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, raw_length + ); + size_t length = ciphertext_length + cipher->ops->mac_length(cipher); + size_t base64_length = _olm_encode_base64_length(length); + uint8_t * raw_output = output + base64_length - length; + cipher->ops->encrypt( + cipher, + key, key_length, + raw_output, raw_length, + raw_output, ciphertext_length, + raw_output, length + ); + _olm_encode_base64(raw_output, length, output); + return base64_length; +} + + +size_t _olm_enc_input(uint8_t const * key, size_t key_length, + uint8_t * input, size_t b64_length, + enum OlmErrorCode * last_error +) { + size_t enc_length = _olm_decode_base64_length(b64_length); + if (enc_length == (size_t)-1) { + if (last_error) { + *last_error = OLM_INVALID_BASE64; + } + return (size_t)-1; + } + _olm_decode_base64(input, b64_length, input); + const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); + size_t raw_length = enc_length - cipher->ops->mac_length(cipher); + size_t result = cipher->ops->decrypt( + cipher, + key, key_length, + input, enc_length, + input, raw_length, + input, raw_length + ); + if (result == (size_t)-1 && last_error) { + *last_error = OLM_BAD_ACCOUNT_KEY; + } + return result; +} diff --git a/ext/olm/src/pk.cpp b/ext/olm/src/pk.cpp new file mode 100644 index 0000000..9217c48 --- /dev/null +++ b/ext/olm/src/pk.cpp @@ -0,0 +1,542 @@ +/* Copyright 2018, 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/pk.h" +#include "olm/cipher.h" +#include "olm/crypto.h" +#include "olm/ratchet.hh" +#include "olm/error.h" +#include "olm/memory.hh" +#include "olm/base64.hh" +#include "olm/pickle_encoding.h" +#include "olm/pickle.hh" + +static const std::size_t MAC_LENGTH = 8; + +const struct _olm_cipher_aes_sha_256 olm_pk_cipher_aes_sha256 = + OLM_CIPHER_INIT_AES_SHA_256(""); +const struct _olm_cipher *olm_pk_cipher = + OLM_CIPHER_BASE(&olm_pk_cipher_aes_sha256); + +extern "C" { + +struct OlmPkEncryption { + OlmErrorCode last_error; + _olm_curve25519_public_key recipient_key; +}; + +const char * olm_pk_encryption_last_error( + const OlmPkEncryption * encryption +) { + auto error = encryption->last_error; + return _olm_error_to_string(error); +} + +OlmErrorCode olm_pk_encryption_last_error_code( + const OlmPkEncryption * encryption +) { + return encryption->last_error; +} + +size_t olm_pk_encryption_size(void) { + return sizeof(OlmPkEncryption); +} + +OlmPkEncryption *olm_pk_encryption( + void * memory +) { + olm::unset(memory, sizeof(OlmPkEncryption)); + return new(memory) OlmPkEncryption; +} + +size_t olm_clear_pk_encryption( + OlmPkEncryption *encryption +) { + /* Clear the memory backing the encryption */ + olm::unset(encryption, sizeof(OlmPkEncryption)); + /* Initialise a fresh encryption object in case someone tries to use it */ + new(encryption) OlmPkEncryption(); + return sizeof(OlmPkEncryption); +} + +size_t olm_pk_encryption_set_recipient_key ( + OlmPkEncryption *encryption, + void const * key, size_t key_length +) { + if (key_length < olm_pk_key_length()) { + encryption->last_error = + OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + olm::decode_base64( + (const uint8_t*)key, + olm_pk_key_length(), + (uint8_t *)encryption->recipient_key.public_key + ); + + return 0; +} + +size_t olm_pk_ciphertext_length( + const OlmPkEncryption *encryption, + size_t plaintext_length +) { + return olm::encode_base64_length( + _olm_cipher_aes_sha_256_ops.encrypt_ciphertext_length(olm_pk_cipher, plaintext_length) + ); +} + +size_t olm_pk_mac_length( + const OlmPkEncryption *encryption +) { + return olm::encode_base64_length(_olm_cipher_aes_sha_256_ops.mac_length(olm_pk_cipher)); +} + +size_t olm_pk_encrypt_random_length( + const OlmPkEncryption *encryption +) { + return CURVE25519_KEY_LENGTH; +} + +size_t olm_pk_encrypt( + OlmPkEncryption *encryption, + void const * plaintext, size_t plaintext_length, + void * ciphertext, size_t ciphertext_length, + void * mac, size_t mac_length, + void * ephemeral_key, size_t ephemeral_key_size, + const void * random, size_t random_length +) { + if (ciphertext_length + < olm_pk_ciphertext_length(encryption, plaintext_length) + || mac_length + < _olm_cipher_aes_sha_256_ops.mac_length(olm_pk_cipher) + || ephemeral_key_size + < olm_pk_key_length()) { + encryption->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + if (random_length < olm_pk_encrypt_random_length(encryption)) { + encryption->last_error = + OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; + return std::size_t(-1); + } + + _olm_curve25519_key_pair ephemeral_keypair; + _olm_crypto_curve25519_generate_key((const uint8_t *) random, &ephemeral_keypair); + olm::encode_base64( + (const uint8_t *)ephemeral_keypair.public_key.public_key, + CURVE25519_KEY_LENGTH, + (uint8_t *)ephemeral_key + ); + + olm::SharedKey secret; + _olm_crypto_curve25519_shared_secret(&ephemeral_keypair, &encryption->recipient_key, secret); + size_t raw_ciphertext_length = + _olm_cipher_aes_sha_256_ops.encrypt_ciphertext_length(olm_pk_cipher, plaintext_length); + uint8_t *ciphertext_pos = (uint8_t *) ciphertext + ciphertext_length - raw_ciphertext_length; + uint8_t raw_mac[MAC_LENGTH]; + size_t result = _olm_cipher_aes_sha_256_ops.encrypt( + olm_pk_cipher, + secret, sizeof(secret), + (const uint8_t *) plaintext, plaintext_length, + (uint8_t *) ciphertext_pos, raw_ciphertext_length, + (uint8_t *) raw_mac, MAC_LENGTH + ); + if (result != std::size_t(-1)) { + olm::encode_base64(raw_mac, MAC_LENGTH, (uint8_t *)mac); + olm::encode_base64(ciphertext_pos, raw_ciphertext_length, (uint8_t *)ciphertext); + } + return result; +} + +struct OlmPkDecryption { + OlmErrorCode last_error; + _olm_curve25519_key_pair key_pair; +}; + +const char * olm_pk_decryption_last_error( + const OlmPkDecryption * decryption +) { + auto error = decryption->last_error; + return _olm_error_to_string(error); +} + +OlmErrorCode olm_pk_decryption_last_error_code( + const OlmPkDecryption * decryption +) { + return decryption->last_error; +} + +size_t olm_pk_decryption_size(void) { + return sizeof(OlmPkDecryption); +} + +OlmPkDecryption *olm_pk_decryption( + void * memory +) { + olm::unset(memory, sizeof(OlmPkDecryption)); + return new(memory) OlmPkDecryption; +} + +size_t olm_clear_pk_decryption( + OlmPkDecryption *decryption +) { + /* Clear the memory backing the decryption */ + olm::unset(decryption, sizeof(OlmPkDecryption)); + /* Initialise a fresh decryption object in case someone tries to use it */ + new(decryption) OlmPkDecryption(); + return sizeof(OlmPkDecryption); +} + +size_t olm_pk_private_key_length(void) { + return CURVE25519_KEY_LENGTH; +} + +size_t olm_pk_generate_key_random_length(void) { + return olm_pk_private_key_length(); +} + +size_t olm_pk_key_length(void) { + return olm::encode_base64_length(CURVE25519_KEY_LENGTH); +} + +size_t olm_pk_key_from_private( + OlmPkDecryption * decryption, + void * pubkey, size_t pubkey_length, + const void * privkey, size_t privkey_length +) { + if (pubkey_length < olm_pk_key_length()) { + decryption->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + if (privkey_length < olm_pk_private_key_length()) { + decryption->last_error = + OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + _olm_crypto_curve25519_generate_key((const uint8_t *) privkey, &decryption->key_pair); + olm::encode_base64( + (const uint8_t *)decryption->key_pair.public_key.public_key, + CURVE25519_KEY_LENGTH, + (uint8_t *)pubkey + ); + return 0; +} + +size_t olm_pk_generate_key( + OlmPkDecryption * decryption, + void * pubkey, size_t pubkey_length, + const void * privkey, size_t privkey_length +) { + return olm_pk_key_from_private(decryption, pubkey, pubkey_length, privkey, privkey_length); +} + +namespace { + static const std::uint32_t PK_DECRYPTION_PICKLE_VERSION = 1; + + static std::size_t pickle_length( + OlmPkDecryption const & value + ) { + std::size_t length = 0; + length += olm::pickle_length(PK_DECRYPTION_PICKLE_VERSION); + length += olm::pickle_length(value.key_pair); + return length; + } + + + static std::uint8_t * pickle( + std::uint8_t * pos, + OlmPkDecryption const & value + ) { + pos = olm::pickle(pos, PK_DECRYPTION_PICKLE_VERSION); + pos = olm::pickle(pos, value.key_pair); + return pos; + } + + + static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + OlmPkDecryption & value + ) { + uint32_t pickle_version; + pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos); + + switch (pickle_version) { + case 1: + break; + + default: + value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION; + return nullptr; + } + + pos = olm::unpickle(pos, end, value.key_pair); UNPICKLE_OK(pos); + + return pos; + } +} + +size_t olm_pickle_pk_decryption_length( + const OlmPkDecryption * decryption +) { + return _olm_enc_output_length(pickle_length(*decryption)); +} + +size_t olm_pickle_pk_decryption( + OlmPkDecryption * decryption, + void const * key, size_t key_length, + void *pickled, size_t pickled_length +) { + OlmPkDecryption & object = *decryption; + std::size_t raw_length = pickle_length(object); + if (pickled_length < _olm_enc_output_length(raw_length)) { + object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + pickle(_olm_enc_output_pos(reinterpret_cast<std::uint8_t *>(pickled), raw_length), object); + return _olm_enc_output( + reinterpret_cast<std::uint8_t const *>(key), key_length, + reinterpret_cast<std::uint8_t *>(pickled), raw_length + ); +} + +size_t olm_unpickle_pk_decryption( + OlmPkDecryption * decryption, + void const * key, size_t key_length, + void *pickled, size_t pickled_length, + void *pubkey, size_t pubkey_length +) { + OlmPkDecryption & object = *decryption; + if (pubkey != NULL && pubkey_length < olm_pk_key_length()) { + object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::uint8_t * const input = reinterpret_cast<std::uint8_t *>(pickled); + std::size_t raw_length = _olm_enc_input( + reinterpret_cast<std::uint8_t const *>(key), key_length, + input, pickled_length, &object.last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + + std::uint8_t const * pos = input; + std::uint8_t const * end = pos + raw_length; + + pos = unpickle(pos, end, object); + + if (!pos) { + /* Input was corrupted. */ + if (object.last_error == OlmErrorCode::OLM_SUCCESS) { + object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE; + } + return std::size_t(-1); + } else if (pos != end) { + /* Input was longer than expected. */ + object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA; + return std::size_t(-1); + } + + if (pubkey != NULL) { + olm::encode_base64( + (const uint8_t *)object.key_pair.public_key.public_key, + CURVE25519_KEY_LENGTH, + (uint8_t *)pubkey + ); + } + + return pickled_length; +} + +size_t olm_pk_max_plaintext_length( + const OlmPkDecryption * decryption, + size_t ciphertext_length +) { + return _olm_cipher_aes_sha_256_ops.decrypt_max_plaintext_length( + olm_pk_cipher, olm::decode_base64_length(ciphertext_length) + ); +} + +size_t olm_pk_decrypt( + OlmPkDecryption * decryption, + void const * ephemeral_key, size_t ephemeral_key_length, + void const * mac, size_t mac_length, + void * ciphertext, size_t ciphertext_length, + void * plaintext, size_t max_plaintext_length +) { + if (max_plaintext_length + < olm_pk_max_plaintext_length(decryption, ciphertext_length)) { + decryption->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + size_t raw_ciphertext_length = olm::decode_base64_length(ciphertext_length); + + if (ephemeral_key_length != olm::encode_base64_length(CURVE25519_KEY_LENGTH) + || mac_length != olm::encode_base64_length(MAC_LENGTH) + || raw_ciphertext_length == std::size_t(-1)) { + decryption->last_error = OlmErrorCode::OLM_INVALID_BASE64; + return std::size_t(-1); + } + + struct _olm_curve25519_public_key ephemeral; + olm::decode_base64( + (const uint8_t*)ephemeral_key, + olm::encode_base64_length(CURVE25519_KEY_LENGTH), + (uint8_t *)ephemeral.public_key + ); + + olm::SharedKey secret; + _olm_crypto_curve25519_shared_secret(&decryption->key_pair, &ephemeral, secret); + + uint8_t raw_mac[MAC_LENGTH]; + olm::decode_base64( + (const uint8_t *)mac, + olm::encode_base64_length(MAC_LENGTH), + raw_mac + ); + + olm::decode_base64( + (const uint8_t *)ciphertext, + ciphertext_length, + (uint8_t *)ciphertext + ); + + size_t result = _olm_cipher_aes_sha_256_ops.decrypt( + olm_pk_cipher, + secret, sizeof(secret), + (uint8_t *) raw_mac, MAC_LENGTH, + (const uint8_t *) ciphertext, raw_ciphertext_length, + (uint8_t *) plaintext, max_plaintext_length + ); + if (result == std::size_t(-1)) { + // we already checked the buffer sizes, so the only error that decrypt + // will return is if the MAC is incorrect + decryption->last_error = + OlmErrorCode::OLM_BAD_MESSAGE_MAC; + return std::size_t(-1); + } else { + return result; + } +} + +size_t olm_pk_get_private_key( + OlmPkDecryption * decryption, + void *private_key, size_t private_key_length +) { + if (private_key_length < olm_pk_private_key_length()) { + decryption->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::memcpy( + private_key, + decryption->key_pair.private_key.private_key, + olm_pk_private_key_length() + ); + return olm_pk_private_key_length(); +} + +struct OlmPkSigning { + OlmErrorCode last_error; + _olm_ed25519_key_pair key_pair; +}; + +size_t olm_pk_signing_size(void) { + return sizeof(OlmPkSigning); +} + +OlmPkSigning *olm_pk_signing(void * memory) { + olm::unset(memory, sizeof(OlmPkSigning)); + return new(memory) OlmPkSigning; +} + +const char * olm_pk_signing_last_error(const OlmPkSigning * sign) { + auto error = sign->last_error; + return _olm_error_to_string(error); +} + +OlmErrorCode olm_pk_signing_last_error_code(const OlmPkSigning * sign) { + return sign->last_error; +} + +size_t olm_clear_pk_signing(OlmPkSigning *sign) { + /* Clear the memory backing the signing */ + olm::unset(sign, sizeof(OlmPkSigning)); + /* Initialise a fresh signing object in case someone tries to use it */ + new(sign) OlmPkSigning(); + return sizeof(OlmPkSigning); +} + +size_t olm_pk_signing_seed_length(void) { + return ED25519_RANDOM_LENGTH; +} + +size_t olm_pk_signing_public_key_length(void) { + return olm::encode_base64_length(ED25519_PUBLIC_KEY_LENGTH); +} + +size_t olm_pk_signing_key_from_seed( + OlmPkSigning * signing, + void * pubkey, size_t pubkey_length, + const void * seed, size_t seed_length +) { + if (pubkey_length < olm_pk_signing_public_key_length()) { + signing->last_error = + OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + if (seed_length < olm_pk_signing_seed_length()) { + signing->last_error = + OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + _olm_crypto_ed25519_generate_key((const uint8_t *) seed, &signing->key_pair); + olm::encode_base64( + (const uint8_t *)signing->key_pair.public_key.public_key, + ED25519_PUBLIC_KEY_LENGTH, + (uint8_t *)pubkey + ); + return 0; +} + +size_t olm_pk_signature_length(void) { + return olm::encode_base64_length(ED25519_SIGNATURE_LENGTH); +} + +size_t olm_pk_sign( + OlmPkSigning *signing, + uint8_t const * message, size_t message_length, + uint8_t * signature, size_t signature_length +) { + if (signature_length < olm_pk_signature_length()) { + signing->last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + uint8_t *raw_sig = signature + olm_pk_signature_length() - ED25519_SIGNATURE_LENGTH; + _olm_crypto_ed25519_sign( + &signing->key_pair, + message, message_length, raw_sig + ); + olm::encode_base64(raw_sig, ED25519_SIGNATURE_LENGTH, signature); + return olm_pk_signature_length(); +} + +} diff --git a/ext/olm/src/ratchet.cpp b/ext/olm/src/ratchet.cpp new file mode 100644 index 0000000..1d284a6 --- /dev/null +++ b/ext/olm/src/ratchet.cpp @@ -0,0 +1,625 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/ratchet.hh" +#include "olm/message.hh" +#include "olm/memory.hh" +#include "olm/cipher.h" +#include "olm/pickle.hh" + +#include <cstring> + +namespace { + +static const std::uint8_t PROTOCOL_VERSION = 3; +static const std::uint8_t MESSAGE_KEY_SEED[1] = {0x01}; +static const std::uint8_t CHAIN_KEY_SEED[1] = {0x02}; +static const std::size_t MAX_MESSAGE_GAP = 2000; + + +/** + * Advance the root key, creating a new message chain. + * + * @param root_key previous root key R(n-1) + * @param our_key our new ratchet key T(n) + * @param their_key their most recent ratchet key T(n-1) + * @param info table of constants for the ratchet function + * @param new_root_key[out] returns the new root key R(n) + * @param new_chain_key[out] returns the first chain key in the new chain + * C(n,0) + */ +static void create_chain_key( + olm::SharedKey const & root_key, + _olm_curve25519_key_pair const & our_key, + _olm_curve25519_public_key const & their_key, + olm::KdfInfo const & info, + olm::SharedKey & new_root_key, + olm::ChainKey & new_chain_key +) { + olm::SharedKey secret; + _olm_crypto_curve25519_shared_secret(&our_key, &their_key, secret); + std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH]; + _olm_crypto_hkdf_sha256( + secret, sizeof(secret), + root_key, sizeof(root_key), + info.ratchet_info, info.ratchet_info_length, + derived_secrets, sizeof(derived_secrets) + ); + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(new_root_key, pos); + pos = olm::load_array(new_chain_key.key, pos); + new_chain_key.index = 0; + olm::unset(derived_secrets); + olm::unset(secret); +} + + +static void advance_chain_key( + olm::ChainKey const & chain_key, + olm::ChainKey & new_chain_key +) { + _olm_crypto_hmac_sha256( + chain_key.key, sizeof(chain_key.key), + CHAIN_KEY_SEED, sizeof(CHAIN_KEY_SEED), + new_chain_key.key + ); + new_chain_key.index = chain_key.index + 1; +} + + +static void create_message_keys( + olm::ChainKey const & chain_key, + olm::KdfInfo const & info, + olm::MessageKey & message_key) { + _olm_crypto_hmac_sha256( + chain_key.key, sizeof(chain_key.key), + MESSAGE_KEY_SEED, sizeof(MESSAGE_KEY_SEED), + message_key.key + ); + message_key.index = chain_key.index; +} + + +static std::size_t verify_mac_and_decrypt( + _olm_cipher const *cipher, + olm::MessageKey const & message_key, + olm::MessageReader const & reader, + std::uint8_t * plaintext, std::size_t max_plaintext_length +) { + return cipher->ops->decrypt( + cipher, + message_key.key, sizeof(message_key.key), + reader.input, reader.input_length, + reader.ciphertext, reader.ciphertext_length, + plaintext, max_plaintext_length + ); +} + + +static std::size_t verify_mac_and_decrypt_for_existing_chain( + olm::Ratchet const & session, + olm::ChainKey const & chain, + olm::MessageReader const & reader, + std::uint8_t * plaintext, std::size_t max_plaintext_length +) { + if (reader.counter < chain.index) { + return std::size_t(-1); + } + + /* Limit the number of hashes we're prepared to compute */ + if (reader.counter - chain.index > MAX_MESSAGE_GAP) { + return std::size_t(-1); + } + + olm::ChainKey new_chain = chain; + + while (new_chain.index < reader.counter) { + advance_chain_key(new_chain, new_chain); + } + + olm::MessageKey message_key; + create_message_keys(new_chain, session.kdf_info, message_key); + + std::size_t result = verify_mac_and_decrypt( + session.ratchet_cipher, message_key, reader, + plaintext, max_plaintext_length + ); + + olm::unset(new_chain); + return result; +} + + +static std::size_t verify_mac_and_decrypt_for_new_chain( + olm::Ratchet const & session, + olm::MessageReader const & reader, + std::uint8_t * plaintext, std::size_t max_plaintext_length +) { + olm::SharedKey new_root_key; + olm::ReceiverChain new_chain; + + /* They shouldn't move to a new chain until we've sent them a message + * acknowledging the last one */ + if (session.sender_chain.empty()) { + return std::size_t(-1); + } + + /* Limit the number of hashes we're prepared to compute */ + if (reader.counter > MAX_MESSAGE_GAP) { + return std::size_t(-1); + } + olm::load_array(new_chain.ratchet_key.public_key, reader.ratchet_key); + + create_chain_key( + session.root_key, session.sender_chain[0].ratchet_key, + new_chain.ratchet_key, session.kdf_info, + new_root_key, new_chain.chain_key + ); + std::size_t result = verify_mac_and_decrypt_for_existing_chain( + session, new_chain.chain_key, reader, + plaintext, max_plaintext_length + ); + olm::unset(new_root_key); + olm::unset(new_chain); + return result; +} + +} // namespace + + +olm::Ratchet::Ratchet( + olm::KdfInfo const & kdf_info, + _olm_cipher const * ratchet_cipher +) : kdf_info(kdf_info), + ratchet_cipher(ratchet_cipher), + last_error(OlmErrorCode::OLM_SUCCESS) { +} + + +void olm::Ratchet::initialise_as_bob( + std::uint8_t const * shared_secret, std::size_t shared_secret_length, + _olm_curve25519_public_key const & their_ratchet_key +) { + std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH]; + _olm_crypto_hkdf_sha256( + shared_secret, shared_secret_length, + nullptr, 0, + kdf_info.root_info, kdf_info.root_info_length, + derived_secrets, sizeof(derived_secrets) + ); + receiver_chains.insert(); + receiver_chains[0].chain_key.index = 0; + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(root_key, pos); + pos = olm::load_array(receiver_chains[0].chain_key.key, pos); + receiver_chains[0].ratchet_key = their_ratchet_key; + olm::unset(derived_secrets); +} + + +void olm::Ratchet::initialise_as_alice( + std::uint8_t const * shared_secret, std::size_t shared_secret_length, + _olm_curve25519_key_pair const & our_ratchet_key +) { + std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH]; + _olm_crypto_hkdf_sha256( + shared_secret, shared_secret_length, + nullptr, 0, + kdf_info.root_info, kdf_info.root_info_length, + derived_secrets, sizeof(derived_secrets) + ); + sender_chain.insert(); + sender_chain[0].chain_key.index = 0; + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(root_key, pos); + pos = olm::load_array(sender_chain[0].chain_key.key, pos); + sender_chain[0].ratchet_key = our_ratchet_key; + olm::unset(derived_secrets); +} + +namespace olm { + + +static std::size_t pickle_length( + const olm::SharedKey & value +) { + return olm::OLM_SHARED_KEY_LENGTH; +} + + +static std::uint8_t * pickle( + std::uint8_t * pos, + const olm::SharedKey & value +) { + return olm::pickle_bytes(pos, value, olm::OLM_SHARED_KEY_LENGTH); +} + + +static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::SharedKey & value +) { + return olm::unpickle_bytes(pos, end, value, olm::OLM_SHARED_KEY_LENGTH); +} + + +static std::size_t pickle_length( + const olm::SenderChain & value +) { + std::size_t length = 0; + length += olm::pickle_length(value.ratchet_key); + length += olm::pickle_length(value.chain_key.key); + length += olm::pickle_length(value.chain_key.index); + return length; +} + + +static std::uint8_t * pickle( + std::uint8_t * pos, + const olm::SenderChain & value +) { + pos = olm::pickle(pos, value.ratchet_key); + pos = olm::pickle(pos, value.chain_key.key); + pos = olm::pickle(pos, value.chain_key.index); + return pos; +} + + +static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::SenderChain & value +) { + pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.chain_key.key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.chain_key.index); UNPICKLE_OK(pos); + return pos; +} + +static std::size_t pickle_length( + const olm::ReceiverChain & value +) { + std::size_t length = 0; + length += olm::pickle_length(value.ratchet_key); + length += olm::pickle_length(value.chain_key.key); + length += olm::pickle_length(value.chain_key.index); + return length; +} + + +static std::uint8_t * pickle( + std::uint8_t * pos, + const olm::ReceiverChain & value +) { + pos = olm::pickle(pos, value.ratchet_key); + pos = olm::pickle(pos, value.chain_key.key); + pos = olm::pickle(pos, value.chain_key.index); + return pos; +} + + +static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::ReceiverChain & value +) { + pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.chain_key.key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.chain_key.index); UNPICKLE_OK(pos); + return pos; +} + + +static std::size_t pickle_length( + const olm::SkippedMessageKey & value +) { + std::size_t length = 0; + length += olm::pickle_length(value.ratchet_key); + length += olm::pickle_length(value.message_key.key); + length += olm::pickle_length(value.message_key.index); + return length; +} + + +static std::uint8_t * pickle( + std::uint8_t * pos, + const olm::SkippedMessageKey & value +) { + pos = olm::pickle(pos, value.ratchet_key); + pos = olm::pickle(pos, value.message_key.key); + pos = olm::pickle(pos, value.message_key.index); + return pos; +} + + +static std::uint8_t const * unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::SkippedMessageKey & value +) { + pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.message_key.key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.message_key.index); UNPICKLE_OK(pos); + return pos; +} + + +} // namespace olm + + +std::size_t olm::pickle_length( + olm::Ratchet const & value +) { + std::size_t length = 0; + length += olm::OLM_SHARED_KEY_LENGTH; + length += olm::pickle_length(value.sender_chain); + length += olm::pickle_length(value.receiver_chains); + length += olm::pickle_length(value.skipped_message_keys); + return length; +} + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + olm::Ratchet const & value +) { + pos = pickle(pos, value.root_key); + pos = pickle(pos, value.sender_chain); + pos = pickle(pos, value.receiver_chains); + pos = pickle(pos, value.skipped_message_keys); + return pos; +} + + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + olm::Ratchet & value, + bool includes_chain_index +) { + pos = unpickle(pos, end, value.root_key); UNPICKLE_OK(pos); + pos = unpickle(pos, end, value.sender_chain); UNPICKLE_OK(pos); + pos = unpickle(pos, end, value.receiver_chains); UNPICKLE_OK(pos); + pos = unpickle(pos, end, value.skipped_message_keys); UNPICKLE_OK(pos); + + // pickle v 0x80000001 includes a chain index; pickle v1 does not. + if (includes_chain_index) { + std::uint32_t dummy; + pos = unpickle(pos, end, dummy); UNPICKLE_OK(pos); + } + return pos; +} + + +std::size_t olm::Ratchet::encrypt_output_length( + std::size_t plaintext_length +) const { + std::size_t counter = 0; + if (!sender_chain.empty()) { + counter = sender_chain[0].chain_key.index; + } + std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length( + ratchet_cipher, + plaintext_length + ); + return olm::encode_message_length( + counter, CURVE25519_KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher) + ); +} + + +std::size_t olm::Ratchet::encrypt_random_length() const { + return sender_chain.empty() ? CURVE25519_RANDOM_LENGTH : 0; +} + + +std::size_t olm::Ratchet::encrypt( + std::uint8_t const * plaintext, std::size_t plaintext_length, + std::uint8_t const * random, std::size_t random_length, + std::uint8_t * output, std::size_t max_output_length +) { + std::size_t output_length = encrypt_output_length(plaintext_length); + + if (random_length < encrypt_random_length()) { + last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; + return std::size_t(-1); + } + if (max_output_length < output_length) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + if (sender_chain.empty()) { + sender_chain.insert(); + _olm_crypto_curve25519_generate_key(random, &sender_chain[0].ratchet_key); + create_chain_key( + root_key, + sender_chain[0].ratchet_key, + receiver_chains[0].ratchet_key, + kdf_info, + root_key, sender_chain[0].chain_key + ); + } + + MessageKey keys; + create_message_keys(sender_chain[0].chain_key, kdf_info, keys); + advance_chain_key(sender_chain[0].chain_key, sender_chain[0].chain_key); + + std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length( + ratchet_cipher, + plaintext_length + ); + std::uint32_t counter = keys.index; + _olm_curve25519_public_key const & ratchet_key = + sender_chain[0].ratchet_key.public_key; + + olm::MessageWriter writer; + + olm::encode_message( + writer, PROTOCOL_VERSION, counter, CURVE25519_KEY_LENGTH, + ciphertext_length, + output + ); + + olm::store_array(writer.ratchet_key, ratchet_key.public_key); + + ratchet_cipher->ops->encrypt( + ratchet_cipher, + keys.key, sizeof(keys.key), + plaintext, plaintext_length, + writer.ciphertext, ciphertext_length, + output, output_length + ); + + olm::unset(keys); + return output_length; +} + + +std::size_t olm::Ratchet::decrypt_max_plaintext_length( + std::uint8_t const * input, std::size_t input_length +) { + olm::MessageReader reader; + olm::decode_message( + reader, input, input_length, + ratchet_cipher->ops->mac_length(ratchet_cipher) + ); + + if (!reader.ciphertext) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + + return ratchet_cipher->ops->decrypt_max_plaintext_length( + ratchet_cipher, reader.ciphertext_length); +} + + +std::size_t olm::Ratchet::decrypt( + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * plaintext, std::size_t max_plaintext_length +) { + olm::MessageReader reader; + olm::decode_message( + reader, input, input_length, + ratchet_cipher->ops->mac_length(ratchet_cipher) + ); + + if (reader.version != PROTOCOL_VERSION) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_VERSION; + return std::size_t(-1); + } + + if (!reader.has_counter || !reader.ratchet_key || !reader.ciphertext) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + + std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length( + ratchet_cipher, + reader.ciphertext_length + ); + + if (max_plaintext_length < max_length) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + + if (reader.ratchet_key_length != CURVE25519_KEY_LENGTH) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + + ReceiverChain * chain = nullptr; + + for (olm::ReceiverChain & receiver_chain : receiver_chains) { + if (0 == std::memcmp( + receiver_chain.ratchet_key.public_key, reader.ratchet_key, + CURVE25519_KEY_LENGTH + )) { + chain = &receiver_chain; + break; + } + } + + std::size_t result = std::size_t(-1); + + if (!chain) { + result = verify_mac_and_decrypt_for_new_chain( + *this, reader, plaintext, max_plaintext_length + ); + } else if (chain->chain_key.index > reader.counter) { + /* Chain already advanced beyond the key for this message + * Check if the message keys are in the skipped key list. */ + for (olm::SkippedMessageKey & skipped : skipped_message_keys) { + if (reader.counter == skipped.message_key.index + && 0 == std::memcmp( + skipped.ratchet_key.public_key, reader.ratchet_key, + CURVE25519_KEY_LENGTH + ) + ) { + /* Found the key for this message. Check the MAC. */ + + result = verify_mac_and_decrypt( + ratchet_cipher, skipped.message_key, reader, + plaintext, max_plaintext_length + ); + + if (result != std::size_t(-1)) { + /* Remove the key from the skipped keys now that we've + * decoded the message it corresponds to. */ + olm::unset(skipped); + skipped_message_keys.erase(&skipped); + return result; + } + } + } + } else { + result = verify_mac_and_decrypt_for_existing_chain( + *this, chain->chain_key, + reader, plaintext, max_plaintext_length + ); + } + + if (result == std::size_t(-1)) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC; + return std::size_t(-1); + } + + if (!chain) { + /* They have started using a new ephemeral ratchet key. + * We need to derive a new set of chain keys. + * We can discard our previous ephemeral ratchet key. + * We will generate a new key when we send the next message. */ + + chain = receiver_chains.insert(); + olm::load_array(chain->ratchet_key.public_key, reader.ratchet_key); + + // TODO: we've already done this once, in + // verify_mac_and_decrypt_for_new_chain(). we could reuse the result. + create_chain_key( + root_key, sender_chain[0].ratchet_key, chain->ratchet_key, + kdf_info, root_key, chain->chain_key + ); + + olm::unset(sender_chain[0]); + sender_chain.erase(sender_chain.begin()); + } + + while (chain->chain_key.index < reader.counter) { + olm::SkippedMessageKey & key = *skipped_message_keys.insert(); + create_message_keys(chain->chain_key, kdf_info, key.message_key); + key.ratchet_key = chain->ratchet_key; + advance_chain_key(chain->chain_key, chain->chain_key); + } + + advance_chain_key(chain->chain_key, chain->chain_key); + + return result; +} diff --git a/ext/olm/src/sas.c b/ext/olm/src/sas.c new file mode 100644 index 0000000..d9cec7e --- /dev/null +++ b/ext/olm/src/sas.c @@ -0,0 +1,229 @@ +/* Copyright 2018-2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/sas.h" +#include "olm/base64.h" +#include "olm/crypto.h" +#include "olm/error.h" +#include "olm/memory.h" + +struct OlmSAS { + enum OlmErrorCode last_error; + struct _olm_curve25519_key_pair curve25519_key; + uint8_t secret[CURVE25519_SHARED_SECRET_LENGTH]; + int their_key_set; +}; + +const char * olm_sas_last_error( + const OlmSAS * sas +) { + return _olm_error_to_string(sas->last_error); +} + +enum OlmErrorCode olm_sas_last_error_code( + const OlmSAS * sas +) { + return sas->last_error; +} + +size_t olm_sas_size(void) { + return sizeof(OlmSAS); +} + +OlmSAS * olm_sas( + void * memory +) { + _olm_unset(memory, sizeof(OlmSAS)); + return (OlmSAS *) memory; +} + +size_t olm_clear_sas( + OlmSAS * sas +) { + _olm_unset(sas, sizeof(OlmSAS)); + return sizeof(OlmSAS); +} + +size_t olm_create_sas_random_length(const OlmSAS * sas) { + return CURVE25519_KEY_LENGTH; +} + +size_t olm_create_sas( + OlmSAS * sas, + void * random, size_t random_length +) { + if (random_length < olm_create_sas_random_length(sas)) { + sas->last_error = OLM_NOT_ENOUGH_RANDOM; + return (size_t)-1; + } + _olm_crypto_curve25519_generate_key((uint8_t *) random, &sas->curve25519_key); + sas->their_key_set = 0; + return 0; +} + +size_t olm_sas_pubkey_length(const OlmSAS * sas) { + return _olm_encode_base64_length(CURVE25519_KEY_LENGTH); +} + +size_t olm_sas_get_pubkey( + OlmSAS * sas, + void * pubkey, size_t pubkey_length +) { + if (pubkey_length < olm_sas_pubkey_length(sas)) { + sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + _olm_encode_base64( + (const uint8_t *)sas->curve25519_key.public_key.public_key, + CURVE25519_KEY_LENGTH, + (uint8_t *)pubkey + ); + return 0; +} + +size_t olm_sas_set_their_key( + OlmSAS *sas, + void * their_key, size_t their_key_length +) { + if (their_key_length < olm_sas_pubkey_length(sas)) { + sas->last_error = OLM_INPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + size_t ret = _olm_decode_base64(their_key, their_key_length, their_key); + if (ret == (size_t)-1) { + sas->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + _olm_crypto_curve25519_shared_secret(&sas->curve25519_key, their_key, sas->secret); + sas->their_key_set = 1; + return 0; +} + +int olm_sas_is_their_key_set( + const OlmSAS *sas +) { + return sas->their_key_set; +} + +size_t olm_sas_generate_bytes( + OlmSAS * sas, + const void * info, size_t info_length, + void * output, size_t output_length +) { + if (!sas->their_key_set) { + sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET; + return (size_t)-1; + } + _olm_crypto_hkdf_sha256( + sas->secret, sizeof(sas->secret), + NULL, 0, + (const uint8_t *) info, info_length, + output, output_length + ); + return 0; +} + +size_t olm_sas_mac_length( + const OlmSAS *sas +) { + return _olm_encode_base64_length(SHA256_OUTPUT_LENGTH); +} + +// A version of the calculate mac function that produces base64 strings that are +// compatible with other base64 implementations. +size_t olm_sas_calculate_mac_fixed_base64( + OlmSAS * sas, + const void * input, size_t input_length, + const void * info, size_t info_length, + void * mac, size_t mac_length +) { + if (mac_length < olm_sas_mac_length(sas)) { + sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + if (!sas->their_key_set) { + sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET; + return (size_t)-1; + } + uint8_t key[32]; + _olm_crypto_hkdf_sha256( + sas->secret, sizeof(sas->secret), + NULL, 0, + (const uint8_t *) info, info_length, + key, 32 + ); + + uint8_t temp_mac[32]; + _olm_crypto_hmac_sha256(key, 32, input, input_length, temp_mac); + _olm_encode_base64((const uint8_t *)temp_mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac); + + return 0; +} + + +size_t olm_sas_calculate_mac( + OlmSAS * sas, + const void * input, size_t input_length, + const void * info, size_t info_length, + void * mac, size_t mac_length +) { + if (mac_length < olm_sas_mac_length(sas)) { + sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + if (!sas->their_key_set) { + sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET; + return (size_t)-1; + } + uint8_t key[32]; + _olm_crypto_hkdf_sha256( + sas->secret, sizeof(sas->secret), + NULL, 0, + (const uint8_t *) info, info_length, + key, 32 + ); + _olm_crypto_hmac_sha256(key, 32, input, input_length, mac); + _olm_encode_base64((const uint8_t *)mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac); + return 0; +} + +// for compatibility with an old version of Riot +size_t olm_sas_calculate_mac_long_kdf( + OlmSAS * sas, + const void * input, size_t input_length, + const void * info, size_t info_length, + void * mac, size_t mac_length +) { + if (mac_length < olm_sas_mac_length(sas)) { + sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + if (!sas->their_key_set) { + sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET; + return (size_t)-1; + } + uint8_t key[256]; + _olm_crypto_hkdf_sha256( + sas->secret, sizeof(sas->secret), + NULL, 0, + (const uint8_t *) info, info_length, + key, 256 + ); + _olm_crypto_hmac_sha256(key, 256, input, input_length, mac); + _olm_encode_base64((const uint8_t *)mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac); + return 0; +} diff --git a/ext/olm/src/session.cpp b/ext/olm/src/session.cpp new file mode 100644 index 0000000..732e0c0 --- /dev/null +++ b/ext/olm/src/session.cpp @@ -0,0 +1,531 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "olm/session.hh" +#include "olm/cipher.h" +#include "olm/crypto.h" +#include "olm/account.hh" +#include "olm/memory.hh" +#include "olm/message.hh" +#include "olm/pickle.hh" + +#include <cstring> +#include <stdio.h> + +namespace { + +static const std::uint8_t PROTOCOL_VERSION = 0x3; + +static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT"; +static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET"; +static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS"; + +static const olm::KdfInfo OLM_KDF_INFO = { + ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1, + RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1 +}; + +static const struct _olm_cipher_aes_sha_256 OLM_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256(CIPHER_KDF_INFO); + +} // namespace + +olm::Session::Session( +) : ratchet(OLM_KDF_INFO, OLM_CIPHER_BASE(&OLM_CIPHER)), + last_error(OlmErrorCode::OLM_SUCCESS), + received_message(false) { + +} + + +std::size_t olm::Session::new_outbound_session_random_length() const { + return CURVE25519_RANDOM_LENGTH * 2; +} + + +std::size_t olm::Session::new_outbound_session( + olm::Account const & local_account, + _olm_curve25519_public_key const & identity_key, + _olm_curve25519_public_key const & one_time_key, + std::uint8_t const * random, std::size_t random_length +) { + if (random_length < new_outbound_session_random_length()) { + last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; + return std::size_t(-1); + } + + _olm_curve25519_key_pair base_key; + _olm_crypto_curve25519_generate_key(random, &base_key); + + _olm_curve25519_key_pair ratchet_key; + _olm_crypto_curve25519_generate_key(random + CURVE25519_RANDOM_LENGTH, &ratchet_key); + + _olm_curve25519_key_pair const & alice_identity_key_pair = ( + local_account.identity_keys.curve25519_key + ); + + received_message = false; + alice_identity_key = alice_identity_key_pair.public_key; + alice_base_key = base_key.public_key; + bob_one_time_key = one_time_key; + + // Calculate the shared secret S via triple DH + std::uint8_t secret[3 * CURVE25519_SHARED_SECRET_LENGTH]; + std::uint8_t * pos = secret; + + _olm_crypto_curve25519_shared_secret(&alice_identity_key_pair, &one_time_key, pos); + pos += CURVE25519_SHARED_SECRET_LENGTH; + _olm_crypto_curve25519_shared_secret(&base_key, &identity_key, pos); + pos += CURVE25519_SHARED_SECRET_LENGTH; + _olm_crypto_curve25519_shared_secret(&base_key, &one_time_key, pos); + + ratchet.initialise_as_alice(secret, sizeof(secret), ratchet_key); + + olm::unset(base_key); + olm::unset(ratchet_key); + olm::unset(secret); + + return std::size_t(0); +} + +namespace { + +static bool check_message_fields( + olm::PreKeyMessageReader & reader, bool have_their_identity_key +) { + bool ok = true; + ok = ok && (have_their_identity_key || reader.identity_key); + if (reader.identity_key) { + ok = ok && reader.identity_key_length == CURVE25519_KEY_LENGTH; + } + ok = ok && reader.message; + ok = ok && reader.base_key; + ok = ok && reader.base_key_length == CURVE25519_KEY_LENGTH; + ok = ok && reader.one_time_key; + ok = ok && reader.one_time_key_length == CURVE25519_KEY_LENGTH; + return ok; +} + +} // namespace + + +std::size_t olm::Session::new_inbound_session( + olm::Account & local_account, + _olm_curve25519_public_key const * their_identity_key, + std::uint8_t const * one_time_key_message, std::size_t message_length +) { + olm::PreKeyMessageReader reader; + decode_one_time_key_message(reader, one_time_key_message, message_length); + + if (!check_message_fields(reader, their_identity_key)) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + + if (reader.identity_key && their_identity_key) { + bool same = 0 == std::memcmp( + their_identity_key->public_key, reader.identity_key, CURVE25519_KEY_LENGTH + ); + if (!same) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID; + return std::size_t(-1); + } + } + + olm::load_array(alice_identity_key.public_key, reader.identity_key); + olm::load_array(alice_base_key.public_key, reader.base_key); + olm::load_array(bob_one_time_key.public_key, reader.one_time_key); + + olm::MessageReader message_reader; + decode_message( + message_reader, reader.message, reader.message_length, + ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher) + ); + + if (!message_reader.ratchet_key + || message_reader.ratchet_key_length != CURVE25519_KEY_LENGTH) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + + _olm_curve25519_public_key ratchet_key; + olm::load_array(ratchet_key.public_key, message_reader.ratchet_key); + + olm::OneTimeKey const * our_one_time_key = local_account.lookup_key( + bob_one_time_key + ); + + if (!our_one_time_key) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID; + return std::size_t(-1); + } + + _olm_curve25519_key_pair const & bob_identity_key = ( + local_account.identity_keys.curve25519_key + ); + _olm_curve25519_key_pair const & bob_one_time_key = our_one_time_key->key; + + // Calculate the shared secret S via triple DH + std::uint8_t secret[CURVE25519_SHARED_SECRET_LENGTH * 3]; + std::uint8_t * pos = secret; + _olm_crypto_curve25519_shared_secret(&bob_one_time_key, &alice_identity_key, pos); + pos += CURVE25519_SHARED_SECRET_LENGTH; + _olm_crypto_curve25519_shared_secret(&bob_identity_key, &alice_base_key, pos); + pos += CURVE25519_SHARED_SECRET_LENGTH; + _olm_crypto_curve25519_shared_secret(&bob_one_time_key, &alice_base_key, pos); + + ratchet.initialise_as_bob(secret, sizeof(secret), ratchet_key); + + olm::unset(secret); + + return std::size_t(0); +} + + +std::size_t olm::Session::session_id_length() const { + return SHA256_OUTPUT_LENGTH; +} + + +std::size_t olm::Session::session_id( + std::uint8_t * id, std::size_t id_length +) { + if (id_length < session_id_length()) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::uint8_t tmp[CURVE25519_KEY_LENGTH * 3]; + std::uint8_t * pos = tmp; + pos = olm::store_array(pos, alice_identity_key.public_key); + pos = olm::store_array(pos, alice_base_key.public_key); + pos = olm::store_array(pos, bob_one_time_key.public_key); + _olm_crypto_sha256(tmp, sizeof(tmp), id); + return session_id_length(); +} + + +bool olm::Session::matches_inbound_session( + _olm_curve25519_public_key const * their_identity_key, + std::uint8_t const * one_time_key_message, std::size_t message_length +) const { + olm::PreKeyMessageReader reader; + decode_one_time_key_message(reader, one_time_key_message, message_length); + + if (!check_message_fields(reader, their_identity_key)) { + return false; + } + + bool same = true; + if (reader.identity_key) { + same = same && 0 == std::memcmp( + reader.identity_key, alice_identity_key.public_key, CURVE25519_KEY_LENGTH + ); + } + if (their_identity_key) { + same = same && 0 == std::memcmp( + their_identity_key->public_key, alice_identity_key.public_key, + CURVE25519_KEY_LENGTH + ); + } + same = same && 0 == std::memcmp( + reader.base_key, alice_base_key.public_key, CURVE25519_KEY_LENGTH + ); + same = same && 0 == std::memcmp( + reader.one_time_key, bob_one_time_key.public_key, CURVE25519_KEY_LENGTH + ); + return same; +} + + +olm::MessageType olm::Session::encrypt_message_type() const { + if (received_message) { + return olm::MessageType::MESSAGE; + } else { + return olm::MessageType::PRE_KEY; + } +} + + +std::size_t olm::Session::encrypt_message_length( + std::size_t plaintext_length +) const { + std::size_t message_length = ratchet.encrypt_output_length( + plaintext_length + ); + + if (received_message) { + return message_length; + } + + return encode_one_time_key_message_length( + CURVE25519_KEY_LENGTH, + CURVE25519_KEY_LENGTH, + CURVE25519_KEY_LENGTH, + message_length + ); +} + + +std::size_t olm::Session::encrypt_random_length() const { + return ratchet.encrypt_random_length(); +} + + +std::size_t olm::Session::encrypt( + std::uint8_t const * plaintext, std::size_t plaintext_length, + std::uint8_t const * random, std::size_t random_length, + std::uint8_t * message, std::size_t message_length +) { + if (message_length < encrypt_message_length(plaintext_length)) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + std::uint8_t * message_body; + std::size_t message_body_length = ratchet.encrypt_output_length( + plaintext_length + ); + + if (received_message) { + message_body = message; + } else { + olm::PreKeyMessageWriter writer; + encode_one_time_key_message( + writer, + PROTOCOL_VERSION, + CURVE25519_KEY_LENGTH, + CURVE25519_KEY_LENGTH, + CURVE25519_KEY_LENGTH, + message_body_length, + message + ); + olm::store_array(writer.one_time_key, bob_one_time_key.public_key); + olm::store_array(writer.identity_key, alice_identity_key.public_key); + olm::store_array(writer.base_key, alice_base_key.public_key); + message_body = writer.message; + } + + std::size_t result = ratchet.encrypt( + plaintext, plaintext_length, + random, random_length, + message_body, message_body_length + ); + + if (result == std::size_t(-1)) { + last_error = ratchet.last_error; + ratchet.last_error = OlmErrorCode::OLM_SUCCESS; + return result; + } + + return result; +} + + +std::size_t olm::Session::decrypt_max_plaintext_length( + MessageType message_type, + std::uint8_t const * message, std::size_t message_length +) { + std::uint8_t const * message_body; + std::size_t message_body_length; + if (message_type == olm::MessageType::MESSAGE) { + message_body = message; + message_body_length = message_length; + } else { + olm::PreKeyMessageReader reader; + decode_one_time_key_message(reader, message, message_length); + if (!reader.message) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + message_body = reader.message; + message_body_length = reader.message_length; + } + + std::size_t result = ratchet.decrypt_max_plaintext_length( + message_body, message_body_length + ); + + if (result == std::size_t(-1)) { + last_error = ratchet.last_error; + ratchet.last_error = OlmErrorCode::OLM_SUCCESS; + } + return result; +} + + +std::size_t olm::Session::decrypt( + olm::MessageType message_type, + std::uint8_t const * message, std::size_t message_length, + std::uint8_t * plaintext, std::size_t max_plaintext_length +) { + std::uint8_t const * message_body; + std::size_t message_body_length; + if (message_type == olm::MessageType::MESSAGE) { + message_body = message; + message_body_length = message_length; + } else { + olm::PreKeyMessageReader reader; + decode_one_time_key_message(reader, message, message_length); + if (!reader.message) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + message_body = reader.message; + message_body_length = reader.message_length; + } + + std::size_t result = ratchet.decrypt( + message_body, message_body_length, plaintext, max_plaintext_length + ); + + if (result == std::size_t(-1)) { + last_error = ratchet.last_error; + ratchet.last_error = OlmErrorCode::OLM_SUCCESS; + return result; + } + + received_message = true; + return result; +} + +// make the description end with "..." instead of stopping abruptly with no +// warning +void elide_description(char *end) { + end[-3] = '.'; + end[-2] = '.'; + end[-1] = '.'; + end[0] = '\0'; +} + +void olm::Session::describe(char *describe_buffer, size_t buflen) { + // how much of the buffer is remaining (this is an int rather than a size_t + // because it will get compared to the return value from snprintf) + int remaining = buflen; + // do nothing if we have a zero-length buffer, or if buflen > INT_MAX, + // resulting in an overflow + if (remaining <= 0) return; + + describe_buffer[0] = '\0'; + // we need at least 23 characters to get any sort of meaningful + // information, so bail if we don't have that. (But more importantly, we + // need it to be at least 4 so that elide_description doesn't go out of + // bounds.) + if (remaining < 23) return; + + int size; + + // check that snprintf didn't return an error or reach the end of the buffer +#define CHECK_SIZE_AND_ADVANCE \ + if (size > remaining) { \ + return elide_description(describe_buffer + remaining - 1); \ + } else if (size > 0) { \ + describe_buffer += size; \ + remaining -= size; \ + } else { \ + return; \ + } + + size = snprintf( + describe_buffer, remaining, + "sender chain index: %ld ", ratchet.sender_chain[0].chain_key.index + ); + CHECK_SIZE_AND_ADVANCE; + + size = snprintf(describe_buffer, remaining, "receiver chain indices:"); + CHECK_SIZE_AND_ADVANCE; + + for (size_t i = 0; i < ratchet.receiver_chains.size(); ++i) { + size = snprintf( + describe_buffer, remaining, + " %ld", ratchet.receiver_chains[i].chain_key.index + ); + CHECK_SIZE_AND_ADVANCE; + } + + size = snprintf(describe_buffer, remaining, " skipped message keys:"); + CHECK_SIZE_AND_ADVANCE; + + for (size_t i = 0; i < ratchet.skipped_message_keys.size(); ++i) { + size = snprintf( + describe_buffer, remaining, + " %ld", ratchet.skipped_message_keys[i].message_key.index + ); + CHECK_SIZE_AND_ADVANCE; + } +#undef CHECK_SIZE_AND_ADVANCE +} + +namespace { +// the master branch writes pickle version 1; the logging_enabled branch writes +// 0x80000001. +static const std::uint32_t SESSION_PICKLE_VERSION = 1; +} + +std::size_t olm::pickle_length( + Session const & value +) { + std::size_t length = 0; + length += olm::pickle_length(SESSION_PICKLE_VERSION); + length += olm::pickle_length(value.received_message); + length += olm::pickle_length(value.alice_identity_key); + length += olm::pickle_length(value.alice_base_key); + length += olm::pickle_length(value.bob_one_time_key); + length += olm::pickle_length(value.ratchet); + return length; +} + + +std::uint8_t * olm::pickle( + std::uint8_t * pos, + Session const & value +) { + pos = olm::pickle(pos, SESSION_PICKLE_VERSION); + pos = olm::pickle(pos, value.received_message); + pos = olm::pickle(pos, value.alice_identity_key); + pos = olm::pickle(pos, value.alice_base_key); + pos = olm::pickle(pos, value.bob_one_time_key); + pos = olm::pickle(pos, value.ratchet); + return pos; +} + + +std::uint8_t const * olm::unpickle( + std::uint8_t const * pos, std::uint8_t const * end, + Session & value +) { + uint32_t pickle_version; + pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos); + + bool includes_chain_index; + switch (pickle_version) { + case 1: + includes_chain_index = false; + break; + + case 0x80000001UL: + includes_chain_index = true; + break; + + default: + value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION; + return nullptr; + } + + pos = olm::unpickle(pos, end, value.received_message); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.alice_identity_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.alice_base_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.bob_one_time_key); UNPICKLE_OK(pos); + pos = olm::unpickle(pos, end, value.ratchet, includes_chain_index); UNPICKLE_OK(pos); + + return pos; +} diff --git a/ext/olm/src/utility.cpp b/ext/olm/src/utility.cpp new file mode 100644 index 0000000..b6bb56e --- /dev/null +++ b/ext/olm/src/utility.cpp @@ -0,0 +1,57 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/utility.hh" +#include "olm/crypto.h" + + +olm::Utility::Utility( +) : last_error(OlmErrorCode::OLM_SUCCESS) { +} + + +size_t olm::Utility::sha256_length() const { + return SHA256_OUTPUT_LENGTH; +} + + +size_t olm::Utility::sha256( + std::uint8_t const * input, std::size_t input_length, + std::uint8_t * output, std::size_t output_length +) { + if (output_length < sha256_length()) { + last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; + return std::size_t(-1); + } + _olm_crypto_sha256(input, input_length, output); + return SHA256_OUTPUT_LENGTH; +} + + +size_t olm::Utility::ed25519_verify( + _olm_ed25519_public_key const & key, + std::uint8_t const * message, std::size_t message_length, + std::uint8_t const * signature, std::size_t signature_length +) { + if (signature_length < ED25519_SIGNATURE_LENGTH) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC; + return std::size_t(-1); + } + if (!_olm_crypto_ed25519_verify(&key, message, message_length, signature)) { + last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC; + return std::size_t(-1); + } + return std::size_t(0); +} |
