view curve25519-donna.c @ 1306:34e6127ef02e

merge fixes from PuTTY import.c toint() from misc.c (revids are from hggit conversion) changeset: 4620:60a336a6c85c user: Simon Tatham <[email protected]> date: Thu Feb 25 20:26:33 2016 +0000 files: import.c description: Fix potential segfaults in reading OpenSSH's ASN.1 key format. The length coming back from ber_read_id_len might have overflowed, so treat it as potentially negative. Also, while I'm here, accumulate it inside ber_read_id_len as an unsigned, so as to avoid undefined behaviour on integer overflow, and toint() it before return. Thanks to Hanno Böck for spotting this, with the aid of AFL. (cherry picked from commit 5b7833cd474a24ec098654dcba8cb9509f3bf2c1) Conflicts: import.c (cherry-picker's note: resolving the conflict involved removing an entire section of the original commit which fixed ECDSA code not present on this branch) changeset: 4619:9c6c638d98d8 user: Simon Tatham <[email protected]> date: Sun Jul 14 10:45:54 2013 +0000 files: import.c ssh.c sshdss.c sshpubk.c sshrsa.c description: Tighten up a lot of casts from unsigned to int which are read by one of the GET_32BIT macros and then used as length fields. Missing bounds checks against zero have been added, and also I've introduced a helper function toint() which casts from unsigned to int in such a way as to avoid C undefined behaviour, since I'm not sure I trust compilers any more to do the obviously sensible thing. [originally from svn r9918] changeset: 4618:3957829f24d3 user: Simon Tatham <[email protected]> date: Mon Jul 08 22:36:04 2013 +0000 files: import.c sshdss.c sshrsa.c description: Add an assortment of extra safety checks. [originally from svn r9896] changeset: 4617:2cddee0bce12 user: Jacob Nevins <[email protected]> date: Wed Dec 07 00:24:45 2005 +0000 files: import.c description: Institutional failure to memset() things pointed at rather than pointers. Things should now be zeroed and memory not leaked. Spotted by Brant Thomsen. 
[originally from svn r6476] changeset: 4616:24ac78a9c71d user: Simon Tatham <[email protected]> date: Wed Feb 11 13:58:27 2004 +0000 files: import.c description: Jacob's last-minute testing found a couple of trivial bugs in import.c, and my attempts to reproduce them in cmdgen found another one there :-) [originally from svn r3847] changeset: 4615:088d39a73db0 user: Simon Tatham <[email protected]> date: Thu Jan 22 18:52:49 2004 +0000 files: import.c description: Placate some gcc warnings. [originally from svn r3761] changeset: 4614:e4288bad4d93 parent: 1758:108b8924593d user: Simon Tatham <[email protected]> date: Fri Oct 03 21:21:23 2003 +0000 files: import.c description: My ASN.1 decoder returned wrong IDs for anything above 0x1E! Good job it's never had to yet. Ahem. [originally from svn r3479]
author Matt Johnston <matt@ucc.asn.au>
date Tue, 12 Jul 2016 23:00:01 +0800
parents d3925ed45a85
children 27b9ddb06b09
line wrap: on
line source

/* Copyright 2008, Google Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *     * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following disclaimer
 * in the documentation and/or other materials provided with the
 * distribution.
 *     * Neither the name of Google Inc. nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * curve25519-donna: Curve25519 elliptic curve, public key function
 *
 * http://code.google.com/p/curve25519-donna/
 *
 * Adam Langley <[email protected]>
 *
 * Derived from public domain C code by Daniel J. Bernstein <[email protected]>
 *
 * More information about curve25519 can be found here
 *   http://cr.yp.to/ecdh.html
 *
 * djb's sample implementation of curve25519 is written in a special assembly
 * language called qhasm and uses the floating point registers.
 *
 * This is, almost, a clean room reimplementation from the curve25519 paper. It
 * uses many of the tricks described therein. Only the crecip function is taken
 * from the sample implementation.
 */

#include <string.h>
#include <stdint.h>

#ifdef _MSC_VER
#define inline __inline
#endif

/* Short aliases for the fixed-width integer types used in this file. */
typedef uint8_t u8;    /* a byte of a scalar or of an encoded field element */
typedef int32_t s32;   /* 32-bit signed intermediate value */
typedef int64_t limb;  /* one signed limb of a field element (see below) */

/* Field element representation:
 *
 * Field elements are written as an array of signed, 64-bit limbs, least
 * significant first. The value of the field element is:
 *   x[0] + 2^26·x[1] + 2^51·x[2] + 2^77·x[3] + 2^102·x[4] + ...
 *
 * i.e. the limbs are 26, 25, 26, 25, ... bits wide.
 */

/* Add two field elements limb by limb: output += in.
 *
 * Only limbs 0..9 are touched; no carrying or reduction is performed. */
static void fsum(limb *output, const limb *in) {
  unsigned k;
  for (k = 0; k < 10; k++) {
    output[k] += in[k];
  }
}

/* Subtract limb by limb, storing into the FIRST argument:
 *   output = in - output
 * (note the reversed operand order relative to the argument order!)
 *
 * Only limbs 0..9 are touched; no carrying or reduction is performed. */
static void fdifference(limb *output, const limb *in) {
  limb *dst = output;
  const limb *src = in;
  limb *const end = output + 10;

  for (; dst != end; ++dst, ++src) {
    *dst = *src - *dst;
  }
}

/* Scale a field element by a constant: output = in * scalar.
 *
 * Limbs 0..9 only; the result is neither carried nor reduced. */
static void fscalar_product(limb *output, const limb *in, const limb scalar) {
  int k = 10;
  while (k-- > 0) {
    output[k] = scalar * in[k];
  }
}

/* Schoolbook-multiply two field elements: output = in2 * in
 *
 * in2 and in are read as 10 limbs each. Every limb is truncated to s32
 * before multiplying, so the caller must keep input limbs within s32 range.
 * output receives all 19 coefficients of the unreduced product. It must
 * therefore have room for 19 limbs and must be distinct from both inputs.
 * The inputs are in reduced coefficient form; the output is reduced in
 * neither degree nor coefficient size (see freduce_degree and
 * freduce_coefficients).
 *
 * Coefficient k collects every in2[i] * in[j] with i + j == k. The limbs
 * alternate between 26 and 25 bits wide. A product of two odd-indexed limbs
 * therefore sits one bit above the weight of output[k], which is why those
 * terms are multiplied by 2 below.
 */
static void fproduct(limb *output, const limb *in2, const limb *in) {
  output[0] =       ((limb) ((s32) in2[0])) * ((s32) in[0]);
  output[1] =       ((limb) ((s32) in2[0])) * ((s32) in[1]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[0]);
  output[2] =  2 *  ((limb) ((s32) in2[1])) * ((s32) in[1]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[2])) * ((s32) in[0]);
  output[3] =       ((limb) ((s32) in2[1])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[2])) * ((s32) in[1]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[0]);
  output[4] =       ((limb) ((s32) in2[2])) * ((s32) in[2]) +
               2 * (((limb) ((s32) in2[1])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[1])) +
                    ((limb) ((s32) in2[0])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[4])) * ((s32) in[0]);
  output[5] =       ((limb) ((s32) in2[2])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[4])) * ((s32) in[1]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[0]);
  output[6] =  2 * (((limb) ((s32) in2[3])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[1])) +
                    ((limb) ((s32) in2[2])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[4])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[0]);
  output[7] =       ((limb) ((s32) in2[3])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[4])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[2])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[1]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[0]);
  output[8] =       ((limb) ((s32) in2[4])) * ((s32) in[4]) +
               2 * (((limb) ((s32) in2[3])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[1])) +
                    ((limb) ((s32) in2[2])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[0]);
  output[9] =       ((limb) ((s32) in2[4])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[2])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[2]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[1]) +
                    ((limb) ((s32) in2[0])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[0]);
  output[10] = 2 * (((limb) ((s32) in2[5])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[1])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[1])) +
                    ((limb) ((s32) in2[4])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[2])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[2]);
  output[11] =      ((limb) ((s32) in2[5])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[4])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[4]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[3]) +
                    ((limb) ((s32) in2[2])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[2]);
  output[12] =      ((limb) ((s32) in2[6])) * ((s32) in[6]) +
               2 * (((limb) ((s32) in2[5])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[3])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[3])) +
                    ((limb) ((s32) in2[4])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[4]);
  output[13] =      ((limb) ((s32) in2[6])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[7])) * ((s32) in[6]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[5]) +
                    ((limb) ((s32) in2[4])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[4]);
  output[14] = 2 * (((limb) ((s32) in2[7])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[5])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[5])) +
                    ((limb) ((s32) in2[6])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[6]);
  output[15] =      ((limb) ((s32) in2[7])) * ((s32) in[8]) +
                    ((limb) ((s32) in2[8])) * ((s32) in[7]) +
                    ((limb) ((s32) in2[6])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[6]);
  output[16] =      ((limb) ((s32) in2[8])) * ((s32) in[8]) +
               2 * (((limb) ((s32) in2[7])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[7]));
  output[17] =      ((limb) ((s32) in2[8])) * ((s32) in[9]) +
                    ((limb) ((s32) in2[9])) * ((s32) in[8]);
  output[18] = 2 *  ((limb) ((s32) in2[9])) * ((s32) in[9]);
}

