
The inference results are incorrect #683

Open
shahidalihakro opened this issue Feb 2, 2025 · 2 comments
Comments

shahidalihakro commented Feb 2, 2025

The following example prints all `True` when the sequence length is less than 30, but when it is above 30 the step-by-step inference produces incorrect results. Does anyone know why this happens?

```python
import torch
from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams


@torch.inference_mode()
def run():
    batch, length, dim = 2, 29, 16
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,   # SSM state expansion factor
        d_conv=4,     # Local convolution width
        expand=2,     # Block expansion factor
        layer_idx=0,
    ).to("cuda")

    # Training-style forward pass (full sequence in parallel)
    y1 = model(x)
    assert y1.shape == x.shape

    # Inference-style forward pass (full sequence in parallel)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    y2 = model(x, inference_params=infer_params)

    # Inference-style forward pass (step by step, one token at a time)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    outs = []
    for i in range(length):
        out = model(x[:, i : i + 1, :], inference_params=infer_params)
        infer_params.seqlen_offset += 1
        outs.append(out)
    y3 = torch.cat(outs, 1)

    print(torch.allclose(y1, y2))  # prints True
    print(torch.allclose(y2, y3))  # prints True
    print(torch.allclose(y1, y3))  # prints True


if __name__ == '__main__':
    run()
```
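One thing worth checking (a general floating-point observation, not specific to this repo): the parallel pass and the step-by-step recurrence accumulate the same terms in different orders, and floating-point addition is not associative, so small discrepancies can grow with sequence length. `torch.allclose` uses fairly tight default tolerances (`rtol=1e-5`, `atol=1e-8`), so before concluding the output is "incorrect" it can help to look at the actual magnitude of the difference and compare with explicit tolerances. A minimal pure-Python sketch of the underlying effect (hypothetical numbers, nothing here comes from Mamba):

```python
import math

# Two summations of the same terms in different orders: mathematically
# equal, but floating-point rounding can make them differ slightly, and
# the discrepancy tends to grow with the number of accumulated terms.
terms = [0.1 * (i % 7) for i in range(1000)]

forward = sum(terms)             # one accumulation order
backward = sum(reversed(terms))  # the reverse order

# The raw difference is typically tiny (often nonzero at machine precision).
print(abs(forward - backward))

# A tolerance-based comparison (what torch.allclose does via rtol/atol)
# is the right way to check agreement between two computation paths.
print(math.isclose(forward, backward, rel_tol=1e-9))  # prints True
```

If `torch.max(torch.abs(y1 - y3))` is only on the order of accumulated rounding error, the two paths likely agree numerically; if it is large, that points to a genuine state-handling bug rather than precision.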


karannb commented Feb 2, 2025

I had raised this issue a while back, #571. Please let me know if you figure out anything.

@shahidalihakro (Author)

> I had raised this issue a while back, #571. Please let me know if you figure out anything.

Sure, I will let you know if I figure it out. If you find a solution, please let me know too.
