LCOV - code coverage report
Current view: top level - src/bls - bls_batchverifier.h (source / functions) Hit Total Coverage
Test: total_coverage.info Lines: 85 94 90.4 %
Date: 2025-02-23 09:33:43 Functions: 13 18 72.2 %

          Line data    Source code
       1             : // Copyright (c) 2018-2022 The Dash Core developers
       2             : // Copyright (c) 2023 The PIVX Core developers
       3             : // Distributed under the MIT/X11 software license, see the accompanying
       4             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       5             : 
       6             : #ifndef PIVX_BLS_BLS_BATCHVERIFIER_H
       7             : #define PIVX_BLS_BLS_BATCHVERIFIER_H
       8             : 
       9             : #include "bls_worker.h"
      10             : 
      11             : #include <map>
      12             : #include <vector>
      13             : 
      14             : template <typename SourceId, typename MessageId>
      15             : class CBLSBatchVerifier
      16             : {
      17             : private:
      18         368 :     struct Message {
      19             :         MessageId msgId;
      20             :         uint256 msgHash;
      21             :         CBLSSignature sig;
      22             :         CBLSPublicKey pubKey;
      23             :     };
      24             : 
      25             :     using MessageMap = std::map<MessageId, Message>;
      26             :     using MessageMapIterator = typename MessageMap::iterator;
      27             :     using MessagesBySourceMap = std::map<SourceId, std::vector<MessageMapIterator>>;
      28             : 
      29             :     bool secureVerification;
      30             :     bool perMessageFallback;
      31             :     size_t subBatchSize;
      32             : 
      33             :     MessageMap messages;
      34             :     MessagesBySourceMap messagesBySource;
      35             : 
      36             : public:
      37             :     std::set<SourceId> badSources;
      38             :     std::set<MessageId> badMessages;
      39             : 
      40             : public:
      41        1279 :     CBLSBatchVerifier(bool _secureVerification, bool _perMessageFallback, size_t _subBatchSize = 0) : secureVerification(_secureVerification),
      42             :                                                                                                       perMessageFallback(_perMessageFallback),
      43        1279 :                                                                                                       subBatchSize(_subBatchSize)
      44             :     {
      45             :     }
      46             : 
      47        2524 :     void PushMessage(const SourceId& sourceId, const MessageId& msgId, const uint256& msgHash, const CBLSSignature& sig, const CBLSPublicKey& pubKey)
      48             :     {
      49        2524 :         assert(sig.IsValid() && pubKey.IsValid());
      50             : 
      51        2524 :         auto it = messages.emplace(msgId, Message{msgId, msgHash, sig, pubKey}).first;
      52        2524 :         messagesBySource[sourceId].emplace_back(it);
      53             : 
      54        2524 :         if (subBatchSize != 0 && messages.size() >= subBatchSize) {
      55           0 :             Verify();
      56           0 :             ClearMessages();
      57             :         }
      58        2524 :     }
      59             : 
      60           0 :     void ClearMessages()
      61             :     {
      62           0 :         messages.clear();
      63           0 :         messagesBySource.clear();
      64           0 :     }
      65             : 
      66             :     size_t GetUniqueSourceCount() const
      67             :     {
      68             :         return messagesBySource.size();
      69             :     }
      70             : 
      71        1279 :     void Verify()
      72             :     {
      73          16 :         std::map<uint256, std::vector<MessageMapIterator>> byMessageHash;
      74             : 
      75        3803 :         for (auto it = messages.begin(); it != messages.end(); ++it) {
      76        2524 :             byMessageHash[it->second.msgHash].emplace_back(it);
      77             :         }
      78             : 
      79        1279 :         if (VerifyBatch(byMessageHash)) {
      80             :             // full batch is valid
      81        1263 :             return;
      82             :         }
      83             : 
      84             :         // revert to per-source verification
      85          96 :         for (const auto& p : messagesBySource) {
      86          80 :             bool batchValid = false;
      87             : 
      88             :             // no need to verify it again if there was just one source
      89          80 :             if (messagesBySource.size() != 1) {
      90          80 :                 byMessageHash.clear();
      91         204 :                 for (auto it = p.second.begin(); it != p.second.end(); ++it) {
      92         124 :                     byMessageHash[(*it)->second.msgHash].emplace_back(*it);
      93             :                 }
      94          80 :                 batchValid = VerifyBatch(byMessageHash);
      95             :             }
      96          80 :             if (!batchValid) {
      97          16 :                 badSources.emplace(p.first);
      98             : 
      99          16 :                 if (perMessageFallback) {
     100             :                     // revert to per-message verification
     101           8 :                     if (p.second.size() == 1) {
     102             :                         // no need to re-verify a single message
     103          86 :                         badMessages.emplace(p.second[0]->second.msgId);
     104             :                     } else {
     105          10 :                         for (const auto& msgIt : p.second) {
     106          16 :                             if (badMessages.count(msgIt->first)) {
     107             :                                 // same message might be invalid from different source, so no need to re-verify it
     108           0 :                                 continue;
     109             :                             }
     110             : 
     111           8 :                             const auto& msg = msgIt->second;
     112           8 :                             if (!msg.sig.VerifyInsecure(msg.pubKey, msg.msgHash)) {
     113          10 :                                 badMessages.emplace(msg.msgId);
     114             :                             }
     115             :                         }
     116             :                     }
     117             :                 }
     118             :             }
     119             :         }
     120             :     }
     121             : 
     122             : private:
     123             :     // All Verify methods take ownership of the passed byMessageHash map and thus might modify the map. This is to avoid
     124             :     // unnecessary copies
     125             : 
     126        1359 :     bool VerifyBatch(std::map<uint256, std::vector<MessageMapIterator>>& byMessageHash)
     127             :     {
     128        1359 :         if (secureVerification) {
     129        1359 :             return VerifyBatchSecure(byMessageHash);
     130             :         } else {
     131        1305 :             return VerifyBatchInsecure(byMessageHash);
     132             :         }
     133             :     }
     134             : 
     135        1305 :     bool VerifyBatchInsecure(const std::map<uint256, std::vector<MessageMapIterator>>& byMessageHash)
     136             :     {
     137        1305 :         CBLSSignature aggSig;
     138        1305 :         std::vector<uint256> msgHashes;
     139        2610 :         std::vector<CBLSPublicKey> pubKeys;
     140        1305 :         std::set<MessageId> dups;
     141             : 
     142        1305 :         msgHashes.reserve(messages.size());
     143        1305 :         pubKeys.reserve(messages.size());
     144             : 
     145        3643 :         for (const auto& p : byMessageHash) {
     146        2338 :             const auto& msgHash = p.first;
     147             : 
     148        2338 :             CBLSPublicKey aggPubKey;
     149             : 
     150        4832 :             for (const auto& msgIt : p.second) {
     151        2494 :                 const auto& msg = msgIt->second;
     152             : 
     153        2494 :                 if (!dups.emplace(msg.msgId).second) {
     154           0 :                     continue;
     155             :                 }
     156             : 
     157        2494 :                 if (!aggSig.IsValid()) {
     158        1305 :                     aggSig = msg.sig;
     159             :                 } else {
     160        1189 :                     aggSig.AggregateInsecure(msg.sig);
     161             :                 }
     162             : 
     163        2494 :                 if (!aggPubKey.IsValid()) {
     164        2338 :                     aggPubKey = msg.pubKey;
     165             :                 } else {
     166         156 :                     aggPubKey.AggregateInsecure(msg.pubKey);
     167             :                 }
     168             :             }
     169             : 
     170        2338 :             if (!aggPubKey.IsValid()) {
     171             :                 // only duplicates for this msgHash
     172           0 :                 continue;
     173             :             }
     174             : 
     175        2338 :             msgHashes.emplace_back(msgHash);
     176        2338 :             pubKeys.emplace_back(aggPubKey);
     177             :         }
     178             : 
     179        1305 :         if (msgHashes.empty()) {
     180             :             return true;
     181             :         }
     182             : 
     183        1305 :         return aggSig.VerifyInsecureAggregated(pubKeys, msgHashes);
     184             :     }
     185             : 
     186             :     bool VerifyBatchSecure(std::map<uint256, std::vector<MessageMapIterator>>& byMessageHash)
     187             :     {
     188             :         // Loop until the byMessageHash map is empty, which means that all messages were verified
     189             :         // The secure form of verification will only aggregate one message for the same message hash, even if multiple
     190             :         // exist (signed with different keys). This avoids the rogue public key attack.
     191             :         // This is slower than the insecure form as it requires more pairings
     192         112 :         while (!byMessageHash.empty()) {
     193          74 :             if (!VerifyBatchSecureStep(byMessageHash)) {
     194             :                 return false;
     195             :             }
     196             :         }
     197             :         return true;
     198             :     }
     199             : 
     200          74 :     bool VerifyBatchSecureStep(std::map<uint256, std::vector<MessageMapIterator>>& byMessageHash)
     201             :     {
     202          74 :         CBLSSignature aggSig;
     203         148 :         std::vector<uint256> msgHashes;
     204         148 :         std::vector<CBLSPublicKey> pubKeys;
     205          74 :         std::set<MessageId> dups;
     206             : 
     207          74 :         msgHashes.reserve(messages.size());
     208          74 :         pubKeys.reserve(messages.size());
     209             : 
     210          74 :         for (auto it = byMessageHash.begin(); it != byMessageHash.end();) {
     211         136 :             const auto& msgHash = it->first;
     212         136 :             auto& messageIts = it->second;
     213         136 :             const auto& msg = messageIts.back()->second;
     214             : 
     215         136 :             if (dups.emplace(msg.msgId).second) {
     216         136 :                 msgHashes.emplace_back(msgHash);
     217         136 :                 pubKeys.emplace_back(msg.pubKey);
     218             : 
     219         136 :                 if (!aggSig.IsValid()) {
     220          74 :                     aggSig = msg.sig;
     221             :                 } else {
     222          62 :                     aggSig.AggregateInsecure(msg.sig);
     223             :                 }
     224             :             }
     225             : 
     226         136 :             messageIts.pop_back();
     227         136 :             if (messageIts.empty()) {
     228         112 :                 it = byMessageHash.erase(it);
     229             :             } else {
     230         210 :                 ++it;
     231             :             }
     232             :         }
     233             : 
     234          74 :         assert(!msgHashes.empty());
     235             : 
     236         148 :         return aggSig.VerifyInsecureAggregated(pubKeys, msgHashes);
     237             :     }
     238             : };
     239             : 
     240             : #endif // PIVX_BLS_BLS_BATCHVERIFIER_H

Generated by: LCOV version 1.14