include "bytedata.s7i";
include "msgdigest.s7i";
include "hmac.s7i";
(**
 *  Type for one RSA key: either the public or the private half.
 *  rsaEncrypt/rsaDecrypt compute modPow(value, exponent, modulus).
 *)
const type: rsaKey is new struct
    var bigInteger: modulus  is 0_;  # RSA modulus (n).
    var bigInteger: exponent is 0_;  # Public (e) or private (d) exponent.
    var integer: modulusLen is 0;    # Length of the modulus in bytes (k).
  end struct;
(**
 *  Type for an RSA key pair.
 *  Both keys are expected to use the same modulus. The private
 *  exponent is the modular inverse of the public exponent
 *  (see genRsaKeyPair).
 *)
const type: rsaKeyPair is new struct
    var rsaKey: publicKey is rsaKey.value;
    var rsaKey: privateKey is rsaKey.value;
  end struct;
(**
 *  Create an rsaKey from a modulus and an exponent.
 *  The byte length of the modulus is derived from its bit length.
 *  @param modulus RSA modulus (n).
 *  @param exponent Public (e) or private (d) exponent.
 *  @return the new rsaKey.
 *)
const func rsaKey: rsaKey (in bigInteger: modulus, in bigInteger: exponent) is func
  result
    var rsaKey: newKey is rsaKey.value;
  begin
    newKey.exponent := exponent;
    newKey.modulus := modulus;
    # Round the number of bits up to whole bytes.
    newKey.modulusLen := (bitLength(modulus) + 7) mdiv 8;
  end func;
(**
 *  Convert an rsaKey to a Seed7 expression that recreates it.
 *  modulusLen is omitted because rsaKey(modulus, exponent) recomputes it.
 *  @return a string like "rsaKey(123_, 65537_)".
 *)
const func string: literal (in rsaKey: aKey) is func
  result
    var string: stri is "";
  begin
    stri := "rsaKey(" <& aKey.modulus <& "_, " <& aKey.exponent <& "_)";
  end func;
(**
 *  Create an rsaKeyPair from an existing public and private key.
 *  The keys are stored as given; their consistency is not checked.
 *)
const func rsaKeyPair: rsaKeyPair (in rsaKey: publicKey, in rsaKey: privateKey) is func
  result
    var rsaKeyPair: pair is rsaKeyPair.value;
  begin
    pair.privateKey := privateKey;
    pair.publicKey := publicKey;
  end func;
(**
 *  Create an rsaKeyPair from a shared modulus and the two exponents.
 *  Both keys are built with the rsaKey constructor, so the byte length
 *  of the modulus is computed in one place only.
 *  @param modulus RSA modulus (n), used by both keys.
 *  @param publicExponent Public exponent (e).
 *  @param privateExponent Private exponent (d).
 *  @return the new key pair.
 *)
const func rsaKeyPair: rsaKeyPair (in bigInteger: modulus,
    in bigInteger: publicExponent, in bigInteger: privateExponent) is func
  result
    var rsaKeyPair: aKeyPair is rsaKeyPair.value;
  begin
    aKeyPair.publicKey := rsaKey(modulus, publicExponent);
    aKeyPair.privateKey := rsaKey(modulus, privateExponent);
  end func;
(**
 *  Convert an rsaKeyPair to a Seed7 expression that recreates it.
 *  @return a string like "rsaKeyPair(rsaKey(...), rsaKey(...))".
 *)
const func string: literal (in rsaKeyPair: aKeyPair) is func
  result
    var string: stri is "";
  begin
    stri := "rsaKeyPair(" & literal(aKeyPair.publicKey) & ", " &
            literal(aKeyPair.privateKey) & ")";
  end func;
(**
 *  Fixed RSA key pair: modulus, public exponent 65537 and private exponent.
 *  The private exponent is part of this source text, so this pair
 *  offers no secrecy.
 *  NOTE(review): presumably meant for tests/demos only - confirm.
 *)
