/* ***** BEGIN LICENSE BLOCK *****
 * Version: MPL 1.1/GPL 2.0/LGPL 2.1
 *
 * The contents of this file are subject to the Mozilla Public License Version
 * 1.1 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 * http://www.mozilla.org/MPL/
 *
 * Software distributed under the License is distributed on an "AS IS" basis,
 * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
 * for the specific language governing rights and limitations under the
 * License.
 *
 * The Original Code is the Netscape security libraries.
 *
 * The Initial Developer of the Original Code is
 * Netscape Communications Corporation.
 * Portions created by the Initial Developer are Copyright (C) 1994-2000
 * the Initial Developer. All Rights Reserved.
 *
 * Contributor(s):
 *
 * Alternatively, the contents of this file may be used under the terms of
 * either the GNU General Public License Version 2 or later (the "GPL"), or
 * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
 * in which case the provisions of the GPL or the LGPL are applicable instead
 * of those above. If you wish to allow use of your version of this file only
 * under the terms of either the GPL or the LGPL, and not to allow others to
 * use your version of this file under the terms of the MPL, indicate your
 * decision by deleting the provisions above and replace them with the notice
 * and other provisions required by the GPL or the LGPL. If you do not delete
 * the provisions above, a recipient may use your version of this file under
 * the terms of any one of the MPL, the GPL or the LGPL.
 *
 * ***** END LICENSE BLOCK ***** */
#include "sechash.h"
#include "secoidt.h"
#include "secerr.h"
#include "blapi.h"
#include "pk11func.h"   /* for the PK11_ calls below. */

/*
 * Per-algorithm value for the length_encoding_size field of SECHashObjects.
 * HASH_EndEx reads it (as `c`) to decide how many bytes to leave free at the
 * end of the final block for the algorithm's own length encoding.
 * These are the byte widths of the message-length field appended during
 * padding: 64 bits for MD5/SHA-1/SHA-256, 128 bits for SHA-384/SHA-512
 * (SHA-384 shares SHA-512's padding, so it is 16, not 8).
 * NOTE(review): MD2's padding does not append a length field at all; the
 * value 8 below is kept as-is but should be confirmed.
 */
#define MD2_LEN_ENCODING_SIZE     8
#define MD5_LEN_ENCODING_SIZE     8
#define SHA1_LEN_ENCODING_SIZE    8
#define SHA256_LEN_ENCODING_SIZE  8
#define SHA384_LEN_ENCODING_SIZE 16
#define SHA512_LEN_ENCODING_SIZE 16

/* The null hash carries no state, so "creating" a context yields NULL. */
static void *
null_hash_new_context(void)
{
    void *no_state = NULL;

    return no_state;
}

/* Clone a null-hash context: the input must be NULL and so is the copy. */
static void *
null_hash_clone_context(void *v)
{
    PORT_Assert(!v);
    return (void *)0;
}

/* The null hash has nothing to initialize. */
static void
null_hash_begin(void *v)
{
    (void)v; /* unused */
}

/* The null hash ignores every input byte. */
static void
null_hash_update(void *v, const unsigned char *input, unsigned int length)
{
    (void)v;
    (void)input;
    (void)length;
}

/* Finish the null hash: it produces an empty (zero-length) digest and
 * leaves the output buffer untouched. */
static void
null_hash_end(void *v, unsigned char *output, unsigned int *outLen,
              unsigned int maxOut)
{
    (void)v;
    (void)output;
    (void)maxOut;
    *outLen = 0u;
}

/* Destroying a null-hash context is a no-op; the context must be NULL. */
static void
null_hash_destroy_context(void *v, PRBool b)
{
    (void)b; /* nothing to free, so the "free it" flag is irrelevant */
    PORT_Assert(!v);
}


static void *
md2_NewContext(void) {
        return (void *) PK11_CreateDigestContext(SEC_OID_MD2);
}

static void *
md5_NewContext(void) {
        return (void *) PK11_CreateDigestContext(SEC_OID_MD5);
}

static void *
sha1_NewContext(void) {
        return (void *) PK11_CreateDigestContext(SEC_OID_SHA1);
}

static void *
sha256_NewContext(void) {
        return (void *) PK11_CreateDigestContext(SEC_OID_SHA256);
}

static void *
sha384_NewContext(void) {
        return (void *) PK11_CreateDigestContext(SEC_OID_SHA384);
}

