/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "PublicKeyPinningService.h"

#include "RootCertificateTelemetryUtils.h"
#include "mozilla/Base64.h"
#include "mozilla/Casting.h"
#include "mozilla/Logging.h"
#include "mozilla/Telemetry.h"
#include "nsISiteSecurityService.h"
#include "nsServiceManagerUtils.h"
#include "nsSiteSecurityService.h"
#include "nssb64.h"
#include "pkix/pkixtypes.h"
#include "seccomon.h"
#include "sechash.h"

#include "StaticHPKPins.h" // autogenerated by genHPKPStaticpins.js

using namespace mozilla;
using namespace mozilla::pkix;
using namespace mozilla::psm;

// Log module used by the MOZ_LOG calls in this file; it is registered under
// the module name "PublicKeyPinningService".
LazyLogModule gPublicKeyPinningLog("PublicKeyPinningService");

/**
 Writes to hashSPKIDigest the Base64 encoding of the SHA-256 digest of the
 DER-encoded subject public key info (cert->derPublicKey) of the given cert.

 hashSPKIDigest is emptied before anything else happens. Any failure reported
 by the digest computation or by the Base64 encoder is returned unchanged.
*/
static nsresult
GetBase64HashSPKI(const CERTCertificate* cert, nsACString& hashSPKIDigest)
{
  hashSPKIDigest.Truncate();

  Digest spkiDigest;
  nsresult rv = spkiDigest.DigestBuf(SEC_OID_SHA256, cert->derPublicKey.data,
                                     cert->derPublicKey.len);
  if (NS_FAILED(rv)) {
    return rv;
  }

  // View the raw digest bytes as a character range without copying them.
  const auto& rawDigest = spkiDigest.get();
  nsDependentCSubstring digestBytes(
    BitwiseCast<char*, unsigned char*>(rawDigest.data), rawDigest.len);
  return Base64Encode(digestBytes, hashSPKIDigest);
}

/*
 * Sets certMatchesPinset to true if the Base64 SHA-256 SPKI hash of the given
 * cert equals any entry of the static pinset (fingerprints) or of the
 * dynamicFingerprints array, and to false otherwise. Static entries are
 * compared before dynamic ones.
 *
 * Returns NS_ERROR_INVALID_ARG when both pin sources are null, the error from
 * GetBase64HashSPKI when hashing fails, and NS_OK otherwise (match or not).
 */
static nsresult
EvalCert(const CERTCertificate* cert, const StaticFingerprints* fingerprints,
         const nsTArray<nsCString>* dynamicFingerprints,
 /*out*/ bool& certMatchesPinset)
{
  certMatchesPinset = false;
  if (!fingerprints && !dynamicFingerprints) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: No hashes found\n"));
    return NS_ERROR_INVALID_ARG;
  }

  nsAutoCString spkiHash;
  nsresult rv = GetBase64HashSPKI(cert, spkiHash);
  if (NS_FAILED(rv)) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: GetBase64HashSPKI failed!\n"));
    return rv;
  }

  // Scan the static pins first, then the dynamic ones; stop at the first hit.
  bool matched = false;
  if (fingerprints) {
    for (size_t idx = 0; !matched && idx < fingerprints->size; ++idx) {
      matched = spkiHash.Equals(fingerprints->data[idx]);
    }
  }
  if (!matched && dynamicFingerprints) {
    for (size_t idx = 0; !matched && idx < dynamicFingerprints->Length();
         ++idx) {
      matched = spkiHash.Equals((*dynamicFingerprints)[idx]);
    }
  }

  if (matched) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: found pin base_64 ='%s'\n", spkiHash.get()));
    certMatchesPinset = true;
  }
  return NS_OK;
}

/*
 * Sets certListIntersectsPinset to true as soon as any certificate in the
 * given list matches the static fingerprints or the dynamicFingerprints
 * array (as decided by EvalCert), or to false if none does.
 *
 * At least one of the two pin sources must be non-null; otherwise this
 * asserts and returns NS_ERROR_FAILURE. An error from EvalCert aborts the
 * walk and is returned unchanged.
 */
