summaryrefslogtreecommitdiffstats
path: root/mobile/android/services/src/main/java/org/mozilla/gecko/sync/crypto/HKDF.java
blob: 16c0d8147daf00b9966188972029bed2c7ec1524 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

package org.mozilla.gecko.sync.crypto;

import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import org.mozilla.gecko.sync.Utils;

/**
 * A standards-compliant implementation of the HMAC-based Extract-and-Expand
 * Key Derivation Function (HKDF) from RFC 5869, instantiated with HMAC-SHA256
 * as the underlying pseudorandom function.
 */
public class HKDF {
  /**
   * JCA name of the MAC algorithm used by both the extract and expand steps.
   * Final so that no caller can silently change the algorithm for every other
   * user of this class.
   */
  public static final String HMAC_ALGORITHM = "hmacSHA256";

  /**
   * Encodes a string as UTF-8.
   *
   * @param in string to encode; must not be null.
   * @return the UTF-8 bytes of {@code in}.
   * @throws IllegalStateException if the platform does not support UTF-8
   *     (every conforming Java platform is required to).
   */
  public static final byte[] bytes(String in) {
    try {
      return in.getBytes("UTF-8");
    } catch (java.io.UnsupportedEncodingException e) {
      // Unreachable on a conforming JVM. Fail loudly rather than returning a
      // null that callers (such as the HMAC_INPUT initializer) never check.
      throw new IllegalStateException("UTF-8 encoding is not supported", e);
    }
  }

  /**
   * Output length in bytes of HMAC-SHA256 (RFC 5869's HashLen). Despite the
   * name, this is the digest size, not SHA-256's internal 64-byte block size.
   */
  public static final int BLOCKSIZE     = 256 / 8;

  /**
   * UTF-8 bytes of "Sync-AES_256_CBC-HMAC256". The array is shared by every
   * caller; do not modify it.
   */
  public static final byte[] HMAC_INPUT = bytes("Sync-AES_256_CBC-HMAC256");

  /** RFC 5869 limits the expand step to 255 blocks: L <= 255 * HashLen. */
  private static final int MAX_EXPAND_BLOCKS = 255;

  /**
   * Step 1 of RFC 5869 ("extract"): PRK = HMAC-SHA256(salt, IKM).
   *
   * @param salt optional salt; null or empty means "not provided", which the
   *     RFC defines as HashLen zero bytes.
   * @param IKM input keying material.
   * @return the 32-byte pseudorandom key (PRK).
   * @throws NoSuchAlgorithmException if HMAC-SHA256 is unavailable.
   * @throws InvalidKeyException if the provider rejects the salt as a key.
   */
  public static byte[] hkdfExtract(byte[] salt, byte[] IKM) throws NoSuchAlgorithmException, InvalidKeyException {
    return digestBytes(IKM, makeHMACHasher(salt));
  }

  /**
   * Step 2 of RFC 5869 ("expand"). Computes T(n) = HMAC(PRK, T(n-1) | info | n)
   * for n = 1, 2, ..., and returns the first {@code len} bytes of
   * T(1) | T(2) | ...
   *
   * @param prk pseudorandom key, normally the output of {@link #hkdfExtract}.
   * @param info optional context/application info; null is treated as empty.
   * @param len number of output bytes, between 0 and 255 * {@link #BLOCKSIZE}.
   * @return {@code len} bytes of output keying material (OKM).
   * @throws IllegalArgumentException if {@code len} is outside the range above.
   * @throws NoSuchAlgorithmException if HMAC-SHA256 is unavailable.
   * @throws InvalidKeyException if the provider rejects {@code prk} as a key.
   */
  public static byte[] hkdfExpand(byte[] prk, byte[] info, int len) throws NoSuchAlgorithmException, InvalidKeyException {
    final int maxLen = MAX_EXPAND_BLOCKS * BLOCKSIZE;
    if (len < 0 || len > maxLen) {
      // Past 255 blocks the one-octet counter would overflow and the output
      // would no longer be RFC 5869 HKDF.
      throw new IllegalArgumentException(
          "HKDF output length must be between 0 and " + maxLen + " bytes; got " + len);
    }
    if (info == null) {
      info = new byte[0];
    }

    final Mac hmacHasher = makeHMACHasher(prk);
    final byte[] okm = new byte[len];
    byte[] t = new byte[0]; // T(0) is the empty string.
    int filled = 0;
    for (int counter = 1; filled < len; counter++) {
      hmacHasher.update(t);
      hmacHasher.update(info);
      // The block counter is a single octet (0x01..0xFF); the length check
      // above guarantees it never exceeds 255.
      hmacHasher.update((byte) counter);
      t = hmacHasher.doFinal(); // doFinal also resets the Mac for the next block.
      final int chunk = Math.min(t.length, len - filled);
      System.arraycopy(t, 0, okm, filled, chunk);
      filled += chunk;
    }
    return okm;
  }