const rsaKeyPair: stdRsaKeyPair is rsaKeyPair(
    24879323998282062445559239376138627672251789324639932255760945448830518737190739384640974288529252990629999154107088744522353230056093690582734642121734324688998783187082956551041534076711835155928707829063256162504640988273175086145464721272768789724523888480071692284791450208699922038033347977493657812564984507698611204922456314960345290535639893058558496084848086047929946200129560169331200998008734185791703119413114661834404705794384334141821974803140053993070324815699000969159963681641979473501237806530434553142494356819076895646395931159040168081404923835928757739530245969373700156294158081824152324534461_,
    65537_,
    13088626772857606374994147660260732180049394881287449598152583688371128141673593733366670911392214849793873855002581835202125435492530910232563666189681493608607352285338757891981781465383991523965240833886856981107389901791087944521771406381777047043992471840577258748417529339084119078189629851351546974405368864243464335513384244053469543096484265307474308910575994412661107393352461835904945528516443603610779331782072425524550819129133717693244132152897414694106718806022258591593987246936431975237846760817339494993145835504324086601210101710742420357193277289205188098444617372451223173144294880256099265243073_);
(**
 *  Probabilistic primality test (Miller-Rabin).
 *  A prime is never rejected. For count > 0, a composite is accepted
 *  with probability at most 4 ** (-count). This also holds for
 *  Carmichael numbers, which fool the plain Fermat test.
 *  @param primeCandidate Number to be tested.
 *  @param count Number of rounds with random witnesses.
 *  @return TRUE if primeCandidate is probably prime,
 *          FALSE if it is composite or smaller than 2.
 *)
const func boolean: isProbablyPrime (in bigInteger: primeCandidate, in var integer: count) is func
  result
    var boolean: isProbablyPrime is TRUE;
  local
    var bigInteger: minusOne is 0_;
    var bigInteger: oddPart is 0_;
    var integer: twoPower is 0;
    var bigInteger: witness is 0_;
    var bigInteger: x is 0_;
    var integer: squarings is 0;
  begin
    if primeCandidate < 2_ then
      isProbablyPrime := FALSE;
    elsif primeCandidate <= 3_ then
      isProbablyPrime := TRUE;
    elsif not odd(primeCandidate) then
      isProbablyPrime := FALSE;
    else
      # Write primeCandidate - 1 as oddPart * 2 ** twoPower.
      minusOne := pred(primeCandidate);
      oddPart := minusOne;
      while not odd(oddPart) do
        oddPart := oddPart >> 1;
        incr(twoPower);
      end while;
      while isProbablyPrime and count > 0 do
        # primeCandidate >= 5 here, so the range 2 .. n-2 is not empty.
        witness := rand(2_, primeCandidate - 2_);
        x := modPow(witness, oddPart, primeCandidate);
        if x <> 1_ and x <> minusOne then
          squarings := 1;
          while squarings < twoPower and x <> minusOne do
            x := modPow(x, 2_, primeCandidate);
            incr(squarings);
          end while;
          # If -1 is never reached, the witness proves compositeness.
          isProbablyPrime := x = minusOne;
        end if;
        decr(count);
      end while;
    end if;
  end func;
(**
 *  Generate a random probable prime with exactly binDigits bits.
 *  The two most significant bits are set. The product of two such
 *  primes therefore has exactly the sum of their bit lengths.
 *  NOTE(review): relies on rand(); confirm it is suitable for key material.
 *  @param binDigits Number of bits of the prime (at least 2).
 *  @param count Number of rounds passed to isProbablyPrime.
 *  @return an odd number in [3 * 2 ** (binDigits - 2), 2 ** binDigits - 1]
 *          that passed isProbablyPrime.
 *  @exception RANGE_ERROR if binDigits < 2.
 *)
const func bigInteger: getProbablyPrime (in integer: binDigits, in integer: count) is func
  result
    var bigInteger: probablyPrime is 0_;
  local
    var bigInteger: lowerBound is 0_;
    var bigInteger: upperBound is 0_;
  begin
    if binDigits < 2 then
      raise RANGE_ERROR;
    else
      # Smallest and largest binDigits-bit numbers with the two top bits set.
      lowerBound := 3_ * (2_ ** (binDigits - 2));
      upperBound := pred(2_ ** binDigits);
      repeat
        probablyPrime := rand(lowerBound, upperBound);
        if not odd(probablyPrime) then
          incr(probablyPrime);
        end if;
        while probablyPrime <= upperBound and
            not isProbablyPrime(probablyPrime, count) do
          probablyPrime +:= 2_;
        end while;
        # Restart with a new random number if the search left the range.
      until probablyPrime <= upperBound;
    end if;
  end func;