static void *
sha512_NewContext(void) {
        return (void *) PK11_CreateDigestContext(SEC_OID_SHA512);
}

/*
 * Built-in hash object table. Lookups index it directly by HASH_HashType
 * (e.g. &SECHashObjects[type] in HASH_GetHashObject), so entries must stay
 * in HASH_HashType order, with HASH_AlgNULL at index 0.
 *
 * Field order of each entry, as used elsewhere in this file:
 *   length (digest bytes), create, clone, destroy, begin, update, end,
 *   blocklength, hash type, length_encoding_size.
 *
 * The casts adapt the PK11_* and null_hash_* signatures to the generic
 * void-pointer context function-pointer types stored in SECHashObject.
 */
const SECHashObject SECHashObjects[] = {
  /* HASH_AlgNULL: stateless no-op hash with a zero-length digest. */
  { 0,
    (void * (*)(void)) null_hash_new_context,
    (void * (*)(void *)) null_hash_clone_context,
    (void (*)(void *, PRBool)) null_hash_destroy_context,
    (void (*)(void *)) null_hash_begin,
    (void (*)(void *, const unsigned char *, unsigned int)) null_hash_update,
    (void (*)(void *, unsigned char *, unsigned int *,
              unsigned int)) null_hash_end,
    0,
    HASH_AlgNULL,
    0
  },
  /* The remaining entries delegate to PKCS#11 digest contexts. */
  { MD2_LENGTH,
    (void * (*)(void)) md2_NewContext,
    (void * (*)(void *)) PK11_CloneContext,
    (void (*)(void *, PRBool)) PK11_DestroyContext,
    (void (*)(void *)) PK11_DigestBegin,
    (void (*)(void *, const unsigned char *, unsigned int)) PK11_DigestOp,
    (void (*)(void *, unsigned char *, unsigned int *, unsigned int)) 
                                                        PK11_DigestFinal,
    MD2_BLOCK_LENGTH,
    HASH_AlgMD2,
    MD2_LEN_ENCODING_SIZE
  },
  { MD5_LENGTH,
    (void * (*)(void)) md5_NewContext,
    (void * (*)(void *)) PK11_CloneContext,
    (void (*)(void *, PRBool)) PK11_DestroyContext,
    (void (*)(void *)) PK11_DigestBegin,
    (void (*)(void *, const unsigned char *, unsigned int)) PK11_DigestOp,
    (void (*)(void *, unsigned char *, unsigned int *, unsigned int)) 
                                                        PK11_DigestFinal,
    MD5_BLOCK_LENGTH,
    HASH_AlgMD5,
    MD5_LEN_ENCODING_SIZE
  },
  { SHA1_LENGTH,
    (void * (*)(void)) sha1_NewContext,
    (void * (*)(void *)) PK11_CloneContext,
    (void (*)(void *, PRBool)) PK11_DestroyContext,
    (void (*)(void *)) PK11_DigestBegin,
    (void (*)(void *, const unsigned char *, unsigned int)) PK11_DigestOp,
    (void (*)(void *, unsigned char *, unsigned int *, unsigned int)) 
                                                        PK11_DigestFinal,
    SHA1_BLOCK_LENGTH,
    HASH_AlgSHA1,
    SHA1_LEN_ENCODING_SIZE
  },
  { SHA256_LENGTH,
    (void * (*)(void)) sha256_NewContext,
    (void * (*)(void *)) PK11_CloneContext,
    (void (*)(void *, PRBool)) PK11_DestroyContext,
    (void (*)(void *)) PK11_DigestBegin,
    (void (*)(void *, const unsigned char *, unsigned int)) PK11_DigestOp,
    (void (*)(void *, unsigned char *, unsigned int *, unsigned int)) 
                                                        PK11_DigestFinal,
    SHA256_BLOCK_LENGTH,
    HASH_AlgSHA256,
    SHA256_LEN_ENCODING_SIZE
  },
  { SHA384_LENGTH,
    (void * (*)(void)) sha384_NewContext,
    (void * (*)(void *)) PK11_CloneContext,
    (void (*)(void *, PRBool)) PK11_DestroyContext,
    (void (*)(void *)) PK11_DigestBegin,
    (void (*)(void *, const unsigned char *, unsigned int)) PK11_DigestOp,
    (void (*)(void *, unsigned char *, unsigned int *, unsigned int)) 
                                                        PK11_DigestFinal,
    SHA384_BLOCK_LENGTH,
    HASH_AlgSHA384,
    SHA384_LEN_ENCODING_SIZE
  },
  { SHA512_LENGTH,
    (void * (*)(void)) sha512_NewContext,
    (void * (*)(void *)) PK11_CloneContext,
    (void (*)(void *, PRBool)) PK11_DestroyContext,
    (void (*)(void *)) PK11_DigestBegin,
    (void (*)(void *, const unsigned char *, unsigned int)) PK11_DigestOp,
    (void (*)(void *, unsigned char *, unsigned int *, unsigned int)) 
                                                        PK11_DigestFinal,
    SHA512_BLOCK_LENGTH,
    HASH_AlgSHA512,
    SHA512_LEN_ENCODING_SIZE
  },
};