/* Reduce a long form to a short form by taking the input mod 2^255 - 19.
 *
 * output holds the 19 coefficients of an unreduced product, as written by
 * fproduct or fsquare_inner. The coefficient at index 10 + i has 2^255 times
 * the weight of the one at index i, and 2^255 == 19 (mod 2^255 - 19). Each
 * high coefficient is therefore folded down as
 *   output[i] += 19 * output[10 + i]      for i = 0..8.
 * output[9..18] are not modified; only output[0..9] are meaningful
 * afterwards. Coefficient sizes are not reduced.
 *
 * A multiplication by 19 is used rather than (x << 4) + (x << 1) + x. The
 * coefficients can be negative, and left-shifting a negative signed value is
 * undefined behaviour in C; the multiplication computes the same value
 * without UB. With input limbs bounded by 2^27 in magnitude, each product
 * coefficient is below 14 * 2^54, so the folded result stays below 2^63 and
 * cannot overflow.
 */
static void freduce_degree(limb *output) {
  int i;
  /* The folds are independent: no output[10 + i] read here is written here. */
  for (i = 8; i >= 0; --i) {
    output[i] += 19 * output[10 + i];
  }
}

#if (-1 & 3) != 3
#error "This code only works on a two's complement system"
#endif

/* Return v / 2^26 rounded towards zero (like C's '/'), using only shifts and
 * adds, i.e. without branching on the sign of v. */
