{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "02eca690",
   "metadata": {},
   "source": [
    "# 机器学习与社会科学应用"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c22dc224",
   "metadata": {},
   "source": [
    "# 第七章 深度学习"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03a90d78",
   "metadata": {},
   "source": [
    "# 第三节 循环神经网络"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97e7c62f-6e60-4fc4-9229-c886cd2a596c",
   "metadata": {},
   "source": [
    "<font face=\"宋体\" >郭峰    \n",
    "    教授、博士生导师  \n",
    "上海财经大学公共经济与管理学院  \n",
    "上海财经大学数实融合与智能治理实验室    \n",
    "    邮箱：guofengsfi@163.com</font> "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "843ca6fb",
   "metadata": {},
   "source": [
    "<font face=\"宋体\">本节目录：   \n",
    "3.1. 数据预处理   \n",
    "3.2. 创建RNN模型   \n",
    "3.3. 定义损失函数和优化器   \n",
    "3.4. 配置检查点（模型保存）   \n",
    "3.5. 训练模型   \n",
    "3.6. 文本生成</font> "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d85a10c0",
   "metadata": {},
   "source": [
    "### 3.1. 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b894f73d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "# Seed every random number generator in use so that shuffling, weight\n",
    "# initialisation and sampling can be repeated from a fresh kernel.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Fetch the Shakespeare corpus through the Keras get_file helper.\n",
    "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')\n",
    "\n",
    "# Read the raw bytes inside a context manager so the file handle is closed\n",
    "# as soon as the content has been loaded, then decode it as UTF-8.\n",
    "with open(path_to_file, 'rb') as corpus_file:\n",
    "    text = corpus_file.read().decode(encoding='utf-8')\n",
    "\n",
    "# Replace every newline with a single space.\n",
    "text = text.replace('\\n', ' ')\n",
    "\n",
    "# Build the vocabulary and the char -> index / index -> char lookups.\n",
    "vocab = sorted(set(text))\n",
    "char2idx = {u: i for i, u in enumerate(vocab)}\n",
    "idx2char = np.array(vocab)\n",
    "text_as_int = np.array([char2idx[c] for c in text])\n",
    "\n",
    "# Build the training examples.\n",
    "seq_length = 100\n",
    "examples_per_epoch = len(text) // (seq_length + 1)\n",
    "\n",
    "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n",
    "\n",
    "# Each chunk holds seq_length + 1 ids so that the input and the target can\n",
    "# be taken from it with a one-position offset.\n",
    "sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)\n",
    "\n",
    "def split_input_target(chunk):\n",
    "    \"\"\"Split one chunk of character ids into an (input, target) pair.\n",
    "\n",
    "    Args:\n",
    "        chunk: sequence of character ids that supports slicing.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (chunk[:-1], chunk[1:]): the target is the input shifted one\n",
    "        position ahead.\n",
    "    \"\"\"\n",
    "    input_text = chunk[:-1]\n",
    "    target_text = chunk[1:]\n",
    "    return input_text, target_text\n",
    "\n",
    "dataset = sequences.map(split_input_target)\n",
    "\n",
    "BATCH_SIZE = 64\n",
    "BUFFER_SIZE = 10000\n",
    "\n",
    "# Shuffle the (input, target) pairs and group them into full batches only.\n",
    "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3539a145",
   "metadata": {},
   "source": [
    "### 3.2. 创建RNN模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0df0e1-49d3-4490-a434-f142f45e92c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next, build a simple recurrent model.\n",
    "\n",
    "vocab_size = len(vocab)\n",
    "embedding_dim = 256\n",
    "rnn_units = 1024\n",
    "\n",
    "def build_model(vocab_size, embedding_dim, rnn_units, batch_size):\n",
    "    \"\"\"Assemble the Embedding -> LSTM -> Dense character model.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of distinct characters; used as the embedding\n",
    "            input size and as the width of the final Dense layer.\n",
    "        embedding_dim: size of each character embedding vector.\n",
    "        rnn_units: number of units in the LSTM layer.\n",
    "        batch_size: fixed batch dimension written into the input shape\n",
    "            [batch_size, None].\n",
    "\n",
    "    Returns:\n",
    "        A tf.keras.Sequential model whose last layer has no activation, so\n",
    "        it emits vocab_size raw scores (logits) per time step.\n",
    "    \"\"\"\n",
    "    embedding = tf.keras.layers.Embedding(\n",
    "        input_dim=vocab_size,\n",
    "        output_dim=embedding_dim,\n",
    "        batch_input_shape=[batch_size, None],\n",
    "    )\n",
    "    recurrent = tf.keras.layers.LSTM(\n",
    "        units=rnn_units,\n",
    "        return_sequences=True,\n",
    "        stateful=True,\n",
    "        recurrent_initializer='glorot_uniform',\n",
    "    )\n",
    "    projection = tf.keras.layers.Dense(units=vocab_size)\n",
    "    return tf.keras.Sequential([embedding, recurrent, projection])\n",
    "\n",
    "# Training model: its batch dimension matches the dataset batch size.\n",
    "model = build_model(vocab_size, embedding_dim, rnn_units, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee4783df",
   "metadata": {},
   "source": [
    "### 3.3. 定义损失函数和优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19cdf100",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(labels, logits):\n",
    "    \"\"\"Sparse categorical cross-entropy evaluated on raw logits.\n",
    "\n",
    "    Args:\n",
    "        labels: integer ids of the target characters.\n",
    "        logits: unnormalised model outputs; from_logits=True tells Keras\n",
    "            to treat them as logits rather than probabilities.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by\n",
    "        tf.keras.losses.sparse_categorical_crossentropy for these inputs.\n",
    "    \"\"\"\n",
    "    crossentropy = tf.keras.losses.sparse_categorical_crossentropy\n",
    "    return crossentropy(labels, logits, from_logits=True)\n",
    "\n",
    "# Compile with the Adam optimizer and the logits-based loss defined above.\n",
    "model.compile(loss=loss, optimizer='adam')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f9d8fb3",
   "metadata": {},
   "source": [
    "### 3.4. 配置检查点（模型保存）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "893c8462",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives the saved weights.\n",
    "checkpoint_dir = './training_checkpoints'\n",
    "\n",
    "# The {epoch} placeholder is deliberately left unformatted here; it is\n",
    "# passed to the callback as part of the file path template.\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')\n",
    "\n",
    "# Save weights only (not the full model) to the path above.\n",
    "checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_prefix, save_weights_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddfab44b",
   "metadata": {},
   "source": [
    "### 3.5. 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c0e53f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of epochs passed to model.fit.\n",
    "EPOCHS = 5\n",
    "\n",
    "# Train on the batched dataset; the checkpoint callback is handed to fit\n",
    "# so that it runs as part of training.\n",
    "history = model.fit(\n",
    "    dataset,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[checkpoint_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c404895",
   "metadata": {},
   "source": [
    "### 3.6. 文本生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bc65b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text generation.\n",
    "\n",
    "# Rebuild the network with a batch size of 1 under its own name, so the\n",
    "# trained `model` is not overwritten and the training cells stay re-runnable.\n",
    "generation_model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)\n",
    "\n",
    "# tf.train.latest_checkpoint returns None when nothing has been saved;\n",
    "# fail with a clear message instead of passing None to load_weights.\n",
    "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "if latest_checkpoint is None:\n",
    "    raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir!r}; run the training cell first.')\n",
    "\n",
    "generation_model.load_weights(latest_checkpoint)\n",
    "generation_model.build(tf.TensorShape([1, None]))\n",
    "\n",
    "def generate_text(model, start_string, num_generate=1000, temperature=0.5):\n",
    "    \"\"\"Sample characters from the model, continuing start_string.\n",
    "\n",
    "    Args:\n",
    "        model: model that is called on a [1, length] batch of character ids\n",
    "            and exposes reset_states().\n",
    "        start_string: non-empty seed text; every character must be a key of\n",
    "            char2idx.\n",
    "        num_generate: number of characters to sample. Defaults to 1000, the\n",
    "            value that used to be hard-coded.\n",
    "        temperature: positive divisor applied to the logits before\n",
    "            sampling. Values below 1 scale the logits up, which makes the\n",
    "            sampling less random. Defaults to 0.5, the value that used to\n",
    "            be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        start_string followed by the sampled characters.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if start_string is empty or contains characters outside\n",
    "            the vocabulary, or if temperature is not positive.\n",
    "    \"\"\"\n",
    "    if not start_string:\n",
    "        raise ValueError('start_string must contain at least one character')\n",
    "    if temperature <= 0:\n",
    "        raise ValueError('temperature must be positive')\n",
    "    unknown_chars = sorted(set(start_string) - set(char2idx))\n",
    "    if unknown_chars:\n",
    "        raise ValueError(f'start_string contains characters outside the vocabulary: {unknown_chars}')\n",
    "\n",
    "    # Encode the seed text and add a leading batch dimension of 1.\n",
    "    input_eval = [char2idx[s] for s in start_string]\n",
    "    input_eval = tf.expand_dims(input_eval, 0)\n",
    "    text_generated = []\n",
    "\n",
    "    model.reset_states()\n",
    "    for _ in range(num_generate):\n",
    "        predictions = model(input_eval)\n",
    "        # Drop the batch dimension, leaving one row of logits per time step.\n",
    "        predictions = tf.squeeze(predictions, 0)\n",
    "\n",
    "        predictions = predictions / temperature\n",
    "        # Sample from every row, then keep only the draw for the last step.\n",
    "        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()\n",
    "\n",
    "        # The sampled character becomes the only input of the next step.\n",
    "        input_eval = tf.expand_dims([predicted_id], 0)\n",
    "\n",
    "        text_generated.append(idx2char[predicted_id])\n",
    "\n",
    "    return start_string + ''.join(text_generated)\n",
    "\n",
    "# Generate text with the restored model.\n",
    "generated_text = generate_text(generation_model, start_string='ROMEO: ')\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e9ec2c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