(**
 *  Generate an RSA key pair with a modulus of about keyLength bits.
 *  p and q are regenerated until they differ and the exponent is
 *  coprime to (p - 1) * (q - 1). This guarantees that the private
 *  exponent (its modular inverse) exists.
 *  @param keyLength Intended bit length of the modulus.
 *  @param exponent Public exponent (e), e.g. 65537_; must be odd.
 *  @return the new key pair.
 *  @exception RANGE_ERROR if exponent is even, or if no suitable
 *             primes are found within maxAttempts tries (tiny keys).
 *)
const func rsaKeyPair: genRsaKeyPair (in integer: keyLength, in bigInteger: exponent) is func
  result
    var rsaKeyPair: keyPair is rsaKeyPair.value;
  local
    const integer: numTests is 10;
    const integer: maxAttempts is 100;
    var integer: attempt is 0;
    var bigInteger: p is 0_;
    var bigInteger: q is 0_;
    var bigInteger: modulus is 0_;         # n
    var bigInteger: phiOfModulus is 0_;    # phi(n) = (p - 1) * (q - 1)
  begin
    if not odd(exponent) then
      # phi(n) is even, so an even exponent never has an inverse.
      raise RANGE_ERROR;
    else
      repeat
        if attempt = maxAttempts then
          raise RANGE_ERROR;
        end if;
        incr(attempt);
        # Split an odd keyLength so that no bit is lost.
        p := getProbablyPrime(succ(keyLength) mdiv 2, numTests);
        q := getProbablyPrime(keyLength mdiv 2, numTests);
        phiOfModulus := pred(p) * pred(q);
      until p <> q and gcd(exponent, phiOfModulus) = 1_;
      modulus := p * q;
      keyPair.publicKey := rsaKey(modulus, exponent);
      keyPair.privateKey := rsaKey(modulus, modInverse(exponent, phiOfModulus));
    end if;
  end func;
(**
 *  Convert a nonnegative bigInteger to an unsigned big-endian byte
 *  string with the given length (I2OSP in RFC 8017).
 *  @param number Number to convert.
 *  @param length Length of the result in bytes.
 *  @return the big-endian representation of number.
 *)
const func string: int2Octets (in bigInteger: number, in integer: length) is func
  result
    var string: octets is "";
  begin
    octets := bytes(number, UNSIGNED, BE, length);
  end func;
(**
 *  Interpret a byte string as an unsigned big-endian bigInteger
 *  (OS2IP in RFC 8017).
 *  @param stri Byte string to convert.
 *  @return the nonnegative number represented by stri.
 *)
const func bigInteger: octets2int (in string: stri) is func
  result
    var bigInteger: number is 0_;
  begin
    number := bytes2BigInt(stri, UNSIGNED, BE);
  end func;
(**
 *  EME-OAEP encoding with SHA-1 and MGF1-SHA-1.
 *  The result is 0x00 | maskedSeed | maskedDb.
 *  The data block is lHash | zero padding | 0x01 | message.
 *  @param message Message to encode.
 *  @param label Label whose SHA-1 hash is placed in the data block.
 *  @param modulusLen Length of the RSA modulus in bytes (k).
 *  @return the encoded message.
 *)
const func string: emeOaepEncoding (in string: message, in string: label, in integer: modulusLen) is func
  result
    var string: encodedMessage is "";
  local
    const integer: hLen is 20;  # Length of a SHA-1 digest in bytes.
    var string: dataBlock is "";
    var string: randomSeed is "";
    var string: maskedDb is "";
  begin
    dataBlock := sha1(label) &
                 ("\0;" mult (modulusLen - length(message) - 2 * hLen - 2)) &
                 "\1;" & message;
    randomSeed := int2Octets(rand(0_, 2_ ** (hLen * 8) - 1_), hLen);
    maskedDb := dataBlock >< mgf1Sha1(randomSeed, modulusLen - hLen - 1);
    encodedMessage := "\0;" & (randomSeed >< mgf1Sha1(maskedDb, hLen)) & maskedDb;
  end func;
(**
 *  EME-OAEP decoding with SHA-1 and MGF1-SHA-1 (RFC 8017, 7.1.2 step 3).
 *  Every format error raises the same exception, so a caller cannot
 *  tell which check failed.
 *  @param encodedMessage Encoded message EM, modulusLen bytes long.
 *  @param label Label whose SHA-1 hash must equal the decoded lHash.
 *  @param modulusLen Length of the RSA modulus in bytes (k).
 *  @return the decoded message.
 *  @exception RANGE_ERROR if encodedMessage is not a valid OAEP encoding.
 *)