static inline limb
div_by_2_26(const limb v)
{
  /* 1 if v is negative, 0 otherwise (the sign bit of v). */
  const uint32_t negative = (uint32_t) (((uint64_t) v) >> 63);
  /* All ones if v is negative, else 0. */
  const uint32_t sign_mask = (uint32_t) 0 - negative;
  /* 2^26 - 1 if v is negative, else 0. Adding it before the arithmetic
   * shift turns floor division into truncating division. */
  const int32_t roundoff = (int32_t) (sign_mask >> 6);
  return (v + roundoff) >> 26;
}

/* Return v / 2^25 rounded towards zero (like C's '/'), using only shifts and
 * adds, i.e. without branching on the sign of v. */
static inline limb
div_by_2_25(const limb v)
{
  /* 1 if v is negative, 0 otherwise (the sign bit of v). */
  const uint32_t negative = (uint32_t) (((uint64_t) v) >> 63);
  /* All ones if v is negative, else 0. */
  const uint32_t sign_mask = (uint32_t) 0 - negative;
  /* 2^25 - 1 if v is negative, else 0. Adding it before the arithmetic
   * shift turns floor division into truncating division. */
  const int32_t roundoff = (int32_t) (sign_mask >> 7);
  return (v + roundoff) >> 25;
}

static inline s32
div_s32_by_2_25(const s32 v)
{
   const s32 roundoff = ((uint32_t)(v >> 31)) >> 7;
   return (v + roundoff) >> 25;
}

/* Reduce all coefficients of the short form input so that |x| < 2^26.
 *
 * On entry: |output[i]| < 2^62 for i in 0..9. output must have room for at
 * least 11 limbs: output[10] is used as scratch and is zero on return.
 *
 * Carries are computed with truncating division (div_by_2_26,
 * div_by_2_25), so they can be negative. They are therefore removed from
 * the source limb with a multiplication rather than a left shift, because
 * left-shifting a negative signed value is undefined behaviour in C. The
 * values computed are the same.
 */
static void freduce_coefficients(limb *output) {
  unsigned i;

  output[10] = 0;

  for (i = 0; i < 10; i += 2) {
    limb over = div_by_2_26(output[i]);
    output[i] -= over * ((limb) 1 << 26);
    output[i+1] += over;

    over = div_by_2_25(output[i+1]);
    output[i+1] -= over * ((limb) 1 << 25);
    output[i+2] += over;
  }
  /* Now |output[10]| < 2 ^ 38 and all other coefficients are reduced.
   * Fold it back into output[0], since 2^255 == 19 (mod 2^255 - 19). */
  output[0] += 19 * output[10];

  output[10] = 0;

  /* Now output[1..9] are reduced, and |output[0]| < 2^26 + 19 * 2^38
   * So |over| will be no more than 77825  */
  {
    limb over = div_by_2_26(output[0]);
    output[0] -= over * ((limb) 1 << 26);
    output[1] += over;
  }

  /* Now output[0,2..9] are reduced, and |output[1]| < 2^25 + 77825
   * So |over| will be no more than 1. */
  {
    /* output[1] fits in 32 bits, so we can use div_s32_by_2_25 here. */
    s32 over32 = div_s32_by_2_25((s32) output[1]);
    output[1] -= over32 * ((s32) 1 << 25);
    output[2] += over32;
  }

  /* Finally, output[0,1,3..9] are reduced, and output[2] is "nearly reduced":
   * we have |output[2]| <= 2^26.  This is good enough for all of our math,
   * but it will require an extra freduce_coefficients before fcontract. */
}

