Below we describe how to use these schedules in practice. We'll be using the `timesteps` parameter in `StableDiffusionPipeline` to pass in our custom schedules. However, since the default `DPMSolverMultistepScheduler` implementation doesn't support setting custom schedules, we must add the functionality ourselves.
```python
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler as DefaultDPMSolver


# Add support for setting custom timesteps
class DPMSolverMultistepScheduler(DefaultDPMSolver):
    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        # Without a custom schedule, fall back to the default behavior
        if timesteps is None:
            super().set_timesteps(num_inference_steps, device)
            return

        # Noise level sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every training timestep
        all_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        # Keep the sigmas at the requested indices, including the final one at t = 0
        self.sigmas = torch.from_numpy(all_sigmas[timesteps])
        self.timesteps = torch.tensor(timesteps[:-1]).to(device=device, dtype=torch.int64)  # Ignore the last 0
        self.num_inference_steps = len(timesteps)

        # Reset the multistep solver state
        self.model_outputs = [None] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
```
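The override computes a noise level $\sigma_t = \sqrt{(1 - \bar\alpha_t)/\bar\alpha_t}$ for every training timestep and then keeps only the entries named in the custom schedule. The NumPy sketch below reproduces that lookup outside of diffusers, so it can be inspected without loading a model. It assumes the Stable Diffusion 1.5 scheduler settings (1000 training timesteps, `scaled_linear` betas from 0.00085 to 0.012); `scaled_linear_alphas_cumprod` is a hypothetical helper, not a diffusers function.

```python
import numpy as np


def scaled_linear_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    # Hypothetical helper, assumed to mirror the "scaled_linear" schedule:
    # betas are linear in sqrt(beta), then squared
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
    return np.cumprod(1.0 - betas)


alphas_cumprod = scaled_linear_alphas_cumprod()
# Noise level for every training timestep, as in the overridden set_timesteps
all_sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

sampling_schedule = [999, 850, 736, 645, 545, 455, 343, 233, 124, 24, 0]
sigmas = all_sigmas[sampling_schedule]         # 11 noise levels, including the final one at t = 0
timesteps = np.array(sampling_schedule[:-1])   # 10 model evaluations; the trailing 0 is dropped

print(np.round(sigmas, 3))
```

The first and last printed values should come out at roughly 14.615 and 0.029, the endpoints listed for Stable Diffusion 1.5 in the table at the end of this section.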
For discrete diffusion models, such as Stable Diffusion, SDXL, and DeepFloyd, the list of timestep indices can be passed through the `timesteps` parameter.
```python
import torch
from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16").to("cuda")
# Replace the default scheduler with the subclass defined above, which accepts custom timesteps
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Optimized 10-step schedule for Stable Diffusion 1.5 (11 indices, ending in 0)
sampling_schedule = [999, 850, 736, 645, 545, 455, 343, 233, 124, 24, 0]

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt, timesteps=sampling_schedule).images[0]
image.save("astronaut_rides_horse.png")
```
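The overridden `set_timesteps` silently drops the last entry and indexes `all_sigmas` directly with the schedule, so a malformed list either misbehaves quietly or fails with an opaque indexing error. A small check such as the hypothetical `check_timestep_schedule` below (not part of diffusers) can catch this before calling the pipeline. It assumes a model trained with 1000 discrete timesteps, as Stable Diffusion 1.5 is, and it assumes that repeated indices should be rejected.

```python
def check_timestep_schedule(schedule, num_train_timesteps=1000):
    """Hypothetical helper: validate a schedule before passing it as `timesteps`.

    Returns the number of denoising steps (the trailing 0 is not a step).
    """
    if len(schedule) < 2:
        raise ValueError("need at least one timestep followed by the final 0")
    if any(int(t) != t for t in schedule):
        raise ValueError("timestep indices must be integers")
    if schedule[-1] != 0:
        raise ValueError("schedule must end with 0, which set_timesteps drops")
    if not 0 <= schedule[0] < num_train_timesteps:
        raise ValueError(f"indices must lie in [0, {num_train_timesteps - 1}]")
    # Assumption: a repeated index would give the solver a zero-length step
    if any(a <= b for a, b in zip(schedule, schedule[1:])):
        raise ValueError("timestep indices must be strictly decreasing")
    return len(schedule) - 1


num_steps = check_timestep_schedule([999, 850, 736, 645, 545, 455, 343, 233, 124, 24, 0])
```

For the Stable Diffusion 1.5 schedule this returns 10, matching the 10 model evaluations the pipeline performs.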
For continuous-time diffusion models in the EDM framework, such as Stable Video Diffusion (SVD), the noise levels can be given directly to the model as its sigma inputs. Unfortunately, the default SVD implementation in diffusers does not support a `timesteps` input, so the underlying code must be modified.
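```python
import torch

sigma_max = 700.00
sigma_min = 0.002
num_steps = 10

# EDM schedule: num_steps + 1 noise levels, evenly spaced in sigma^(1/rho)
rho = 7
t_steps = torch.linspace(sigma_max ** (1 / rho), sigma_min ** (1 / rho), num_steps + 1) ** rho

# Optimized schedule for Stable Video Diffusion (replaces the EDM schedule above)
t_steps = [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]
```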
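The EDM baseline line spaces the noise levels evenly in $\sigma^{1/\rho}$ and then raises them back to the power $\rho$. Written out, with $N$ = `num_steps` and $\rho$ = `rho`, the `torch.linspace` call yields:

$$
\sigma_i = \left( \sigma_{\max}^{1/\rho} + \frac{i}{N}\left( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \right) \right)^{\rho}, \qquad i = 0, 1, \dots, N
$$

so $\sigma_0 = \sigma_{\max} = 700$, $\sigma_N = \sigma_{\min} = 0.002$, and $N + 1 = 11$ noise levels bound the 10 sampling steps. A larger $\rho$ concentrates more of the steps near $\sigma_{\min}$. The optimized schedule keeps the same two endpoints but places the intermediate levels differently.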
For a larger number of steps, we've found that log-linearly interpolating the noise levels of these 10-step schedules works well in practice.
```python
import numpy as np


def loglinear_interp(t_steps, num_steps):
    """
    Performs log-linear interpolation of a given array of decreasing numbers.
    """
    xs = np.linspace(0, 1, len(t_steps))
    ys = np.log(t_steps[::-1])
    new_xs = np.linspace(0, 1, num_steps)
    new_ys = np.interp(new_xs, xs, ys)
    interped_ys = np.exp(new_ys)[::-1].copy()
    return interped_ys


sampling_schedule = [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]
num_steps = 20
# Returns num_steps noise levels, from the largest to the smallest
t_steps = loglinear_interp(sampling_schedule, num_steps)
```
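For discrete models, interpolated noise levels still have to be turned into integer timestep indices before they can be passed through `timesteps`. One simple way to do this, sketched below as an assumption rather than the procedure used to build the table, is to pick for each noise level the training timestep whose $\sigma_t$ is closest in log space. `noise_levels_to_timesteps` is a hypothetical helper, and the beta settings are assumed to be those of the Stable Diffusion 1.5 scheduler.

```python
import numpy as np


def noise_levels_to_timesteps(noise_levels, alphas_cumprod):
    """Hypothetical helper: map each noise level to the training timestep
    whose sigma is closest in log space."""
    all_sigmas = np.sqrt((1 - alphas_cumprod) / alphas_cumprod)  # increases with t
    log_sigmas = np.log(all_sigmas)
    targets = np.log(np.asarray(noise_levels, dtype=np.float64))
    # Distance between every (target, timestep) pair; keep the closest timestep per target
    idx = np.abs(log_sigmas[None, :] - targets[:, None]).argmin(axis=1)
    return [int(t) for t in idx]


# Assumed Stable Diffusion 1.5 beta schedule ("scaled_linear", 1000 training timesteps)
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

# Any decreasing list of noise levels works here, e.g. the output of loglinear_interp;
# the 10-step Stable Diffusion 1.5 levels from the table are used for illustration
noise_levels = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
timesteps = noise_levels_to_timesteps(noise_levels, alphas_cumprod)
```

With many interpolated steps, neighboring noise levels near $\sigma_{\min}$ may round to the same index, so it is worth checking the result for duplicates before passing it to the pipeline.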
Optimized 10-step schedules for each model:

| Model | Schedule (noise levels) | Schedule (timestep indices) |
|---|---|---|
| Stable Diffusion 1.5 | [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] | [999, 850, 736, 645, 545, 455, 343, 233, 124, 24, 0] |
| SDXL | [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029] | [999, 845, 730, 587, 443, 310, 193, 116, 53, 13, 0] |
| DeepFloyd-IF / Stage-1 | [160.41, 8.081, 3.315, 1.885, 1.207, 0.785, 0.553, 0.293, 0.186, 0.030, 0.006] | [995, 920, 811, 686, 555, 418, 315, 174, 109, 12, 0] |
| Stable Video Diffusion | [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002] | N/A |