python ECIES：使用 HKDF / KDF2 进行密钥派生，计算共享密钥并解密数据

  • Post author:
  • Post category:python




python ECIES：使用 HKDF / KDF2 进行密钥派生，计算共享密钥并解密数据

# -*- coding: utf-8 -*-

import base64
import binascii
import os
from hashlib import sha256

from sd.utils.KDF import KDF2
from Crypto.Cipher import AES
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.kdf.hkdf import HKDF


def generate_key_pair():
    """Create a fresh SECP256R1 (P-256) key pair.

    Side effect: prints the public key's numbers and its key size to stdout.

    Returns:
        tuple: ``(private_key, public_key, public_key_b64)`` where
        ``public_key_b64`` is the Base64 text of the public key in X9.62
        uncompressed-point encoding.
    """
    priv = ec.generate_private_key(ec.SECP256R1())
    pub = priv.public_key()
    print(pub.public_numbers())
    print(pub.key_size)

    uncompressed_point = pub.public_bytes(
        serialization.Encoding.X962,
        serialization.PublicFormat.UncompressedPoint,
    )
    encoded = base64.b64encode(uncompressed_point).decode()
    return priv, pub, encoded


def load_ec_pub_key(remote_key_base64_str=None):
    """Turn Base64 text of an X9.62-encoded SECP256R1 point into a public key.

    Args:
        remote_key_base64_str: Base64 of the encoded point. The ``None``
            default is not usable: ``base64.b64decode(None)`` raises
            ``TypeError``.

    Returns:
        The ``cryptography`` EC public key object for the decoded point.
    """
    point = base64.b64decode(remote_key_base64_str)
    return ec.EllipticCurvePublicKey.from_encoded_point(
        curve=ec.SECP256R1(),
        data=point,
    )


def ecies_kem_hkdf(private_key, remote_key_base64_str, salt=None, info=None, shared_key_len=32):
    """Derive a symmetric key from an ECDH exchange using HKDF-SHA256.

    The HKDF input keying material is the peer's encoded public key followed
    by the raw ECDH shared secret (``remote_key_bytes || shared_secret``).

    Args:
        private_key: Local SECP256R1 private key (``cryptography`` EC key).
        remote_key_base64_str: Base64 of the peer's X9.62-encoded public point.
        salt: Base64 HKDF salt, or ``None`` for no salt.
        info: Base64 HKDF context info, or ``None`` for no info.
        shared_key_len: Length of the derived key in bytes.

    Returns:
        str: Base64 text of the derived key.
    """
    remote_key_bytes = base64.b64decode(remote_key_base64_str)
    remote_public_key = ec.EllipticCurvePublicKey.from_encoded_point(
        curve=ec.SECP256R1(), data=remote_key_bytes
    )
    shared_secret = private_key.exchange(ec.ECDH(), remote_public_key)

    # HKDF accepts None for salt and info; b64decode(None) would raise, so
    # only decode values that were actually supplied.
    derived_key = HKDF(
        algorithm=hashes.SHA256(),
        length=shared_key_len,
        salt=None if salt is None else base64.b64decode(salt),
        info=None if info is None else base64.b64decode(info),
    ).derive(remote_key_bytes + shared_secret)

    return base64.b64encode(derived_key).decode()


def ecies_kem_kdf2(private_key, ephemeral_public_key_str, secret_key_len=32):
    """Derive a symmetric key from ECDH with an ephemeral key using KDF2-SHA256.

    Args:
        private_key: Local SECP256R1 private key (``cryptography`` EC key).
        ephemeral_public_key_str: Base64 of the sender's X9.62-encoded
            ephemeral public point.
        secret_key_len: Number of key bytes to derive.

    Returns:
        str: Base64 text of the derived secret key.
    """
    # Key material (private key, ECDH shared secret, derived key) is
    # deliberately never printed or logged.
    ephemeral_public_key = ec.EllipticCurvePublicKey.from_encoded_point(
        curve=ec.SECP256R1(), data=base64.b64decode(ephemeral_public_key_str)
    )

    # TODO ecies kem exchange with cofactor mode
    shared_secret = private_key.exchange(ec.ECDH(), ephemeral_public_key)

    # KDF2 is fed the shared secret as a hex octet string and returns hex.
    shared_secret_hex = binascii.hexlify(shared_secret).decode('utf-8')
    secret_key_hex = KDF2(shared_secret_hex, secret_key_len, sha256)

    return base64.b64encode(bytes.fromhex(secret_key_hex)).decode('utf-8')