/* Multiply two field elements: output = in * in2.
 *
 * The 19-limb product is formed in a local buffer, then reduced in degree
 * and (nearly) reduced in coefficient size. The 10 resulting limbs are
 * copied to output. All reads of the inputs happen before output is
 * written, so output may be the same array as either input.
 */
static void
fmul(limb *output, const limb *in, const limb *in2) {
  limb wide[19];

  fproduct(wide, in, in2);
  freduce_degree(wide);
  freduce_coefficients(wide);
  memcpy(output, wide, 10 * sizeof wide[0]);
}

/* Square a field element without any reduction: output = in * in.
 *
 * This gives the same coefficients as fproduct(output, in, in), with the
 * symmetric cross terms in[i]*in[j] (i != j) combined; hence the factors
 * of 2. Products of two odd-indexed limbs get a further factor of 2 because
 * the limbs alternate between 26 and 25 bits, as in fproduct. Each input
 * limb is truncated to s32 before multiplying. All 19 limbs of output are
 * written. output must be distinct from in, since output[0] is written
 * before in[] is read for later limbs.
 */
static void fsquare_inner(limb *output, const limb *in) {
  output[0] =       ((limb) ((s32) in[0])) * ((s32) in[0]);
  output[1] =  2 *  ((limb) ((s32) in[0])) * ((s32) in[1]);
  output[2] =  2 * (((limb) ((s32) in[1])) * ((s32) in[1]) +
                    ((limb) ((s32) in[0])) * ((s32) in[2]));
  output[3] =  2 * (((limb) ((s32) in[1])) * ((s32) in[2]) +
                    ((limb) ((s32) in[0])) * ((s32) in[3]));
  output[4] =       ((limb) ((s32) in[2])) * ((s32) in[2]) +
               4 *  ((limb) ((s32) in[1])) * ((s32) in[3]) +
               2 *  ((limb) ((s32) in[0])) * ((s32) in[4]);
  output[5] =  2 * (((limb) ((s32) in[2])) * ((s32) in[3]) +
                    ((limb) ((s32) in[1])) * ((s32) in[4]) +
                    ((limb) ((s32) in[0])) * ((s32) in[5]));
  output[6] =  2 * (((limb) ((s32) in[3])) * ((s32) in[3]) +
                    ((limb) ((s32) in[2])) * ((s32) in[4]) +
                    ((limb) ((s32) in[0])) * ((s32) in[6]) +
               2 *  ((limb) ((s32) in[1])) * ((s32) in[5]));
  output[7] =  2 * (((limb) ((s32) in[3])) * ((s32) in[4]) +
                    ((limb) ((s32) in[2])) * ((s32) in[5]) +
                    ((limb) ((s32) in[1])) * ((s32) in[6]) +
                    ((limb) ((s32) in[0])) * ((s32) in[7]));
  output[8] =       ((limb) ((s32) in[4])) * ((s32) in[4]) +
               2 * (((limb) ((s32) in[2])) * ((s32) in[6]) +
                    ((limb) ((s32) in[0])) * ((s32) in[8]) +
               2 * (((limb) ((s32) in[1])) * ((s32) in[7]) +
                    ((limb) ((s32) in[3])) * ((s32) in[5])));
  output[9] =  2 * (((limb) ((s32) in[4])) * ((s32) in[5]) +
                    ((limb) ((s32) in[3])) * ((s32) in[6]) +
                    ((limb) ((s32) in[2])) * ((s32) in[7]) +
                    ((limb) ((s32) in[1])) * ((s32) in[8]) +
                    ((limb) ((s32) in[0])) * ((s32) in[9]));
  output[10] = 2 * (((limb) ((s32) in[5])) * ((s32) in[5]) +
                    ((limb) ((s32) in[4])) * ((s32) in[6]) +
                    ((limb) ((s32) in[2])) * ((s32) in[8]) +
               2 * (((limb) ((s32) in[3])) * ((s32) in[7]) +
                    ((limb) ((s32) in[1])) * ((s32) in[9])));
  output[11] = 2 * (((limb) ((s32) in[5])) * ((s32) in[6]) +
                    ((limb) ((s32) in[4])) * ((s32) in[7]) +
                    ((limb) ((s32) in[3])) * ((s32) in[8]) +
                    ((limb) ((s32) in[2])) * ((s32) in[9]));
  output[12] =      ((limb) ((s32) in[6])) * ((s32) in[6]) +
               2 * (((limb) ((s32) in[4])) * ((s32) in[8]) +
               2 * (((limb) ((s32) in[5])) * ((s32) in[7]) +
                    ((limb) ((s32) in[3])) * ((s32) in[9])));
  output[13] = 2 * (((limb) ((s32) in[6])) * ((s32) in[7]) +
                    ((limb) ((s32) in[5])) * ((s32) in[8]) +
                    ((limb) ((s32) in[4])) * ((s32) in[9]));
  output[14] = 2 * (((limb) ((s32) in[7])) * ((s32) in[7]) +
                    ((limb) ((s32) in[6])) * ((s32) in[8]) +
               2 *  ((limb) ((s32) in[5])) * ((s32) in[9]));
  output[15] = 2 * (((limb) ((s32) in[7])) * ((s32) in[8]) +
                    ((limb) ((s32) in[6])) * ((s32) in[9]));
  output[16] =      ((limb) ((s32) in[8])) * ((s32) in[8]) +
               4 *  ((limb) ((s32) in[7])) * ((s32) in[9]);
  output[17] = 2 *  ((limb) ((s32) in[8])) * ((s32) in[9]);
  output[18] = 2 *  ((limb) ((s32) in[9])) * ((s32) in[9]);
}

