summaryrefslogtreecommitdiffstats
path: root/security/nss/lib/freebl/verified/specs/Spec.CTR.fst
blob: e411cd35357abff0a3221308c0ed66e9ae551b93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
(* Copyright 2016-2017 INRIA and Microsoft Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *)

module Spec.CTR

module ST = FStar.HyperStack.ST

open FStar.Mul
open FStar.Seq
open Spec.Lib

#reset-options "--initial_fuel 0 --max_fuel 0 --initial_ifuel 0 --max_ifuel 0"

(* Static parameters of a block cipher as consumed by the counter-mode
   specification in this module. The fields only index the types declared
   right after this record; no cipher behaviour is fixed here. *)
type block_cipher_ctx = {
     (* Length index of the key type (key c = lbytes c.keylen). *)
     keylen: nat ;
     (* Strictly positive, since it is used as a divisor of the plaintext
        length in the counter bounds below. *)
     blocklen: (x:nat{x>0});
     (* Length index of the nonce type (nonce c = lbytes c.noncelen). *)
     noncelen: nat;
     (* Bit width of the counter (counter c = UInt.uint_t c.counterbits);
        counters are required to stay below pow2 counterbits. *)
     counterbits: nat;
     (* Strictly positive multiplier: one cipher invocation yields
        blocklen times incr bytes, and counter offsets are scaled by incr. *)
     incr: pos}

(* Types derived from a cipher context c.
   lbytes is provided by Spec.Lib; presumably a sequence of UInt8.t of the
   given length (the functions below build lbytes values out of seq UInt8.t)
   - TODO confirm against Spec.Lib. *)
(* Key material, indexed by c.keylen. *)
type key (c:block_cipher_ctx) = lbytes c.keylen
(* Nonce, indexed by c.noncelen. *)
type nonce (c:block_cipher_ctx) = lbytes c.noncelen
(* Output of one cipher invocation: c.blocklen times c.incr elements. *)
type block (c:block_cipher_ctx) = lbytes (c.blocklen*c.incr)
(* Counter value: an unsigned integer that fits in c.counterbits bits. *)
type counter (c:block_cipher_ctx) = UInt.uint_t c.counterbits
(* A block cipher maps a key, a nonce and a counter to one block of output;
   counter mode xors that output with the data. *)
type block_cipher (c:block_cipher_ctx) =  key c -> nonce c -> counter c -> block c

(* Element-wise exclusive-or of two byte sequences that share the same
   length len; the result has that length as well. *)
val xor: #len:nat -> x:lbytes len -> y:lbytes len -> Tot (lbytes len)
let xor #len x y =
  map2 (fun a b -> FStar.UInt8.logxor a b) x y


(* Counter-mode transformation of an input made only of whole chunks.

   Arguments:
     ctx   - cipher parameters
     bc    - the block cipher (bound as block_enc in the definition)
     k, n  - key and nonce passed unchanged to every cipher invocation
     c     - initial counter
     plain - input; its length must be a multiple of
             ctx.blocklen times ctx.incr, and
             c + ctx.incr * (length plain / ctx.blocklen)
             must stay below pow2 ctx.counterbits

   Returns a byte sequence of the same length as plain.
   Termination is by decreasing length of plain: each recursive call drops
   the last chunk of ctx.blocklen times ctx.incr elements. *)
val counter_mode_blocks:
  ctx: block_cipher_ctx ->
  bc: block_cipher ctx ->
  k:key ctx -> n:nonce ctx -> c:counter ctx ->
  plain:seq UInt8.t{c + ctx.incr * (length plain / ctx.blocklen) < pow2 ctx.counterbits /\
    length plain % (ctx.blocklen * ctx.incr) = 0} ->
  Tot (lbytes (length plain))
      (decreases (length plain))
