
Fix the auto-cast of `linalg.batch_reduce_matmul` from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`
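
As a rough illustration of the corrected form `cast_to_T(A) * cast_to_T(B) + C`, here is a minimal NumPy sketch, not the MLIR implementation. It assumes `A` has shape `(batch, M, K)`, `B` has shape `(batch, K, N)`, `C` has shape `(M, N)`, and `T` is the element type of `C`; the helper name `batch_reduce_matmul_ref` is hypothetical.

```python
import numpy as np


def batch_reduce_matmul_ref(A, B, C):
    """Reference sketch: cast_to_T(A) * cast_to_T(B) + C, reduced over the batch.

    Assumption: T is taken to be C's dtype. Both operands are cast to T
    *before* the multiplication, so the products are formed in T.
    """
    T = C.dtype
    a = A.astype(T)  # cast_to_T(A)
    b = B.astype(T)  # cast_to_T(B)
    # Multiply each batch entry and sum the partial products over the batch dim.
    return np.einsum("bmk,bkn->mn", a, b) + C


# int8 inputs accumulated into an int32 output.
A = np.full((2, 1, 1), 100, dtype=np.int8)
B = np.full((2, 1, 1), 100, dtype=np.int8)
C = np.zeros((1, 1), dtype=np.int32)

result = batch_reduce_matmul_ref(A, B, C)

# For contrast: a product formed in the narrow type wraps around before any
# later cast could help (100 * 100 = 10000 does not fit in int8).
narrow_product = np.array([100], dtype=np.int8) * np.array([100], dtype=np.int8)
```

With both operands cast first, each product is computed in `int32`, so the two batch entries contribute 10000 each. The `narrow_product` line only illustrates why the position of the cast matters in general; it is not a claim about what the previous lowering emitted.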