/* Square a field element: output = in^2.
 *
 * The 19-limb square is formed in a local buffer, reduced in degree and
 * (nearly) reduced in coefficient size, then its 10 limbs are copied to
 * output. Because in is fully read first, output may alias in.
 */
static void
fsquare(limb *output, const limb *in) {
  limb wide[19];

  fsquare_inner(wide, in);
  freduce_degree(wide);
  freduce_coefficients(wide);
  memcpy(output, wide, 10 * sizeof wide[0]);
}

/* Take a little-endian, 32-byte number and expand it into polynomial form.
 *
 * Limb n receives the 26 or 25 bits (alternately) that start at bit
 * ceil(25.5 * n) of the input; F(n, start, shift, mask) reads the 4 bytes
 * starting at byte 'start', shifts right by 'shift' and keeps 'mask'.
 *
 * Bit 255 (the top bit of input[31]) is ignored, as RFC 7748 section 5
 * requires of X25519 implementations. The top limb is therefore masked to
 * 25 bits like every other odd limb. Every limb ends up non-negative and
 * within its nominal width. The value itself may still be >= 2^255 - 19:
 * non-canonical inputs are accepted, not rejected.
 */
static void
fexpand(limb *output, const u8 *input) {
#define F(n,start,shift,mask) \
  output[n] = ((((limb) input[start + 0]) | \
                ((limb) input[start + 1]) << 8 | \
                ((limb) input[start + 2]) << 16 | \
                ((limb) input[start + 3]) << 24) >> shift) & mask;
  F(0, 0, 0, 0x3ffffff);
  F(1, 3, 2, 0x1ffffff);
  F(2, 6, 3, 0x3ffffff);
  F(3, 9, 5, 0x1ffffff);
  F(4, 12, 6, 0x3ffffff);
  F(5, 16, 0, 0x1ffffff);
  F(6, 19, 1, 0x3ffffff);
  F(7, 22, 3, 0x1ffffff);
  F(8, 25, 4, 0x3ffffff);
  /* 25 bits: bits 230..254. Bit 255 is deliberately dropped (RFC 7748). */
  F(9, 28, 6, 0x1ffffff);
#undef F
}

#if (-32 >> 1) != -16
#error "This code only works when >> does sign-extension on negative numbers"
#endif

/* Take a fully reduced polynomial form number and contract it into a
 * little-endian, 32-byte array
 */