def ecies_kem_decrypt(private_key, ephemeral_public_key, shared_key_len):
    """Derive raw key bytes from ECDH with an already-loaded public key.

    Despite the name, nothing is decrypted here. The ECDH shared secret is run
    through HKDF-SHA256 with no salt and no info.

    Args:
        private_key: Local EC private key (``cryptography`` object).
        ephemeral_public_key: Peer's EC public key object (not encoded text).
        shared_key_len: Number of bytes to derive.

    Returns:
        bytes: The derived key (raw bytes, not Base64).
    """
    secret = private_key.exchange(ec.ECDH(), ephemeral_public_key)
    kdf = HKDF(algorithm=hashes.SHA256(), length=shared_key_len, salt=None, info=None)
    return kdf.derive(secret)


def gen_ecc():
    """Generate a SECP256R1 key pair and return both halves as PEM text.

    Returns:
        dict: ``{"private_key_pem": <unencrypted PKCS#8 PEM>,
        "public_key_pem": <SubjectPublicKeyInfo PEM>}``.
    """
    key = ec.generate_private_key(ec.SECP256R1())
    pem_pub = key.public_key().public_bytes(
        serialization.Encoding.PEM,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    ).decode()
    pem_priv = key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode()
    return {"private_key_pem": pem_priv, "public_key_pem": pem_pub}


def aes_gcm_encrypt(plaintext, secret):
    """AES-GCM encrypt a text string under a Base64-encoded key.

    A fresh random 12-byte nonce is generated on every call.

    Args:
        plaintext: Text to encrypt; it is UTF-8 encoded first.
        secret: Base64 of the raw AES key.

    Returns:
        str: Base64 of ``nonce (12 bytes) || ciphertext || tag``. The tag is
        16 bytes with pycryptodome's default MAC length.
    """
    key = base64.b64decode(secret)
    nonce = os.urandom(12)
    cipher = AES.new(key, AES.MODE_GCM, nonce)
    body, tag = cipher.encrypt_and_digest(plaintext.encode('utf-8'))
    return base64.b64encode(b''.join((nonce, body, tag))).decode('utf-8')


def aes_gcm_decrypt(encrypted, secret_key, iv=None, aad=None):
    """Decrypt and authenticate AES-GCM data, returning the plaintext as Base64.

    Args:
        encrypted: Base64 text. When ``iv`` is given it holds
            ``ciphertext || tag`` (16-byte tag). When ``iv`` is None it holds
            ``nonce (12 bytes) || ciphertext || tag``, the layout produced by
            ``aes_gcm_encrypt``.
        secret_key: Base64 of the raw AES key.
        iv: Base64 of the GCM nonce, or None if the nonce is prefixed to
            ``encrypted``.
        aad: Base64 of additional authenticated data; None or empty for none.

    Returns:
        str: Base64 of the recovered plaintext bytes.

    Raises:
        ValueError: raised by pycryptodome's ``decrypt_and_verify`` when the
            tag does not verify.
    """
    data = base64.b64decode(encrypted.encode('utf-8'))
    if iv is None:
        # Nonce-prefixed layout (matches aes_gcm_encrypt's output).
        nonce, data = data[:12], data[12:]
    else:
        nonce = base64.b64decode(iv)

    cipher = AES.new(base64.b64decode(secret_key), AES.MODE_GCM, nonce)
    if aad:
        cipher.update(base64.b64decode(aad))

    # The last 16 bytes are the GCM authentication tag.
    plaintext = cipher.decrypt_and_verify(data[:-16], data[-16:])
    return base64.b64encode(plaintext).decode()


def aes_gcm_generate_secret():
    """Return a new random 256-bit (32-byte) AES key as Base64 text."""
    key_bytes = os.urandom(32)
    encoded = base64.b64encode(key_bytes)
    return encoded.decode('utf-8')


if __name__ == "__main__":
    # Demo: create a P-256 key pair (generate_key_pair prints the public
    # numbers and key size as it runs).
    pri_key, pub_key, pub_key_b64 = generate_key_pair()



KDF.py（即上文 `from sd.utils.KDF import KDF2` 所导入的模块）

# Key Derivation functions from ISO 18033 
# Author: Peio Popov <peio@peio.org>
# License: Public Domain

from hashlib import *
from math import ceil
import random
import binascii