static nsresult
EvalChain(const UniqueCERTCertList& certList,
          const StaticFingerprints* fingerprints,
          const nsTArray<nsCString>* dynamicFingerprints,
  /*out*/ bool& certListIntersectsPinset)
{
  certListIntersectsPinset = false;

  if (!fingerprints && !dynamicFingerprints) {
    MOZ_ASSERT(false, "Must pass in at least one type of pinset");
    return NS_ERROR_FAILURE;
  }

  for (CERTCertListNode* node = CERT_LIST_HEAD(certList);
       !CERT_LIST_END(node, certList); node = CERT_LIST_NEXT(node)) {
    CERTCertificate* chainCert = node->cert;
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: certArray subject: '%s'\n", chainCert->subjectName));
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: certArray issuer: '%s'\n", chainCert->issuerName));

    nsresult rv = EvalCert(chainCert, fingerprints, dynamicFingerprints,
                           certListIntersectsPinset);
    if (NS_FAILED(rv)) {
      return rv;
    }
    // One matching certificate is enough for the whole chain.
    if (certListIntersectsPinset) {
      return NS_OK;
    }
  }

  MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug, ("pkpin: no matches found\n"));
  return NS_OK;
}

/**
  Comparator for the is public key pinned host.
*/
static int
TransportSecurityPreloadCompare(const void* key, const void* entry) {
  auto keyStr = static_cast<const char*>(key);
  auto preloadEntry = static_cast<const TransportSecurityPreload*>(entry);

  return strcmp(keyStr, preloadEntry->mHost);
}

// Sets chainMatchesPinset to true if any certificate in certList matches one
// of the caller-supplied pins in aSHA256keys. No static pinset is consulted.
nsresult
PublicKeyPinningService::ChainMatchesPinset(const UniqueCERTCertList& certList,
                                            const nsTArray<nsCString>& aSHA256keys,
                                    /*out*/ bool& chainMatchesPinset)
{
  const StaticFingerprints* noStaticPins = nullptr;
  return EvalChain(certList, noStaticPins, &aSHA256keys, chainMatchesPinset);
}

// Returns via one of the output parameters the most relevant pinning
// information that is valid for the given host at the given time.
// Dynamic pins are prioritized over static pins.
//
// The lookup starts with the full hostname and then retries with each parent
// domain, dropping one leading label per iteration. An entry found for a
// parent domain only applies if that entry includes subdomains. The walk ends
// as soon as the remaining name contains no '.'.
//
// hostname            - host to look up; null or empty yields
//                       NS_ERROR_INVALID_ARG with the outputs left untouched.
// time                - evaluation time, forwarded to the site security
//                       service and compared against the preload list expiry.
// dynamicFingerprints - cleared, then filled with the pins reported by the
//                       site security service if a dynamic entry applies.
// staticFingerprints  - reset to nullptr, then pointed at the matching
//                       preload entry if one applies, has a pinset, and the
//                       preload list has not expired at the given time.
//
// Returns NS_ERROR_FAILURE if the site security service is unavailable, the
// error from GetKeyPinsForHostname if that call fails, and NS_OK otherwise
// (including when no pinning information exists for the host).
static nsresult
FindPinningInformation(const char* hostname, mozilla::pkix::Time time,
               /*out*/ nsTArray<nsCString>& dynamicFingerprints,
               /*out*/ TransportSecurityPreload*& staticFingerprints)
{
  if (!hostname || hostname[0] == 0) {
    return NS_ERROR_INVALID_ARG;
  }
  staticFingerprints = nullptr;
  dynamicFingerprints.Clear();
  nsCOMPtr<nsISiteSecurityService> sssService =
    do_GetService(NS_SSSERVICE_CONTRACTID);
  if (!sssService) {
    return NS_ERROR_FAILURE;
  }

  // Element count of the preload table; invariant across loop iterations.
  const size_t preloadEntryCount =
    sizeof(kPublicKeyPinningPreloadList) / sizeof(TransportSecurityPreload);

  TransportSecurityPreload* foundEntry = nullptr;
  // The hostname is only ever read, so walk it through const pointers.
  const char* evalHost = hostname;
  const char* evalPart = nullptr;
  // Notice how the (xx = strchr) prevents pins for unqualified domain names.
  while (!foundEntry && (evalPart = strchr(evalHost, '.'))) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: Querying pinsets for host: '%s'\n", evalHost));
    // Attempt dynamic pins first
    bool found = false;
    bool includeSubdomains = false;
    nsTArray<nsCString> pinArray;
    nsresult rv = sssService->GetKeyPinsForHostname(evalHost, time, pinArray,
                                                    &includeSubdomains,
                                                    &found);
    if (NS_FAILED(rv)) {
      return rv;
    }
    // A dynamic entry applies to the exact host, or to a parent domain that
    // opted in to covering its subdomains.
    if (found && (evalHost == hostname || includeSubdomains)) {
      MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
             ("pkpin: Found dyn match for host: '%s'\n", evalHost));
      dynamicFingerprints = pinArray;
      return NS_OK;
    }

    foundEntry = static_cast<TransportSecurityPreload*>(
      bsearch(evalHost,
              kPublicKeyPinningPreloadList,
              preloadEntryCount,
              sizeof(TransportSecurityPreload),
              TransportSecurityPreloadCompare));
    if (foundEntry) {
      MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
             ("pkpin: Found pinset for host: '%s'\n", evalHost));
      if (evalHost != hostname && !foundEntry->mIncludeSubdomains) {
        // Does not apply to this host, continue iterating
        foundEntry = nullptr;
      }
    } else {
      MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
             ("pkpin: Didn't find pinset for host: '%s'\n", evalHost));
    }
    // Add one for '.'
    evalHost = evalPart + 1;
  }

  if (foundEntry && foundEntry->pinset) {
    // Static pins are ignored once the preload list itself has expired.
    if (time > TimeFromEpochInSeconds(kPreloadPKPinsExpirationTime /
                                      PR_USEC_PER_SEC)) {
      return NS_OK;
    }
    staticFingerprints = foundEntry;
  }
  return NS_OK;
}