static void
fcontract(u8 *output, limb *input) {
  int i;
  int j;

  for (j = 0; j < 2; ++j) {
    for (i = 0; i < 9; ++i) {
      if ((i & 1) == 1) {
        /* This calculation is a time-invariant way to make input[i] positive
           by borrowing from the next-larger limb.
        */
        const s32 mask = (s32)(input[i]) >> 31;
        const s32 carry = -(((s32)(input[i]) & mask) >> 25);
        input[i] = (s32)(input[i]) + (carry << 25);
        input[i+1] = (s32)(input[i+1]) - carry;
      } else {
        const s32 mask = (s32)(input[i]) >> 31;
        const s32 carry = -(((s32)(input[i]) & mask) >> 26);
        input[i] = (s32)(input[i]) + (carry << 26);
        input[i+1] = (s32)(input[i+1]) - carry;
      }
    }
    {
      const s32 mask = (s32)(input[9]) >> 31;
      const s32 carry = -(((s32)(input[9]) & mask) >> 25);
      input[9] = (s32)(input[9]) + (carry << 25);
      input[0] = (s32)(input[0]) - (carry * 19);
    }
  }

  /* The first borrow-propagation pass above ended with every limb
     except (possibly) input[0] non-negative.

     Since each input limb except input[0] is decreased by at most 1
     by a borrow-propagation pass, the second borrow-propagation pass
     could only have wrapped around to decrease input[0] again if the
     first pass left input[0] negative *and* input[1] through input[9]
     were all zero.  In that case, input[1] is now 2^25 - 1, and this
     last borrow-propagation step will leave input[1] non-negative.
  */
  {
    const s32 mask = (s32)(input[0]) >> 31;
    const s32 carry = -(((s32)(input[0]) & mask) >> 26);
    input[0] = (s32)(input[0]) + (carry << 26);
    input[1] = (s32)(input[1]) - carry;
  }

  /* Both passes through the above loop, plus the last 0-to-1 step, are
     necessary: if input[9] is -1 and input[0] through input[8] are 0,
     negative values will remain in the array until the end.
   */

  input[1] <<= 2;
  input[2] <<= 3;
  input[3] <<= 5;
  input[4] <<= 6;
  input[6] <<= 1;
  input[7] <<= 3;
  input[8] <<= 4;
  input[9] <<= 6;
#define F(i, s) \
  output[s+0] |=  input[i] & 0xff; \
  output[s+1]  = (input[i] >> 8) & 0xff; \
  output[s+2]  = (input[i] >> 16) & 0xff; \
  output[s+3]  = (input[i] >> 24) & 0xff;
  output[0] = 0;
  output[16] = 0;
  F(0,0);
  F(1,3);
  F(2,6);
  F(3,9);
  F(4,12);
  F(5,16);
  F(6,19);
  F(7,22);
  F(8,25);
  F(9,28);
#undef F
}

/* One step of the x-only Montgomery ladder used by cmult.
 *
 * Input: Q, Q', Q-Q'
 * Output: 2Q, Q+Q'
 *
 * Points are in projective x/z form (the affine x-coordinate is x/z).
 *
 *   x2 z2: long form -- each is passed to fproduct, which writes 19 limbs,
 *          so they need room for 19 limbs; on return limbs 0..9 hold the
 *          reduced result
 *   x3 z3: only limbs 0..9 are written (copied from local temporaries)
 *   x z: short form, destroyed (on return x holds x + z and z holds x - z)
 *   xprime zprime: short form, destroyed (likewise x' + z' and x' - z')
 *   qmqp: short form, preserved
 */
static void fmonty(limb *x2, limb *z2,  /* output 2Q */
                   limb *x3, limb *z3,  /* output Q + Q' */
                   limb *x, limb *z,    /* input Q */
                   limb *xprime, limb *zprime,  /* input Q' */
                   const limb *qmqp /* input Q - Q' */) {
  limb origx[10], origxprime[10], zzz[19], xx[19], zz[19], xxprime[19],
        zzprime[19], zzzprime[19], xxxprime[19];

  /* x <- x + z, z <- x - z (using the saved copy of x). */
  memcpy(origx, x, 10 * sizeof(limb));
  fsum(x, z);
  fdifference(z, origx);  /* does x - z */

  /* Differential addition (Q + Q'): form x' + z' and x' - z', then
   * cross-multiply with the sum and difference of Q. */
  memcpy(origxprime, xprime, sizeof(limb) * 10);
  fsum(xprime, zprime);
  fdifference(zprime, origxprime);
  fproduct(xxprime, xprime, z);
  fproduct(zzprime, x, zprime);
  freduce_degree(xxprime);
  freduce_coefficients(xxprime);
  freduce_degree(zzprime);
  freduce_coefficients(zzprime);
  /* origxprime is reused as a temporary copy of xxprime. */
  memcpy(origxprime, xxprime, sizeof(limb) * 10);
  fsum(xxprime, zzprime);
  fdifference(zzprime, origxprime);
  /* x3 = (sum)^2, z3 = (difference)^2 * qmqp. */
  fsquare(xxxprime, xxprime);
  fsquare(zzzprime, zzprime);
  fproduct(zzprime, zzzprime, qmqp);
  freduce_degree(zzprime);
  freduce_coefficients(zzprime);
  memcpy(x3, xxxprime, sizeof(limb) * 10);
  memcpy(z3, zzprime, sizeof(limb) * 10);

  /* Doubling (2Q): xx = (x + z)^2, zz = (x - z)^2, x2 = xx * zz. */
  fsquare(xx, x);
  fsquare(zz, z);
  fproduct(x2, xx, zz);
  freduce_degree(x2);
  freduce_coefficients(x2);
  fdifference(zz, xx);  /* does zz = xx - zz */
  /* fscalar_product only writes limbs 0..9 of zzz; clear the rest. */
  memset(zzz + 10, 0, sizeof(limb) * 9);
  /* 121665 matches the ladder constant a24 of RFC 7748, section 5. */
  fscalar_product(zzz, zz, 121665);
  /* No need to call freduce_degree here:
     fscalar_product doesn't increase the degree of its input. */
  freduce_coefficients(zzz);
  fsum(zzz, xx);
  /* z2 = (xx - zz) * (xx + 121665 * (xx - zz)). */
  fproduct(z2, zz, zzz);
  freduce_degree(z2);
  freduce_coefficients(z2);
}