const func string: emeOaepDecoding (in string: encodedMessage, in string: label, in integer: modulusLen) is func
  result
    var string: message is "";
  local
    const integer: hLen is 20;  # Length of a SHA-1 digest in bytes.
    var string: maskedSeed is "";
    var string: maskedDb is "";
    var string: seed is "";
    var string: db is "";
    var integer: dbPos is 0;
  begin
    # EM must have k bytes and its first byte (Y) must be zero.
    if length(encodedMessage) <> modulusLen or modulusLen < 2 * hLen + 2 or
        encodedMessage[1] <> '\0;' then
      raise RANGE_ERROR;
    else
      maskedSeed := encodedMessage[2 len hLen];
      maskedDb := encodedMessage[hLen + 2 ..];
      seed := maskedSeed >< mgf1Sha1(maskedDb, hLen);
      db := maskedDb >< mgf1Sha1(seed, modulusLen - hLen - 1);
      # After lHash, PS may contain only zero bytes, followed by 0x01.
      dbPos := succ(hLen);
      while dbPos <= length(db) and db[dbPos] = '\0;' do
        incr(dbPos);
      end while;
      if db[.. hLen] <> sha1(label) or dbPos > length(db) or
          db[dbPos] <> '\1;' then
        raise RANGE_ERROR;
      else
        message := db[succ(dbPos) ..];
      end if;
    end if;
  end func;
(**
 *  EME-PKCS1-v1_5 encoding: 0x00 0x02 | PS | 0x00 | message.
 *  PS consists of random nonzero bytes and fills the result up to
 *  modulusLen bytes.
 *  @param message Message to encode; at most modulusLen - 11 bytes.
 *  @param modulusLen Length of the RSA modulus in bytes (k).
 *  @return the encoded message.
 *  @exception RANGE_ERROR if PS would be shorter than 8 bytes.
 *)
const func string: emePkcs1V15Encoding (in string: message, in integer: modulusLen) is func
  result
    var string: encodedMessage is "";
  local
    var integer: index is 0;
    var string: ps is "";
  begin
    if length(message) > modulusLen - 11 then
      # RFC 8017 requires a padding string of at least 8 bytes.
      raise RANGE_ERROR;
    else
      for index range 1 to modulusLen - length(message) - 3 do
        ps &:= rand('\1;', '\255;');
      end for;
      encodedMessage := "\0;\2;" & ps & "\0;" & message;
    end if;
  end func;
(**
 *  EME-PKCS1-v1_5 decoding: checks 0x00 0x02 | PS | 0x00 | message.
 *  @param encodedMessage Encoded message EM.
 *  @param modulusLen Length of the RSA modulus in bytes (unused here).
 *  @return the message after the zero separator.
 *  @exception RANGE_ERROR if the prefix is wrong, the separator is
 *             missing, or PS is shorter than 8 bytes.
 *)
const func string: emePkcs1V15Decoding (in string: encodedMessage, in integer: modulusLen) is func
  result
    var string: message is "";
  local
    var integer: zeroPos is 0;
  begin
    if not startsWith(encodedMessage, "\0;\2;") then
      raise RANGE_ERROR;
    else
      zeroPos := pos(encodedMessage[3 ..], '\0;');
      # zeroPos is 0 when the separator is missing. A value below 9
      # means the padding string PS has fewer than 8 bytes.
      if zeroPos < 9 then
        raise RANGE_ERROR;
      else
        message := encodedMessage[zeroPos + 3 ..];
      end if;
    end if;
  end func;
(**
 *  EMSA-PKCS1-v1_5 style padding: 0x00 0x01 | 0xFF... | 0x00 | message.
 *  The message is used as given; no hashing is done here.
 *  @param message Data to pad; at most modulusLen - 11 bytes.
 *  @param modulusLen Length of the RSA modulus in bytes (k).
 *  @return the encoded message of modulusLen bytes.
 *  @exception RANGE_ERROR if fewer than 8 bytes of 0xFF would remain.
 *)
const func string: emsaPkcs1V15Encoding (in string: message, in integer: modulusLen) is func
  result
    var string: encodedMessage is "";
  local
    var string: ps is "";
  begin
    if length(message) > modulusLen - 11 then
      # RFC 8017 9.2: "intended encoded message length too short".
      raise RANGE_ERROR;
    else
      ps := "\255;" mult (modulusLen - length(message) - 3);
      encodedMessage := "\0;\1;" & ps & "\0;" & message;
    end if;
  end func;
