본문 바로가기
Data

An AI application that can chat with very large SQL databases. - 초대형 SQL 데이터베이스와 채팅할 수 있는 AI 애플리케이션

by Hagrid 2024. 2. 18.
반응형

코드공유

지난 글에서는 SQL 데이터베이스와 채팅할 수 있는 간단한 애플리케이션을 만들었습니다. 이 글을 계속 진행하기 전에 여기에서 확인해 보세요.

또한 여기에서 이 코드의 전체 리포지토리를 받으세요.


이전과 지금 글의 차이점 

이전 글에서 구축한 애플리케이션은 OpenAI의 토큰 제한으로 인해 매우 큰 데이터베이스에서 제대로 작동하지 않습니다. 데이터베이스가 너무 크면 프롬프트에 전체 열 및 테이블 목록을 콘텍스트로 전송할 수 없습니다. 이 글에서는 이러한 한계를 극복하기 위해 노력하겠습니다.

이전 글에서 이미 만든 가장 간단한 애플리케이션의 코드부터 시작해 보겠습니다. 아래 코드는 SQL 데이터베이스에 연결하여 채팅을 시작할 수 있는 간단한 Streamlit 애플리케이션입니다.

import streamlit as st
import requests
import os
import pandas as pd
from uuid import uuid4
import psycopg2

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate

from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv




folders_to_create = ['csvs']
# Make sure every working directory exists before the app starts writing to it.
for folder_name in folders_to_create:
    if os.path.exists(folder_name):
        print(f"Folder '{folder_name}' already exists.")
        continue
    os.makedirs(folder_name)
    print(f"Folder '{folder_name}' created.")




## Read the OpenAI API key from the environment (a .env file is loaded first)
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")


# The same key is shared by the completion model, the chat model and the embedder.
llm = OpenAI(openai_api_key=openai_api_key)
chat_llm = ChatOpenAI(temperature=0.4, openai_api_key=openai_api_key)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)





