Code development platform for open source projects from the European Union institutions

Skip to content
Snippets Groups Projects
prompt_engineering_fireworks.ai.ipynb 8.32 KiB
Newer Older
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install necessary packages\n",
    "LangChain supports many LLM inference providers, including Fireworks; `python-dotenv` is used to load the API key from a `.env` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel\n",
    "%pip install langchain python-dotenv requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import requests\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Copy the entries of a local .env file (e.g. FIREWORKS_API_KEY) into os.environ\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### API key\n",
    "* Register and get an API key at https://fireworks.ai/api-keys\n",
    "* Put the key in a `.env` file as the `FIREWORKS_API_KEY` variable, e.g. `FIREWORKS_API_KEY=<your key>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "api_key = os.getenv(\"FIREWORKS_API_KEY\")\n",
    "# Fail fast here rather than sending 'Bearer None' with every request later on\n",
    "if not api_key:\n",
    "    raise RuntimeError(\"FIREWORKS_API_KEY is not set; add it to the .env file.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation budget per prompt variant: maps the shot count (0, 1 or 2 examples)\n",
    "# to the maximum number of tokens the model may generate\n",
    "max_tokens = {\n",
    "    0: 1000,\n",
    "    1: 1000,\n",
    "    2: 2000,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
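    "### Prompting models\n",
    "Every model in `models` is run on the 0-, 1- and 2-shot prompts (`data/prompts/<n>-shot/prompt.txt`). Each response is saved under that prompt's `results/` folder, and models that already have a result file are skipped."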
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fireworks model identifiers to evaluate; the last path segment becomes the result file name\n",
    "models = [\n",
    "    'accounts/fireworks/models/starcoder-7b',\n",
    "    'accounts/fireworks/models/starcoder-16b',\n",
    "    'accounts/fireworks/models/llama-v2-13b-code-instruct',\n",
    "    'accounts/fireworks/models/llama-v2-34b-code-instruct',\n",
    "    'accounts/fireworks/models/llama-v2-70b-code-instruct',\n",
    "    'accounts/fireworks/models/mixtral-8x7b-instruct',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "def send_fireworks_request(model, api_key, model_type='chat', prompt_or_messages=None, max_tokens=1024,\n",
    "                           temperature=0.6, top_p=1, top_k=40,\n",
    "                           frequency_penalty=0, presence_penalty=0, timeout=600):\n",
    "    \"\"\"Send one inference request to the Fireworks REST API and return the decoded response.\n",
    "\n",
    "    Args:\n",
    "        model: Fireworks model id, e.g. 'accounts/fireworks/models/mixtral-8x7b-instruct'.\n",
    "        api_key: Fireworks API key, sent as a Bearer token.\n",
    "        model_type: 'chat' posts `prompt_or_messages` as `messages` to chat/completions;\n",
    "            'completion' posts it as `prompt` to completions.\n",
    "        prompt_or_messages: the chat messages or the prompt text, depending on model_type.\n",
    "        max_tokens, temperature, top_p, top_k, frequency_penalty, presence_penalty:\n",
    "            generation parameters copied verbatim into the request body.\n",
    "        timeout: seconds to wait for the server (None waits indefinitely).\n",
    "\n",
    "    Returns:\n",
    "        The parsed JSON body on HTTP 200; otherwise a dict\n",
    "        {'error': <JSON error body, or raw text if it is not JSON>, 'status_code': <HTTP status>}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if model_type is not 'chat' or 'completion'.\n",
    "        requests.exceptions.RequestException: on connection errors or timeout.\n",
    "    \"\"\"\n",
    "    base_url = \"https://api.fireworks.ai/inference/v1/\"\n",
    "    payload = {\n",
    "        \"model\": model,\n",
    "        \"max_tokens\": max_tokens,\n",
    "        \"temperature\": temperature,\n",
    "        \"top_p\": top_p,\n",
    "        \"top_k\": top_k,\n",
    "        \"presence_penalty\": presence_penalty,\n",
    "        \"frequency_penalty\": frequency_penalty\n",
    "    }\n",
    "\n",
    "    # The two endpoints differ only in URL and in the key that carries the input\n",
    "    if model_type == 'chat':\n",
    "        url = base_url + \"chat/completions\"\n",
    "        payload[\"messages\"] = prompt_or_messages\n",
    "    elif model_type == 'completion':\n",
    "        url = base_url + \"completions\"\n",
    "        payload[\"prompt\"] = prompt_or_messages\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported model type. Choose 'chat' or 'completion'.\")\n",
    "\n",
    "    headers = {\n",
    "        \"Accept\": \"application/json\",\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {api_key}\"\n",
    "    }\n",
    "\n",
    "    # Without a timeout a stalled connection would block the notebook forever\n",
    "    response = requests.post(url, json=payload, headers=headers, timeout=timeout)\n",
    "    if response.status_code != 200:\n",
    "        # Error bodies are not guaranteed to be JSON (e.g. an HTML page from a proxy);\n",
    "        # fall back to the raw text instead of raising while reporting the error\n",
    "        try:\n",
    "            error_body = response.json()\n",
    "        except ValueError:\n",
    "            error_body = response.text\n",
    "        return {\"error\": error_body, \"status_code\": response.status_code}\n",
    "    return response.json()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shots = [0, 1, 2]\n",
    "\n",
    "for shot in shots:\n",
    "    print(f'Processing shot: {shot}')\n",
    "\n",
    "    base_path = f'data/prompts/{shot}-shot'\n",
    "    # Read the prompt with a context manager so the file handle is closed\n",
    "    with open(f'{base_path}/prompt.txt', 'r') as prompt_file:\n",
    "        prompt = prompt_file.read()\n",
    "\n",
    "    results_dir = f'{base_path}/results'\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "    for model in models:\n",
    "        model_name = model.split('/')[-1]\n",
    "        print(f'Processing model: {model_name}')\n",
    "\n",
    "        file_path = f'{results_dir}/{model_name}.fireworks.ai.txt'\n",
    "\n",
    "        # Results are cached on disk: skip models already processed for this shot\n",
    "        if os.path.exists(file_path):\n",
    "            print('Skipping...')\n",
    "            continue\n",
    "\n",
    "        # send_fireworks_request is defined earlier in this notebook; 'completion' mode\n",
    "        # sends the prompt text as-is (the chat endpoint would need a messages list)\n",
    "        response = send_fireworks_request(\n",
    "            model,\n",
    "            api_key,\n",
    "            model_type='completion',\n",
    "            prompt_or_messages=prompt,\n",
    "            max_tokens=max_tokens[shot],\n",
    "        )\n",
    "\n",
    "        # On failure nothing is written, so re-running the cell retries this model\n",
    "        if 'error' in response:\n",
    "            print(f'Request failed for {model_name}: {response[\"error\"]}')\n",
    "            continue\n",
    "\n",
    "        result = response['choices'][0]['text']\n",
    "\n",
    "        with open(file_path, 'w') as file:\n",
    "            file.write(result)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
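    "### XML extraction from results\n",
    "Pull the `<coverPage>...</coverPage>` element out of each saved model response and write it to the `results-xml` folder."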
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "base_dir = 'data/prompts'\n",
    "shots = [0, 1, 2]\n",
    "\n",
    "# Only the <coverPage>...</coverPage> element is kept; any text around it is discarded\n",
    "start_tag = '<coverPage>'\n",
    "end_tag = '</coverPage>'\n",
    "\n",
    "for shot in shots:\n",
    "    results_path = os.path.join(base_dir, f'{shot}-shot', 'results')\n",
    "    results_xml_path = os.path.join(base_dir, f'{shot}-shot', 'results-xml')\n",
    "\n",
    "    # Ensure the results-xml directory exists\n",
    "    os.makedirs(results_xml_path, exist_ok=True)\n",
    "\n",
    "    if not os.path.isdir(results_path):\n",
    "        continue\n",
    "\n",
    "    for result_file in sorted(os.listdir(results_path)):\n",
    "        if not result_file.endswith('.fireworks.ai.txt'):\n",
    "            continue\n",
    "        file_path = os.path.join(results_path, result_file)\n",
    "        with open(file_path, 'r') as file:\n",
    "            result_content = file.read()\n",
    "\n",
    "        # Use the last start tag and the first end tag that follows it\n",
    "        # (earlier occurrences may be echoed few-shot examples -- assumption)\n",
    "        start = result_content.rfind(start_tag)\n",
    "        end = result_content.find(end_tag, start) if start != -1 else -1\n",
    "\n",
    "        # Skip responses without a complete element (e.g. a truncated answer)\n",
    "        # instead of writing an empty or garbage .xml file\n",
    "        if start == -1 or end == -1:\n",
    "            print(f\"No XML content found in {result_file}\")\n",
    "            continue\n",
    "\n",
    "        xml_content = result_content[start:end + len(end_tag)]\n",
    "\n",
    "        # Save alongside the other results, swapping the .txt suffix for .xml\n",
    "        xml_file_name = result_file.replace('.txt', '.xml')\n",
    "        xml_file_path = os.path.join(results_xml_path, xml_file_name)\n",
    "        with open(xml_file_path, 'w') as xml_file:\n",
    "            xml_file.write(xml_content)\n",
    "        print(f'Extracted and saved XML for {xml_file_name}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}