/*
 * Return the built-in hash object for `type`.
 *
 * Returns NULL and sets SEC_ERROR_INVALID_ALGORITHM when `type` lies outside
 * [HASH_AlgNULL, HASH_AlgTOTAL). Without this check an invalid type would
 * yield a pointer past the end of SECHashObjects. The range test matches
 * the one used by HASH_ResultLen and HASH_Create.
 */
const SECHashObject *
HASH_GetHashObject(HASH_HashType type)
{
    if ( ( type < HASH_AlgNULL ) || ( type >= HASH_AlgTOTAL ) ) {
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
        return NULL;
    }
    return &SECHashObjects[type];
}

/*
 * Map a digest-algorithm OID tag to its HASH_HashType.
 * Unrecognized tags return HASH_AlgNULL and set SEC_ERROR_INVALID_ALGORITHM.
 */
HASH_HashType
HASH_GetHashTypeByOidTag(SECOidTag hashOid)
{
    switch (hashOid) {
    case SEC_OID_MD2:
        return HASH_AlgMD2;
    case SEC_OID_MD5:
        return HASH_AlgMD5;
    case SEC_OID_SHA1:
        return HASH_AlgSHA1;
    case SEC_OID_SHA256:
        return HASH_AlgSHA256;
    case SEC_OID_SHA384:
        return HASH_AlgSHA384;
    case SEC_OID_SHA512:
        return HASH_AlgSHA512;
    default:
        break;
    }
    PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
    return HASH_AlgNULL;
}

/*
 * Look up the hash object for a digest OID tag.
 * Returns NULL for unrecognized tags (HASH_GetHashTypeByOidTag sets the
 * error code in that case).
 */
const SECHashObject *
HASH_GetHashObjectByOidTag(SECOidTag hashOid)
{
    HASH_HashType type = HASH_GetHashTypeByOidTag(hashOid);

    if (type == HASH_AlgNULL) {
        return NULL;
    }
    return &SECHashObjects[type];
}

/* Digest length in bytes for a hash OID tag; returns zero for unknown hash OID */
unsigned int
HASH_ResultLenByOidTag(SECOidTag hashOid)
{
    const SECHashObject *obj = HASH_GetHashObjectByOidTag(hashOid);

    return (obj != NULL) ? obj->length : 0;
}

/* Digest length in bytes for `type`; returns zero (and sets
 * SEC_ERROR_INVALID_ALGORITHM) if the hash type is invalid. */
unsigned int
HASH_ResultLen(HASH_HashType type)
{
    if ( type >= HASH_AlgNULL && type < HASH_AlgTOTAL ) {
        return SECHashObjects[type].length;
    }
    PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
    return 0;
}

/* Digest length in bytes of the hash algorithm bound to `context`. */
unsigned int
HASH_ResultLenContext(HASHContext *context)
{
    const SECHashObject *obj = context->hashobj;

    return obj->length;
}



/*
 * One-shot hash: digest src_len bytes at `src` with algorithm `type` and
 * write the result to `dest`. The result buffer capacity passed through is
 * HASH_ResultLenContext() of the created context.
 * Returns SECFailure for an invalid type or if the context cannot be created.
 */
SECStatus
HASH_HashBuf(HASH_HashType type,
             unsigned char *dest,
             unsigned char *src,
             uint32 src_len)
{
    HASHContext *ctx;
    unsigned int written;

    if ( type < HASH_AlgNULL || type >= HASH_AlgTOTAL ) {
        return SECFailure;
    }

    ctx = HASH_Create(type);
    if (!ctx) {
        return SECFailure;
    }

    HASH_Begin(ctx);
    HASH_Update(ctx, src, src_len);
    HASH_End(ctx, dest, &written, HASH_ResultLenContext(ctx));
    HASH_Destroy(ctx);
    return SECSuccess;
}

