{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyMDQAGwDnmThXqSHdNwTwvq", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# LangChain Callback 示例" ], "metadata": { "id": "GIUvp6uXWAcn" } }, { "cell_type": "markdown", "source": [ "## 准备环境" ], "metadata": { "id": "h4Ps0wdmWFGh" } }, { "cell_type": "markdown", "source": [ "1. 安装langchain版本0.0.235,以及openai" ], "metadata": { "id": "E8OKcwR6VdHL" } }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "yDa7VdF9SlAm", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "bf91585d-8b22-4bd9-d3ab-111d991cbd9b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.3 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.2/1.3 MB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m21.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.6/73.6 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.0/90.0 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip install -q -U langchain==0.0.235 openai" ] }, { "cell_type": "markdown", "source": [ "2. 设置OPENAI API Key" ], "metadata": { "id": "N2-YYQa1s_jQ" } }, { "cell_type": "code", "source": [ "import os\n", "\n", "os.environ['OPENAI_API_KEY'] = \"您的有效openai api key\"" ], "metadata": { "id": "OSGQIm6FtD8A" }, "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 示例代码" ], "metadata": { "id": "i0wXx4wfWIjF" } }, { "cell_type": "markdown", "source": [ "1. 
{ "cell_type": "markdown", "source": [ "2. A custom callback handler\n", "\n", "Let's implement a handler that measures how long each `LLM` interaction takes." ], "metadata": { "id": "NS9I0GYhs4Ny" } },
{ "cell_type": "code", "source": [ "from langchain.callbacks.base import BaseCallbackHandler\n", "import time\n", "\n", "class TimerHandler(BaseCallbackHandler):\n", "\n", "    def __init__(self) -> None:\n", "        super().__init__()\n", "        self.previous_ms = None\n", "        self.durations = []\n", "\n", "    def current_ms(self):\n", "        # Current wall-clock time in milliseconds.\n", "        return round(time.time() * 1000)\n", "\n", "    def on_chain_start(self, serialized, inputs, **kwargs) -> None:\n", "        print(inputs)  # log chain inputs (visible in the cell outputs below)\n", "        self.previous_ms = self.current_ms()\n", "\n", "    def on_chain_end(self, outputs, **kwargs) -> None:\n", "        print(outputs)  # log chain outputs\n", "        if self.previous_ms:\n", "            duration = self.current_ms() - self.previous_ms\n", "            self.durations.append(duration)\n", "\n", "    def on_llm_start(self, serialized, prompts, **kwargs) -> None:\n", "        print(prompts)  # log the prompts sent to the LLM\n", "        self.previous_ms = self.current_ms()\n", "\n", "    def on_llm_end(self, response, **kwargs) -> None:\n", "        print(response)  # log the raw LLM response\n", "        if self.previous_ms:\n", "            duration = self.current_ms() - self.previous_ms\n", "            self.durations.append(duration)" ], "metadata": { "id": "WJQTUyrnwIf6" }, "execution_count": 4, "outputs": [] },
{ "cell_type": "code", "source": [ "llm = OpenAI()\n", "timerHandler = TimerHandler()\n", "prompt = PromptTemplate.from_template(\"What is the HEX code of color {color_name}?\")\n", "chain = LLMChain(llm=llm, prompt=prompt, callbacks=[timerHandler])\n", "response = chain.run(color_name=\"blue\")\n", "print(response)\n", "response = chain.run(color_name=\"purple\")\n", "print(response)\n", "\n", "timerHandler.durations" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "B7Y9E34VPObp", "outputId": "5d5f7945-8ad2-43ce-ce52-b6787ae61dac" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'color_name': 'blue'}\n", "{'text': '\\n\\nThe HEX code of color blue is #0000FF.'}\n", "\n", "\n", "The HEX code of color blue is #0000FF.\n", "{'color_name': 'purple'}\n", "{'text': '\\n\\nThe HEX code of color purple is #800080.'}\n", "\n", "\n", "The HEX code of color purple is #800080.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[1065, 1075]" ] }, "metadata": {}, "execution_count": 5 } ] },
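{ "cell_type": "markdown", "source": [ "Note that the handler above recorded only two durations, one per `run` call: in this langchain version, callbacks passed to a constructor stay scoped to that object, so only the chain-level events fired. Callbacks passed at request time are documented to also propagate to sub-runs. The cell below is a minimal sketch of the request-scoped variant (not executed here); assuming propagation works as documented, each `run` should append two durations, one for the chain and one for the nested LLM call:" ], "metadata": {} },
{ "cell_type": "code", "source": [ "timerHandler = TimerHandler()\n", "chain = LLMChain(llm=OpenAI(), prompt=prompt)\n", "# Request-time callbacks should also reach the nested LLM run, so this\n", "# handler is expected to time both the chain and the LLM call.\n", "chain.run(color_name=\"red\", callbacks=[timerHandler])\n", "timerHandler.durations" ], "metadata": {}, "execution_count": null, "outputs": [] },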
{ "cell_type": "markdown", "source": [ "3. `Model` and `callbacks`\n", "\n", "`callbacks` can be specified in the model's constructor, or passed to individual function calls at execution time.\n", "\n", "See the code below:" ], "metadata": { "id": "3U6j88miOHis" } },
{ "cell_type": "code", "source": [ "timerHandler = TimerHandler()\n", "llm = OpenAI(callbacks=[timerHandler])\n", "response = llm.predict(\"What is the HEX code of color BLACK?\")\n", "print(response)\n", "\n", "timerHandler.durations" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Gh8e7GWzOOkT", "outputId": "2af5da44-5a0f-40b0-8e55-1c0643c9c4a8" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['What is the HEX code of color BLACK?']\n", "generations=[[Generation(text='\\n\\nThe hex code of black is #000000.', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'prompt_tokens': 10, 'total_tokens': 21, 'completion_tokens': 11}, 'model_name': 'text-davinci-003'} run=None\n", "\n", "\n", "The hex code of black is #000000.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[1223]" ] }, "metadata": {}, "execution_count": 6 } ] },
{ "cell_type": "code", "source": [ "timerHandler = TimerHandler()\n", "llm = OpenAI()\n", "response = llm.predict(\"What is the HEX code of color BLACK?\", callbacks=[timerHandler])\n", "print(response)\n", "\n", "timerHandler.durations" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Wm94hadJR1a2", "outputId": "00e4053b-0006-40f9-9c1a-d6ae1afef3c6" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['What is the HEX code of color BLACK?']\n", "generations=[[Generation(text='\\n\\nThe Hex code of the color black is #000000.', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'prompt_tokens': 10, 'total_tokens': 23, 'completion_tokens': 13}, 'model_name': 'text-davinci-003'} run=None\n", "\n", "\n", "The Hex code of the color black is #000000.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[1777]" ] }, "metadata": {}, "execution_count": 7 } ] },
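{ "cell_type": "markdown", "source": [ "Finally, a related built-in worth knowing about: this langchain version also ships a `get_openai_callback` context manager that aggregates OpenAI token usage across the calls made inside it. A minimal sketch (not executed here):" ], "metadata": {} },
{ "cell_type": "code", "source": [ "from langchain.callbacks import get_openai_callback\n", "\n", "llm = OpenAI()\n", "# get_openai_callback aggregates token usage for OpenAI calls made\n", "# within the context manager's scope.\n", "with get_openai_callback() as cb:\n", "    llm.predict(\"What is the HEX code of color WHITE?\")\n", "print(cb.total_tokens)" ], "metadata": {}, "execution_count": null, "outputs": [] }
] }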