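// Scraped page header (Requirements | Concepts | Design | Development | QA | Lifecycle | Control):
// Anforderungen | Konzepte | Entwurf | Entwicklung | Qualitätssicherung | Lebenszyklus | Steuerung — Quelle: tls_agent.cc, Sprache: C++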
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */


#include "tls_agent.h"
#include "databuffer.h"
#include "keyhi.h"
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslexp.h"
#include "sslproto.h"
#include "tls_filter.h"
#include "tls_parser.h"

extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"

// Defined elsewhere in the test binary. NOTE(review): presumably the working
// directory the tests run from (e.g. where the NSS test DB lives) — confirm.
extern std::string g_working_dir_path;

namespace nss_test {

const char* TlsAgent::states[] = {"INIT""CONNECTING""CONNECTED""ERROR"};

const std::string TlsAgent::kClient = "client";    // both sign and encrypt
const std::string TlsAgent::kRsa2048 = "rsa2048";  // bigger
const std::string TlsAgent::kRsa8192 = "rsa8192";  // biggest allowed
const std::string TlsAgent::kServerRsa = "rsa";    // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256";
const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048";
const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048";

// A canned TLS 1.3 ServerHello body (no handshake message header), 86 bytes.
// The field breakdown below follows the ServerHello layout in RFC 8446 4.1.3;
// the declared lengths (0x2e, 0x24, 0x20, 0x02) are consistent with it.
static const uint8_t kCannedTls13ServerHello[] = {
    // legacy_version = 0x0303
    0x03, 0x03,
    // random (32 bytes)
    0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0, 0x5c, 0x70,
    0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5, 0xc2, 0x28,
    0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08,
    // legacy_session_id_echo: empty
    0x00,
    // cipher_suite = 0x1301
    0x13, 0x01,
    // legacy_compression_method = 0
    0x00,
    // extensions: 46 bytes total
    0x00, 0x2e,
    // extension 0x0033 (key_share), 36 bytes: group 0x001d, 32-byte key
    0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20,
    0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26,
    0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea,
    0x53, 0x3a, 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19,
    // extension 0x002b (supported_versions), 2 bytes: 0x0304
    0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};

// Constructs an agent named |nm| that plays role |rl| over protocol variant
// |var|. Only plain state is set up here. The NSS socket (ssl_fd_) starts out
// null and is created later by EnsureTlsSetup(). vrange_ is seeded with the
// library's default version range for the variant.
TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
    : name_(nm),
      variant_(var),
      role_(rl),
      server_key_bits_(0),
      // NOTE(review): role_str() is called while members are still being
      // initialized. It presumably reads role_, so this relies on role_ being
      // declared (and thus initialized) before adapter_ in the header — confirm.
      adapter_(new DummyPrSocket(role_str(), var)),
      ssl_fd_(nullptr),
      state_(STATE_INIT),
      timer_handle_(nullptr),
      falsestart_enabled_(false),
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_client_auth_(false),
      expect_ech_(false),
      expect_psk_(ssl_psk_none),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      // close_notify at warning level doubles as "no alert expected". The
      // destructor reports a failure if any other value is left behind.
      expected_received_alert_(kTlsAlertCloseNotify),
      expected_received_alert_level_(kTlsAlertWarning),
      expected_sent_alert_(kTlsAlertCloseNotify),
      expected_sent_alert_level_(kTlsAlertWarning),
      handshake_callback_called_(false),
      resumption_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expect_readwrite_error_(false),
      handshake_callback_(),
      auth_certificate_callback_(),
      sni_callback_(),
      skip_version_checks_(false),
      resumption_token_(),
      policy_() {
  // The channel-info caches are plain C structs; start them all-zero.
  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));
  SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}

TlsAgent::~TlsAgent() {
  if (timer_handle_) {
    timer_handle_->Cancel();
  }

  if (adapter_) {
    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
  }

  // Add failures manually, if any, so we don't throw in a destructor.
  if (expected_received_alert_ != kTlsAlertCloseNotify ||
      expected_received_alert_level_ != kTlsAlertWarning) {
    ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
  }
  if (expected_sent_alert_ != kTlsAlertCloseNotify ||
      expected_sent_alert_level_ != kTlsAlertWarning) {
    ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
  }
}

void TlsAgent::SetState(State s) {
  if (state_ == s) return;

  LOG("Changing state from " << state_ << " to " << s);
  state_ = s;
}

/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
                                          ScopedCERTCertificate* cert,
                                          ScopedSECKEYPrivateKey* priv) {
  cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
  EXPECT_NE(nullptr, cert);
  if (!cert) return false;
  EXPECT_NE(nullptr, cert->get());
  if (!cert->get()) return false;

  priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
  EXPECT_NE(nullptr, priv);
  if (!priv) return false;
  EXPECT_NE(nullptr, priv->get());
  if (!priv->get()) return false;

  return true;
}

// Loads a key pair from the certificate identified by |id|.
/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name,
                                              ScopedSECKEYPublicKey* pub,
                                              ScopedSECKEYPrivateKey* priv) {
  ScopedCERTCertificate cert;
  if (!TlsAgent::LoadCertificate(name, &cert, priv)) {
    return false;
  }

  pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo));
  if (!pub->get()) {
    return false;
  }

  return true;
}

void TlsAgent::DelegateCredential(const std::string& name,
                                  const ScopedSECKEYPublicKey& dc_pub,
                                  SSLSignatureScheme dc_cert_verify_alg,
                                  PRUint32 dc_valid_for, PRTime now,
                                  SECItem* dc) {
  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey cert_priv;
  EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv))
      << "Could not load delegate certificate: " << name
      << "; test db corrupt?";

  EXPECT_EQ(SECSuccess,
            SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(),
                                   dc_cert_verify_alg, dc_valid_for, now, dc));
}

// Turns on SSL_ENABLE_DELEGATED_CREDENTIALS for this agent. The socket is
// created first if needed; the function returns early if that fails.
void TlsAgent::EnableDelegatedCredentials() {
  ASSERT_TRUE(EnsureTlsSetup());
  SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE);
}

void TlsAgent::AddDelegatedCredential(const std::string& dc_name,
                                      SSLSignatureScheme dc_cert_verify_alg,
                                      PRUint32 dc_valid_for, PRTime now) {
  ASSERT_TRUE(EnsureTlsSetup());

  ScopedSECKEYPublicKey pub;
  ScopedSECKEYPrivateKey priv;
  EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv));

  StackSECItem dc;
  TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for,
                               now, &dc);

  SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
                                       nullptr,       &dc,     priv.get()};
  EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data));
}

bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits,
                                const SSLExtraServerCertData* serverCertData) {
  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey priv;
  if (!TlsAgent::LoadCertificate(id, &cert, &priv)) {
    return false;
  }

  if (updateKeyBits) {
    ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
    EXPECT_NE(nullptr, pub.get());
    if (!pub.get()) return false;
    server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
  }

  SECStatus rv =
      SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
  EXPECT_EQ(SECFailure, rv);
  rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
                            serverCertData ? sizeof(*serverCertData) : 0);
  return rv == SECSuccess;
}

bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
  // Don't set up twice
  if (ssl_fd_) return true;
  NssManagePolicy policyManage(policy_, option_);

  ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
  EXPECT_NE(nullptr, dummy_fd);
  if (!dummy_fd) {
    return false;
  }
  if (adapter_->variant() == ssl_variant_stream) {
    ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
  } else {
    ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
  }

  EXPECT_NE(nullptr, ssl_fd_);
  if (!ssl_fd_) {
    return false;
  }
  dummy_fd.release();  // Now subsumed by ssl_fd_.

  SECStatus rv;
  if (!skip_version_checks_) {
    rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  ScopedCERTCertList anchors(CERT_NewCertList());
  rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
  if (rv != SECSuccess) return false;

  if (role_ == SERVER) {
    EXPECT_TRUE(ConfigServerCert(name_, true));

    rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;

    rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  } else {
    rv = SSL_SetURL(ssl_fd(), "server");
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  // All these tests depend on having this disabled to start with.
  SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE);

  return true;
}

bool TlsAgent::MaybeSetResumptionToken() {
  if (!resumption_token_.empty()) {
    LOG("setting external resumption token");
    SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
                                          resumption_token_.size());

    // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR
    // if the resumption token was bad (expired/malformed/etc.).
    if (expect_psk_ == ssl_psk_resume) {
      // Only in case we expect resumption this has to be successful. We might
      // not expect resumption due to some reason but the token is totally fine.
      EXPECT_EQ(SECSuccess, rv);
    }
    if (rv != SECSuccess) {
      EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
      resumption_token_.clear();
      EXPECT_FALSE(expect_psk_ == ssl_psk_resume);
      if (expect_psk_ == ssl_psk_resume) return false;
    }
  }

  return true;
}

// Attaches the anti-replay context |ctx| to this agent's socket. This does not
// call EnsureTlsSetup(), so the socket is expected to exist already.
void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
  EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get()));
}

// Defaults to a Sync callback returning success
void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType,
                               bool callbackSuccess) {
  EXPECT_TRUE(EnsureTlsSetup());
  ASSERT_EQ(CLIENT, role_);

  client_auth_callback_type_ = callbackType;
  client_auth_callback_success_ = callbackSuccess;

  if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) {
    // Don't set a callback for this case.
    return;
  }
  EXPECT_EQ(SECSuccess,
            SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
                                      reinterpret_cast<void*>(this)));
}

void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
  ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr));

  ASSERT_EQ(expected->nnames, caNames->nnames);

  for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) {
    EXPECT_EQ(SECEqual,
              SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i])));
  }
}

// Complete processing of Client Certificate Selection
// A No-op if the agent is using synchronous client cert selection.
// Otherwise, calls SSL_ClientCertCallbackComplete.
// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to
// ensure that the socket is correctly blocked.
void TlsAgent::ClientAuthCallbackComplete() {
  ASSERT_EQ(CLIENT, role_);

  if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay &&
      client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) {
    return;
  }
  client_auth_callback_fired_++;
  EXPECT_TRUE(client_auth_callback_awaiting_);

  std::cerr << "client: calling SSL_ClientCertCallbackComplete with status "
            << (client_auth_callback_success_ ? "success" : "failed")
            << std::endl;

  client_auth_callback_awaiting_ = false;

  if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) {
    std::cerr
        << "Running Handshake prior to running SSL_ClientCertCallbackComplete"
        << std::endl;
    SECStatus rv = SSL_ForceHandshake(ssl_fd());
    EXPECT_EQ(rv, SECFailure);
    EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR);
  }

  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey priv;
  if (client_auth_callback_success_) {
    ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv));
    EXPECT_EQ(SECSuccess,
              SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess,
                                             priv.release(), cert.release()));
  } else {
    EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure,
                                                         nullptr, nullptr));
  }
}

SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** clientCert,
                                          SECKEYPrivateKey** clientKey) {
  TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
  EXPECT_EQ(CLIENT, agent->role_);
  agent->client_auth_callback_fired_++;

  switch (agent->client_auth_callback_type_) {
    case ClientAuthCallbackType::kAsyncDelay:
    case ClientAuthCallbackType::kAsyncImmediate:
      std::cerr << "Waiting for complete call" << std::endl;
      agent->client_auth_callback_awaiting_ = true;
      return SECWouldBlock;
    case ClientAuthCallbackType::kSync:
    case ClientAuthCallbackType::kNone:
      // Handle the sync case. None && Success is treated as Sync and Success.
      if (!agent->client_auth_callback_success_) {
        return SECFailure;
      }
      ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
      EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";

      // See bug 1573945
      // CheckCertReqAgainstDefaultCAs(caNames);

      ScopedCERTCertificate cert;
      ScopedSECKEYPrivateKey priv;
      if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
        return SECFailure;
      }

      *clientCert = cert.release();
      *clientKey = priv.release();
      return SECSuccess;
  }
  /* This is unreachable, but some old compilers can't tell that. */
  PORT_Assert(0);
  PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
  return SECFailure;
}

// Increments by 1 for each callback
bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) {
  EXPECT_EQ(CLIENT, role_);
  return expected == client_auth_callback_fired_;
}

bool TlsAgent::GetPeerChainLength(size_t* count) {
  SECItemArray* chain = nullptr;
  SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &chain);
  if (rv != SECSuccess) return false;

  *count = chain->len;

  SECITEM_FreeArray(chain, true);

  return true;
}

void TlsAgent::CheckPeerChainFunctionConsistency() {
  SECItemArray* derChain = nullptr;
  SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &derChain);
  PRErrorCode err1 = PR_GetError();
  CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
  PRErrorCode err2 = PR_GetError();
  if (rv != SECSuccess) {
    ASSERT_EQ(nullptr, chain);
    ASSERT_EQ(nullptr, derChain);
    ASSERT_EQ(err1, SSL_ERROR_NO_CERTIFICATE);
    ASSERT_EQ(err2, SSL_ERROR_NO_CERTIFICATE);
    return;
  }
  ASSERT_NE(nullptr, chain);
  ASSERT_NE(nullptr, derChain);

  unsigned int count = 0;
  for (PRCList* cursor = PR_NEXT_LINK(&chain->list);
       count < derChain->len && cursor != &chain->list;
       cursor = PR_NEXT_LINK(cursor)) {
    CERTCertListNode* node = (CERTCertListNode*)cursor;
    EXPECT_TRUE(
        SECITEM_ItemsAreEqual(&node->cert->derCert, &derChain->items[count]));
    ++count;
  }
  ASSERT_EQ(count, derChain->len);

  SECITEM_FreeArray(derChain, true);
  CERT_DestroyCertList(chain);
}

// Expects the cached cipher-suite info (csinfo_, filled outside this view) to
// describe |suite|.
void TlsAgent::CheckCipherSuite(uint16_t suite) {
  EXPECT_EQ(csinfo_.cipherSuite, suite);
}

void TlsAgent::RequestClientAuth(bool requireAuth) {
  ASSERT_EQ(SERVER, role_);

  SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
  SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);

  EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
                            ssl_fd(), &TlsAgent::ClientAuthenticated, this));
  expect_client_auth_ = true;
}

void TlsAgent::StartConnect(PRFileDesc* model) {
  EXPECT_TRUE(EnsureTlsSetup(model));

  SECStatus rv;
  rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, rv);
  SetState(STATE_CONNECTING);
}

void TlsAgent::DisableAllCiphers() {
  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    SECStatus rv =
        SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
    EXPECT_EQ(SECSuccess, rv);
  }
}

// Not actually all groups, just the ones that we are actually willing
// to use. Covers the EC curves, the FFDHE groups and the hybrid KEM groups.
// NOTE(review): element order is presumably also the preference order once
// passed to ConfigNamedGroups (SSL_NamedGroupConfig) — confirm.
const std::vector<SSLNamedGroup> kAllDHEGroups = {
    ssl_grp_ec_curve25519,   ssl_grp_ec_secp256r1,       ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1,    ssl_grp_ffdhe_2048,         ssl_grp_ffdhe_3072,
    ssl_grp_ffdhe_4096,      ssl_grp_ffdhe_6144,         ssl_grp_ffdhe_8192,
    ssl_grp_kem_xyber768d00, ssl_grp_kem_mlkem768x25519,
};

// EC curves plus the hybrid KEM groups. Used for ssl_kea_ecdh in
// EnableGroupsByKeyExchange and for EC/TLS 1.3 auth in EnableGroupsByAuthType.
const std::vector<SSLNamedGroup> kECDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1,    ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1,  ssl_grp_kem_xyber768d00, ssl_grp_kem_mlkem768x25519,
};

// Finite-field DH groups only; used for ssl_kea_dh.
const std::vector<SSLNamedGroup> kFFDHEGroups = {
    ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
    ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};

// Defined because the big DHE groups are ridiculously slow. Same as
// kAllDHEGroups minus secp521r1 and the FFDHE groups above 3072 bits.
const std::vector<SSLNamedGroup> kFasterDHEGroups = {
    ssl_grp_ec_curve25519,      ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ffdhe_2048,         ssl_grp_ffdhe_3072,   ssl_grp_kem_xyber768d00,
    ssl_grp_kem_mlkem768x25519,
};

// Hybrid KEM groups only; used for ssl_kea_ecdh_hybrid.
const std::vector<SSLNamedGroup> kEcdhHybridGroups = {
    ssl_grp_kem_xyber768d00,
    ssl_grp_kem_mlkem768x25519,
};

void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
  EXPECT_TRUE(EnsureTlsSetup());

  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    SSLCipherSuiteInfo csinfo;

    SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
                                          sizeof(csinfo));
    ASSERT_EQ(SECSuccess, rv);
    EXPECT_EQ(sizeof(csinfo), csinfo.length);

    if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
      rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
      EXPECT_EQ(SECSuccess, rv);
    }
  }
}

void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
  switch (kea) {
    case ssl_kea_dh:
      ConfigNamedGroups(kFFDHEGroups);
      break;
    case ssl_kea_ecdh:
      ConfigNamedGroups(kECDHEGroups);
      break;
    case ssl_kea_ecdh_hybrid:
      ConfigNamedGroups(kEcdhHybridGroups);
      break;
    default:
      break;
  }
}

void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
  if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
      authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) {
    ConfigNamedGroups(kECDHEGroups);
  }
}

void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
  EXPECT_TRUE(EnsureTlsSetup());

  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    SSLCipherSuiteInfo csinfo;

    SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
                                          sizeof(csinfo));
    ASSERT_EQ(SECSuccess, rv);

    if ((csinfo.authType == authType) ||
        (csinfo.keaType == ssl_kea_tls13_any)) {
      rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
      EXPECT_EQ(SECSuccess, rv);
    }
  }
}

void TlsAgent::EnableSingleCipher(uint16_t cipher) {
  DisableAllCiphers();
  SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
  EXPECT_EQ(SECSuccess, rv);
}

void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
  EXPECT_TRUE(EnsureTlsSetup());
  SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size());
  EXPECT_EQ(SECSuccess, rv);
}

void TlsAgent::Set0RttEnabled(bool en) {
  SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
}

void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
  vrange_.min = minver;
  vrange_.max = maxver;

  if (ssl_fd()) {
    SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
    EXPECT_EQ(SECSuccess, rv);
  }
}

SECStatus ResumptionTokenCallback(PRFileDesc* fd,
                                  const PRUint8* resumptionToken,
                                  unsigned int len, void* ctx) {
  EXPECT_NE(nullptr, resumptionToken);
  if (!resumptionToken) {
    return SECFailure;
  }

  std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len);
  reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token);
  reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled();
  return SECSuccess;
}

// Registers ResumptionTokenCallback on this agent's socket, with |this| as
// context, so new resumption tokens are stored in this agent.
void TlsAgent::SetResumptionTokenCallback() {
  EXPECT_TRUE(EnsureTlsSetup());
  SECStatus rv =
      SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this);
  EXPECT_EQ(SECSuccess, rv);
}

void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
  *minver = vrange_.min;
  *maxver = vrange_.max;
}

void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; }

void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }

void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }

void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }

void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
                                   size_t count) {
  EXPECT_TRUE(EnsureTlsSetup());
  EXPECT_LE(count, SSL_SignatureMaxCount());
  EXPECT_EQ(SECSuccess,
            SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
                                       static_cast<unsigned int>(count)));
  EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
      << "setting no schemes should fail and do nothing";

  std::vector<SSLSignatureScheme> configuredSchemes(count);
  unsigned int configuredCount;
  EXPECT_EQ(SECFailure,
            SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
      << "get schemes, schemes is nullptr";
  EXPECT_EQ(SECFailure,
            SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
                                       &configuredCount, 0))
      << "get schemes, too little space";
  EXPECT_EQ(SECFailure,
            SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
                                       configuredSchemes.size()))
      << "get schemes, countOut is nullptr";

  EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
                            ssl_fd(), &configuredSchemes[0], &configuredCount,
                            configuredSchemes.size()));
  // SignatureSchemePrefSet drops unsupported algorithms silently, so the
  // number that are configured might be fewer.
  EXPECT_LE(configuredCount, count);
  unsigned int i = 0;
  for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
    if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
      ++i;
    }
  }
  EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
}

void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group,
                        size_t kea_size) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(kea, info_.keaType);
  if (kea_size == 0) {
    switch (kea_group) {
      case ssl_grp_ec_curve25519:
      case ssl_grp_kem_xyber768d00:
      case ssl_grp_kem_mlkem768x25519:
        kea_size = 255;
        break;
      case ssl_grp_ec_secp256r1:
        kea_size = 256;
        break;
      case ssl_grp_ec_secp384r1:
        kea_size = 384;
        break;
      case ssl_grp_ffdhe_2048:
        kea_size = 2048;
        break;
      case ssl_grp_ffdhe_3072:
        kea_size = 3072;
        break;
      case ssl_grp_ffdhe_custom:
        break;
      default:
        if (kea == ssl_kea_rsa) {
          kea_size = server_key_bits_;
        } else {
          EXPECT_TRUE(false) << "need to update group sizes";
        }
    }
  }
  if (kea_group != ssl_grp_ffdhe_custom) {
    EXPECT_EQ(kea_size, info_.keaKeyBits);
    EXPECT_EQ(kea_group, info_.keaGroup);
  }
}

void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
  if (kea_group != ssl_grp_ffdhe_custom) {
    EXPECT_EQ(kea_group, info_.originalKeaGroup);
  }
}

void TlsAgent::CheckAuthType(SSLAuthType auth,
                             SSLSignatureScheme sig_scheme) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(auth, info_.authType);
  if (auth != ssl_auth_psk) {
    EXPECT_EQ(server_key_bits_, info_.authKeyBits);
  }
  if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
    switch (auth) {
      case ssl_auth_rsa_sign:
        sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
        break;
      case ssl_auth_ecdsa:
        sig_scheme = ssl_sig_ecdsa_sha1;
        break;
      default:
        break;
    }
  }
  EXPECT_EQ(sig_scheme, info_.signatureScheme);

  if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
    return;
  }

  // Check authAlgorithm, which is the old value for authType.  This is a second
  // switch statement because default label is different.
  switch (auth) {
    case ssl_auth_rsa_sign:
    case ssl_auth_rsa_pss:
      EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
          << "authAlgorithm for RSA is always decrypt";
      break;
    case ssl_auth_ecdh_rsa:
      EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
          << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
      break;
    case ssl_auth_ecdh_ecdsa:
      EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
          << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
      break;
    default:
      EXPECT_EQ(auth, csinfo_.authAlgorithm)
          << "authAlgorithm is (usually) the same as authType";
      break;
  }
}

// Enables TLS False Start. This installs CanFalseStartCallback, sets
// SSL_ENABLE_FALSE_START, and records falsestart_enabled_ so CheckCallbacks()
// expects the hook to fire.
void TlsAgent::EnableFalseStart() {
  EXPECT_TRUE(EnsureTlsSetup());

  falsestart_enabled_ = true;
  EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
                            ssl_fd(), CanFalseStartCallback, this));
  SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
}

void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; }

void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; }

void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; }

// Configures the ALPN protocol list |val| (|len| bytes) via
// SSL_SetNextProtoNego. NOTE(review): |val| is presumably in the
// length-prefixed wire format — confirm against SSL_SetNextProtoNego.
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
  EXPECT_TRUE(EnsureTlsSetup());
  EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}

// Registers external PSK |psk| under identity |label| with hash |hash| and
// 0-RTT cipher suite |zeroRttSuite|. NOTE(review): the trailing 1000 is
// presumably the maximum early-data size argument of SSL_AddExternalPsk0Rtt —
// confirm in sslexp.h.
void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label,
                      SSLHashType hash, uint16_t zeroRttSuite) {
  EXPECT_TRUE(EnsureTlsSetup());
  EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt(
                            ssl_fd(), psk.get(),
                            reinterpret_cast<const uint8_t*>(label.data()),
                            label.length(), hash, zeroRttSuite, 1000));
}

// Removes the external PSK registered under identity |label|. This does not
// call EnsureTlsSetup(), so the socket is expected to exist already.
void TlsAgent::RemovePsk(std::string label) {
  EXPECT_EQ(SECSuccess,
            SSL_RemoveExternalPsk(
                ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()),
                label.length()));
}

void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
                         const std::string& expected) const {
  SSLNextProtoState alpn_state;
  char chosen[10];
  unsigned int chosen_len;
  SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state,
                                  reinterpret_cast<unsigned char*>(chosen),
                                  &chosen_len, sizeof(chosen));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(expected_state, alpn_state);
  if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) {
    EXPECT_EQ("", expected);
  } else {
    EXPECT_NE("", expected);
    EXPECT_EQ(expected, std::string(chosen, chosen_len));
  }
}

// Checks the socket's current read and write epochs against the expected
// values, tagging any failure with this agent's role.
void TlsAgent::CheckEpochs(uint16_t expected_read,
                           uint16_t expected_write) const {
  uint16_t read_epoch = 0;
  uint16_t write_epoch = 0;
  EXPECT_EQ(SECSuccess,
            SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch));
  EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch";
  EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch";
}

// Offers two SRTP protection profiles, SRTP_AES128_CM_HMAC_SHA1_80 listed
// before SRTP_AES128_CM_HMAC_SHA1_32. CheckSrtp() expects the _80 profile.
void TlsAgent::EnableSrtp() {
  EXPECT_TRUE(EnsureTlsSetup());
  const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
                              SRTP_AES128_CM_HMAC_SHA1_32};
  EXPECT_EQ(SECSuccess,
            SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
}

void TlsAgent::CheckSrtp() const {
  uint16_t actual;
  EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
  EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}

// Expects the agent to be in STATE_ERROR with error_code_ == |expected|. On
// mismatch, both codes are named via PORT_ErrorToName.
void TlsAgent::CheckErrorCode(int32_t expected) const {
  EXPECT_EQ(STATE_ERROR, state_);
  EXPECT_EQ(expected, error_code_)
      << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
      << PORT_ErrorToName(expected) << std::endl;
}

static uint8_t GetExpectedAlertLevel(uint8_t alert) {
  if (alert == kTlsAlertCloseNotify) {
    return kTlsAlertWarning;
  }
  return kTlsAlertFatal;
}

void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
  expected_received_alert_ = alert;
  if (level == 0) {
    expected_received_alert_level_ = GetExpectedAlertLevel(alert);
  } else {
    expected_received_alert_level_ = level;
  }
}

void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
  expected_sent_alert_ = alert;
  if (level == 0) {
    expected_sent_alert_level_ = GetExpectedAlertLevel(alert);
  } else {
    expected_sent_alert_level_ = level;
  }
}

void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
  LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
      << " alert " << (sent ? "sent" : "received") << ": "
      << static_cast<int>(alert->description));

  auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
  auto& expected_level =
      sent ? expected_sent_alert_level_ : expected_received_alert_level_;
  /* Silently pass close_notify in case the test has already ended. */
  if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
      alert->description == expected && alert->level == expected_level) {
    return;
  }

  EXPECT_EQ(expected, alert->description);
  EXPECT_EQ(expected_level, alert->level);
  expected = kTlsAlertCloseNotify;
  expected_level = kTlsAlertWarning;
}

// Asserts that no error has been recorded yet, then waits (via WAIT_, bounded
// by |delay|) for error_code_ to become non-zero, and expects it to equal
// |expected|. NOTE(review): |delay| units are whatever WAIT_ uses, presumably
// milliseconds — confirm.
void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
  ASSERT_EQ(0, error_code_);
  WAIT_(error_code_ != 0, delay);
  EXPECT_EQ(expected, error_code_)
      << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
      << PORT_ErrorToName(expected) << std::endl;
}

void TlsAgent::CheckPreliminaryInfo() {
  SSLPreliminaryChannelInfo preinfo;
  EXPECT_EQ(SECSuccess,
            SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo)));
  EXPECT_EQ(sizeof(preinfo), preinfo.length);
  EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);

  // A version of 0 is invalid and indicates no expectation.  This value is
  // initialized to 0 so that tests that don't explicitly set an expected
  // version can negotiate a version.
  if (!expected_version_) {
    expected_version_ = preinfo.protocolVersion;
  }
  EXPECT_EQ(expected_version_, preinfo.protocolVersion);

  // As with the version; 0 is the null cipher suite (and also invalid).
  if (!expected_cipher_suite_) {
    expected_cipher_suite_ = preinfo.cipherSuite;
  }
  EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite);
}

// Check that all the expected callbacks have been called.
void TlsAgent::CheckCallbacks() const {
  // If false start happens, the handshake is reported as being complete at the
  // point that false start happens.
  if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) {
    EXPECT_TRUE(handshake_callback_called_);
  }

  // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
  if (role_ == SERVER) {
    PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
    EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) ||
               expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
              sni_hook_called_);
  } else {
    EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_);
    // Note that this isn't unconditionally called, even with false start on.
    // But the callback is only skipped if a cipher that is ridiculously weak
    // (80 bits) is chosen.  Don't test that: plan to remove bad ciphers.
    EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume,
              can_falsestart_hook_called_);
  }
}

void TlsAgent::ResetPreliminaryInfo() {
  expected_version_ = 0;
  expected_cipher_suite_ = 0;
}

void TlsAgent::UpdatePreliminaryChannelInfo() {
  SECStatus rv =
      SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(sizeof(pre_info_), pre_info_.length);
}

void TlsAgent::ValidateCipherSpecs() {
  PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
  // We use one ciphersuite in each direction.
  PRInt32 expected = 2;
  if (variant_ == ssl_variant_datagram) {
    // For DTLS 1.3, the client retains the cipher spec for early data and the
    // handshake so that it can retransmit EndOfEarlyData and its final flight.
    // It also retains the handshake read cipher spec so that it can read ACKs
    // from the server. The server retains the handshake read cipher spec so it
    // can read the client's retransmitted Finished.
    if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
      if (role_ == CLIENT) {
        expected = info_.earlyDataAccepted ? 5 : 4;
      } else {
        expected = 3;
      }
    } else {
      // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
      // until the holddown timer runs down.
      if (expect_psk_ == ssl_psk_resume) {
        if (role_ == CLIENT) {
          expected = 3;
        }
      } else {
        if (role_ == SERVER) {
          expected = 3;
        }
      }
    }
  }
  // This function will be run before the handshake completes if false start is
  // enabled.  In that case, the client will still be reading cleartext, but
  // will have a spec prepared for reading ciphertext.  With DTLS, the client
  // will also have a spec retained for retransmission of handshake messages.
  if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
    EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
    expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
  }
  EXPECT_EQ(expected, cipherSpecs);
  if (expected != cipherSpecs) {
    SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
  }
}

void TlsAgent::Connected() {
  if (state_ == STATE_CONNECTED) {
    return;
  }

  LOG("Handshake success");
  CheckPreliminaryInfo();
  CheckCallbacks();

  SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(sizeof(info_), info_.length);

  EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE);
  EXPECT_EQ(expect_psk_, info_.pskType);
  EXPECT_EQ(expect_ech_, info_.echAccepted);

  // Preliminary values are exposed through callbacks during the handshake.
  // If either expected values were set or the callbacks were called, check
  // that the final values are correct.
  UpdatePreliminaryChannelInfo();
  EXPECT_EQ(expected_version_, info_.protocolVersion);
  EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);

  rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(sizeof(csinfo_), csinfo_.length);

  ValidateCipherSpecs();

  SetState(STATE_CONNECTED);
}

void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) {
  EXPECT_FALSE(client_auth_callback_awaiting_);
  switch (client_auth_callback_type_) {
    case ClientAuthCallbackType::kNone:
      if (!client_auth_callback_success_) {
        EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0));
        break;
      }
    case ClientAuthCallbackType::kSync:
      EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
      break;
    case ClientAuthCallbackType::kAsyncDelay:
    case ClientAuthCallbackType::kAsyncImmediate:
      EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes));
      break;
  }
}

void TlsAgent::EnableExtendedMasterSecret() {
  SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
}

void TlsAgent::CheckExtendedMasterSecret(bool expected) {
  if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
    expected = PR_TRUE;
  }
  ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
      << "unexpected extended master secret state for " << name_;
}

void TlsAgent::CheckEarlyDataAccepted(bool expected) {
  if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
    expected = false;
  }
  ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
      << "unexpected early data state for " << name_;
}

void TlsAgent::CheckSecretsDestroyed() {
  ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}

void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
  ASSERT_TRUE(EnsureTlsSetup());

  SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver);
  ASSERT_EQ(SECSuccess, rv);
}

void TlsAgent::Handshake() {
  LOGV("Handshake");
  SECStatus rv = SSL_ForceHandshake(ssl_fd());
  if (client_auth_callback_awaiting_) {
    ClientAuthCallbackComplete();
    rv = SSL_ForceHandshake(ssl_fd());
  }
  if (rv == SECSuccess) {
    Connected();
    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                             &TlsAgent::ReadableCallback);
    return;
  }

  int32_t err = PR_GetError();
  if (err == PR_WOULD_BLOCK_ERROR) {
    LOGV("Would have blocked");
    if (variant_ == ssl_variant_datagram) {
      if (timer_handle_) {
        timer_handle_->Cancel();
        timer_handle_ = nullptr;
      }

      PRIntervalTime timeout;
      rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
      if (rv == SECSuccess) {
        Poller::Instance()->SetTimer(
            timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
      }
    }
    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                             &TlsAgent::ReadableCallback);
    return;
  }

  if (err != 0) {
    LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
                                       << PORT_ErrorToString(err));
  }

  error_code_ = err;
  SetState(STATE_ERROR);
}

void TlsAgent::PrepareForRenegotiate() {
  EXPECT_EQ(STATE_CONNECTED, state_);

  SetState(STATE_CONNECTING);
}

void TlsAgent::StartRenegotiate() {
  PrepareForRenegotiate();

  SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
  EXPECT_EQ(SECSuccess, rv);
}

void TlsAgent::SendDirect(const DataBuffer& buf) {
  LOG("Send Direct " << buf);
  auto peer = adapter_->peer().lock();
  if (peer) {
    peer->PacketReceived(buf);
  } else {
    LOG("Send Direct peer absent");
  }
}

void TlsAgent::SendRecordDirect(const TlsRecord& record) {
  DataBuffer buf;

  auto rv = record.header.Write(&buf, 0, record.buffer);
  EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
  SendDirect(buf);
}

static bool ErrorIsFatal(PRErrorCode code) {
  return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ;
}

void TlsAgent::SendData(size_t bytes, size_t blocksize) {
  uint8_t block[16385];  // One larger than the maximum record size.

  ASSERT_LE(blocksize, sizeof(block));

  while (bytes) {
    size_t tosend = std::min(blocksize, bytes);

    for (size_t i = 0; i < tosend; ++i) {
      block[i] = 0xff & send_ctr_;
      ++send_ctr_;
    }

    SendBuffer(DataBuffer(block, tosend));
    bytes -= tosend;
  }
}

void TlsAgent::SendBuffer(const DataBuffer& buf) {
  LOGV("Writing " << buf.len() << " bytes");
  int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
  if (expect_readwrite_error_) {
    EXPECT_GT(0, rv);
    EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
    error_code_ = PR_GetError();
    expect_readwrite_error_ = false;
  } else {
    ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
  }
}

bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
                                   uint64_t seq, uint8_t ct,
                                   const DataBuffer& buf) {
  // Ensure that we are doing TLS 1.3.
  EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
  if (variant_ != ssl_variant_datagram) {
    ADD_FAILURE();
    return false;
  }

  LOGV("Encrypting " << buf.len() << " bytes");
  uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
                      kCtDtlsCiphertextLengthPresent;
  TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
  TlsRecordHeader out_header(header);
  DataBuffer padded = buf;
  padded.Write(padded.len(), ct, 1);
  DataBuffer ciphertext;
  if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
    return false;
  }

  DataBuffer record;
  auto rv = out_header.Write(&record, 0, ciphertext);
  EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
  SendDirect(record);
  return true;
}

void TlsAgent::ReadBytes(size_t amount) {
  uint8_t block[16384];

  size_t remaining = amount;
  while (remaining > 0) {
    int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
    LOGV("ReadBytes " << rv);

    if (rv > 0) {
      size_t count = static_cast<size_t>(rv);
      for (size_t i = 0; i < count; ++i) {
        ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
        recv_ctr_++;
      }
      remaining -= rv;
    } else {
      PRErrorCode err = 0;
      if (rv < 0) {
        err = PR_GetError();
        if (err != 0) {
          LOG("Read error " << PORT_ErrorToName(err) << ": "
                            << PORT_ErrorToString(err));
        }
        if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
          if (ErrorIsFatal(err)) {
            SetState(STATE_ERROR);
          }
          error_code_ = err;
          expect_readwrite_error_ = false;
        }
      }
      if (err != 0 && ErrorIsFatal(err)) {
        // If we hit a fatal error, we're done.
        remaining = 0;
      }
      break;
    }
  }

  // If closed, then don't bother waiting around.
  if (remaining) {
    LOGV("Re-arming");
    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                             &TlsAgent::ReadableCallback);
  }
}

void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; }

void TlsAgent::SetOption(int32_t option, int value) {
  ASSERT_TRUE(EnsureTlsSetup());
  EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
}

void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
  SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
  SetOption(SSL_ENABLE_SESSION_TICKETS,
            mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
}

void TlsAgent::EnableECDHEServerKeyReuse() {
  ASSERT_EQ(TlsAgent::SERVER, role_);
  SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
}

// Parameter values for tests that run once per TLS role. The elements must be
// comma-separated: adjacent string literals would otherwise concatenate into
// a single "CLIENTSERVER" parameter.
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
::testing::internal::ParamGenerator<std::string>
    TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);

void TlsAgentTestBase::SetUp() {
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
}

void TlsAgentTestBase::TearDown() {
  agent_ = nullptr;
  SSL_ClearSessionCache();
  SSL_ShutdownServerSessionIDCache();
}

void TlsAgentTestBase::Reset(const std::string& server_name) {
  agent_.reset(
      new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
                   role_, variant_));
  if (version_) {
    agent_->SetVersionRange(version_, version_);
  }
  agent_->adapter()->SetPeer(sink_adapter_);
  agent_->StartConnect();
}

void TlsAgentTestBase::EnsureInit() {
  if (!agent_) {
    Reset();
  }
  const std::vector<SSLNamedGroup> groups = {
      ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
      ssl_grp_ffdhe_2048};
  agent_->ConfigNamedGroups(groups);
}

void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
  EnsureInit();
  agent_->ExpectSendAlert(alert);
}

void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
                                      TlsAgent::State expected_state,
                                      int32_t error_code) {
  std::cerr << "Process message: " << buffer << std::endl;
  EnsureInit();
  agent_->adapter()->PacketReceived(buffer);
  agent_->Handshake();

  ASSERT_EQ(expected_state, agent_->state());

  if (expected_state == TlsAgent::STATE_ERROR) {
    ASSERT_EQ(error_code, agent_->error_code());
  }
}

void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
                                  uint16_t version, const uint8_t* buf,
                                  size_t len, DataBuffer* out,
                                  uint64_t sequence_number) {
  // Fixup the content type for DTLSCiphertext
  if (variant == ssl_variant_datagram &&
      version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
      type == ssl_ct_application_data) {
    type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
           kCtDtlsCiphertextLengthPresent;
  }

  size_t index = 0;
  if (variant == ssl_variant_stream) {
    index = out->Write(index, type, 1);
    index = out->Write(index, version, 2);
  } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
             (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
    uint32_t epoch = (sequence_number >> 48) & 0x3;
    index = out->Write(index, type | epoch, 1);
    uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
    index = out->Write(index, seqno, 2);
  } else {
    index = out->Write(index, type, 1);
    index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
    index = out->Write(index, sequence_number >> 32, 4);
    index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
  }
  index = out->Write(index, len, 2);
  out->Write(index, buf, len);
}

void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
                                  const uint8_t* buf, size_t len,
                                  DataBuffer* out, uint64_t seq_num) const {
  MakeRecord(variant_, type, version, buf, len, out, seq_num);
}

void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
                                            const uint8_t* data, size_t hs_len,
                                            DataBuffer* out,
                                            uint64_t seq_num) const {
  return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
                                      0);
}

void TlsAgentTestBase::MakeHandshakeMessageFragment(
    uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
    uint64_t seq_num, uint32_t fragment_offset,
    uint32_t fragment_length) const {
  size_t index = 0;
  if (!fragment_length) fragment_length = hs_len;
  index = out->Write(index, hs_type, 1);  // Handshake record type.
  index = out->Write(index, hs_len, 3);   // Handshake length
  if (variant_ == ssl_variant_datagram) {
    index = out->Write(index, seq_num, 2);
    index = out->Write(index, fragment_offset, 3);
    index = out->Write(index, fragment_length, 3);
  }
  if (data) {
    index = out->Write(index, data, fragment_length);
  } else {
    for (size_t i = 0; i < fragment_length; ++i) {
      index = out->Write(index, 1, 1);
    }
  }
}

void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
                                                  size_t hs_len,
                                                  DataBuffer* out) {
  size_t index = 0;
  index = out->Write(index, ssl_ct_handshake, 1);  // Content Type
  index = out->Write(index, 3, 1);                 // Version high
  index = out->Write(index, 1, 1);                 // Version low
  index = out->Write(index, 4 + hs_len, 2);        // Length

  index = out->Write(index, hs_type, 1);  // Handshake record type.
  index = out->Write(index, hs_len, 3);   // Handshake length
  for (size_t i = 0; i < hs_len; ++i) {
    index = out->Write(index, 1, 1);
  }
}

DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
  DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
  if (variant_ == ssl_variant_datagram) {
    sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
    // The version should be at the end.
    uint32_t v;
    EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
    EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
    sh.Write(sh.len() - 2, SSL_LIBRARY_VERSION_DTLS_1_3_WIRE, 2);
  }
  return sh;
}

}  // namespace nss_test

Messung V0.5
C=95 H=96 G=95

¤ Dauer der Verarbeitung: 0.18 Sekunden  (vorverarbeitet)  ¤

*© Formatika GbR, Deutschland






Wurzel

Suchen

Beweissystem der NASA

Beweissystem Isabelle

NIST Cobol Testsuite

Cephes Mathematical Library

Wiener Entwicklungsmethode

Haftungshinweis

Die Informationen auf dieser Webseite wurden nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit, noch Qualität der bereit gestellten Informationen zugesichert.

Bemerkung:

Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.






                                                                                                                                                                                                                                                                                                                                                                                                     


Neuigkeiten

     Aktuelles
     Motto des Tages

Software

     Produkte
     Quellcodebibliothek

Aktivitäten

     Artikel über Sicherheit
     Anleitung zur Aktivierung von SSL

Muße

     Gedichte
     Musik
     Bilder

Jenseits des Üblichen ....

Besucherstatistik

Besucherstatistik

Monitoring

Montastic status badge