// Returns true via the output parameter if the given certificate list meets
// pinning requirements for the given host at the given time. It must be the
// case that either there is an intersection between the set of hashes of
// subject public key info data in the list and the most relevant non-expired
// pinset for the host or there is no pinning information for the host.
//
// certList             - chain to evaluate; null yields NS_ERROR_INVALID_ARG.
// hostname             - host to evaluate; null or empty yields
//                        NS_ERROR_INVALID_ARG.
// enforceTestMode      - when false, a static pinset marked as test mode
//                        never causes chainHasValidPins to be false.
// time                 - evaluation time used to find pinning information.
// chainHasValidPins    - set to false on entry and on every error path.
// pinningTelemetryInfo - optional; when non-null and static pins were
//                        evaluated, filled with the histogram and bucket to
//                        accumulate (and the root bucket on a pin mismatch).
//
// Dynamic pins, when present, take precedence and are evaluated without any
// telemetry. Errors from the pin lookup or from the chain evaluation are
// returned unchanged.
static nsresult
CheckPinsForHostname(const UniqueCERTCertList& certList, const char* hostname,
                     bool enforceTestMode, mozilla::pkix::Time time,
             /*out*/ bool& chainHasValidPins,
    /*optional out*/ PinningTelemetryInfo* pinningTelemetryInfo)
{
  chainHasValidPins = false;
  if (!certList) {
    return NS_ERROR_INVALID_ARG;
  }
  if (!hostname || hostname[0] == 0) {
    return NS_ERROR_INVALID_ARG;
  }

  nsTArray<nsCString> dynamicFingerprints;
  TransportSecurityPreload* staticFingerprints = nullptr;
  nsresult rv = FindPinningInformation(hostname, time, dynamicFingerprints,
                                       staticFingerprints);
  // A failed lookup leaves both outputs empty. Without this check it would be
  // indistinguishable from "no pins for this host" and the chain would be
  // accepted, so propagate the failure instead of failing open.
  if (NS_FAILED(rv)) {
    return rv;
  }
  // If we have no pinning information, the certificate chain trivially
  // validates with respect to pinning.
  if (dynamicFingerprints.Length() == 0 && !staticFingerprints) {
    chainHasValidPins = true;
    return NS_OK;
  }
  if (dynamicFingerprints.Length() > 0) {
    return EvalChain(certList, nullptr, &dynamicFingerprints, chainHasValidPins);
  }
  if (staticFingerprints) {
    // enforceTestModeResult is the raw pin-match outcome, i.e. what the
    // result would be if test mode were enforced.
    bool enforceTestModeResult;
    rv = EvalChain(certList, staticFingerprints->pinset, nullptr,
                   enforceTestModeResult);
    if (NS_FAILED(rv)) {
      return rv;
    }
    chainHasValidPins = enforceTestModeResult;
    Telemetry::ID histogram = staticFingerprints->mIsMoz
      ? Telemetry::CERT_PINNING_MOZ_RESULTS
      : Telemetry::CERT_PINNING_RESULTS;
    if (staticFingerprints->mTestMode) {
      histogram = staticFingerprints->mIsMoz
        ? Telemetry::CERT_PINNING_MOZ_TEST_RESULTS
        : Telemetry::CERT_PINNING_TEST_RESULTS;
      // Test-mode pinsets only report; they do not fail the connection
      // unless the caller asked for test mode to be enforced.
      if (!enforceTestMode) {
        chainHasValidPins = true;
      }
    }
    // We can collect per-host pinning violations for this host because it is
    // operationally critical to Firefox.
    if (pinningTelemetryInfo) {
      if (staticFingerprints->mId != kUnknownId) {
        // Two buckets per known host id: even = mismatch, odd = match.
        int32_t bucket = staticFingerprints->mId * 2
                         + (enforceTestModeResult ? 1 : 0);
        histogram = staticFingerprints->mTestMode
          ? Telemetry::CERT_PINNING_MOZ_TEST_RESULTS_BY_HOST
          : Telemetry::CERT_PINNING_MOZ_RESULTS_BY_HOST;
        pinningTelemetryInfo->certPinningResultBucket = bucket;
      } else {
        pinningTelemetryInfo->certPinningResultBucket =
            enforceTestModeResult ? 1 : 0;
      }
      pinningTelemetryInfo->accumulateResult = true;
      pinningTelemetryInfo->certPinningResultHistogram = histogram;
    }

    // We only collect per-CA pinning statistics upon failures.
    CERTCertListNode* rootNode = CERT_LIST_TAIL(certList);
    // Only log telemetry if the certificate list is non-empty.
    if (!CERT_LIST_END(rootNode, certList)) {
      if (!enforceTestModeResult && pinningTelemetryInfo) {
        int32_t binNumber = RootCABinNumber(&rootNode->cert->derCert);
        if (binNumber != ROOT_CERTIFICATE_UNKNOWN ) {
          pinningTelemetryInfo->accumulateForRoot = true;
          pinningTelemetryInfo->rootBucket = binNumber;
        }
      }
    }

    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
           ("pkpin: Pin check %s for %s host '%s' (mode=%s)\n",
            enforceTestModeResult ? "passed" : "failed",
            staticFingerprints->mIsMoz ? "mozilla" : "non-mozilla",
            hostname, staticFingerprints->mTestMode ? "test" : "production"));
  }

  return NS_OK;
}

