Multimodal AI Made Easy with LanceDB and CLIP

Posted by alex on November 30, 2023

This post walks through three multimodal applications that use CLIP, with LanceDB as the vector store.


  • Multimodal search with CLIP and LanceDB
  • Turning it into a Gradio app
  • Multimodal video search


[Figure: the CLIP model]


Introduction


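In the image above, you can see the CLIP (Contrastive Language-Image Pre-training) model, which was trained on a large corpus of image-text pairs. It is the model we focus on throughout this post. Because CLIP maps images and text into a shared embedding space, the same vector table can be queried with either an image or a piece of text.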
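To make "contrastive pre-training" concrete, here is a toy, pure-Python sketch of the symmetric contrastive objective described in the CLIP paper: within a batch, each image should be most similar to its own caption, and each caption to its own image. It uses a fixed temperature of 0.07 (CLIP learns this value, starting from 0.07) and tiny hand-made 2-D vectors; it illustrates the idea and is not OpenAI's training code.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def cross_entropy(logits, target):
    """Softmax cross-entropy of one row of logits against the correct index `target`."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def clip_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric contrastive loss over N matching (image, text) pairs.

    Pair i is the positive for row i and column i; every other pair in the batch is a negative.
    """
    n = len(image_embs)
    # logits[i][j] = similarity of image i and text j, sharpened by the temperature
    logits = [[cosine(img, txt) / temperature for txt in text_embs] for img in image_embs]
    # image -> text: each image should pick out its own caption (rows of the matrix)
    loss_img = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    # text -> image: each caption should pick out its own image (columns of the matrix)
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    loss_txt = sum(cross_entropy(columns[j], j) for j in range(n)) / n
    return (loss_img + loss_txt) / 2


# toy 2-D "embeddings": captions that line up with their images give a lower loss
images = [[1.0, 0.0], [0.0, 1.0]]
texts_aligned = [[0.9, 0.1], [0.1, 0.9]]
texts_swapped = [[0.1, 0.9], [0.9, 0.1]]
print(clip_loss(images, texts_aligned), clip_loss(images, texts_swapped))
```

Training pushes the loss down, which pulls matching image and text embeddings together; that shared space is what makes the cross-modal searches below possible.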


Walkthrough


Example 1: Multimodal search


First, install the dependencies, including OpenAI's CLIP package straight from GitHub.


%pip install pillow datasets lancedb
%pip install git+https://github.com/openai/CLIP.git


Load the dataset


from datasets import load_dataset

dataset = load_dataset("CVdatasets/ImageNet15_animals_unbalanced_aug1", split="train")


This dataset labels images only with integers, which are hard for us to interpret. So we'll create an Enum that maps each number to its class name.


from enum import Enum

# map each integer label in the dataset to a readable class name
class Animal(Enum):
    italian_greyhound = 0
    coyote = 1
    beagle = 2
    rottweiler = 3
    hyena = 4
    greater_swiss_mountain_dog = 5
    Triceratops = 6
    french_bulldog = 7
    red_wolf = 8
    egyptian_cat = 9
    chihuahua = 10
    irish_terrier = 11
    tiger_cat = 12
    white_wolf = 13
    timber_wolf = 14


# inspect the first sample and its class name
print(dataset[0])
print(Animal(dataset[0]['labels']).name)


We'll use CLIP's pretrained ViT-B/32 model, a base-size Vision Transformer that splits each image into 32×32-pixel patches.


import clip
import torch


#use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
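The "/32" in the model name is the patch size. As a quick worked example, assuming the 224×224 input resolution that CLIP's `preprocess` produces for this model, here is how many patches (tokens) the transformer actually sees:

```python
# Assumption: CLIP's ViT-B/32 operates on 224x224 inputs (what `preprocess` produces)
IMAGE_SIZE = 224
PATCH_SIZE = 32  # the "/32" in "ViT-B/32"

patches_per_side = IMAGE_SIZE // PATCH_SIZE  # 224 / 32 = 7
num_patches = patches_per_side ** 2          # 7 * 7 = 49 image patches
num_tokens = num_patches + 1                 # plus one class token
print(patches_per_side, num_patches, num_tokens)  # 7 49 50
```

Larger patches mean fewer tokens, which is why the /32 variant is faster than CLIP's /16 or /14 models.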


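Next, we create an image-embedding function whose output can be written to LanceDB. We want each embedding as a plain Python list, so we convert the tensor to a NumPy array and then to a list.


We use CLIP's encode_image function to embed the image.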


# embed the image
def embed(img):
    image = preprocess(img).unsqueeze(0).to(device)
    embs = model.encode_image(image)
    return embs.detach().cpu().numpy()[0].tolist()
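The `embed` function stores raw CLIP vectors, and LanceDB's default metric for vector search is L2 distance. An optional variation (an assumption, not what this post does) is to L2-normalize every vector before storing and querying it: for unit vectors, squared L2 distance equals 2 - 2·cosine similarity, so ranking by L2 becomes ranking by cosine similarity, the measure CLIP was trained with. The helper names below are hypothetical:

```python
import math


def normalize(vec):
    """Scale a vector to unit length (hypothetical helper, not part of the original code)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def squared_l2(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def dot(a, b):
    """Dot product; equals cosine similarity when both vectors have unit length."""
    return sum(x * y for x, y in zip(a, b))


a = normalize([3.0, 4.0])
b = normalize([1.0, 2.0])
# for unit vectors: ||a - b||^2 == 2 - 2 * cos(a, b)
print(squared_l2(a, b), 2 - 2 * dot(a, b))
```

If you adopt this, apply `normalize` both to the stored image vectors and to every query vector, otherwise the distances are not comparable.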


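Next, we define a PyArrow schema for the LanceDB table: a 512-dimensional float32 vector (the size of a ViT-B/32 embedding) plus the image id and its label.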




We then create the table (overwriting any existing one) and add the embedded training images to it.


import lancedb
import pyarrow as pa
from tqdm import tqdm


# connect to a local database and create the table with a 512-dim vector schema
db = lancedb.connect('./data/tables')
schema = pa.schema(
  [
      pa.field("vector", pa.list_(pa.float32(), 512)),
      pa.field("id", pa.int32()),
      pa.field("label", pa.int32()),
  ])
tbl = db.create_table("animal_images", schema=schema, mode="overwrite")
# embed every training image (index 0 included) and append it to the table
data = []
for i in tqdm(range(len(dataset))):
    data.append({'vector': embed(dataset[i]['img']), 'id': i, 'label': dataset[i]['labels']})


tbl.add(data)
# convert to pandas for easier inspection
tbl.to_pandas()


[Figure: the animal_images table as a pandas DataFrame]


To test image search, it's best to start with the validation split, whose images were not added to the table.


#load the dataset
test = load_dataset("CVdatasets/ImageNet15_animals_unbalanced_aug1", split="validation")


#display the data along with length
print(len(test))
print(test[100])
print(Animal(test[100]['labels']).name)
test[100]['img']


The output should look like this:


[Figure: class name and image of validation sample 100]


To search the table, we:


  • embed the query image
  • call the search function
  • get the results back as a Pandas DataFrame.


embs = embed(test[100]['img'])


#search the db after embedding the question(image)
res = tbl.search(embs).limit(1).to_df()
res
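Without an index, `tbl.search(vector).limit(k)` amounts to an exhaustive nearest-neighbour scan: compute the distance from the query to every stored vector and keep the k smallest. Here is a toy pure-Python version using squared L2 distance; the function and data are hypothetical, and the score is named `_distance` after the column LanceDB adds to its search results.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def brute_force_search(rows, query, k):
    """Return the k rows whose 'vector' is closest to `query` (smallest L2 distance)."""
    scored = [dict(row, _distance=squared_l2(row["vector"], query)) for row in rows]
    scored.sort(key=lambda r: r["_distance"])
    return scored[:k]


# three toy rows shaped like the animal_images table (2-D vectors instead of 512-D)
rows = [
    {"id": 0, "label": 1, "vector": [0.0, 0.0]},
    {"id": 1, "label": 7, "vector": [1.0, 1.0]},
    {"id": 2, "label": 7, "vector": [0.9, 1.2]},
]
top = brute_force_search(rows, [1.0, 1.1], k=2)
print([r["id"] for r in top])  # [1, 2]
```

This scan costs one distance computation per row, which is fine for a few thousand images; for much larger tables a vector index trades a little accuracy for far fewer comparisons.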


[Figure: the top search result as a DataFrame]


We can also wrap all of this in a function to make repeated searches easier.


from IPython.display import display

# image search: show the query image, then its 5 nearest neighbours from the table
def image_search(id):
    print(Animal(test[id]['labels']).name)
    display(test[id]['img'])

    res = tbl.search(embed(test[id]['img'])).limit(5).to_df()
    print(res)
    for i in range(len(res)):
        print(Animal(res['label'][i]).name)
        data_id = int(res['id'][i])
        display(dataset[data_id]['img'])
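Eyeballing images works for a few queries, but since the validation split is labelled, agreement can also be measured: the fraction of the top-k neighbours whose label matches the query's. A hypothetical helper; in the notebook you would call it as `label_precision(test[id]['labels'], res['label'].tolist())` and average it over many validation ids.

```python
def label_precision(query_label, neighbour_labels):
    """Fraction of retrieved neighbours whose label matches the query's label."""
    if not neighbour_labels:
        return 0.0
    hits = sum(1 for label in neighbour_labels if label == query_label)
    return hits / len(neighbour_labels)


# e.g. a query labelled 7 (french_bulldog) whose top-5 results carry these labels
print(label_precision(7, [7, 7, 2, 7, 10]))  # 0.6
```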


If you've made it this far, congratulations on your patience and dedication. Things are about to get even better.


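Now we can move on to multimodal text search.


Here we use the encode_text function instead of encode_image. Because CLIP maps text and images into the same embedding space, a text embedding can be searched directly against the image vectors already stored in the table.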


#text embedding function
def embed_txt(txt):
    text = clip.tokenize([txt]).to(device)
    embs = model.encode_text(text)
    return embs.detach().cpu().numpy()[0].tolist()


#check the length of the embedded text
len(embed_txt("Black and white dog"))


Search the table


#search through the database
res = tbl.search(embed_txt("a french_bulldog ")).limit(1).to_df()
res


print(Animal(res['label'][0]).name)
data_id = int(res['id'][0])
display(dataset[data_id]['img'])
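The query above spells the class as `french_bulldog`, the way the Enum does. CLIP was trained on natural-language captions, and the CLIP paper reports that wrapping labels in a template such as "A photo of a {label}." helps zero-shot classification, so it may be worth turning Enum names into phrases like that; whether it improves results on this dataset is untested. A hypothetical helper:

```python
def class_prompt(class_name):
    """Turn an identifier-style class name into a natural-language CLIP query.

    'french_bulldog' -> 'a photo of a french bulldog'
    """
    return f"a photo of a {class_name.replace('_', ' ').strip().lower()}"


print(class_prompt("french_bulldog"))  # a photo of a french bulldog
```

In the notebook this would be used as `tbl.search(embed_txt(class_prompt(Animal.french_bulldog.name))).limit(1).to_df()`.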


[Figure: the retrieved french_bulldog image]


Once again, let's combine everything into a single function.


#making a text_search function to streamline the process
def text_search(text):
    res = tbl.search(embed_txt(text)).limit(5).to_df()
    print(res)
    for i in range(len(res)):
        print(Animal(res['label'][i]).name)
        data_id = int(res['id'][i])
        display(dataset[data_id]['img'])


Great: we've now used the CLIP model for both image search and text search.


Example 2: Multimodal search with CLIP in a Gradio app


Load the data. We'll use a copy of the diffusiondb dataset, in Lance format, that is already stored in an S3 bucket.




%pip install transformers gradio duckdb tantivy
!wget https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz
!tar -xvf diffusiondb_lance.tar.gz
!mv diffusiondb_test rawdata.lance


Create (or open) the LanceDB table.


import lance
import lancedb
import pyarrow.compute as pc

db = lancedb.connect("~/datasets/demo")
if "diffusiondb" in db.table_names():
    tbl = db.open_table("diffusiondb")
else:
    # first run: load the raw Lance data, drop rows with null prompts,
    # and build a full-text-search index on the prompt column
    data = lance.dataset("rawdata.lance/diffusiondb_test").to_table()
    tbl = db.create_table("diffusiondb", data.filter(~pc.field("prompt").is_null()), mode="overwrite")
    tbl.create_fts_index(["prompt"])
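`create_fts_index` builds a full-text index over `prompt`, which is what later lets `tbl.search()` with a plain string run a keyword search. As a rough mental model only, here is a toy inverted index in pure Python; it scores documents by raw term counts, whereas LanceDB delegates to a real search engine (tantivy, at the time of the original post) with more sophisticated ranking. All names here are hypothetical.

```python
import re
from collections import defaultdict


def tokenize(text):
    """Lower-case the text and split it into alphanumeric words."""
    return re.findall(r"[a-z0-9]+", text.lower())


def build_index(docs):
    """Map each term to {doc_id: number of occurrences} (a toy inverted index)."""
    index = defaultdict(dict)
    for doc_id, text in enumerate(docs):
        for term in tokenize(text):
            index[term][doc_id] = index[term].get(doc_id, 0) + 1
    return index


def keyword_search(index, query, k):
    """Score documents by how many query-term occurrences they contain; return the top-k ids."""
    scores = defaultdict(int)
    for term in tokenize(query):
        for doc_id, count in index.get(term, {}).items():
            scores[doc_id] += count
    ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
    return [doc_id for doc_id, _ in ranked[:k]]


prompts = [
    "a ninja turtle in a sewer, digital art",
    "portrait of a person, oil painting",
    "teenage mutant ninja turtle, comic style, ninja pose",
]
idx = build_index(prompts)
print(keyword_search(idx, "ninja turtle", k=2))  # [2, 0]
```

Only documents containing a query term are ever touched, which is why keyword search stays fast even though it ignores meaning; the CLIP vector search covers that gap.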


Create CLIP embeddings for text queries.


from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# embed a text query into CLIP's 512-dim joint embedding space
def embed_func(query):
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    text_features = model.get_text_features(**inputs)
    return text_features.detach().numpy()[0]


Let's look at the table's schema and a few rows of its data.


print(tbl.schema)
tbl.to_pandas().head()


[Figure: the first rows of the diffusiondb table]


To explore the data and the search results interactively, we'll build a Gradio interface. Before that, let's define a few search helper functions.


import io

import duckdb
import PIL.Image

# vector search: embed the text query with CLIP and return the 9 nearest images
def find_image_vectors(query):
    emb = embed_func(query)
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        f"embedding = embed_func('{query}')\n"
        "tbl.search(embedding).limit(9).to_df()"
    )
    return (_extract(tbl.search(emb).limit(9).to_df()), code)
# keyword search: a plain string query runs full-text search on the prompt column
def find_image_keywords(query):
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        f"tbl.search('{query}').limit(9).to_df()"
    )
    return (_extract(tbl.search(query).limit(9).to_df()), code)
# SQL search: DuckDB finds the Lance dataset through the local variable `diffusiondb`
def find_image_sql(query):
    code = (
        "import lancedb\n"
        "import duckdb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        "diffusiondb = tbl.to_lance()\n"
        f"duckdb.sql('{query}').to_df()"
    )
    diffusiondb = tbl.to_lance()
    return (_extract(duckdb.sql(query).to_df()), code)
# decode the stored image bytes into PIL images paired with their prompts
def _extract(df):
    image_col = "image"
    return [(PIL.Image.open(io.BytesIO(row[image_col])), row["prompt"]) for _, row in df.iterrows()]
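The `code` strings above are display-only, but they splice the query in with `'{query}'`, so a query containing an apostrophe (for example, van gogh's starry night) produces a snippet that is not valid Python. One option, sketched here with a hypothetical `vector_snippet` helper, is the `!r` conversion, which writes the query as a properly quoted Python literal:

```python
def vector_snippet(query, db_path="~/datasets/demo", table="diffusiondb"):
    """Build the display-only code string; !r quotes each value as a valid Python literal."""
    return (
        "import lancedb\n"
        f"db = lancedb.connect({db_path!r})\n"
        f"tbl = db.open_table({table!r})\n\n"
        f"embedding = embed_func({query!r})\n"
        "tbl.search(embedding).limit(9).to_df()"
    )


snippet = vector_snippet("van gogh's starry night")
print(snippet.splitlines()[-2])  # embedding = embed_func("van gogh's starry night")
```

The same `!r` approach applies to the keyword and SQL snippets, where SQL string literals such as `WHERE prompt = 'cat'` would otherwise clash with the surrounding quotes.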


Let's set up the Gradio interface.


import gradio as gr
#gradio block
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Embeddings"):
            vector_query = gr.Textbox(value="portraits of a person", show_label=False)
            b1 = gr.Button("Submit")
        with gr.Tab("Keywords"):
            keyword_query = gr.Textbox(value="ninja turtle", show_label=False)
            b2 = gr.Button("Submit")
        with gr.Tab("SQL"):
            sql_query = gr.Textbox(value="SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9", show_label=False)
            b3 = gr.Button("Submit")
    with gr.Row():
        code = gr.Code(label="Code", language="python")
    with gr.Row():
        gallery = gr.Gallery(
                label="Found images", show_label=False, elem_id="gallery",
                columns=[3], rows=[3], object_fit="contain", height="auto"
            )
        
    b1.click(find_image_vectors, inputs=vector_query, outputs=[gallery, code])
    b2.click(find_image_keywords, inputs=keyword_query, outputs=[gallery, code])
    b3.click(find_image_sql, inputs=sql_query, outputs=[gallery, code])
    
demo.launch()


[Figure: the Gradio search interface]


With this interface, we can search the images by text embedding, by keyword, or with SQL.


Example 3: Multimodal video search


Now let's apply the same approach to video search.


We've prepared a tar file containing the data in Lance format.


#getting the data
!wget https://vectordb-recipes.s3.us-west-2.amazonaws.com/multimodal_video_lance.tar.gz
!tar -xvf multimodal_video_lance.tar.gz
!mkdir -p data/video-lancedb
!mv multimodal_video.lance data/video-lancedb/


Open the table


import lancedb

# initialize the db connection and open the prepared table
db = lancedb.connect("data/video-lancedb")
tbl = db.open_table("multimodal_video")


Load the CLIP model with its tokenizer and processor, and define the embedding function


from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
MODEL_ID = "openai/clip-vit-base-patch32"
#load the tokenizer and processor for CLIP model
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
#embedding function for the query
def embed_func(query):
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    text_features = model.get_text_features(**inputs)
    return text_features.detach().numpy()[0]


We'll use Gradio again, so let's first define the search functions.


import duckdb

# vector search: embed the text query with CLIP and return the 9 closest rows
def find_video_vectors(query):
    emb = embed_func(query)
    code = (
        "import lancedb\n"
        "db = lancedb.connect('data/video-lancedb')\n"
        "tbl = db.open_table('multimodal_video')\n\n"
        f"embedding = embed_func('{query}')\n"
        "tbl.search(embedding).limit(9).to_df()"
    )
    return (_extract(tbl.search(emb).limit(9).to_df()), code)
# keyword search: a plain string query runs full-text search on the table
def find_video_keywords(query):
    code = (
        "import lancedb\n"
        "db = lancedb.connect('data/video-lancedb')\n"
        "tbl = db.open_table('multimodal_video')\n\n"
        f"tbl.search('{query}').limit(9).to_df()"
    )
    return (_extract(tbl.search(query).limit(9).to_df()), code)
#create a SQL command to retrieve the video from the db
def find_video_sql(query):
    code = (
        "import lancedb\n"
        "import duckdb\n"
        "db = lancedb.connect('data/video-lancedb')\n"
        "tbl = db.open_table('multimodal_video')\n\n"
        "videos = tbl.to_lance()\n"
        f"duckdb.sql('{query}').to_df()"
    )
    videos = tbl.to_lance()
    return (_extract(duckdb.sql(query).to_df()), code)
#extract the video from the df
def _extract(df):
    video_id_col = "video_id"
    start_time_col = "start_time"
    grid_html = '<div style="display: grid; grid-template-columns: repeat(3, 1fr); grid-gap: 20px;">'
    for _, row in df.iterrows():
        iframe_code = f'<iframe width="100%" height="315" src="https://www.youtube.com/embed/{row[video_id_col]}?start={str(row[start_time_col])}" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>'
        grid_html += f'<div style="width: 100%;">{iframe_code}</div>'
    grid_html += '</div>'
    return grid_html
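The `_extract` function above inserts `str(row[start_time_col])` directly into the embed URL. YouTube's embed `start` parameter expects a whole number of seconds, so if `start_time` happens to be stored as a float (the column's type isn't shown here, so this is an assumption), a value like `30.0` would not match that format. A hypothetical helper that truncates to an integer and URL-encodes the pieces:

```python
from urllib.parse import quote, urlencode


def youtube_embed_url(video_id, start_time):
    """Build an embed URL that starts playback at `start_time` seconds.

    The `start` parameter is sent as whole seconds, so fractional times are truncated.
    """
    query = urlencode({"start": int(start_time)})
    return f"https://www.youtube.com/embed/{quote(str(video_id), safe='')}?{query}"


print(youtube_embed_url("VIDEO_ID_123", 90.0))  # https://www.youtube.com/embed/VIDEO_ID_123?start=90
```

Inside `_extract`, the iframe's `src` could then be `youtube_embed_url(row[video_id_col], row[start_time_col])`.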


Set up the video search interface


import gradio as gr
#gradio block
with gr.Blocks() as demo:
    gr.Markdown('''
            # Multimodal video search using CLIP and LanceDB
            We used LanceDB to store frames every thirty seconds and the title of 13000+ videos, 5 random from each top category from the Youtube 8M dataset. 
            Then, we used the CLIP model to embed frames and titles together. With LanceDB, we can perform embedding, keyword, and SQL search on these videos.
            ''')
    with gr.Row():
        with gr.Tab("Embeddings"):
            vector_query = gr.Textbox(value="retro gaming", show_label=False)
            b1 = gr.Button("Submit")
        with gr.Tab("Keywords"):
            keyword_query = gr.Textbox(value="ninja turtles", show_label=False)
            b2 = gr.Button("Submit")
        with gr.Tab("SQL"):
            sql_query = gr.Textbox(value="SELECT DISTINCT video_id, * from videos WHERE start_time > 0 LIMIT 9", show_label=False)
            b3 = gr.Button("Submit")
    with gr.Row():
        code = gr.Code(label="Code", language="python")
    with gr.Row():
        gallery = gr.HTML()
        
    b1.click(find_video_vectors, inputs=vector_query, outputs=[gallery, code])
    b2.click(find_video_keywords, inputs=keyword_query, outputs=[gallery, code])
    b3.click(find_video_sql, inputs=sql_query, outputs=[gallery, code])
    
demo.launch()


[Figure: the Gradio video search interface]


That's it. I hope you found this post helpful!




Source: https://blog.lancedb.com/multi-modal-ai-made-easy-with-lancedb-clip-5aaf8801c939