An AI application that can chat with any SQL database.
https://systemdesigner.medium.com/an-ai-application-that-can-chat-with-any-sql-database-71099a0c82ef
Can you chat with a SQL database? In this tutorial, we will build exactly that application in Python, using Streamlit for the frontend and LangChain for the AI integration.
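The first step is to create a virtual environment and install the dependencies. We also create a project directory with two files at its root: app.py, and .env to hold environment variables. The only variable this project needs is OPENAI_API_KEY. Replace its value with your own key.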
```bash
mkdir chatdb
cd chatdb
virtualenv venv
source venv/bin/activate
pip install langchain langchain-openai openai sqlalchemy streamlit python-dotenv pandas psycopg2-binary
touch app.py
touch .env
```
```
# .env
OPENAI_API_KEY=sk-...
```
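python-dotenv's `load_dotenv()` copies the values from .env into the process environment, and by default it does not override variables that are already set. A missing key normally surfaces only at the first LLM call. To make it fail loudly at startup, you could check for it up front. The following is a minimal sketch. `read_dotenv` is a hypothetical, simplified stand-in for python-dotenv: it handles plain `KEY=VALUE` lines only and ignores features such as `export` prefixes and multi-line values. `require_api_key` is also an illustrative helper and is not part of the tutorial's code.

```python
import os
from pathlib import Path


def read_dotenv(path=".env"):
    """Parse simple KEY=VALUE lines (hypothetical, simplified stand-in for python-dotenv)."""
    env_file = Path(path)
    if not env_file.exists():
        return {}
    values = {}
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        # Skip blank lines, comments such as "# .env", and lines without "="
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop optional surrounding quotes around the value
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def require_api_key(path=".env", name="OPENAI_API_KEY"):
    """Prefer a real environment variable, fall back to .env, and fail fast if neither is set."""
    key = os.environ.get(name) or read_dotenv(path).get(name)
    if not key:
        raise RuntimeError(f"{name} is missing: add it to {path} or export it")
    return key
```

Checking the real environment first follows the same precedence as `load_dotenv()`'s default: an exported variable wins over the file.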
Now let's start writing the app. Each code snippet below is a complete version of app.py and can be run on its own with `streamlit run app.py`.
First, let's build a basic chat interface. It consists of an input that accepts a database URI. We use that URI to connect to the chosen database and start chatting.
```python
import streamlit as st

# Function to establish connection and read metadata for the database
# (in this first version it only stores the URI in the session)
def connect_with_db(uri):
    st.session_state.db_uri = uri
    return {"message": "Connection established to Database!"}

# Function to send the user's message (in this first version it just echoes it back)
def send_message(message):
    return {"message": message}

# Instructions
st.subheader("Instructions")
st.markdown(
    """
1. Enter the URI of your RDS Database in the text box below.
2. Click the **Start Chat** button to start the chat.
3. Enter your message in the text box below and press **Enter** to send the message to the API.
"""
)

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the API (a mock example)
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    response = send_message(prompt)["message"]
    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")
```
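The interface stores whatever string the user types, and that string usually contains a password. You may want to check it before handing it to psycopg2, and you may want to avoid echoing the password back onto the page. The following is a minimal standard-library sketch. The helper name `describe_db_uri` is hypothetical and is not part of the original app. It accepts both `postgresql://` and `postgres://`, the two URI schemes libpq understands.

```python
from urllib.parse import urlsplit


def describe_db_uri(uri):
    """Validate a PostgreSQL URI and return a display-safe copy with the password masked."""
    parts = urlsplit(uri)
    if parts.scheme not in ("postgresql", "postgres"):
        raise ValueError("Expected a URI starting with postgresql://")
    if not parts.hostname or not parts.path.strip("/"):
        raise ValueError("The URI needs both a host and a database name")
    user = parts.username or ""
    # Keep the user name but never show the password
    auth = f"{user}:***@" if parts.password else (f"{user}@" if user else "")
    port = f":{parts.port}" if parts.port else ""
    return f"{parts.scheme}://{auth}{parts.hostname}{port}{parts.path}"
```

For example, `describe_db_uri("postgresql://app:secret@db.example.com:5432/sales")` returns `"postgresql://app:***@db.example.com:5432/sales"`. That string is safe to show in a `st.info` message, while the raw URI stays in the session.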
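Next, let's write a function that connects to the database and fetches basic information about its tables.
The system works by generating a relevant SQL query and then executing it.
For ChatGPT to write a SQL query, we need to give it the relevant context. That context can include the relevant table names, the relevant column names, and the relevant values that go into the WHERE clause. Here we automate that step: as soon as we connect to the database, we can pass all of its tables and columns along with the user's prompt.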
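Concretely, "all tables and columns" means the `(table_name, column_name, data_type)` rows returned by the metadata query, rendered as text with one section per table. The following is a minimal sketch of that formatting step. `build_table_context` is a hypothetical helper, and the bullet layout is one possible choice, not the format the app below uses. The important detail is grouping by table, so that each table appears only once in the prompt.

```python
from collections import defaultdict


def build_table_context(rows):
    """Turn (table_name, column_name, data_type) rows into one text section per table."""
    columns_by_table = defaultdict(list)
    for table_name, column_name, data_type in rows:
        columns_by_table[table_name].append(f"  - {column_name} ({data_type})")
    # dicts keep insertion order, so tables appear in the order the query returned them
    sections = [
        f"Table {table_name}:\n" + "\n".join(columns)
        for table_name, columns in columns_by_table.items()
    ]
    return "\n\n".join(sections)
```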
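In the code below, clicking Start Chat makes the app read the database's metadata and save it as a CSV file.
We create a folder named csvs to hold these CSVs, and each file is named with a unique ID. We store that ID in the Streamlit session so that later, when the user asks a question, the app can find the right CSV.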
```python
import os
from uuid import uuid4

import pandas as pd
import psycopg2
import streamlit as st

folders_to_create = ['csvs']
# Create the folders if they don't exist yet
for folder_name in folders_to_create:
    os.makedirs(folder_name, exist_ok=True)

def get_basic_table_details(cursor):
    # Every column (with its data type) of every table in the public schema
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
            )
        ORDER BY c.table_name, c.ordinal_position;""")
    tables_and_columns = cursor.fetchall()
    return tables_and_columns

def save_db_details(db_uri):
    unique_id = str(uuid4()).replace("-", "_")
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        tables_and_columns = get_basic_table_details(cursor)
        cursor.close()
    finally:
        # Always release the connection, even if the query fails
        connection.close()
    ## Get all the tables and columns and enter them in a pandas dataframe
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)
    return unique_id

# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    try:
        st.session_state.db_uri = uri
        st.session_state.unique_id = save_db_details(uri)
    except psycopg2.Error as e:
        # Reported to the user through the "error" branch below
        return {"error": str(e)}
    return {"message": "Connection established to Database!"}

# Function to send the user's message (still an echo in this version)
def send_message(message):
    return {"message": message}

# Instructions
st.subheader("Instructions")
st.markdown(
    """
1. Enter the URI of your RDS Database in the text box below.
2. Click the **Start Chat** button to start the chat.
3. Enter your message in the text box below and press **Enter** to send the message to the API.
"""
)

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the API (a mock example)
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    response = send_message(prompt)["message"]
    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")
```
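To try the save-then-look-up flow without a PostgreSQL server, here is a sketch that swaps in SQLite from the standard library. `sqlite_master` and `PRAGMA table_info` take the place of the `information_schema`/`pg_tables` query, and the `csv` module takes the place of pandas. All three function names are hypothetical. The file naming (`tables_<id>.csv` with underscores instead of dashes) mirrors the scheme described above.

```python
import csv
import os
from uuid import uuid4


def read_sqlite_schema(connection):
    """SQLite stand-in for the metadata query: (table, column, type) rows."""
    rows = []
    tables = connection.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for (table_name,) in tables:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk)
        for _, column_name, data_type, *_ in connection.execute(f'PRAGMA table_info("{table_name}")'):
            rows.append((table_name, column_name, data_type))
    return rows


def save_schema_csv(rows, folder="csvs"):
    """Write rows to <folder>/tables_<id>.csv and return the id to keep in the session."""
    os.makedirs(folder, exist_ok=True)
    unique_id = str(uuid4()).replace("-", "_")
    path = os.path.join(folder, f"tables_{unique_id}.csv")
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["table_name", "column_name", "data_type"])
        writer.writerows(rows)
    return unique_id


def load_schema_csv(unique_id, folder="csvs"):
    """Read the rows back for a given session id (header row skipped)."""
    path = os.path.join(folder, f"tables_{unique_id}.csv")
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader)
        return [tuple(row) for row in reader]
```

In the Streamlit app, the returned id is what would go into `st.session_state`. Because the id is stored per browser session, two users connected to different databases each find their own CSV.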
Now let's define a function that takes the user's question, combines it with the database metadata, and passes both to the LLM to get back a SQL query.
```python
import os
from uuid import uuid4

import pandas as pd
import psycopg2
import streamlit as st
from dotenv import load_dotenv
from langchain_core.messages import SystemMessage
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_openai import ChatOpenAI

folders_to_create = ['csvs']
# Create the folders if they don't exist yet
for folder_name in folders_to_create:
    os.makedirs(folder_name, exist_ok=True)

## load the API key from the environment variable
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
chat_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.4)

def get_basic_table_details(cursor):
    # Every column (with its data type) of every table in the public schema
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
            )
        ORDER BY c.table_name, c.ordinal_position;""")
    tables_and_columns = cursor.fetchall()
    return tables_and_columns

def save_db_details(db_uri):
    unique_id = str(uuid4()).replace("-", "_")
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        tables_and_columns = get_basic_table_details(cursor)
        cursor.close()
    finally:
        connection.close()
    ## Get all the tables and columns and enter them in a pandas dataframe
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)
    return unique_id

def generate_template_for_sql(query, table_info):
    # Only the schema goes into the prompt; the connection string is not sent
    # to the LLM because it contains the database password
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(
                content=(
                    "You are an assistant that can write SQL queries for a PostgreSQL database. "
                    "Given the text below, write a SQL query that answers the user's question. "
                    "Here is a detailed description of the table(s):\n"
                    f"{table_info}\n"
                    "Prepend and append the SQL query with three backticks '```'."
                )
            ),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    answer = chat_llm.invoke(template.format_messages(text=query))
    return answer.content

def get_the_output_from_llm(query, unique_id):
    ## Load the tables csv
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df = pd.read_csv(filename_t)
    ## For each table, list all of its columns and their data types (once per table)
    table_info = ''
    for table in df['table_name'].unique():
        table_info += 'Information about table ' + table + ':\n'
        table_info += df[df['table_name'] == table].to_string(index=False) + '\n\n\n'
    return generate_template_for_sql(query, table_info)

# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    try:
        st.session_state.db_uri = uri
        st.session_state.unique_id = save_db_details(uri)
    except psycopg2.Error as e:
        return {"error": str(e)}
    return {"message": "Connection established to Database!"}

# Function to send the user's question to the LLM
def send_message(message):
    if "unique_id" not in st.session_state:
        return {"message": "Please connect to a database first (click **Start Chat**)."}
    return {"message": get_the_output_from_llm(message, st.session_state.unique_id)}

# Instructions
st.subheader("Instructions")
st.markdown(
    """
1. Enter the URI of your RDS Database in the text box below.
2. Click the **Start Chat** button to start the chat.
3. Enter your message in the text box below and press **Enter** to send the message to the API.
"""
)

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the database
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    response = send_message(prompt)["message"]
    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")
```
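The prompt-building step can also be written without LangChain, which makes it easy to unit-test. In this sketch, messages use the common `role`/`content` chat format. `complete` is a hypothetical callable standing in for whatever model client you use, so a fake function can replace the LLM in tests. The schema context goes in the system turn and the question goes in the user turn. The connection string is left out because it carries the password.

```python
def build_sql_messages(question, table_context, dialect="PostgreSQL"):
    """Schema context goes in the system turn, the user's question in the user turn."""
    system_prompt = (
        f"You are an assistant that writes {dialect} SQL queries. "
        "Write one SQL query that answers the user's question, "
        "using only the tables and columns described below.\n\n"
        f"{table_context}\n\n"
        "Wrap the query in a fenced block that starts with ```sql and ends with ```."
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]


def ask_for_sql(question, table_context, complete):
    """`complete` is any callable that takes chat messages and returns the reply text."""
    return complete(build_sql_messages(question, table_context))
```

In a real app, `complete` would be a thin wrapper around the chat model. In tests, it can be a lambda that returns a canned reply.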
Finally, we run the generated SQL query and display its result.
```python
import os
import re
from uuid import uuid4

import pandas as pd
import psycopg2
import streamlit as st
from dotenv import load_dotenv
from langchain_core.messages import SystemMessage
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_openai import ChatOpenAI

folders_to_create = ['csvs']
# Create the folders if they don't exist yet
for folder_name in folders_to_create:
    os.makedirs(folder_name, exist_ok=True)

## load the API key from the environment variable
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
chat_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.4)

# Matches the first fenced block in the LLM answer, with or without a "sql" tag
SQL_BLOCK = re.compile(r"```(?:sql\b)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)

def get_basic_table_details(cursor):
    # Every column (with its data type) of every table in the public schema
    cursor.execute("""SELECT
            c.table_name,
            c.column_name,
            c.data_type
        FROM
            information_schema.columns c
        WHERE
            c.table_schema = 'public'
            AND c.table_name IN (
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
            )
        ORDER BY c.table_name, c.ordinal_position;""")
    tables_and_columns = cursor.fetchall()
    return tables_and_columns

def save_db_details(db_uri):
    unique_id = str(uuid4()).replace("-", "_")
    connection = psycopg2.connect(db_uri)
    try:
        cursor = connection.cursor()
        tables_and_columns = get_basic_table_details(cursor)
        cursor.close()
    finally:
        connection.close()
    ## Get all the tables and columns and enter them in a pandas dataframe
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)
    return unique_id

def generate_template_for_sql(query, table_info):
    # The connection string (which contains the password) is not sent to the LLM
    template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(
                content=(
                    "You are an assistant that can write SQL queries for a PostgreSQL database. "
                    "Given the text below, write a SQL query that answers the user's question. "
                    "Here is a detailed description of the table(s):\n"
                    f"{table_info}\n"
                    "Prepend and append the SQL query with three backticks '```'."
                )
            ),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    answer = chat_llm.invoke(template.format_messages(text=query))
    return answer.content

def get_the_output_from_llm(query, unique_id):
    ## Load the tables csv
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df = pd.read_csv(filename_t)
    ## For each table, list all of its columns and their data types (once per table)
    table_info = ''
    for table in df['table_name'].unique():
        table_info += 'Information about table ' + table + ':\n'
        table_info += df[df['table_name'] == table].to_string(index=False) + '\n\n\n'
    return generate_template_for_sql(query, table_info)

def execute_the_solution(solution, db_uri):
    # Pull the query out of the fenced block in the LLM answer
    match = SQL_BLOCK.search(solution)
    if match is None:
        raise ValueError("The model did not return a fenced SQL query.")
    final_query = match.group(1).strip()
    connection = psycopg2.connect(db_uri)
    try:
        # Read-only session: generated SQL must not be able to modify the database
        connection.set_session(readonly=True)
        cursor = connection.cursor()
        cursor.execute(final_query)
        result = cursor.fetchall()
        cursor.close()
    finally:
        connection.close()
    return str(result)

# Function to establish connection and read metadata for the database
def connect_with_db(uri):
    try:
        st.session_state.db_uri = uri
        st.session_state.unique_id = save_db_details(uri)
    except psycopg2.Error as e:
        return {"error": str(e)}
    return {"message": "Connection established to Database!"}

# Function to generate the SQL, run it, and return both
def send_message(message):
    if "unique_id" not in st.session_state:
        return {"message": "Please connect to a database first (click **Start Chat**)."}
    solution = get_the_output_from_llm(message, st.session_state.unique_id)
    try:
        result = execute_the_solution(solution, st.session_state.db_uri)
    except (ValueError, psycopg2.Error) as e:
        result = f"Could not run the query: {e}"
    return {"message": solution + "\n\n" + "Result:\n" + result}

# Instructions
st.subheader("Instructions")
st.markdown(
    """
1. Enter the URI of your RDS Database in the text box below.
2. Click the **Start Chat** button to start the chat.
3. Enter your message in the text box below and press **Enter** to send the message to the API.
"""
)

# Input for the database URI
uri = st.text_input("Enter the RDS Database URI")

if st.button("Start Chat"):
    if not uri:
        st.warning("Please enter a valid database URI.")
    else:
        st.info("Connecting to the API and starting the chat...")
        chat_response = connect_with_db(uri)
        if "error" in chat_response:
            st.error("Error: Failed to start the chat. Please check the URI and try again.")
        else:
            st.success("Chat started successfully!")

# Chat with the database
st.subheader("Chat with the API")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    response = send_message(prompt)["message"]
    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Run the Streamlit app
if __name__ == "__main__":
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")
```
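Two steps in this last version are worth testing in isolation. The first is pulling the SQL out of the model's reply. The second is running that SQL without giving the model write access. The following sketch covers both, using SQLite from the standard library as a stand-in for PostgreSQL. A regex tolerates an optional `sql` tag after the opening backticks. SQLite's `PRAGMA query_only` plays the role that psycopg2's `connection.set_session(readonly=True)` plays for a PostgreSQL session. The helper names `extract_sql` and `run_read_only` are hypothetical.

```python
import re
import sqlite3

# First fenced block in the reply, with or without a "sql" tag after the backticks
SQL_BLOCK = re.compile(r"```(?:sql\b)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql(reply):
    """Return the SQL inside the first fenced block of the model's reply."""
    match = SQL_BLOCK.search(reply)
    if match is None:
        raise ValueError("No fenced SQL block found in the model reply")
    return match.group(1).strip()


def run_read_only(connection, sql):
    """Run the query with writes disabled and return (column names, rows)."""
    # PRAGMA query_only makes SQLite reject INSERT/UPDATE/DELETE/DDL on this connection
    connection.execute("PRAGMA query_only = ON")
    cursor = connection.execute(sql)
    columns = [d[0] for d in cursor.description or []]
    return columns, cursor.fetchall()
```

Returning the column names alongside the rows lets the UI render a proper table instead of the `str()` of a list of tuples.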
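As the screenshot above shows, the chat displays both the generated SQL and its result. This approach works well for a small database, but what if the database is really large? If we put the list of every table and column into the prompt, we can exceed the token limit of OpenAI's API.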
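To see when the schema stops fitting, you can estimate the prompt size before sending it. This sketch uses the rough rule of thumb of about four characters per token for English text. That is only an approximation; OpenAI's `tiktoken` library gives exact counts for a given model's encoding. The 4096-token window and the 512 tokens reserved for the answer are assumptions for illustration, not limits taken from this tutorial. `estimate_tokens` and `fits_in_context` are hypothetical helpers.

```python
def estimate_tokens(text):
    """Rough token estimate: about 4 characters per token for English text (heuristic only)."""
    return (len(text) + 3) // 4


def fits_in_context(table_context, question, context_window=4096, reserved_for_answer=512):
    """Check whether schema + question leave room for the answer (all limits are assumptions)."""
    used = estimate_tokens(table_context) + estimate_tokens(question)
    return used + reserved_for_answer <= context_window
```

When this check fails, the prompt needs a smaller schema, which is the problem the retrieval step below addresses.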
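To get around this, we can use retrieval-augmented generation (RAG). The idea is to store all of the tables and columns in a vector database and, for each question, retrieve only the names of the most relevant tables and columns.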
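The following sketch shows the retrieval step on its own. A real setup would embed each table description with an embedding model and query a vector database. To keep the sketch runnable with the standard library, it swaps in bag-of-words cosine similarity. The ranking logic is the same shape: score every table description against the question, then keep the top `k`. `top_tables` and its helpers are hypothetical names.

```python
import math
import re
from collections import Counter


def tokenize(text):
    """Lowercase words; underscores split, so 'customer_id' yields 'customer' and 'id'."""
    return re.findall(r"[a-z0-9]+", text.lower())


def cosine(a, b):
    """Cosine similarity between two word-count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def top_tables(question, table_docs, k=2):
    """Rank tables by similarity to the question; table_docs maps table name -> description."""
    q = Counter(tokenize(question))
    scored = [(cosine(q, Counter(tokenize(doc))), name) for name, doc in table_docs.items()]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    # Drop tables with no overlap at all, even if they fall inside the top k
    return [name for score, name in scored[:k] if score > 0]
```

Only the returned tables' columns would then go into the prompt, which keeps its size roughly constant as the database grows.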
We can also implement chat memory, so that the prompt is built from the whole chat history rather than only the latest message.
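A minimal sketch of chat memory, assuming the history uses the same `{"role": ..., "content": ...}` dictionaries the Streamlit app keeps in `st.session_state.messages`. Only the last few exchanges are kept, so memory does not reintroduce the token-limit problem. `build_messages_with_memory` and the `max_turns` default are hypothetical. The app appends the user's message to the history before answering, so the history passed in should exclude the current question to avoid sending it twice.

```python
def build_messages_with_memory(system_prompt, history, question, max_turns=3):
    """Prepend the last `max_turns` user/assistant exchanges to the new question."""
    # history[-0:] would return everything, so handle max_turns == 0 explicitly
    recent = history[-2 * max_turns:] if max_turns > 0 else []
    return [
        {"role": "system", "content": system_prompt},
        *recent,
        {"role": "user", "content": question},
    ]
```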
We'll cover these use cases in the second part of this tutorial, which explains in detail how to chat with a very large database.