  /**
   * Wraps raw bytes as an HMAC-SHA256 key.
   *
   * @param key key bytes. Null or empty is replaced by {@link #BLOCKSIZE} zero
   *     bytes, which is RFC 5869's definition of an absent salt.
   * @return a key usable with {@link #HMAC_ALGORITHM}.
   */
  public static Key makeHMACKey(byte[] key) {
    if (key == null || key.length == 0) {
      // SecretKeySpec rejects empty keys. HMAC zero-pads short keys (RFC 2104),
      // so this gives the same MAC an empty key would.
      key = new byte[BLOCKSIZE];
    }
    return new SecretKeySpec(key, HMAC_ALGORITHM);
  }

  /**
   * Creates a fresh HMAC-SHA256 instance initialized with the given key.
   *
   * @param key raw key bytes, handled as described in {@link #makeHMACKey}.
   * @return an initialized {@link Mac}.
   * @throws NoSuchAlgorithmException if HMAC-SHA256 is unavailable.
   * @throws InvalidKeyException if the provider rejects the key.
   */
  public static Mac makeHMACHasher(byte[] key) throws NoSuchAlgorithmException, InvalidKeyException {
    // Mac.getInstance either returns a non-null instance or throws.
    final Mac hmacHasher = Mac.getInstance(HMAC_ALGORITHM);
    hmacHasher.init(makeHMACKey(key));
    return hmacHasher;
  }

  /**
   * Computes the MAC of {@code message} with {@code hasher}.
   *
   * @param message bytes to authenticate.
   * @param hasher an initialized Mac; it is left reset and ready for reuse.
   * @return the MAC value.
   */
  public static byte[] digestBytes(byte[] message, Mac hasher) {
    // doFinal processes the input, finishes, and resets the Mac to its
    // post-init state.
    return hasher.doFinal(message);
  }

  /**
   * Full HKDF: extract with salt {@code xts} from input keying material
   * {@code skm}, then expand with {@code ctxInfo} to {@code dkLen} bytes.
   *
   * @param skm input keying material.
   * @param xts salt; null or empty means "not provided".
   * @param ctxInfo context info; null is treated as empty.
   * @param dkLen number of output bytes, between 0 and 255 * {@link #BLOCKSIZE}.
   * @return {@code dkLen} bytes of derived key material.
   * @throws IllegalArgumentException if {@code dkLen} is out of range.
   * @throws InvalidKeyException if the provider rejects a key.
   * @throws NoSuchAlgorithmException if HMAC-SHA256 is unavailable.
   */
  public static byte[] derive(byte[] skm, byte[] xts, byte[] ctxInfo, int dkLen) throws InvalidKeyException, NoSuchAlgorithmException {
    return hkdfExpand(hkdfExtract(xts, skm), ctxInfo, dkLen);
  }

  /**
   * Derives one contiguous block of key material sized to the total length of
   * {@code keys}, then copies consecutive slices of it into each array, in
   * order, overwriting their contents.
   *
   * @param skm input keying material.
   * @param xts salt; null or empty means "not provided".
   * @param ctxInfo context info; null is treated as empty.
   * @param keys output arrays; their combined length must not exceed
   *     255 * {@link #BLOCKSIZE}.
   * @throws IllegalArgumentException if the combined length is out of range.
   * @throws InvalidKeyException if the provider rejects a key.
   * @throws NoSuchAlgorithmException if HMAC-SHA256 is unavailable.
   */
  public static void deriveMany(byte[] skm, byte[] xts, byte[] ctxInfo, byte[]... keys) throws InvalidKeyException, NoSuchAlgorithmException {
    int length = 0;
    for (byte[] key : keys) {
      length += key.length;
    }
    final byte[] derived = derive(skm, xts, ctxInfo, length);
    int offset = 0;
    for (byte[] key : keys) {
      System.arraycopy(derived, offset, key, 0, key.length);
      offset += key.length;
    }
  }
}