(**
 *  Undo EMSA-PKCS1-v1_5 style padding: 0x00 0x01 | 0xFF... | 0x00 | message.
 *  At least 8 bytes of 0xFF are required.
 *  @param encodedMessage Padded data.
 *  @return the data after the zero separator.
 *  @exception RANGE_ERROR if the padding is malformed.
 *)
const func string: emsaPkcs1V15Decoding (in string: encodedMessage) is func
  result
    var string: message is "";
  local
    var integer: separator is 0;
  begin
    if startsWith(encodedMessage, "\0;\1;") then
      separator := pos(encodedMessage[3 ..], '\0;');
      if separator >= 9 and
          encodedMessage[3 len pred(separator)] = ("\255;" mult pred(separator)) then
        message := encodedMessage[separator + 3 ..];
      else
        raise RANGE_ERROR;
      end if;
    else
      raise RANGE_ERROR;
    end if;
  end func;
(**
 *  RSA primitive RSAEP: ciphertext = message ** exponent mod modulus.
 *  @param encryptionKey Key whose exponent and modulus are used.
 *  @param message Message representative, 0 <= message < modulus.
 *  @return the ciphertext representative.
 *  @exception RANGE_ERROR if message is outside [0, modulus - 1].
 *)
const func bigInteger: rsaEncrypt (in rsaKey: encryptionKey, in bigInteger: message) is func
  result
    var bigInteger: ciphertext is 0_;
  begin
    if message < 0_ or message >= encryptionKey.modulus then
      raise RANGE_ERROR;
    else
      ciphertext := modPow(message, encryptionKey.exponent, encryptionKey.modulus);
    end if;
  end func;
(**
 *  RSA primitive RSADP: message = ciphertext ** exponent mod modulus.
 *  @param decryptionKey Key whose exponent and modulus are used.
 *  @param ciphertext Ciphertext representative, 0 <= ciphertext < modulus.
 *  @return the message representative.
 *  @exception RANGE_ERROR if ciphertext is outside [0, modulus - 1].
 *)
const func bigInteger: rsaDecrypt (in rsaKey: decryptionKey, in bigInteger: ciphertext) is func
  result
    var bigInteger: message is 0_;
  begin
    if ciphertext < 0_ or ciphertext >= decryptionKey.modulus then
      raise RANGE_ERROR;
    else
      message := modPow(ciphertext, decryptionKey.exponent, decryptionKey.modulus);
    end if;
  end func;
(**
 *  RSAES-OAEP encryption with SHA-1.
 *  @param encryptionKey Public key of the recipient.
 *  @param message Message; at most modulusLen - 2 * 20 - 2 bytes.
 *  @param label Label bound to the ciphertext.
 *  @return the ciphertext as a string of modulusLen bytes.
 *  @exception RANGE_ERROR if the message is too long.
 *)
const func string: rsaesOaepEncrypt (in rsaKey: encryptionKey, in string: message,
    in string: label) is func
  result
    var string: encryptedMessage is "";
  local
    const integer: hLen is 20;  # Length of a SHA-1 digest in bytes.
    var integer: k is 0;
  begin
    k := encryptionKey.modulusLen;
    if length(message) <= k - 2 * hLen - 2 then
      encryptedMessage := int2Octets(
          rsaEncrypt(encryptionKey, octets2int(emeOaepEncoding(message, label, k))), k);
    else
      raise RANGE_ERROR;
    end if;
  end func;
(**
 *  RSAES-OAEP decryption with SHA-1.
 *  @param decryptionKey Private key of the recipient.
 *  @param ciphertext Ciphertext of exactly modulusLen bytes.
 *  @param label Label that was used for encryption.
 *  @return the decrypted message.
 *  @exception RANGE_ERROR if the ciphertext length or the key size is
 *             wrong, or the decoded data is not a valid OAEP encoding.
 *)
const func string: rsaesOaepDecrypt (in rsaKey: decryptionKey, in string: ciphertext,
    in string: label) is func
  result
    var string: message is "";
  local
    const integer: hLen is 20;  # Length of a SHA-1 digest in bytes.
    var integer: k is 0;
  begin
    k := decryptionKey.modulusLen;
    if length(ciphertext) = k and k >= 2 * hLen + 2 then
      message := emeOaepDecoding(
          int2Octets(rsaDecrypt(decryptionKey, octets2int(ciphertext)), k), label, k);
    else
      raise RANGE_ERROR;
    end if;
  end func;