def get_basic_table_details(cursor):
    """Return (table_name, column_name, data_type) rows for the public schema.

    Args:
        cursor: An open DB-API cursor on a PostgreSQL connection.

    Returns:
        The list of rows produced by ``cursor.fetchall()``.
    """
    # The explicit table_schema filter keeps a same-named table that lives in
    # another schema from contributing its columns; ORDER BY makes the
    # resulting CSV deterministic.
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
            )
        ORDER BY
            c.table_name,
            c.ordinal_position;""")
    return cursor.fetchall()





def save_db_details(db_uri):
    """Snapshot the database's table/column metadata into a CSV file.

    Args:
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        The generated unique id; the CSV is written to
        ``csvs/tables_<unique_id>.csv``.
    """
    unique_id = str(uuid4()).replace("-", "_")

    # try/finally guarantees the cursor and connection are released even when
    # the metadata query raises.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            tables_and_columns = get_basic_table_details(cursor)
        finally:
            cursor.close()
    finally:
        connection.close()

    ## Get all the tables and columns and enter them in a pandas dataframe
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)

    return unique_id







def generate_template_for_sql(query, table_info, db_uri):
    """Ask the chat model for a SQL query that answers ``query``.

    Args:
        query: The user's natural-language question.
        table_info: Text describing the tables and columns, used as context.
        db_uri: Kept for backward compatibility. It is deliberately NOT placed
            in the prompt: a connection string may embed credentials, and the
            prompt is sent to an external service.

    Returns:
        The text content of the model's reply.
    """
    # Join with newlines so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant that can write SQL Queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            "The database is PostgreSQL.",
            "Here is a detailed description of the table(s):",
            f"{table_info}",
            "Prepend and append the SQL query with three backticks '```'",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    return answer.content




def get_the_output_from_llm(query, unique_id, db_uri):
    """Build the schema context from the saved CSV and ask the LLM for SQL.

    Args:
        query: The user's natural-language question.
        unique_id: Id identifying ``csvs/tables_<unique_id>.csv``.
        db_uri: Connection string, forwarded to ``generate_template_for_sql``.

    Returns:
        Whatever ``generate_template_for_sql`` returns.
    """
    ## Load the tables csv
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df = pd.read_csv(filename_t)

    ## Describe every table exactly once. The CSV has one row per column, so
    ## iterating the raw 'table_name' series would repeat each table's
    ## description once per column.
    sections = []
    for table in df['table_name'].unique():
        columns = df[df['table_name'] == table].to_string(index=False)
        sections.append(f'Information about table {table}:\n{columns}\n\n\n')
    table_info = ''.join(sections)

    return generate_template_for_sql(query, table_info, db_uri)





def execute_the_solution(solution, db_uri):
    """Run the SQL found in the first ``` fenced block of ``solution``.

    Args:
        solution: The LLM reply, expected to contain a ``` fenced SQL block.
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        ``str()`` of the rows returned by ``cursor.fetchall()``.

    Raises:
        ValueError: If ``solution`` contains no complete fenced block.
    """
    parts = solution.split("```")
    if len(parts) < 3:
        raise ValueError("No SQL code block fenced with ``` found in the answer.")
    final_query = parts[1].strip()
    # Remove only a leading "sql" language tag. str.strip('sql') would strip the
    # characters s/q/l from both ends and corrupt e.g. "select ... from sales".
    if final_query[:3].lower() == "sql" and (len(final_query) == 3 or final_query[3].isspace()):
        final_query = final_query[3:].lstrip()

    # NOTE(review): this executes model-generated SQL verbatim; consider
    # connecting with a read-only database role.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(final_query)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, including when the query fails.
        connection.close()
    return str(result)





# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    """Read the database metadata and remember the connection in the session.

    Args:
        uri: Database connection string entered by the user.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when psycopg2
        reports a failure, which is the key the calling UI code checks for.
    """
    try:
        unique_id = save_db_details(uri)
    except psycopg2.Error as exc:
        return {"error": f"Could not read database metadata: {exc}"}

    # Store session state only after the metadata was read successfully.
    st.session_state.db_uri = uri
    st.session_state.unique_id = unique_id

    return {"message": "Connection established to Database!"}

# Turn one chat message into generated SQL plus the result of running it
def send_message(message):
    """Generate SQL for ``message``, execute it and return both as text.

    Reads ``unique_id`` and ``db_uri`` from ``st.session_state``.
    Returns a dict with a single ``"message"`` key.
    """
    session = st.session_state
    solution = get_the_output_from_llm(message, session.unique_id, session.db_uri)
    result = execute_the_solution(solution, session.db_uri)
    combined = "".join((solution, "\n\n", "Result:\n", result))
    return {"message": combined}



# ## Instructions
st.subheader("Instructions")
st.markdown(
    """
    1. Enter the URI of your RDS Database in the text box below.
    2. Click the **Start Chat** button to start the chat.
    3. Enter your message in the text box below and press **Enter** to send the message to the API.
    """
)

# Initialize the chat history list
chat_history = []

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the database
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    if "unique_id" not in st.session_state:
        # send_message reads unique_id/db_uri from the session; without a
        # prior successful "Start Chat" it would fail on the missing key.
        st.warning("Please connect to a database with **Start Chat** before sending a message.")
    else:
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        try:
            response = send_message(prompt)["message"]
        except (psycopg2.Error, ValueError) as exc:
            # Invalid generated SQL or an unparsable answer should be shown
            # in the chat instead of crashing the app.
            response = f"Error: could not answer the question.\n\n{exc}"
        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            st.markdown(response)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")

프롬프트를 단축하는 기본 아이디어는 프롬프트에서 사용자의 쿼리와 관련된 테이블 및 열 이름만 전송하는 것입니다. 

이를 위해 테이블 및 열 이름의 임베딩을 생성하고 사용자의 메시지와 가장 관련성이 높은 이름을 즉석에서 검색하여 프롬프트에 전달할 수 있습니다.

이 글에서는 벡터 데이터베이스로 ChromaDB를 사용하지만 Pinecone, Milvus 또는 다른 데이터베이스를 사용할 수 있습니다. 이제 chromadb를 설치해 보겠습니다.

pip install chromadb

먼저 테이블 및 열 이름의 임베딩을 저장할 vectors라는 이름의 폴더를 csvs 폴더와 함께 하나 더 만듭니다. 여기에는 다른 테이블과 조인하는 외래 키가 무엇인지, WHERE 절에 들어갈 수 있는 값의 일부와 같은 데이터베이스에 대한 다른 정보도 저장할 수 있습니다.

 


def create_vectors(filename, persist_directory):
    """Embed the rows of the CSV at ``filename`` into a persisted Chroma store.

    Args:
        filename: Path of the CSV file to load.
        persist_directory: Directory the Chroma store is written to.
    """
    documents = CSVLoader(file_path=filename, encoding="utf8").load()
    store = Chroma.from_documents(
        documents,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    store.persist()

또한 사용자의 쿼리에 테이블에 대한 정보가 필요한지 아니면 데이터베이스의 일반적인 스키마에 대해서만 질문하는 것인지 먼저 확인합니다.

 


def check_if_users_query_want_general_schema_information_or_sql(query):
    """Classify whether ``query`` asks about the schema or needs a SQL query.

    Args:
        query: The user's natural-language question.

    Returns:
        ``"yes"`` when the model's answer starts with "yes" (schema question),
        otherwise ``"no"``. The answer is normalised so that replies such as
        "Yes." or "'yes'" still match an exact comparison in the caller.
    """
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "The user is asking a question about a database.",
            "Figure out whether the user wants information about the database schema or wants to write a SQL query.",
            "Answer 'yes' if the user wants information about the database schema and 'no' if the user wants to write a SQL query.",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    verdict = answer.content.strip().strip("'\"").lower()
    return "yes" if verdict.startswith("yes") else "no"

사용자가 원하는 것에 따라 예 또는 아니요로 대답합니다. 예라고 대답하면 다음과 같은 프롬프트가 생성됩니다.

 


def prompt_when_user_want_general_db_information(query, db_uri):
    """Ask the chat model for SQL that answers a general schema question.

    Args:
        query: The user's natural-language question.
        db_uri: Kept for backward compatibility. It is deliberately NOT placed
            in the prompt: a connection string may embed credentials, and the
            prompt is sent to an external service.

    Returns:
        The text content of the model's reply.
    """
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant who writes SQL queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            "Prepend and append the SQL query with three backticks '```'",
            "Write select query whenever possible",
            "The database is PostgreSQL.",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content

다음으로 대답이 '아니요'인 경우 사용자의 쿼리에 답하려면 구체적인 테이블 이름과 열 이름이 필요하다는 뜻입니다.
이를 위해 먼저 가장 관련성이 높은 테이블과 열을 검색하고 프롬프트에 추가할 문자열을 생성합니다.
벡터가 생성되고 다른 모든 것이 정상적으로 실행되는지 확인해 보겠습니다. 

지금까지 작성한 코드는 다음과 같습니다.

import streamlit as st
import requests
import os
import pandas as pd
from uuid import uuid4
import psycopg2

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate

from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.document_loaders.csv_loader import CSVLoader




folders_to_create = ['csvs', 'vectors']
# Make sure every working directory exists before the app starts writing to it.
for folder_name in folders_to_create:
    if os.path.exists(folder_name):
        print(f"Folder '{folder_name}' already exists.")
        continue
    os.makedirs(folder_name)
    print(f"Folder '{folder_name}' created.")




## Read the OpenAI API key from the environment (a .env file is loaded first)
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")


# The same key is shared by the completion model, the chat model and the embedder.
llm = OpenAI(openai_api_key=openai_api_key)
chat_llm = ChatOpenAI(temperature=0.4, openai_api_key=openai_api_key)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)





def get_basic_table_details(cursor):
    """Return (table_name, column_name, data_type) rows for the public schema.

    Args:
        cursor: An open DB-API cursor on a PostgreSQL connection.

    Returns:
        The list of rows produced by ``cursor.fetchall()``.
    """
    # The explicit table_schema filter keeps a same-named table that lives in
    # another schema from contributing its columns; ORDER BY makes the
    # resulting CSV deterministic.
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
            )
        ORDER BY
            c.table_name,
            c.ordinal_position;""")
    return cursor.fetchall()

