Thanks a lot for the example; it is great that you explained the errors instead of just posting the correct implementation. I was able to run the FLUX.1-dev pipeline with only slight modifications based on the guidelines in the post. For those interested, this involved one extra pytree registration and slightly edited compilation code. The Flux pipeline has two text encoders, and the second one (T5) returns a different output class that needs its own flatten/unflatten functions:
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from jax.tree_util import register_pytree_node  # JAX pytree registry

def base_model_output_with_past_and_cross_attentions_flatten(v: BaseModelOutputWithPastAndCrossAttentions):
    # children (the tensors) first, aux_data (the field names) second
    return list(v.values()), list(v.keys())

def base_model_output_with_past_and_cross_attentions_unflatten(aux_data, children):
    return BaseModelOutputWithPastAndCrossAttentions(**dict(zip(aux_data, children)))

register_pytree_node(
    BaseModelOutputWithPastAndCrossAttentions,
    base_model_output_with_past_and_cross_attentions_flatten,
    base_model_output_with_past_and_cross_attentions_unflatten,
)
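To sanity-check that the flatten/unflatten pair really are inverses (which is what the pytree registry relies on during tracing), here is a minimal dependency-free sketch; `FakeOutput` is a hypothetical stand-in for the transformers output class, which is essentially an ordered mapping of field names to tensors:

```python
from collections import OrderedDict

# Hypothetical stand-in for BaseModelOutputWithPastAndCrossAttentions:
# an ordered mapping from field names to values (plain lists here).
class FakeOutput(OrderedDict):
    pass

def flatten(v):
    # children first, aux_data (the keys) second, matching the
    # register_pytree_node contract used above
    return list(v.values()), list(v.keys())

def unflatten(aux_data, children):
    return FakeOutput(**dict(zip(aux_data, children)))

out = FakeOutput(last_hidden_state=[1.0, 2.0], past_key_values=None)
children, keys = flatten(out)
assert unflatten(keys, children) == out  # round-trip preserves all fields
```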
The model weights are bfloat16, but I let the VAE run in float32; without this, the decoding stage fails. I suspect the JAX precision flags could be arranged so that the tensors stay in bfloat16.
pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16)
pipe.vae = pipe.vae.to(torch.float32)
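The likely culprit is bfloat16's 7-bit mantissa, which rounds away small differences that the decoder is sensitive to. A rough dependency-free illustration (truncating a float32 to its top 16 bits approximates bfloat16 conversion, ignoring round-to-nearest):

```python
import struct

def to_bfloat16(x: float) -> float:
    # bfloat16 keeps the top 16 bits of a float32 (1 sign bit,
    # 8 exponent bits, 7 mantissa bits); the lower mantissa bits
    # are dropped, so fine-grained values collapse together
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

print(to_bfloat16(1.001))  # -> 1.0: the small offset is lost entirely
```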
We then need to compile the transformer, the VAE decoder, and the two text encoders with the appropriate static arguments:
with env:
    pipe.transformer = torchax.compile(
        pipe.transformer,
        torchax.CompileOptions(
            jax_jit_kwargs={'static_argnames': ('return_dict',)}
        ))
    pipe.vae = torchax.compile(
        pipe.vae,
        torchax.CompileOptions(
            methods_to_compile=['decode'],
            jax_jit_kwargs={'static_argnames': ('return_dict',)}
        ))
    pipe.text_encoder = torchax.compile(
        pipe.text_encoder,
        torchax.CompileOptions(
            jax_jit_kwargs={'static_argnames': ('output_hidden_states',)}
        ))
    pipe.text_encoder_2 = torchax.compile(
        pipe.text_encoder_2,
        torchax.CompileOptions(
            jax_jit_kwargs={'static_argnames': ('output_hidden_states',)}
        ))