# Try to use the Data Primitives conversions class (ISO 18033 / PKCS#1) from
# the optional DataPrimitives module; fall back to a minimal local version.
try:
    from DataPrimitives import DataPrimitives, RSAPrimitives
except ImportError:
    class DataPrimitives():
        """Minimal stand-in for DataPrimitives.DataPrimitives (Explain + I2OSP)."""

        def __init__(self, explain=False):
            # When truthy, Explain() prints its formatted messages.
            self.explain = explain

        def Explain(self, explanation, *args):
            """Print ``explanation % args`` if explanations are enabled."""
            if self.explain:
                print(explanation % args)

        def I2OSP(self, longint, length):
            """I2OSP(longint, length) -> bytes

            Convert a non-negative integer into an Octet String: its big-endian
            encoding, left-padded with zero bytes to exactly ``length`` bytes.
            Defined in PKCS #1 v2.1: RSA Cryptography Standard (June 14, 2002).

            Raises:
                ValueError: if ``longint`` is negative or needs more than
                    ``length`` octets.
            """
            if longint < 0:
                raise ValueError('integer %i is negative and cannot be encoded' % longint)
            if longint >= 256 ** length:
                raise ValueError('integer %i too large to encode in %i octets' % (longint, length))
            return longint.to_bytes(length, 'big')

# Shared primitive helpers for the KDFs below, with explanations disabled.
cp = DataPrimitives(0)
ex = DataPrimitives(0)

# Hash function output length in octets, keyed by hashlib constructor.
Hash_len = {md5: 16, sha1: 20, sha224: 28, sha256: 32, sha384: 48, sha512: 64}


def KDF1(x, l, hashfunct=sha1):
    '''KDF1(x, l) -> hex str

    Concatenate Hash(x || I2OSP(i, 4)) for i = 0, 1, 2, ... and keep the
    first l octets.

    x: input octet string (bytes); l: number of output octets (>= 0);
    hashfunct: a hashlib constructor listed in Hash_len.
    Returns the l output octets as a lowercase hex string (2*l characters).
    '''
    assert l >= 0, 'l should be positive integer'

    hlen = Hash_len[hashfunct]
    ratio = l / float(hlen)
    blocks = int(ceil(ratio))
    ex.Explain('l=%d Hash_len=%d k=%f [k]=%d', l, hlen, ratio, blocks)

    out = ''
    for counter in range(blocks):
        out += hashfunct(x + cp.I2OSP(counter, 4)).hexdigest()
        ex.Explain('i = %d len = %d str(hex) = %s', counter, len(out), out)

    return out[:2 * l]