def create_vectors(filename, persist_directory):
    """Embed the rows of the CSV at ``filename`` into a persisted Chroma store.

    Args:
        filename: Path of the CSV file to load.
        persist_directory: Directory the Chroma store is written to.
    """
    documents = CSVLoader(file_path=filename, encoding="utf8").load()
    store = Chroma.from_documents(
        documents,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    store.persist()





def save_db_details(db_uri):
    """Snapshot the database's table/column metadata to CSV and embed it.

    Args:
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        The generated unique id; the CSV is written to
        ``csvs/tables_<unique_id>.csv`` and the vectors to
        ``./vectors/tables_<unique_id>``.
    """
    unique_id = str(uuid4()).replace("-", "_")

    # try/finally guarantees the cursor and connection are released even when
    # the metadata query raises, and frees them before the embedding step.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            tables_and_columns = get_basic_table_details(cursor)
        finally:
            cursor.close()
    finally:
        connection.close()

    ## Get all the tables and columns and enter them in a pandas dataframe
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)

    create_vectors(filename_t, "./vectors/tables_" + unique_id)

    return unique_id







def generate_template_for_sql(query, table_info, db_uri):
    """Ask the chat model for a SQL query that answers ``query``.

    Args:
        query: The user's natural-language question.
        table_info: Text describing the tables and columns, used as context.
        db_uri: Kept for backward compatibility. It is deliberately NOT placed
            in the prompt: a connection string may embed credentials, and the
            prompt is sent to an external service.

    Returns:
        The text content of the model's reply.
    """
    # Join with newlines so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant that can write SQL Queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            "The database is PostgreSQL.",
            "Here is a detailed description of the table(s):",
            f"{table_info}",
            "Prepend and append the SQL query with three backticks '```'",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    return answer.content


def check_if_users_query_want_general_schema_information_or_sql(query):
    """Classify whether ``query`` asks about the schema or needs a SQL query.

    Args:
        query: The user's natural-language question.

    Returns:
        ``"yes"`` when the model's answer starts with "yes" (schema question),
        otherwise ``"no"``. The answer is normalised so that replies such as
        "Yes." or "'yes'" still match an exact comparison in the caller.
    """
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "The user is asking a question about a database.",
            "Figure out whether the user wants information about the database schema or wants to write a SQL query.",
            "Answer 'yes' if the user wants information about the database schema and 'no' if the user wants to write a SQL query.",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    verdict = answer.content.strip().strip("'\"").lower()
    return "yes" if verdict.startswith("yes") else "no"


def prompt_when_user_want_general_db_information(query, db_uri):
    """Ask the chat model for SQL that answers a general schema question.

    Args:
        query: The user's natural-language question.
        db_uri: Kept for backward compatibility. It is deliberately NOT placed
            in the prompt: a connection string may embed credentials, and the
            prompt is sent to an external service.

    Returns:
        The text content of the model's reply.
    """
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant who writes SQL queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            "Prepend and append the SQL query with three backticks '```'",
            "Write select query whenever possible",
            "The database is PostgreSQL.",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content



    




def get_the_output_from_llm(query, unique_id, db_uri):
    """Route ``query`` to the schema prompt or the SQL-generation prompt.

    Args:
        query: The user's natural-language question.
        unique_id: Id identifying ``csvs/tables_<unique_id>.csv``.
        db_uri: Connection string, forwarded to the prompt helpers.

    Returns:
        The text returned by whichever prompt helper was used.
    """
    # Classify first so the schema text is only built when it is needed.
    # Compare leniently: the model may answer "Yes", "yes." and so on.
    verdict = check_if_users_query_want_general_schema_information_or_sql(query)
    if verdict.strip().strip("'\"").lower().startswith("yes"):
        return prompt_when_user_want_general_db_information(query, db_uri)

    ## Load the tables csv
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df = pd.read_csv(filename_t)

    ## Describe every table exactly once. The CSV has one row per column, so
    ## iterating the raw 'table_name' series would repeat each table's
    ## description once per column.
    sections = []
    for table in df['table_name'].unique():
        columns = df[df['table_name'] == table].to_string(index=False)
        sections.append(f'Information about table {table}:\n{columns}\n\n\n')
    table_info = ''.join(sections)

    return generate_template_for_sql(query, table_info, db_uri)





def execute_the_solution(solution, db_uri):
    """Run the SQL found in the first ``` fenced block of ``solution``.

    Args:
        solution: The LLM reply, expected to contain a ``` fenced SQL block.
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        ``str()`` of the rows returned by ``cursor.fetchall()``.

    Raises:
        ValueError: If ``solution`` contains no complete fenced block.
    """
    parts = solution.split("```")
    if len(parts) < 3:
        raise ValueError("No SQL code block fenced with ``` found in the answer.")
    final_query = parts[1].strip()
    # Remove only a leading "sql" language tag. str.strip('sql') would strip the
    # characters s/q/l from both ends and corrupt e.g. "select ... from sales".
    if final_query[:3].lower() == "sql" and (len(final_query) == 3 or final_query[3].isspace()):
        final_query = final_query[3:].lstrip()

    # NOTE(review): this executes model-generated SQL verbatim; consider
    # connecting with a read-only database role.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(final_query)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, including when the query fails.
        connection.close()
    return str(result)





# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    """Read the database metadata and remember the connection in the session.

    Args:
        uri: Database connection string entered by the user.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when psycopg2
        reports a failure, which is the key the calling UI code checks for.
    """
    try:
        unique_id = save_db_details(uri)
    except psycopg2.Error as exc:
        return {"error": f"Could not read database metadata: {exc}"}

    # Store session state only after the metadata was read successfully.
    st.session_state.db_uri = uri
    st.session_state.unique_id = unique_id

    return {"message": "Connection established to Database!"}

# Turn one chat message into generated SQL plus the result of running it
def send_message(message):
    """Generate SQL for ``message``, execute it and return both as text.

    Reads ``unique_id`` and ``db_uri`` from ``st.session_state``.
    Returns a dict with a single ``"message"`` key.
    """
    session = st.session_state
    solution = get_the_output_from_llm(message, session.unique_id, session.db_uri)
    result = execute_the_solution(solution, session.db_uri)
    combined = "".join((solution, "\n\n", "Result:\n", result))
    return {"message": combined}



# ## Instructions
st.subheader("Instructions")
st.markdown(
    """
    1. Enter the URI of your RDS Database in the text box below.
    2. Click the **Start Chat** button to start the chat.
    3. Enter your message in the text box below and press **Enter** to send the message to the API.
    """
)

# Initialize the chat history list
chat_history = []

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the database
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    if "unique_id" not in st.session_state:
        # send_message reads unique_id/db_uri from the session; without a
        # prior successful "Start Chat" it would fail on the missing key.
        st.warning("Please connect to a database with **Start Chat** before sending a message.")
    else:
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        try:
            response = send_message(prompt)["message"]
        except (psycopg2.Error, ValueError) as exc:
            # Invalid generated SQL or an unparsable answer should be shown
            # in the chat instead of crashing the app.
            response = f"Error: could not answer the question.\n\n{exc}"
        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            st.markdown(response)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")

벡터가 생성되고 다른 모든 것이 정상적으로 실행되는지 확인

이제 다음 단계에서는 마지막으로 가장 관련성이 높은 테이블의 벡터 검색을 수행하겠습니다. 

관련 테이블의 경우 프롬프트에 더 많은 컨텍스트를 제공하기 위해 모든 열을 가져올 것입니다. 마지막으로 이 정보에서 프롬프트에 전달할 문자열을 생성합니다.

# Retrieve the schema rows most similar to the user's question.
vectordb = Chroma(embedding_function=embeddings, persist_directory="./vectors/tables_"+ unique_id)
retriever = vectordb.as_retriever()
docs = retriever.get_relevant_documents(query)
print(docs)

relevant_tables = []
relevant_tables_and_columns = []


for doc in docs:
    # Each document is parsed as "name: value" lines in the order table,
    # column, data type. partition(":") splits on the first colon only, so a
    # value that itself contains ":" is kept whole.
    values = [line.partition(":")[2].strip() for line in doc.page_content.split("\n")]
    if len(values) < 3:
        continue  # not a table/column/type row; nothing to use
    table_name, column_name, data_type = values[:3]
    # Several hits can belong to the same table; list each table only once so
    # its description is not repeated in the prompt.
    if table_name not in relevant_tables:
        relevant_tables.append(table_name)
    relevant_tables_and_columns.append((table_name, column_name, data_type))


## Load the tables csv
filename_t = 'csvs/tables_' + unique_id + '.csv'
df = pd.read_csv(filename_t)

## For each relevant table create a string that list down all the columns and their data types
table_info = ''.join(
    f"Information about table {table}:\n"
    + df[df['table_name'] == table].to_string(index=False)
    + '\n\n\n'
    for table in relevant_tables
)
def generate_template_for_sql(query, relevant_tables, table_info):
    """Ask the chat model for a SQL query using only the relevant tables.

    Args:
        query: The user's natural-language question.
        relevant_tables: Iterable of table names to mention in the prompt.
        table_info: Text describing those tables' columns and data types.

    Returns:
        The text content of the model's reply.
    """
    # dict.fromkeys drops repeated names while keeping first-seen order.
    tables = ",".join(dict.fromkeys(relevant_tables))
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant that can write SQL Queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            f"Assume that there is/are SQL table(s) named '{tables}'",
            "Here is a more detailed description of the table(s):",
            f"{table_info}",
            "Prepend and append the SQL query with three backticks '```'",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content

최종 완성 코드본 

Here is the complete code.

import streamlit as st
import requests
import os
import pandas as pd
from uuid import uuid4
import psycopg2

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate

from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.document_loaders.csv_loader import CSVLoader




folders_to_create = ['csvs', 'vectors']
# Make sure every working directory exists before the app starts writing to it.
for folder_name in folders_to_create:
    if os.path.exists(folder_name):
        print(f"Folder '{folder_name}' already exists.")
        continue
    os.makedirs(folder_name)
    print(f"Folder '{folder_name}' created.")




## Read the OpenAI API key from the environment (a .env file is loaded first)
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")


# The same key is shared by the completion model, the chat model and the embedder.
llm = OpenAI(openai_api_key=openai_api_key)
chat_llm = ChatOpenAI(temperature=0.4, openai_api_key=openai_api_key)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)





def get_basic_table_details(cursor):
    """Return (table_name, column_name, data_type) rows for the public schema.

    Args:
        cursor: An open DB-API cursor on a PostgreSQL connection.

    Returns:
        The list of rows produced by ``cursor.fetchall()``.
    """
    # The explicit table_schema filter keeps a same-named table that lives in
    # another schema from contributing its columns; ORDER BY makes the
    # resulting CSV deterministic.
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
            )
        ORDER BY
            c.table_name,
            c.ordinal_position;""")
    return cursor.fetchall()