(* Solver settings for the definition below: Z3 resource limit 200, no fuel. *)
#reset-options "--z3rlimit 200 --max_fuel 0"
let rec counter_mode_blocks ctx block_enc key nonce counter plain =
  let len  = length plain in
  (* Number of whole chunks. Not referenced again in this body; possibly kept
     as a hint for the solver - TODO confirm before removing. *)
  let len' = len / (ctx.blocklen * ctx.incr) in
  (* Presumably relates len to its quotient and remainder by the chunk size
     (FStar.Math.Lemmas) - TODO confirm. *)
  Math.Lemmas.lemma_div_mod len (ctx.blocklen * ctx.incr) ;
  (* Base case: empty input gives empty output. *)
  if len = 0 then Seq.createEmpty #UInt8.t
  else (
    (* Peel off the last chunk; prefix keeps everything before it. *)
    let prefix, block = split plain (len - ctx.blocklen * ctx.incr) in    
      (* TODO: move to a single lemma for clarity *)
      (* The three calls below presumably re-establish the refinement on
         prefix needed by the recursive call (length still a multiple of the
         chunk size, counter bound still satisfied) and the bound on the
         counter used for the mask. Spec.CTR.Lemmas.lemma_div is defined
         outside this file - TODO confirm its statement. *)
      Math.Lemmas.lemma_mod_plus (length prefix) 1 (ctx.blocklen * ctx.incr);
      Math.Lemmas.lemma_div_le (length prefix) len ctx.blocklen;
      Spec.CTR.Lemmas.lemma_div len (ctx.blocklen * ctx.incr);
      (* End TODO *)
    (* Process all chunks before the last one with the same initial counter. *)
    let cipher        = counter_mode_blocks ctx block_enc key nonce counter prefix in
    (* Key stream for the last chunk.
       NOTE(review): the offset is derived from len / ctx.blocklen rather than
       from len / (ctx.blocklen * ctx.incr); for non-empty input the two agree
       only when ctx.incr = 1 - confirm this is intended for incr > 1. *)
    let mask          = block_enc key nonce (counter + (len / ctx.blocklen - 1) * ctx.incr) in
    (* Xor the last chunk with its key stream and append it to the result. *)
    let eb            = xor block mask in
    cipher @| eb
  )


(* Counter-mode transformation of an input of arbitrary length.

   Arguments:
     ctx   - cipher parameters
     bc    - the block cipher (bound as block_enc in the definition)
     k, n  - key and nonce passed unchanged to every cipher invocation
     c     - initial counter
     plain - input; c + ctx.incr * (length plain / ctx.blocklen) must stay
             below pow2 ctx.counterbits

   Returns a byte sequence of the same length as plain: the whole chunks are
   handled by counter_mode_blocks, and a trailing partial chunk, if any, is
   xored with a truncated key-stream block.
   The definition is not recursive, although the signature carries a
   decreases clause. *)
val counter_mode:
  ctx: block_cipher_ctx ->
  bc: block_cipher ctx ->
  k:key ctx -> n:nonce ctx -> c:counter ctx ->
  plain:seq UInt8.t{c + ctx.incr * (length plain / ctx.blocklen) < pow2 ctx.counterbits} ->
  Tot (lbytes (length plain))
      (decreases (length plain))
(* Solver settings for the definition below: Z3 resource limit 200, no fuel. *)
#reset-options "--z3rlimit 200 --max_fuel 0"
let counter_mode ctx block_enc key nonce counter plain =
  let len = length plain in
  (* Length of the longest prefix made of whole chunks. *)
  let blocks_len = (ctx.incr * ctx.blocklen) * (len / (ctx.blocklen * ctx.incr)) in
  (* Length of the trailing partial chunk; zero when len is a multiple of the
     chunk size. *)
  let part_len   = len % (ctx.blocklen * ctx.incr) in
    (* TODO: move to a single lemma for clarity *)
    (* These calls presumably establish that blocks_len is at most len, is a
       multiple of the chunk size, and satisfies the counter bound required
       by counter_mode_blocks (FStar.Math.Lemmas) - TODO confirm. *)
    Math.Lemmas.lemma_div_mod len (ctx.blocklen * ctx.incr);
    Math.Lemmas.multiple_modulo_lemma (len / (ctx.blocklen * ctx.incr)) (ctx.blocklen * ctx.incr);
    Math.Lemmas.lemma_div_le (blocks_len) len ctx.blocklen;
    (* End TODO *)
  (* blocks holds the whole chunks, last_block the remaining part_len elements. *)
  let blocks, last_block = split plain blocks_len in
  let cipher_blocks = counter_mode_blocks ctx block_enc key nonce counter blocks in
  let cipher_last_block =
    if part_len > 0
    then (* encrypt final partial block(s) *)
      (* Key stream for the tail, using the counter offset
         ctx.incr * (length plain / ctx.blocklen) computed from the full
         input length. *)
      let mask = block_enc key nonce (counter+ctx.incr*(length plain / ctx.blocklen)) in
      (* Keep only as much key stream as there is data left. *)
      let mask = slice mask 0 part_len in
      assert(length last_block = part_len);
      xor #part_len last_block mask
    else createEmpty in
  cipher_blocks @| cipher_last_block