/* Conditionally swap two reduced-form limb arrays if 'iswap' is 1, but leave
 * them unchanged if 'iswap' is 0.  Runs in data-invariant time to avoid
 * side-channel attacks: both arrays are always rewritten and nothing
 * branches on 'iswap'.
 *
 * NOTE that this function requires that 'iswap' be 1 or 0; other values give
 * wrong results.  Only limbs 0..9 are swapped (a[10..18] and b[10..18] are
 * left alone). Every value in a[0..9] and b[0..9] must fit in an s32, since
 * each limb is truncated to 32 bits and sign-extended when written back.
 */
static void
swap_conditional(limb a[19], limb b[19], limb iswap) {
  /* 0 when iswap == 0, all ones when iswap == 1. */
  const s32 all_or_nothing = (s32) -iswap;
  unsigned k;

  for (k = 0; k < 10; k++) {
    const s32 lhs = (s32) a[k];
    const s32 rhs = (s32) b[k];
    const s32 delta = all_or_nothing & (lhs ^ rhs);
    a[k] = lhs ^ delta;
    b[k] = rhs ^ delta;
  }
}

/* Calculates nQ where Q is the x-coordinate of a point on the curve
 *
 *   resultx/resultz: the x coordinate of the resulting curve point (short form)
 *   n: a little endian, 32-byte number
 *   q: a point of the curve (short form)
 *
 * The 256 bits of n are scanned from the most significant (bit 7 of n[31])
 * down to the least significant (bit 0 of n[0]). For each bit, the working
 * pair is conditionally swapped and fmonty writes the next pair into the
 * scratch buffers. That pair is swapped back, and then the working and
 * scratch buffers trade roles.
 */
static void
cmult(limb *resultx, limb *resultz, const u8 *n, const limb *q) {
  limb a[19] = {0}, b[19] = {1}, c[19] = {1}, d[19] = {0};
  limb e[19] = {0}, f[19] = {1}, g[19] = {0}, h[19] = {1};
  /* Working pair: nQ = (nqx : nqz) and (n+1)Q = (nqpqx : nqpqz). */
  limb *nqpqx = a, *nqpqz = b, *nqx = c, *nqz = d;
  /* Scratch buffers that receive the next working pair. */
  limb *nqpqx2 = e, *nqpqz2 = f, *nqx2 = g, *nqz2 = h;
  limb *spare;
  int bitpos;

  memcpy(nqpqx, q, sizeof(limb) * 10);

  for (bitpos = 255; bitpos >= 0; --bitpos) {
    const limb bit = (n[bitpos >> 3] >> (bitpos & 7)) & 1;

    swap_conditional(nqx, nqpqx, bit);
    swap_conditional(nqz, nqpqz, bit);
    fmonty(nqx2, nqz2,
           nqpqx2, nqpqz2,
           nqx, nqz,
           nqpqx, nqpqz,
           q);
    swap_conditional(nqx2, nqpqx2, bit);
    swap_conditional(nqz2, nqpqz2, bit);

    /* The freshly computed pair becomes the working pair; the old working
     * buffers become scratch for the next bit. */
    spare = nqx;   nqx = nqx2;     nqx2 = spare;
    spare = nqz;   nqz = nqz2;     nqz2 = spare;
    spare = nqpqx; nqpqx = nqpqx2; nqpqx2 = spare;
    spare = nqpqz; nqpqz = nqpqz2; nqpqz2 = spare;
  }

  memcpy(resultx, nqx, sizeof(limb) * 10);
  memcpy(resultz, nqz, sizeof(limb) * 10);
}

/* -----------------------------------------------------------------------------
 * Shamelessly copied from djb's code
 * ----------------------------------------------------------------------------- */
/* Compute out = z^(2^255 - 21) = z^(p - 2), where p = 2^255 - 19.
 *
 * By Fermat's little theorem this is the multiplicative inverse of z mod p
 * for z != 0 (and 0 for z == 0). It is a fixed addition chain of 254
 * squarings and 11 multiplications, so the sequence of operations does not
 * depend on z. The comment before each step gives the exponent e such that
 * the destination holds z^e afterwards.
 */