(**
 *  RSAES-PKCS1-v1_5 encryption (RFC 8017, section 7.2.1).
 *  @param encryptionKey Public key of the recipient.
 *  @param message Message; at most modulusLen - 11 bytes.
 *  @return the ciphertext as a string of modulusLen bytes.
 *  @exception RANGE_ERROR if the message is too long.
 *)
const func string: rsaesPkcs1V15Encrypt (in rsaKey: encryptionKey, in string: message) is func
  result
    var string: encryptedMessage is "";
  local
    var string: encodedMessage is "";
  begin
    # v1.5 padding needs 3 fixed bytes plus at least 8 random bytes.
    # The OAEP bound (2 * hLen + 2) does not apply here.
    if length(message) > encryptionKey.modulusLen - 11 then
      raise RANGE_ERROR;
    else
      encodedMessage := emePkcs1V15Encoding(message, encryptionKey.modulusLen);
      encryptedMessage := int2Octets(rsaEncrypt(encryptionKey, octets2int(encodedMessage)),
                                     encryptionKey.modulusLen);
    end if;
  end func;
(**
 *  RSAES-PKCS1-v1_5 decryption (RFC 8017, section 7.2.2).
 *  @param decryptionKey Private key of the recipient.
 *  @param ciphertext Ciphertext of exactly modulusLen bytes.
 *  @return the decrypted message.
 *  @exception RANGE_ERROR if the ciphertext length is wrong, the key is
 *             shorter than 11 bytes, or the padding is invalid.
 *)
const func string: rsaesPkcs1V15Decrypt (in rsaKey: decryptionKey, in string: ciphertext) is func
  result
    var string: message is "";
  local
    var string: encodedMessage is "";
  begin
    # RFC 8017 7.2.2 step 1: reject if length(C) <> k or k < 11.
    if length(ciphertext) <> decryptionKey.modulusLen or
        decryptionKey.modulusLen < 11 then
      raise RANGE_ERROR;
    else
      encodedMessage := int2Octets(rsaDecrypt(decryptionKey, octets2int(ciphertext)),
                                   decryptionKey.modulusLen);
      message := emePkcs1V15Decoding(encodedMessage, decryptionKey.modulusLen);
    end if;
  end func;
(**
 *  Pad message with EMSA-PKCS1-v1_5 style padding and apply the key.
 *  NOTE(review): presumably used with a private key to sign an encoded
 *  DigestInfo - confirm with callers.
 *  @param encryptionKey Key whose exponent and modulus are used.
 *  @param message Data to pad; at most modulusLen - 11 bytes.
 *  @return the result as a string of modulusLen bytes.
 *  @exception RANGE_ERROR if the message is too long.
 *)
const func string: rsassaPkcs1V15Encrypt (in rsaKey: encryptionKey, in string: message) is func
  result
    var string: encryptedMessage is "";
  local
    var string: encodedMessage is "";
  begin
    # RFC 8017 9.2 step 3: emLen must be at least tLen + 11.
    if length(message) > encryptionKey.modulusLen - 11 then
      raise RANGE_ERROR;
    else
      encodedMessage := emsaPkcs1V15Encoding(message, encryptionKey.modulusLen);
      encryptedMessage := int2Octets(rsaEncrypt(encryptionKey, octets2int(encodedMessage)),
                                     encryptionKey.modulusLen);
    end if;
  end func;
(**
 *  Apply the key to ciphertext and remove EMSA-PKCS1-v1_5 style padding.
 *  NOTE(review): presumably used with a public key to recover signed
 *  data - confirm with callers.
 *  @param decryptionKey Key whose exponent and modulus are used.
 *  @param ciphertext Input of exactly modulusLen bytes.
 *  @return the data after the padding.
 *  @exception RANGE_ERROR if the length is wrong, the key is shorter
 *             than 11 bytes, or the padding is invalid.
 *)
const func string: rsassaPkcs1V15Decrypt (in rsaKey: decryptionKey, in string: ciphertext) is func
  result
    var string: message is "";
  local
    var string: encodedMessage is "";
  begin
    # 11 bytes are the minimum for 0x00 0x01, 8 bytes 0xFF and 0x00.
    if length(ciphertext) <> decryptionKey.modulusLen or
        decryptionKey.modulusLen < 11 then
      raise RANGE_ERROR;
    else
      encodedMessage := int2Octets(rsaDecrypt(decryptionKey, octets2int(ciphertext)),
                                   decryptionKey.modulusLen);
      message := emsaPkcs1V15Decoding(encodedMessage);
    end if;
  end func;