def MGF1(mgfSeed, maskLen, hashfunct=sha1):
    '''MGF1(mgfSeed, maskLen) -> hex str

    MGF1 mask generation function, defined in section B.2.1 of PKCS #1 (also
    IEEE P1363). It is the same construction as KDF1 (counter from 0).

    Options: hashfunct -- hash constructor; hLen is its output length in octets
    Input:   mgfSeed -- seed octet string (bytes)
             maskLen -- intended mask length in octets, at most 2**32 * hLen
    Output:  the mask, as a lowercase hex string of 2*maskLen characters
    Error:   ValueError("mask too long") if maskLen > 2**32 * hLen
    Steps:
    1. If maskLen > 2**32 * hLen, output "mask too long" and stop.
    2. Let T be the empty octet string.
    3. For counter from 0 to ceil(maskLen / hLen) - 1:
       a. C = I2OSP(counter, 4)
       b. T = T || Hash(mgfSeed || C)
    4. Output the leading maskLen octets of T.
    '''
    hlen = hashfunct().digest_size
    # Step 1 limits the mask length (not the seed length).
    if maskLen > 2 ** 32 * hlen:
        raise ValueError("mask too long")

    blocks = -(-maskLen // hlen)  # ceil(maskLen / hLen) in exact integer arithmetic
    # I2OSP(counter, 4) is the 4-octet big-endian encoding of counter.
    T = ''.join(
        hashfunct(mgfSeed + counter.to_bytes(4, 'big')).hexdigest()
        for counter in range(blocks)
    )
    return T[:maskLen * 2]


def KDF2(x, l, hashfunct=sha1):
    '''KDF2(x, l) -> hex str

    Concatenate Hash(x || I2OSP(i, 4)) for i = 1, 2, ... and keep the first
    l octets. This is the same as KDF1 except that the counter starts at 1.

    x: the secret input, either as bytes or as a hex string. Hex input is
       decoded to octets before hashing.
    l: number of output octets (>= 0).
    hashfunct: any hashlib-style constructor whose objects expose digest_size.
    Returns the l output octets as a lowercase hex string (2*l characters).
    '''
    assert l >= 0, 'l should be positive integer'

    # Accept raw octets (like KDF1/KDF3) as well as hex text (existing callers).
    seed = bytes.fromhex(x) if isinstance(x, str) else x
    hlen = hashfunct().digest_size
    blocks = -(-l // hlen)  # ceil(l / hlen) in exact integer arithmetic

    # No debug output here: x and the result are secret key material.
    # I2OSP(i, 4) is the 4-octet big-endian encoding of the counter i.
    digest_hex = ''.join(
        hashfunct(seed + i.to_bytes(4, 'big')).hexdigest()
        for i in range(1, blocks + 1)
    )
    return digest_hex[:l * 2]


def KDF3(x, l, hashfunct=sha1, pamt=64):
    '''KDF3(x, l) -> hex str

    Concatenate Hash(I2OSP(i, pamt) || x) for i = 0, 1, 2, ... and keep the
    first l octets.

    x: input octet string (bytes); l: number of output octets (>= 0);
    hashfunct: a hashlib constructor listed in Hash_len;
    pamt: padding amount, i.e. the counter width in octets (>= 4).
    Returns the l output octets as a lowercase hex string (2*l characters).
    '''
    assert l >= 0, 'l should be positive integer'

    hlen = Hash_len[hashfunct]
    ratio = l / float(hlen)
    blocks = int(ceil(ratio))
    ex.Explain('l=%d Hash_len=%d k=%f [k]=%d', l, hlen, ratio, blocks)

    out = ''
    for counter in range(blocks):
        # Having the counter value as the first input to the Hash function
        # removes a possible security issue compared to KDF1 and KDF2.
        out += hashfunct(cp.I2OSP(counter, pamt) + x).hexdigest()
        ex.Explain('i = %d len = %d str(hex) = %s', counter, len(out), out)

    return out[:2 * l]


def KDF4(x, l, hashfunct=sha1):
    '''KDF4(x, l) -> hex str

    Expand Hash(x) into l pseudo-random octets using Python's ``random``
    generator seeded with the digest, and return them as a hex string of
    2*l characters.

    WARNING: ``random`` is a Mersenne Twister, which is predictable and NOT
    cryptographically secure. Do not use this output as key material.
    '''
    # A private generator: seeding the module-level one would silently change
    # the random state of every other user of ``random`` in the process.
    rng = random.Random(hashfunct(x).digest())
    # randint is inclusive, so 0..255 yields exactly one byte value.
    return bytes(rng.randint(0, 255) for _ in range(l)).hex()


def KDFTestVectors():
    '''Run the KDF1/KDF2/KDF3 known-answer tests, printing pass/fail for each.

    Inputs: shared secret 0xdeadbeeffeebdaed, SHA-1, 32 output octets; KDF3
    uses a 4-octet counter (pamt=4).

    Returns True when all three results match their expected values.
    '''
    shared_hex = 'deadbeeffeebdaed'
    # Python 3: str has no .decode('hex'); bytes.fromhex is the replacement.
    shared = bytes.fromhex(shared_hex)
    l = 32
    hashfunct = sha1

    cases = (
        ('KDF1', KDF1, (shared, l, hashfunct),
         'b0ad565b14b478cad4763856ff3016b1a93d840f87261bede7ddf0f9305a6e44'),
        # KDF2 in this module accepts the shared secret as a hex string.
        ('KDF2', KDF2, (shared_hex, l, hashfunct),
         '87261bede7ddf0f9305a6e44a74e6a0846dede27f48205c6b141888742b0ce2c'),
        ('KDF3', KDF3, (shared, l, hashfunct, 4),
         '60cef67059af33f6aebce1e10188f434f80306ac0360470aeb41f81bafb35790'),
    )

    all_passed = True
    for name, kdf, args, expected in cases:
        passed = kdf(*args) == expected
        all_passed = all_passed and passed
        print('%s test %s' % (name, 'passed' if passed else 'failed'))
    return all_passed


if __name__ == "__main__":
    # Running this module directly executes its KDF known-answer tests.
    KDFTestVectors()



版权声明:本文为weixin_39038035原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。