static void
crecip(limb *out, const limb *z) {
  limb z2[10];
  limb z9[10];
  limb z11[10];
  limb z2_5_0[10];
  limb z2_10_0[10];
  limb z2_20_0[10];
  limb z2_50_0[10];
  limb z2_100_0[10];
  limb t0[10];
  limb t1[10];
  int i;

  /* 2 */ fsquare(z2,z);
  /* 4 */ fsquare(t1,z2);
  /* 8 */ fsquare(t0,t1);
  /* 9 */ fmul(z9,t0,z);
  /* 11 */ fmul(z11,z9,z2);
  /* 22 */ fsquare(t0,z11);
  /* 2^5 - 2^0 = 31 */ fmul(z2_5_0,t0,z9);

  /* 2^6 - 2^1 */ fsquare(t0,z2_5_0);
  /* 2^7 - 2^2 */ fsquare(t1,t0);
  /* 2^8 - 2^3 */ fsquare(t0,t1);
  /* 2^9 - 2^4 */ fsquare(t1,t0);
  /* 2^10 - 2^5 */ fsquare(t0,t1);
  /* 2^10 - 2^0 */ fmul(z2_10_0,t0,z2_5_0);

  /* 2^11 - 2^1 */ fsquare(t0,z2_10_0);
  /* 2^12 - 2^2 */ fsquare(t1,t0);
  /* 2^20 - 2^10 */ for (i = 2;i < 10;i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
  /* 2^20 - 2^0 */ fmul(z2_20_0,t1,z2_10_0);

  /* 2^21 - 2^1 */ fsquare(t0,z2_20_0);
  /* 2^22 - 2^2 */ fsquare(t1,t0);
  /* 2^40 - 2^20 */ for (i = 2;i < 20;i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
  /* 2^40 - 2^0 */ fmul(t0,t1,z2_20_0);

  /* 2^41 - 2^1 */ fsquare(t1,t0);
  /* 2^42 - 2^2 */ fsquare(t0,t1);
  /* 2^50 - 2^10 */ for (i = 2;i < 10;i += 2) { fsquare(t1,t0); fsquare(t0,t1); }
  /* 2^50 - 2^0 */ fmul(z2_50_0,t0,z2_10_0);

  /* 2^51 - 2^1 */ fsquare(t0,z2_50_0);
  /* 2^52 - 2^2 */ fsquare(t1,t0);
  /* 2^100 - 2^50 */ for (i = 2;i < 50;i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
  /* 2^100 - 2^0 */ fmul(z2_100_0,t1,z2_50_0);

  /* 2^101 - 2^1 */ fsquare(t1,z2_100_0);
  /* 2^102 - 2^2 */ fsquare(t0,t1);
  /* 2^200 - 2^100 */ for (i = 2;i < 100;i += 2) { fsquare(t1,t0); fsquare(t0,t1); }
  /* 2^200 - 2^0 */ fmul(t1,t0,z2_100_0);

  /* 2^201 - 2^1 */ fsquare(t0,t1);
  /* 2^202 - 2^2 */ fsquare(t1,t0);
  /* 2^250 - 2^50 */ for (i = 2;i < 50;i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
  /* 2^250 - 2^0 */ fmul(t0,t1,z2_50_0);

  /* 2^251 - 2^1 */ fsquare(t1,t0);
  /* 2^252 - 2^2 */ fsquare(t0,t1);
  /* 2^253 - 2^3 */ fsquare(t1,t0);
  /* 2^254 - 2^4 */ fsquare(t0,t1);
  /* 2^255 - 2^5 */ fsquare(t1,t0);
  /* 2^255 - 21 */ fmul(out,t1,z11);
}

int curve25519_donna(u8 *, const u8 *, const u8 *);

/* Curve25519 scalar multiplication: mypublic = secret * basepoint.
 *
 *   mypublic:  receives the 32-byte little-endian result
 *   secret:    32-byte scalar. A clamped private copy is used (bits 0..2 and
 *              bit 255 cleared, bit 254 set); the caller's buffer is not
 *              modified.
 *   basepoint: 32-byte little-endian x-coordinate of the input point
 *
 * Always returns 0.
 */
int
curve25519_donna(u8 *mypublic, const u8 *secret, const u8 *basepoint) {
  limb bp[10], x[10], z[11], zmone[10];
  uint8_t e[32];

  /* Clamp a private copy of the scalar. */
  memcpy(e, secret, sizeof e);
  e[0] &= 248;
  e[31] = (uint8_t) ((e[31] & 127) | 64);

  fexpand(bp, basepoint);
  cmult(x, z, e, bp);
  /* Turn the projective result (x : z) into the affine x = x * z^-1.
   * z has 11 limbs because freduce_coefficients uses z[10] as scratch. */
  crecip(zmone, z);
  fmul(z, x, zmone);
  freduce_coefficients(z);
  fcontract(mypublic, z);
  return 0;
}