// Canonicalizes hostname and then checks certList against the pinning
// information for that host via CheckPinsForHostname. chainHasValidPins is
// false on every error path; a null certList or a null/empty hostname yields
// NS_ERROR_INVALID_ARG.
nsresult
PublicKeyPinningService::ChainHasValidPins(const UniqueCERTCertList& certList,
                                           const char* hostname,
                                           mozilla::pkix::Time time,
                                           bool enforceTestMode,
                                   /*out*/ bool& chainHasValidPins,
                          /*optional out*/ PinningTelemetryInfo* pinningTelemetryInfo)
{
  chainHasValidPins = false;

  const bool hostnameMissing = !hostname || hostname[0] == 0;
  if (!certList || hostnameMissing) {
    return NS_ERROR_INVALID_ARG;
  }

  const nsAutoCString canonicalHost(CanonicalizeHostname(hostname));
  return CheckPinsForHostname(certList, canonicalHost.get(), enforceTestMode,
                              time, chainHasValidPins, pinningTelemetryInfo);
}

// Sets hostHasPins to true if pinning information applies to the
// canonicalized hostname at the given time: either any dynamic pins, or a
// static pinset that is not in test mode (or is, while enforceTestMode is
// set). Errors from FindPinningInformation are returned unchanged with
// hostHasPins left false.
nsresult
PublicKeyPinningService::HostHasPins(const char* hostname,
                                     mozilla::pkix::Time time,
                                     bool enforceTestMode,
                                     /*out*/ bool& hostHasPins)
{
  hostHasPins = false;

  nsTArray<nsCString> dynamicPins;
  TransportSecurityPreload* staticPins = nullptr;
  const nsAutoCString canonicalHost(CanonicalizeHostname(hostname));
  nsresult rv = FindPinningInformation(canonicalHost.get(), time, dynamicPins,
                                       staticPins);
  if (NS_FAILED(rv)) {
    return rv;
  }

  // Dynamic pins win; static pins are only consulted when there are none.
  if (dynamicPins.Length() != 0) {
    hostHasPins = true;
    return NS_OK;
  }
  if (staticPins) {
    hostHasPins = enforceTestMode || !staticPins->mTestMode;
  }
  return NS_OK;
}

// Returns a copy of hostname converted with ToLowerCase and with every
// trailing '.' removed.
nsAutoCString
PublicKeyPinningService::CanonicalizeHostname(const char* hostname)
{
  nsAutoCString result(hostname);
  ToLowerCase(result);

  // Find how many leading characters survive once trailing dots are dropped,
  // then cut the string once.
  auto keptLength = result.Length();
  while (keptLength > 0 && result.CharAt(keptLength - 1) == '.') {
    --keptLength;
  }
  result.Truncate(keptLength);
  return result;
}