def create_vectors(filename, persist_directory):
    """Embed the rows of the CSV at ``filename`` into a persisted Chroma store.

    Args:
        filename: Path of the CSV file to load.
        persist_directory: Directory the Chroma store is written to.
    """
    documents = CSVLoader(file_path=filename, encoding="utf8").load()
    store = Chroma.from_documents(
        documents,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    store.persist()





def save_db_details(db_uri):
    """Snapshot the database's table/column metadata to CSV and embed it.

    Args:
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        The generated unique id; the CSV is written to
        ``csvs/tables_<unique_id>.csv`` and the vectors to
        ``./vectors/tables_<unique_id>``.
    """
    unique_id = str(uuid4()).replace("-", "_")

    # try/finally guarantees the cursor and connection are released even when
    # the metadata query raises, and frees them before the embedding step.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            tables_and_columns = get_basic_table_details(cursor)
        finally:
            cursor.close()
    finally:
        connection.close()

    ## Get all the tables and columns and enter them in a pandas dataframe
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)

    create_vectors(filename_t, "./vectors/tables_" + unique_id)

    return unique_id


# def generate_template_for_sql(query, table_info, db_uri):
#     template = ChatPromptTemplate.from_messages(
#             [
#                 SystemMessage(
#                     content=(
#                         f"You are an assistant that can write SQL Queries."
#                         f"Given the text below, write a SQL query that answers the user's question."
#                         f"DB connection string is {db_uri}"
#                         f"Here is a detailed description of the table(s): "
#                         f"{table_info}"
#                         "Prepend and append the SQL query with three backticks '```'"
                        
                        
#                     )
#                 ),
#                 HumanMessagePromptTemplate.from_template("{text}"),

#             ]
#         )
    
#     answer = chat_llm(template.format_messages(text=query))
#     return answer.content



def generate_template_for_sql(query, relevant_tables, table_info):
    """Ask the chat model for a SQL query using only the relevant tables.

    Args:
        query: The user's natural-language question.
        relevant_tables: Iterable of table names to mention in the prompt.
        table_info: Text describing those tables' columns and data types.

    Returns:
        The text content of the model's reply.
    """
    # dict.fromkeys drops repeated names while keeping first-seen order.
    tables = ",".join(dict.fromkeys(relevant_tables))
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant that can write SQL Queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            f"Assume that there is/are SQL table(s) named '{tables}'",
            "Here is a more detailed description of the table(s):",
            f"{table_info}",
            "Prepend and append the SQL query with three backticks '```'",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content



def check_if_users_query_want_general_schema_information_or_sql(query):
    """Classify whether ``query`` asks about the schema or needs a SQL query.

    Args:
        query: The user's natural-language question.

    Returns:
        ``"yes"`` when the model's answer starts with "yes" (schema question),
        otherwise ``"no"``. The answer is normalised so that replies such as
        "Yes." or "'yes'" still match an exact comparison in the caller.
    """
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "The user is asking a question about a database.",
            "Figure out whether the user wants information about the database schema or wants to write a SQL query.",
            "Answer 'yes' if the user wants information about the database schema and 'no' if the user wants to write a SQL query.",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    verdict = answer.content.strip().strip("'\"").lower()
    return "yes" if verdict.startswith("yes") else "no"


def prompt_when_user_want_general_db_information(query, db_uri):
    """Ask the chat model for SQL that answers a general schema question.

    Args:
        query: The user's natural-language question.
        db_uri: Kept for backward compatibility. It is deliberately NOT placed
            in the prompt: a connection string may embed credentials, and the
            prompt is sent to an external service.

    Returns:
        The text content of the model's reply.
    """
    # Newline-separated so the instructions do not run into each other.
    system_prompt = "\n".join(
        [
            "You are an assistant who writes SQL queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            "Prepend and append the SQL query with three backticks '```'",
            "Write select query whenever possible",
            "The database is PostgreSQL.",
        ]
    )
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content






def get_the_output_from_llm(query, unique_id, db_uri):
    """Answer ``query`` via the schema prompt or a retrieval-scoped SQL prompt.

    Args:
        query: The user's natural-language question.
        unique_id: Id identifying ``csvs/tables_<unique_id>.csv`` and the
            vector store under ``./vectors/tables_<unique_id>``.
        db_uri: Connection string, forwarded to the schema prompt helper.

    Returns:
        The text returned by whichever prompt helper was used.
    """
    # Compare leniently: the model may answer "Yes", "yes." and so on.
    verdict = check_if_users_query_want_general_schema_information_or_sql(query)
    if verdict.strip().strip("'\"").lower().startswith("yes"):
        return prompt_when_user_want_general_db_information(query, db_uri)

    vectordb = Chroma(embedding_function=embeddings, persist_directory="./vectors/tables_" + unique_id)
    retriever = vectordb.as_retriever()
    docs = retriever.get_relevant_documents(query)
    print(docs)

    relevant_tables = []
    for doc in docs:
        # The first line of each document is parsed as "name: value" holding
        # the table name; partition splits on the first colon only.
        first_line = doc.page_content.split("\n", 1)[0]
        table_name = first_line.partition(":")[2].strip()
        # Several hits can belong to the same table; list each table once so
        # its description is not repeated in the prompt.
        if table_name and table_name not in relevant_tables:
            relevant_tables.append(table_name)

    ## Load the tables csv (once; only the relevant tables are described)
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df = pd.read_csv(filename_t)

    ## For each relevant table create a string that list down all the columns and their data types
    table_info = ''.join(
        f"Information about table {table}:\n"
        + df[df['table_name'] == table].to_string(index=False)
        + '\n\n\n'
        for table in relevant_tables
    )
    return generate_template_for_sql(query, relevant_tables, table_info)





def execute_the_solution(solution, db_uri):
    """Run the SQL found in the first ``` fenced block of ``solution``.

    Args:
        solution: The LLM reply, expected to contain a ``` fenced SQL block.
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        ``str()`` of the rows returned by ``cursor.fetchall()``.

    Raises:
        ValueError: If ``solution`` contains no complete fenced block.
    """
    parts = solution.split("```")
    if len(parts) < 3:
        raise ValueError("No SQL code block fenced with ``` found in the answer.")
    final_query = parts[1].strip()
    # Remove only a leading "sql" language tag. str.strip('sql') would strip the
    # characters s/q/l from both ends and corrupt e.g. "select ... from sales".
    if final_query[:3].lower() == "sql" and (len(final_query) == 3 or final_query[3].isspace()):
        final_query = final_query[3:].lstrip()

    # NOTE(review): this executes model-generated SQL verbatim; consider
    # connecting with a read-only database role.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(final_query)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, including when the query fails.
        connection.close()
    return str(result)





# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    """Read the database metadata and remember the connection in the session.

    Args:
        uri: Database connection string entered by the user.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when psycopg2
        reports a failure, which is the key the calling UI code checks for.
    """
    try:
        unique_id = save_db_details(uri)
    except psycopg2.Error as exc:
        return {"error": f"Could not read database metadata: {exc}"}

    # Store session state only after the metadata was read successfully.
    st.session_state.db_uri = uri
    st.session_state.unique_id = unique_id

    return {"message": "Connection established to Database!"}

# Turn one chat message into generated SQL plus the result of running it
def send_message(message):
    """Generate SQL for ``message``, execute it and return both as text.

    Reads ``unique_id`` and ``db_uri`` from ``st.session_state``.
    Returns a dict with a single ``"message"`` key.
    """
    session = st.session_state
    solution = get_the_output_from_llm(message, session.unique_id, session.db_uri)
    result = execute_the_solution(solution, session.db_uri)
    combined = "".join((solution, "\n\n", "Result:\n", result))
    return {"message": combined}



# ## Instructions
st.subheader("Instructions")
st.markdown(
    """
    1. Enter the URI of your RDS Database in the text box below.
    2. Click the **Start Chat** button to start the chat.
    3. Enter your message in the text box below and press **Enter** to send the message to the API.
    """
)

# Initialize the chat history list
chat_history = []

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the database
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    if "unique_id" not in st.session_state:
        # send_message reads unique_id/db_uri from the session; without a
        # prior successful "Start Chat" it would fail on the missing key.
        st.warning("Please connect to a database with **Start Chat** before sending a message.")
    else:
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        try:
            response = send_message(prompt)["message"]
        except (psycopg2.Error, ValueError) as exc:
            # Invalid generated SQL or an unparsable answer should be shown
            # in the chat instead of crashing the app.
            response = f"Error: could not answer the question.\n\n{exc}"
        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            st.markdown(response)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")

One final thing we can do is to give information about foreign keys to the prompt.

import streamlit as st
import requests
import os
import pandas as pd
from uuid import uuid4
import psycopg2

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate

from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.document_loaders.csv_loader import CSVLoader




folders_to_create = ['csvs', 'vectors']
# Make sure every working folder exists, reporting what was done for each one
for folder_name in folders_to_create:
    if os.path.exists(folder_name):
        print(f"Folder '{folder_name}' already exists.")
        continue
    os.makedirs(folder_name)
    print(f"Folder '{folder_name}' created.")




## Load environment variables (python-dotenv) and read the OpenAI API key from them
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")  # None when the variable is not set


# Module-level clients, created once when the script is loaded.
llm = OpenAI(openai_api_key=openai_api_key)  # NOTE(review): not referenced by the code visible in this listing
chat_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.4)  # called by the prompt functions in this listing
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)  # passed to Chroma when building and loading the vector store





def get_basic_table_details(cursor):
    """Return the columns of every table in the ``public`` schema.

    Args:
        cursor: An open DB-API cursor on a PostgreSQL connection.

    Returns:
        The rows from ``cursor.fetchall()``: one
        ``(table_name, column_name, data_type)`` tuple per column.
    """
    # information_schema.columns covers every schema, so filtering on the
    # table name alone would also return columns of same-named tables that
    # live outside ``public``. Pin the schema on the outer query as well.
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
    );""")
    return cursor.fetchall()



def get_foreign_key_info(cursor):
    """Return one row per foreign-key constraint defined in the ``public`` schema.

    Args:
        cursor: An open DB-API cursor on a PostgreSQL connection.

    Returns:
        The rows from ``cursor.fetchall()``, each being
        ``(table_name, foreign_key, constraint_definition, referred_table,
        referred_columns)``.
    """
    # Only the referenced-side attributes are joined. Also joining the
    # referencing-side attributes (which were never selected) produced a
    # cross product for multi-column keys, so every referred column showed up
    # several times in the aggregate. The ORDER BY inside array_agg keeps the
    # columns in the order the constraint declares them.
    query_for_foreign_keys = """SELECT
    conrelid::regclass AS table_name,
    conname AS foreign_key,
    pg_get_constraintdef(oid) AS constraint_definition,
    confrelid::regclass AS referred_table,
    array_agg(a2.attname ORDER BY array_position(confkey, a2.attnum)) AS referred_columns
    FROM
        pg_constraint
    JOIN
        pg_attribute AS a2 ON confrelid = a2.attrelid AND a2.attnum = ANY(confkey)
    WHERE
        contype = 'f'
        AND connamespace = 'public'::regnamespace
    GROUP BY
        conrelid, conname, oid, confrelid
    ORDER BY
        conrelid::regclass::text, contype DESC;
    """

    cursor.execute(query_for_foreign_keys)
    return cursor.fetchall()




def create_vectors(filename, persist_directory):
    """Embed the contents of a CSV file and persist them as a Chroma store.

    Args:
        filename: Path of the CSV file, read as UTF-8 through ``CSVLoader``.
        persist_directory: Directory in which Chroma stores the vectors.
    """
    documents = CSVLoader(file_path=filename, encoding="utf8").load()
    store = Chroma.from_documents(
        documents,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    store.persist()





def save_db_details(db_uri):
    """Snapshot the schema of a PostgreSQL database to disk.

    Writes ``csvs/tables_<id>.csv`` (table, column, data type), builds a
    vector store for it under ``./vectors/tables_<id>``, and writes
    ``csvs/foreign_keys_<id>.csv``.

    Args:
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        The freshly generated unique id (a UUID4 with ``-`` replaced by ``_``)
        that identifies the files written for this database.
    """
    unique_id = str(uuid4()).replace("-", "_")

    # Read all metadata first and always release the cursor and connection,
    # even if one of the catalog queries fails. This also avoids keeping the
    # database connection open while the embeddings are being computed.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            tables_and_columns = get_basic_table_details(cursor)
            foreign_keys = get_foreign_key_info(cursor)
        finally:
            cursor.close()
    finally:
        connection.close()

    ## Store all the tables and columns in a CSV and index that CSV as vectors
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)

    create_vectors(filename_t, "./vectors/tables_" + unique_id)

    ## Store all the foreign keys in a CSV
    df = pd.DataFrame(foreign_keys, columns=['table_name', 'foreign_key', 'foreign_key_details', 'referred_table', 'referred_columns'])
    filename_fk = 'csvs/foreign_keys_' + unique_id + '.csv'
    df.to_csv(filename_fk, index=False)

    return unique_id