/*
 * Allocate a new hashing context for `type`.
 * Returns NULL for an invalid type, if the algorithm's create hook fails
 * (the null hash's hook always returns NULL), or on allocation failure;
 * any partially created algorithm context is destroyed first.
 */
HASHContext *
HASH_Create(HASH_HashType type)
{
    const SECHashObject *obj;
    void *raw;
    HASHContext *ctx;

    if ( ( type < HASH_AlgNULL ) || ( type >= HASH_AlgTOTAL ) ) {
        return NULL;
    }

    obj = &SECHashObjects[type];
    raw = (*obj->create)();
    if (raw == NULL) {
        return NULL;
    }

    ctx = (HASHContext *)PORT_Alloc(sizeof(HASHContext));
    if (ctx == NULL) {
        (*obj->destroy)(raw, PR_TRUE);
        return NULL;
    }

    ctx->hash_context = raw;
    ctx->hashobj = obj;
    ctx->rand_params = NULL;   /* not randomized until HASH_SetRandomize */
    return ctx;
}


/*
 * Create an independent copy of `context`, including its current digest
 * state and any randomized-hashing parameters.
 *
 * The clone always gets a defined rand_params: NULL when the source is
 * not randomized, otherwise its own copy of the parameters AND of the
 * salt buffer. HASH_Destroy frees rand_params->salt, so sharing the salt
 * pointer between source and clone would double-free it.
 *
 * Returns NULL on failure; any partially built state is released.
 */
HASHContext *
HASH_Clone(HASHContext *context)
{
    void *hash_context = NULL;
    HASHContext *ret = NULL;
    RANDHashParams *params = NULL;

    hash_context = (* context->hashobj->clone)(context->hash_context);
    if ( hash_context == NULL ) {
        goto loser;
    }

    if (context->rand_params != NULL) {
        const RANDHashParams *src = context->rand_params;
        /* HASH_SetRandomize sizes the salt buffer to one hash block. */
        unsigned int salt_buf_len = context->hashobj->blocklength;

        params = (RANDHashParams *)PORT_Alloc(sizeof(RANDHashParams));
        if (params == NULL) {
            goto loser;
        }
        memcpy(params, src, sizeof(RANDHashParams));
        params->salt = NULL;   /* never alias the source's salt */

        if (src->salt != NULL && salt_buf_len > 0) {
            params->salt = (unsigned char *)PORT_Alloc(salt_buf_len);
            if (params->salt == NULL) {
                goto loser;
            }
            memcpy(params->salt, src->salt, salt_buf_len);
        }
    }

    ret = (HASHContext *)PORT_Alloc(sizeof(HASHContext));
    if ( ret == NULL ) {
        goto loser;
    }

    ret->hash_context = hash_context;
    ret->hashobj = context->hashobj;
    ret->rand_params = params;   /* NULL when the source is not randomized */
    return(ret);

loser:
    if (params != NULL) {
        PORT_Free(params->salt);
        PORT_Free(params);
    }
    if ( hash_context != NULL ) {
        (* context->hashobj->destroy)(hash_context, PR_TRUE);
    }

    return(NULL);
}

/*
 * Release `context`: destroy the algorithm context, then free the
 * randomized-hashing parameters (salt buffer included) and the wrapper.
 * This context is treated as the sole owner of rand_params and its salt,
 * so two contexts must never share the same salt buffer.
 */
void
HASH_Destroy(HASHContext *context)
{
    RANDHashParams *params = context->rand_params;

    (* context->hashobj->destroy)(context->hash_context, PR_TRUE);

    if (params != NULL) {
        PORT_Free(params->salt);
        PORT_Free(params);
    }
    PORT_Free(context);
}


/* Start the digest held by `context` via its hash object's begin hook. */
void
HASH_Begin(HASHContext *context)
{
    const SECHashObject *obj = context->hashobj;

    (*obj->begin)(context->hash_context);
}


/* Feed `len` bytes at `src` into the digest (no randomized-hash masking;
 * see HASH_UpdateEx for that). */
void
HASH_Update(HASHContext *context,
            const unsigned char *src,
            unsigned int len)
{
    const SECHashObject *obj = context->hashobj;

    (*obj->update)(context->hash_context, src, len);
}

/*
 * Finish the digest via the hash object's end hook, which writes the result
 * into `result` and its length into *result_len; max_result_len is passed
 * through as the capacity of `result`.
 */
