Image Generation Prompt Iteration
This is based on a tweet from @ThorondorLLC.
Tweet is here
This notebook takes an initial desired prompt and iteratively refines it until the generated image matches the desired prompt.
This is not DSPy prompt optimization as it is normally used, but it is a good example of how to use multimodal DSPy.
A future upgrade would be to create a dataset of (initial prompt, final prompt) pairs and use it to optimize the prompt generation.
You can install DSPy via:
pip install -U dspy
For this example, we'll use Flux Pro from FAL. You can get an API key here
We will also need to install the FAL client, Pillow, and python-dotenv (the package that provides the `dotenv` module imported below).
pip install fal-client pillow python-dotenv
Now, let's import the necessary libraries and set up the environment:
In [ ]:
Copied!
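# Optional: set the API keys here if they are not already in your environment or a .env file.
#import os
#os.environ["FAL_KEY"] = "your_fal_api_key"  # fal_client reads the FAL_KEY variable
#os.environ["OPENAI_API_KEY"] = "your_openai_api_key"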
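# Optional: set the API keys here if they are not already in your environment or a .env file.
#import os
#os.environ["FAL_KEY"] = "your_fal_api_key"  # fal_client reads the FAL_KEY variable
#os.environ["OPENAI_API_KEY"] = "your_openai_api_key"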
In [1]:
Copied!
import dspy
from PIL import Image
from io import BytesIO
import requests
import fal_client
from dotenv import load_dotenv
load_dotenv()
# import display
from IPython.display import display
# Language model handle used in this notebook: gpt-4o-mini at temperature 0.5.
lm = dspy.LM(model="gpt-4o-mini", temperature=0.5)
# Register the model in DSPy's global settings.
dspy.settings.configure(lm=lm)
import dspy
from PIL import Image
from io import BytesIO
import requests
import fal_client
from dotenv import load_dotenv
load_dotenv()
# import display
from IPython.display import display
# Language model handle used in this notebook: gpt-4o-mini at temperature 0.5.
lm = dspy.LM(model="gpt-4o-mini", temperature=0.5)
# Register the model in DSPy's global settings.
dspy.settings.configure(lm=lm)
/Users/isaac/sd_optimizer/.venv/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2: * 'fields' has been removed warnings.warn(message, UserWarning) /Users/isaac/sd_optimizer/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
In [9]:
Copied!
def generate_image(prompt):
    """Generate an image for ``prompt`` and return it as a ``dspy.Image``.

    The prompt is submitted to the ``fal-ai/flux-pro/v1.1-ultra`` application
    through ``fal_client``; the result is then fetched by request id and the
    URL of the first returned image is wrapped with ``dspy.Image.from_url``.
    """
    endpoint = "fal-ai/flux-pro/v1.1-ultra"
    handle = fal_client.submit(endpoint, arguments={"prompt": prompt})
    payload = fal_client.result(endpoint, handle.request_id)
    # Only the first image in the response is used.
    first_image = payload["images"][0]
    return dspy.Image.from_url(first_image["url"])
def display_image(image):
    """Download an image by URL and display it at a quarter of its size.

    Args:
        image: Object exposing a ``url`` attribute that points at the image
            (here, the ``dspy.Image`` returned by ``generate_image``).

    Raises:
        requests.HTTPError: If the download responds with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    # A timeout keeps the notebook from hanging on a stalled download, and
    # raise_for_status surfaces HTTP failures here instead of letting an
    # error page reach Image.open and fail with an opaque decoding error.
    response = requests.get(image.url, timeout=30)
    response.raise_for_status()
    picture = Image.open(BytesIO(response.content))
    # Display at 25% of original size; clamp so very small images never
    # collapse to a zero-sized dimension.
    scaled_size = (max(1, picture.width // 4), max(1, picture.height // 4))
    display(picture.resize(scaled_size))
def generate_image(prompt):
    """Generate an image for ``prompt`` and return it as a ``dspy.Image``.

    The prompt is submitted to the ``fal-ai/flux-pro/v1.1-ultra`` application
    through ``fal_client``; the result is then fetched by request id and the
    URL of the first returned image is wrapped with ``dspy.Image.from_url``.
    """
    endpoint = "fal-ai/flux-pro/v1.1-ultra"
    handle = fal_client.submit(endpoint, arguments={"prompt": prompt})
    payload = fal_client.result(endpoint, handle.request_id)
    # Only the first image in the response is used.
    first_image = payload["images"][0]
    return dspy.Image.from_url(first_image["url"])
def display_image(image):
    """Download an image by URL and display it at a quarter of its size.

    Args:
        image: Object exposing a ``url`` attribute that points at the image
            (here, the ``dspy.Image`` returned by ``generate_image``).

    Raises:
        requests.HTTPError: If the download responds with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    # A timeout keeps the notebook from hanging on a stalled download, and
    # raise_for_status surfaces HTTP failures here instead of letting an
    # error page reach Image.open and fail with an opaque decoding error.
    response = requests.get(image.url, timeout=30)
    response.raise_for_status()
    picture = Image.open(BytesIO(response.content))
    # Display at 25% of original size; clamp so very small images never
    # collapse to a zero-sized dimension.
    scaled_size = (max(1, picture.width // 4), max(1, picture.height // 4))
    display(picture.resize(scaled_size))
In [18]:
Copied!
# Multimodal predictor: given the desired prompt, the image produced so far and
# the prompt that produced it, it returns feedback, a strict-match verdict and
# a revised prompt.
check_and_revise_prompt = dspy.Predict(
    "desired_prompt: str, current_image: dspy.Image, current_prompt:str -> feedback:str, image_strictly_matches_desired_prompt: bool, revised_prompt: str"
)

initial_prompt = "A scene that's both peaceful and tense"
current_prompt = initial_prompt
max_iter = 5

# Generate, critique and revise until the image is judged a strict match or
# the iteration budget runs out.
for i in range(max_iter):
    print(f"Iteration {i+1} of {max_iter}")
    current_image = generate_image(current_prompt)
    result = check_and_revise_prompt(
        desired_prompt=initial_prompt,
        current_image=current_image,
        current_prompt=current_prompt,
    )
    display_image(current_image)
    if result.image_strictly_matches_desired_prompt:
        break
    # Not a match yet: adopt the revision for the next attempt.
    current_prompt = result.revised_prompt
    print(f"Feedback: {result.feedback}")
    print(f"Revised prompt: {result.revised_prompt}")

# NOTE(review): if the loop ends without a match, this is the last revision,
# which has not itself been rendered or checked.
print(f"Final prompt: {current_prompt}")
# Multimodal predictor: given the desired prompt, the image produced so far and
# the prompt that produced it, it returns feedback, a strict-match verdict and
# a revised prompt.
check_and_revise_prompt = dspy.Predict(
    "desired_prompt: str, current_image: dspy.Image, current_prompt:str -> feedback:str, image_strictly_matches_desired_prompt: bool, revised_prompt: str"
)

initial_prompt = "A scene that's both peaceful and tense"
current_prompt = initial_prompt
max_iter = 5

# Generate, critique and revise until the image is judged a strict match or
# the iteration budget runs out.
for i in range(max_iter):
    print(f"Iteration {i+1} of {max_iter}")
    current_image = generate_image(current_prompt)
    result = check_and_revise_prompt(
        desired_prompt=initial_prompt,
        current_image=current_image,
        current_prompt=current_prompt,
    )
    display_image(current_image)
    if result.image_strictly_matches_desired_prompt:
        break
    # Not a match yet: adopt the revision for the next attempt.
    current_prompt = result.revised_prompt
    print(f"Feedback: {result.feedback}")
    print(f"Revised prompt: {result.revised_prompt}")

# NOTE(review): if the loop ends without a match, this is the last revision,
# which has not itself been rendered or checked.
print(f"Final prompt: {current_prompt}")
Iteration 1 of 5