# def generate_template_for_sql(query, table_info, db_uri):
#     template = ChatPromptTemplate.from_messages(
#             [
#                 SystemMessage(
#                     content=(
#                         f"You are an assistant that can write SQL Queries."
#                         f"Given the text below, write a SQL query that answers the user's question."
#                         f"DB connection string is {db_uri}"
#                         f"Here is a detailed description of the table(s): "
#                         f"{table_info}"
#                         "Prepend and append the SQL query with three backticks '```'"
                        
                        
#                     )
#                 ),
#                 HumanMessagePromptTemplate.from_template("{text}"),

#             ]
#         )
    
#     answer = chat_llm(template.format_messages(text=query))
#     return answer.content



def generate_template_for_sql(query, relevant_tables, table_info, foreign_key_info):
    """Ask the chat model to write a SQL query for the user's question.

    Args:
        query: The user's question; sent as the human message.
        relevant_tables: Iterable of table names, joined with "," in the prompt.
        table_info: Text describing the columns of the relevant tables.
        foreign_key_info: Text describing the relevant foreign keys.

    Returns:
        The content of the model's reply (also printed to stdout).
    """
    table_list = ",".join(relevant_tables)

    # The fragments are concatenated as they are, without any separator.
    system_text = "".join(
        [
            "You are an assistant that can write SQL Queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            f"Assume that there is/are SQL table(s) named '{table_list}' ",
            "Here is a more detailed description of the table(s): ",
            f"{table_info}",
            "Here is some information about some relevant foreign keys:",
            f"{foreign_key_info}",
            "Prepend and append the SQL query with three backticks '```'",
        ]
    )

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_text),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    reply = chat_llm(chat_prompt.format_messages(text=query))
    print(reply.content)
    return reply.content



def check_if_users_query_want_general_schema_information_or_sql(query):
    """Classify whether the user asks about the schema or wants a SQL query.

    Args:
        query: The user's question; sent as the human message.

    Returns:
        ``"yes"`` when the model says the user wants information about the
        database schema, otherwise ``"no"``. The raw reply is normalised so
        that variants such as "Yes", "yes." or "'yes'" still compare equal to
        ``"yes"`` in callers.
    """
    template = ChatPromptTemplate.from_messages(
            [
                SystemMessage(
                    content=(
                        # Each sentence ends with a space so the adjacent
                        # literals do not fuse into one run of words.
                        "In the given text the user is asking a question about a database. "
                        "Figure out whether the user wants information about the database schema or wants to write a SQL query. "
                        "Answer 'yes' if the user wants information about the database schema and 'no' if the user wants to write a SQL query."
                    )
                ),
                HumanMessagePromptTemplate.from_template("{text}"),

            ]
        )

    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)

    # Chat models often add capitalisation, quotes or punctuation.
    verdict = answer.content.strip().strip("'\"`").lower()
    return "yes" if verdict.startswith("yes") else "no"


def prompt_when_user_want_general_db_information(query, db_uri):
    """Ask the chat model for a SQL query that answers a schema-level question.

    Args:
        query: The user's question; sent as the human message.
        db_uri: Database connection string, embedded in the system prompt.

    Returns:
        The content of the model's reply (also printed to stdout).
    """
    # NOTE(review): db_uri is sent to the model verbatim. If the connection
    # string carries credentials they leave the application here - confirm
    # this is intended.
    # The fragments are concatenated as they are, without any separator.
    system_text = "".join(
        [
            "You are an assistant who writes SQL queries.",
            "Given the text below, write a SQL query that answers the user's question.",
            "Prepend and append the SQL query with three backticks '```'",
            "Write select query whenever possible",
            f"Connection string to this database is {db_uri}",
        ]
    )

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_text),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )

    reply = chat_llm(chat_prompt.format_messages(text=query))
    print(reply.content)
    return reply.content