void
HASH_End(HASHContext *context,
         unsigned char *result,
         unsigned int *result_len,
         unsigned int max_result_len)
{
    const SECHashObject *obj = context->hashobj;

    (*obj->end)(context->hash_context, result, result_len, max_result_len);
}
#define MIN_SALT_LENGTH 16
/* Set randomized-hashing parameters for a HASHContext. This must be called
 * before hashing starts (HASH_BeginEx).
 * - If salt_len < MIN_SALT_LENGTH (16), the salt is rejected.
 * - If salt_len > the hash block size, the salt is truncated to it.
 * The stored salt buffer is exactly one hash block long and is filled by
 * repeating the (possibly truncated) input salt.
 *
 * Returns PR_TRUE on success. Returns PR_FALSE, leaving the context
 * unchanged, if context or salt is NULL, the salt is too short, the hash has
 * no block size (e.g. the null hash, which would otherwise make later salt
 * indexing run out of bounds), or memory allocation fails.
 * Parameters installed by an earlier call are freed and replaced.
 */
PRBool
HASH_SetRandomize(HASHContext* context,
                  unsigned char* salt,
                  int salt_len)
{
  RANDHashParams *params;
  unsigned char *block;
  int len;
  int total_written;
  int to_write;

  if (context == NULL || context->hashobj == NULL || salt == NULL ||
      salt_len < MIN_SALT_LENGTH) {
    return PR_FALSE;
  }
  len = (int) context->hashobj->blocklength;
  if (len <= 0) {
    return PR_FALSE;
  }
  if (salt_len > len) {
    /* ignore the rest */
    salt_len = len;
  }

  params = (RANDHashParams*)PORT_Alloc(sizeof(RANDHashParams));
  if (params == NULL) {
    return PR_FALSE;
  }
  block = (unsigned char*) PORT_Alloc(len);
  if (block == NULL) {
    PORT_Free(params);
    return PR_FALSE;
  }

  /* Tile the salt across the whole block; the last copy may be partial. */
  for (total_written = 0; total_written < len; total_written += to_write) {
    to_write = len - total_written;
    if (to_write > salt_len) {
      to_write = salt_len;
    }
    memcpy(block + total_written, salt, to_write);
  }

  params->salt = block;
  params->salt_length = salt_len;
  params->next_index_ = 0;

  /* Don't leak parameters from a previous call. */
  if (context->rand_params != NULL) {
    PORT_Free(context->rand_params->salt);
    PORT_Free(context->rand_params);
  }
  context->rand_params = params;
  return PR_TRUE;
}

/*
 * Create a hashing context, optionally randomized.
 *
 * If `params` is non-NULL and of type siSaltValue, its data is installed as
 * the randomized-hashing salt via HASH_SetRandomize. If that fails (e.g. the
 * salt is shorter than the minimum), the context is destroyed, the error is
 * set to SEC_ERROR_INVALID_ARGS and NULL is returned, rather than silently
 * handing back an un-randomized context. A NULL `params` yields a plain
 * context, like HASH_Create.
 */
HASHContext* HASH_CreateEx(HASH_HashType type,
                           SECItem* params)
{
  /* TODO: deprecate HASH_Create */
  HASHContext* ctx = HASH_Create(type);

  if (ctx == NULL || params == NULL || params->type != siSaltValue) {
    return ctx;
  }
  if (!HASH_SetRandomize(ctx, params->data, params->len)) {
    HASH_Destroy(ctx);
    PORT_SetError(SEC_ERROR_INVALID_ARGS);
    return NULL;
  }
  return ctx;
}

/*
 * Start a digest that may be randomized.
 *
 * Calls HASH_Begin; then, if rand_params is set, hashes one block-length
 * block r_0 that holds the salt's salt_length bytes followed by zeros.
 * The hashobj NULL check runs before HASH_Begin, because HASH_Begin
 * dereferences hashobj (in the previous ordering the check was dead code).
 */