def get_the_output_from_llm(query, unique_id, db_uri):
    """Produce the model's answer (containing a SQL query) for a user question.

    The question is first classified. Schema-level questions go to
    ``prompt_when_user_want_general_db_information``. For all other questions
    the tables most relevant to the question are retrieved from the vector
    store, and only their columns and foreign keys are put in the prompt.

    Args:
        query: The user's question.
        unique_id: Id returned by ``save_db_details``; selects the CSV files
            and the vector store to read.
        db_uri: Database connection string, used only for schema-level
            questions.

    Returns:
        The content of the model's reply.
    """
    # Classify first: the schema-level branch needs none of the stored
    # metadata, so nothing is loaded or formatted before it is required.
    wants_schema_info = check_if_users_query_want_general_schema_information_or_sql(query)
    if wants_schema_info.strip().lower().startswith("yes"):
        return prompt_when_user_want_general_db_information(query, db_uri)

    vectordb = Chroma(embedding_function=embeddings, persist_directory="./vectors/tables_" + unique_id)
    retriever = vectordb.as_retriever()
    docs = retriever.get_relevant_documents(query)
    print(docs)

    # Each document is expected to hold "name: value" lines with the
    # table_name line first. Several documents may belong to the same table,
    # so names are de-duplicated while keeping retrieval order.
    relevant_tables = []
    for doc in docs:
        table_line = doc.page_content.split("\n")[0]
        table_name = table_line.split(":", 1)[1].strip()
        if table_name not in relevant_tables:
            relevant_tables.append(table_name)

    ## Load the tables csv
    df = pd.read_csv('csvs/tables_' + unique_id + '.csv')

    ## For each relevant table create a string that lists all the columns and their data types
    table_info = ''.join(
        'Information about table ' + table + ':\n'
        + df[df['table_name'] == table].to_string(index=False) + '\n\n\n'
        for table in relevant_tables
    )

    ## Load the foreign keys csv
    df_fk = pd.read_csv('csvs/foreign_keys_' + unique_id + '.csv')

    ## Describe every foreign key that starts from or points to a relevant table.
    ## The sentence names the tables of the row itself, not a leftover loop variable.
    relevant_set = set(relevant_tables)
    fk_sentences = []
    for _, row in df_fk.iterrows():
        if row['table_name'] in relevant_set:
            fk_sentences.append(
                f"{row['table_name']} has a foreign key {row['foreign_key']}"
                f" which refers to table {row['referred_table']}"
                f" and column(s) {row['referred_columns']}"
            )
        if row['referred_table'] in relevant_set:
            fk_sentences.append(
                f"{row['referred_table']} is referred to by table {row['table_name']}"
                f" via foreign key {row['foreign_key']}"
                f" and column(s) {row['referred_columns']}"
            )
    foreign_key_info = ''.join(sentence + '\n\n' for sentence in fk_sentences)

    return generate_template_for_sql(query, relevant_tables, table_info, foreign_key_info)





def execute_the_solution(solution, db_uri):
    """Run the SQL query contained in a model reply and return its rows as text.

    Args:
        solution: The model's reply; the query is expected between two
            triple-backtick fences, optionally preceded by a ``sql`` tag.
        db_uri: Connection string passed to ``psycopg2.connect``.

    Returns:
        ``str()`` of the list returned by ``cursor.fetchall()``.

    Raises:
        ValueError: If the reply has no fenced block or the block is empty.
    """
    # Extract the query before connecting, so a malformed reply never opens
    # a database connection. The first fenced block is used.
    parts = solution.split("```")
    if len(parts) < 3:
        raise ValueError("No SQL query fenced in triple backticks was found in the reply.")
    final_query = parts[1].strip()

    # Drop a leading "sql" language tag. str.strip('sql') must not be used
    # here: it removes the characters s/q/l from both ends and would turn
    # "select ... employees" into "elect ... employee".
    if final_query[:3].lower() == "sql" and (len(final_query) == 3 or final_query[3].isspace()):
        final_query = final_query[3:].strip()
    if not final_query:
        raise ValueError("The fenced block in the reply does not contain a SQL query.")

    # NOTE(review): the query is model-generated and executed unchecked;
    # consider connecting with a read-only database role.
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(final_query)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, including when the query fails.
        connection.close()
    return str(result)





# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    """Read the database metadata and remember the connection in the session.

    Args:
        uri: Database connection string entered by the user.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when the
        metadata could not be read. Callers test for the ``"error"`` key.
    """
    try:
        unique_id = save_db_details(uri)
    except Exception as exc:
        # UI boundary: report the failure through the "error" key instead of
        # letting the exception abort the page.
        return {"error": f"Could not read the database metadata: {exc}"}

    # Store both values only after success, so a failed attempt cannot leave
    # a new URI paired with the id of a previous database.
    st.session_state.db_uri = uri
    st.session_state.unique_id = unique_id

    return {"message": "Connection established to Database!"}

# Function to answer one chat message using the connected database
def send_message(message):
    """Turn a user message into SQL, run it, and return the text to display.

    Args:
        message: The user's chat message.

    Returns:
        A dict with a single ``"message"`` key: the model's reply followed by
        the query result, or by the error raised while running the query.
        If no database has been connected yet, the message asks the user to
        connect first.
    """
    # Both values are stored by connect_with_db. Without them there is
    # nothing to query and attribute access on the session would fail.
    if "unique_id" not in st.session_state or "db_uri" not in st.session_state:
        return {"message": "Please connect to a database with **Start Chat** before sending a message."}

    solution = get_the_output_from_llm(message, st.session_state.unique_id, st.session_state.db_uri)
    try:
        result = execute_the_solution(solution, st.session_state.db_uri)
    except Exception as exc:
        # UI boundary: a reply without a usable query, or a query the database
        # rejects, is shown in the chat rather than aborting the page.
        return {"message": solution + "\n\n" + "Error while running the query:\n" + str(exc)}
    return {"message": solution + "\n\n" + "Result:\n" + result}



# ## Instructions
st.subheader("Instructions")
instructions_md = """
    1. Enter the URI of your RDS Database in the text box below.
    2. Click the **Start Chat** button to start the chat.
    3. Enter your message in the text box below and press **Enter** to send the message to the API.
    """
st.markdown(instructions_md)

# Initialize the chat history list
chat_history = []

# Text box for the database URI
uri = st.text_input("Enter the RDS Database URI")

# Connect when the button is pressed, provided a URI was entered
if st.button("Start Chat"):
    if uri:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" not in chat_response:
            st.success("Chat started successfully!")
        else:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
    else:
        st.warning("Please enter a valid database URI.")

# Chat section of the page
st.subheader("Chat with the API")

# The message history lives in the session state; create it on the first run
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation each time the app reruns
for message in st.session_state.messages:
    st.chat_message(message["role"]).markdown(message["content"])

# Handle a newly submitted user message, if there is one
prompt = st.chat_input("What is up?")
if prompt:
    # Show the user's message and remember it
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Ask the backend for an answer, show it, and remember it
    reply = send_message(prompt)
    response = reply["message"]
    st.chat_message("assistant").markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

# Written only when this file is the entry-point module
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")

In a similar way, we can keep enhancing this application by adding fallbacks, with each fallback supplying additional information to the prompt.

 
 
반응형

댓글