void HASH_BeginEx(HASHContext *context)
{
  unsigned char r_0[HASH_BLOCK_LENGTH_MAX];
  unsigned int len;
  unsigned int salt_len;

  if (context->hashobj == NULL) {
    /* something seriously wrong. e.g, incorrect call sequence */
    return;
  }
  HASH_Begin(context); /* TODO: deprecate HASH_Begin later */

  if (context->rand_params == NULL) {
    return;
  }

  len = context->hashobj->blocklength;
  if (len > sizeof(r_0)) {
    /* r_0 cannot hold a full block for this hash; refuse rather than overrun */
    PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
    return;
  }
  /* HASH_SetRandomize already clamps salt_length to the block size. */
  salt_len = context->rand_params->salt_length;
  if (salt_len > len) {
    salt_len = len;
  }

  /* get r_0: the salt padded with zeros to one block */
  memset(r_0, 0, sizeof(r_0));
  memcpy(r_0, context->rand_params->salt, salt_len);
  (* context->hashobj->update)(context->hash_context, r_0, len);
}

/*
 * Feed `len` bytes of `src` into a possibly randomized digest.
 *
 * Without rand_params this is equivalent to HASH_Update. With rand_params,
 * each input byte is XOR-ed with the salt block (cycling through its
 * blocklength bytes), continuing from the salt position saved by the previous
 * call; the new position is stored back in rand_params->next_index_.
 *
 * The masked bytes are built in a fixed-size stack buffer and hashed in
 * chunks. This relies on the hash object's update being incremental (several
 * small updates digest the same as one large one), and it removes the
 * unchecked per-call heap allocation that could dereference NULL on OOM.
 */
void HASH_UpdateEx(HASHContext *context,
                   const unsigned char *src,
                   unsigned int len)
{
  enum { MASK_CHUNK_LEN = 512 };
  unsigned char masked[MASK_CHUNK_LEN];
  RANDHashParams *params = context->rand_params;
  unsigned int block_len;
  unsigned short next_index;

  if (params == NULL) {
    (* context->hashobj->update)(context->hash_context, src, len);
    return;
  }

  block_len = context->hashobj->blocklength;
  next_index = params->next_index_;

  while (len > 0) {
    unsigned int n = (len < MASK_CHUNK_LEN) ? len : MASK_CHUNK_LEN;
    unsigned int i;

    for (i = 0; i < n; i++) {
      masked[i] = params->salt[next_index++] ^ src[i];
      if (next_index == block_len) {
        next_index = 0;   /* wrap to the start of the salt block */
      }
    }
    (* context->hashobj->update)(context->hash_context, masked, n);
    src += n;
    len -= n;
  }
  /*context->msg_len += len;*/
  params->next_index_ = next_index;
}


void HASH_EndEx(HASHContext *context,
                unsigned char *result,
                unsigned int *result_len,
                unsigned int max_result_len)
{
  /* TODO: figure out r_2 */
  unsigned int b = context->hashobj->blocklength;
  unsigned int c = context->hashobj->length_encoding_size;
  unsigned char* padding = NULL;
  unsigned short next_index = context->rand_params->next_index_;

  unsigned short bit_length = next_index << 3;
  int i;
  PORT_Assert((b-c-3) > 0);

  if (context->rand_params != NULL) {
    if (next_index <=(b-c-3) && next_index != 0) {
      unsigned short length = b - c - 1 - next_index;
      padding = (unsigned char*)PORT_Alloc(length);
      memset(padding, 0, length);
      memcpy(padding + (length-sizeof(unsigned short)), &bit_length,
             sizeof(unsigned short));

      for (i=0; i<length; i++) {
        padding[i] = context->rand_params->salt[next_index++] ^ padding[i];
      }
      (*context->hashobj->update)(context->hash_context, padding, b-c-next_index);
      PORT_Free(padding);
    } else {
      /* TODO: need to use htons */
      if (next_index != 0) {
        padding = (unsigned char*)PORT_Alloc(b-next_index);
        memset(padding, 0, b-next_index);
        for (i=0; i<b-next_index; i++) {
          /* TODO(wshao): combine memset and XOR into a single assignment */
          padding[i] = context->rand_params->salt[next_index + i] ^ padding[i];
        }
        (*context->hashobj->update)(context->hash_context, padding, b-next_index);
        PORT_Free(padding);
      }
      padding = (unsigned char*)PORT_Alloc(b-c-1);
      memset(padding, 0, b-c-1);
      memcpy(padding + (b-c-1-sizeof(unsigned short)),&bit_length,
             sizeof(unsigned short));

      for (i=0; i<b-c-1; i++) {
        padding[i] = padding[i] ^ context->rand_params->salt[i];
      }
      (*context->hashobj->update)(context->hash_context, padding, b-c-1);
      PORT_Free(padding);
    }
  }
  (* context->hashobj->end)(context->hash_context, result, result_len,
                            max_result_len);
}