In this blogpost, we present the key highlights and the rationale behind the Falcon-Edge series - a collection of powerful, universal, and fine-tunable language models available in ternary format, based on the BitNet architecture.

Drawing from our experience with BitNet, Falcon-Edge introduces and validates a new pre-training paradigm that delivers, from a single training process, both non-quantized and quantized model variants. This comprehensive approach produces a non-BitNet model in bfloat16 format, the native BitNet model, and a pre-quantized BitNet variant specifically engineered for effortless fine-tuning, enabling users and developers to tailor these models precisely to their specific applications and needs.

Falcon-Edge is available now in two sizes, 1 billion and 3 billion parameters, each released as both a base and an instruction-tuned model. Discover the Falcon-Edge series on our dedicated Hugging Face collection.

Instruct models performance

Base models performance

Introduction

Large Language Models (LLMs), by design, are inherently large and resource-intensive. As demand grows to deploy these models efficiently on edge devices, research into model compression has accelerated. Recent efforts, such as those by DeepSeek and Llama 4, explore training with reduced precision formats—down to FP8—to improve deployment scalability. On the other hand, many state-of-the-art methods emphasize post-training quantization. In contrast to these approaches, BitNet introduces a fundamentally different paradigm: unlike reduced-precision training which still relies on floating-point formats, and post-training quantization which adjusts weights after full-precision training, BitNet operates with the lowest possible precision — ternary weights ({-1, 0, 1}) — directly during training, enabling an end-to-end ultra-efficient model design.

These ternary weights are paving the way for a “matmul-free” LLM design that is notably faster and remarkably memory-efficient in practice. The primary challenge of this innovative approach is the necessity for pre-training BitNet models, which can be computationally demanding and costly for typical users.
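To give a rough sense of the memory argument, the back-of-the-envelope arithmetic below (our own illustration, not taken from the paper) compares the size of the linear-layer weights of a roughly 3B-parameter model in bfloat16 against a simple 2-bit packing of ternary values; each ternary weight carries log2(3) ≈ 1.58 bits of information.

import math

n_params = 3e9  # linear-layer weights of a ~3B model (illustrative figure)
bits_of_information = math.log2(3)  # ~1.58 bits per ternary weight
packed_bits = 2                     # simple packing: four ternary values per byte

print(f"information content: {n_params * bits_of_information / 8 / 1e9:.2f} GB")
print(f"2-bit packed size:   {n_params * packed_bits / 8 / 1e9:.2f} GB")
print(f"bfloat16 size:       {n_params * 16 / 8 / 1e9:.2f} GB")

The packed size is consistent in order of magnitude with the sub-1GB memory footprints reported for the 3B models below, with the remaining budget going to components that stay in higher precision, such as the embeddings.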

Unleashing the full potential of BitNet models

Six months ago, Microsoft introduced bitnet.cpp, a framework designed to accelerate CPU inference speeds by up to 5 times for certain architectures. This advancement makes BitNet models highly appealing for local deployment, significantly enhancing their readiness for production and ease of use across various applications. However, from the community’s viewpoint, BitNet is still largely regarded as a proof of concept or prototype, primarily due to two key challenges:

  • Model Performance: Despite efforts by the community, including our own, to develop robust BitNet models, recent 1-bit large language models (LLMs) have struggled to match the overall performance of similarly sized models built with other approaches. This was the case at the time of writing this blogpost, although Microsoft recently released a powerful BitNet model that attracted a lot of interest from the community.
  • Accessibility: The process of 1-bit fine-tuning—which involves converting a non-BitNet model checkpoint into a BitNet one—has proven ineffective. Currently, pre-training appears to be the most viable approach.

With this release, we aim to demonstrate that a novel pre-training approach has the potential to overcome these limitations and cater to a wide range of applications. In the subsequent sections, we will delve deeper into each of these issues and explore how our new paradigm addresses them.

Proposed architecture

We adopted the architecture outlined in the paper The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits, but made a key modification by eliminating the Layer Normalization layers within the BitNet layers. However, we retained the original pre-attention and pre-MLP layer norms to ensure compatibility with the Llama architecture, allowing seamless integration from the outset. Interestingly, we discovered that removing these Layer Normalization layers had no adverse effect on model performance, while also ensuring compatibility with the broader ecosystem with minimal adjustments.

Beyond the methods described in the paper, we implemented optimized Triton kernels for both activation_quant and weight_quant, significantly lowering the pre-training costs of our models. We are making these kernels accessible to the community through the Python package onebitllms, enabling researchers and developers to leverage them for efficient BitNet pre-training and fine-tuning.
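For reference, the snippet below is a minimal PyTorch sketch of the training-time quantization that these kernels implement, assuming the standard BitNet b1.58 recipe (abs-mean ternary weight quantization, per-token abs-max 8-bit activation quantization, and a straight-through estimator). The actual Triton kernels fuse and parallelize these operations, and the class name here is illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant(x):
    # Per-token abs-max "fake" quantization of the activations to 8 bits.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def weight_quant(w):
    # Abs-mean scaling followed by rounding to the ternary set {-1, 0, 1}.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

class BitLinearTraining(nn.Linear):
    # Training-time BitLinear: note there is no LayerNorm inside the layer,
    # matching the modification described above.
    def forward(self, x):
        w = self.weight
        # Straight-through estimator: the forward pass sees quantized values,
        # while gradients flow unchanged to the latent bfloat16 weights.
        x_q = x + (activation_quant(x) - x).detach()
        w_q = w + (weight_quant(w) - w).detach()
        return F.linear(x_q, w_q, self.bias)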

To further reduce the memory footprint of the final model, we intentionally opted for a smaller vocabulary size of 32678. This tokenizer was trained on a large English-focused corpus, with the most common LaTeX tokens manually added to the vocabulary.

Falcon-Edge, a series of powerful BitNet models

Leveraging the pre-training data strategies developed at our center, we pre-train our models on an internal data mixture of approximately 1.5 trillion tokens, using the classic WSD learning rate scheduler.
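For readers unfamiliar with it, WSD (Warmup-Stable-Decay) keeps the learning rate constant for most of training, between a short warmup and a final decay phase. The sketch below is purely illustrative: the warmup and decay fractions and the final learning-rate floor are placeholder values, not the ones used for Falcon-Edge.

def wsd_learning_rate(step, total_steps, peak_lr,
                      warmup_frac=0.01, decay_frac=0.1, min_lr_ratio=0.1):
    # Warmup-Stable-Decay: linear warmup, long constant plateau, final linear decay.
    warmup_steps = max(int(warmup_frac * total_steps), 1)
    decay_steps = max(int(decay_frac * total_steps), 1)
    stable_end = total_steps - decay_steps
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    if step < stable_end:
        return peak_lr
    progress = (step - stable_end) / decay_steps  # goes from 0 to 1 over the decay phase
    return peak_lr * (1.0 - progress * (1.0 - min_lr_ratio))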

We evaluate our models (base and instruct versions) on the former Hugging Face leaderboard v2 benchmark and report the normalized results below in comparison with other models of similar size:

For 1B scale models and below

Detailed results:
| Model | Nb Params | Mem Footprint | IFEVAL | Math-Hard | GPQA | MuSR | BBH | MMLU-Pro | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 0.5B | 1GB | 16.27 | 3.93 | 0.0 | 2.08 | 6.95 | 10.06 | 6.55 |
| Qwen3-0.6B | 0.6B | 1.5GB | 23.19 | 5.89 | 3.78 | 3.38 | 11.75 | 13.98 | 10.32 |
| SmolLM2-360M | 0.36B | 720MB | 21.15 | 1.21 | 0.0 | 7.73 | 5.54 | 1.88 | 6.25 |
| Qwen2.5-1.5B | 1.5B | 3.1GB | 26.74 | 9.14 | 16.66 | 5.27 | 20.61 | 4.7 | 13.85 |
| Qwen3-1.7B | 1.5B | 4.06GB | 26.81 | 13.95 | 6.71 | 13.65 | 20.40 | 26.45 | 17.99 |
| SmolLM2-1.7B | 1.7B | 3.4GB | 24.4 | 2.64 | 9.3 | 4.6 | 12.64 | 3.91 | 9.58 |
| Falcon-3-1B-Base | 1.5B | 3GB | 24.28 | 3.32 | 11.34 | 9.71 | 6.76 | 3.91 | 9.89 |
| Falcon-E-1B-Base | 1.8B | 665MB | 32.9 | 10.97 | 2.8 | 3.65 | 12.28 | 17.82 | 13.40 |

For 3B scale models

Detailed results:
| Model | Nb Params | Mem Footprint | IFEVAL | Math-Hard | GPQA | MuSR | BBH | MMLU-Pro | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| Falcon-3-3B-Base | 3B | 6.46GB | 15.74 | 11.78 | 6.26 | 6.27 | 21.58 | 18.09 | 15.74 |
| Qwen2.5-3B | 3B | 6.17GB | 26.9 | 14.8 | 6.38 | 11.76 | 24.3 | 24.48 | 18.1 |
| Falcon-E-3B-Base | 3B | 999MB | 36.67 | 13.45 | 8.67 | 4.14 | 19.83 | 27.16 | 18.32 |

Below are the results for instruction fine-tuned models (for Qwen3 series, we disable the thinking mode during evaluation):

Detailed results:
| Model | Nb Params | Mem Footprint | IFEVAL | Math-Hard | GPQA | MuSR | BBH | MMLU-Pro | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | 500M | 1GB | 31.53 | 10.35 | 1.23 | 1.37 | 8.17 | 8.00 | 10.11 |
| Qwen3-0.6B | 600M | 1.5GB | 62.15 | 16.15 | 2.81 | 1.86 | 6.13 | 7.71 | 16.13 |
| SmolLM2-360M-Instruct | 360M | 720MB | 38.42 | 1.51 | 0.67 | 2.77 | 4.17 | 1.3 | 8.14 |
| Qwen2.5-1.5B-Instruct | 1.5B | 3.1GB | 44.76 | 22.05 | 0.78 | 3.19 | 19.91 | 19.99 | 18.43 |
| Qwen3-1.7B | 1.7B | 4.02GB | 70.76 | 38.17 | 3.29 | 8.32 | 9.26 | 15.04 | 24.14 |
| SmolLM2-1.7B | 1.7B | 3.4GB | 53.68 | 5.82 | 0 | 4.1 | 10.92 | 11.71 | 15.02 |
| Bitnet-b1.58-2B-4T | 2B | 1.18GB | 59.11 | 7.23 | 5.25 | 1.74 | 17.94 | 14.82 | 17.70 |
| Falcon-3-1B-Instruct | 1.5B | 3GB | 55.57 | 6.34 | 2.24 | 10.56 | 12.96 | 9.32 | 16.16 |
| Falcon-E-1B-Instruct | 1.8B | 665MB | 54.35 | 9.12 | 9.64 | 2.51 | 19.42 | 16.5 | 18.59 |

For 3B scale models:

Detailed results:
| Model | Nb Params | Mem Footprint | IFEVAL | Math-Hard | GPQA | MuSR | BBH | MMLU-Pro | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| Falcon-3-3B-Instruct | 3B | 6.46GB | 69.77 | 25 | 5.15 | 11.13 | 26.29 | 22.28 | 26.6 |
| Qwen2.5-3B-Instruct | 3B | 6.17GB | 64.75 | 36.78 | 3.02 | 7.57 | 25.80 | 25.05 | 27.16 |
| Falcon-E-3B-Instruct | 3B | 999MB | 60.97 | 15.3 | 7.45 | 2.12 | 23.59 | 26.45 | 22.65 |

Additional results (leaderboard v1) comparing our instruct models with Microsoft's new BitNet model:

| Model | Nb Params | Mem Footprint | ARC-Challenge | GSM8K | HellaSwag | MMLU | TruthfulQA | Average |
|---|---|---|---|---|---|---|---|---|
| Bitnet-b1.58-2B-4T | 2B | 1.18GB | 38.31 | 65.27 | 59.02 | 47.43 | 47.65 | 51.54 |
| Falcon-E-1B-Instruct | 1.8B | 665MB | 36.60 | 54.74 | 50.19 | 48.33 | 42.09 | 46.39 |
| Falcon-E-3B-Instruct | 3B | 999MB | 43.09 | 64.52 | 56.97 | 55.70 | 45.58 | 53.17 |

Falcon-Edge performs on par with or better than models of comparable size on the leaderboard v2 tasks, demonstrating that it is possible to train powerful BitNet models on the domains of interest while remaining competitive on other tasks.

Falcon-Edge, a series of universal models

If we look closer at the BitNet linear layer used at inference time (expressed as Python code):

import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_norm_quant(x):
    # Per-token abs-max quantization of the activations to int8.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale

class BitLinear(nn.Linear):

    def post_quant_process(self, input, input_scale, weight_scale):
        # Rescale the int8-by-ternary matmul output back to half precision.
        out = input / (input_scale * weight_scale)
        return out

    def forward(self, input):
        w = self.weight
        # unpack_weights expands the packed ternary weights back to {-1, 0, 1}.
        w_quant = unpack_weights(w, dtype=self.dtype)
        input_quant, input_scale = activation_norm_quant(input)
        y = F.linear(input_quant.to(self.dtype), w_quant)
        y = self.post_quant_process(y, input_scale, self.weight_scale)
        if self.bias is not None:
            y += self.bias.view(1, -1).expand_as(y)
        return y

The function activation_norm_quant quantizes the activations to int8 format; the activation is then recovered in half precision by dividing by x_scale. Since the model has been trained with fake 8-bit activation quantization, we argue that the following approximation holds:

x_quant, x_scale = activation_norm_quant(x)
x ~= (x_quant / x_scale)
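As a quick sanity check of this approximation (an illustrative snippet with random activations, not part of our evaluation pipeline), the round-trip error of the 8-bit quantization stays well below one percent of the activation range:

import torch

def activation_norm_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale

x = torch.randn(4, 1024)
x_quant, x_scale = activation_norm_quant(x)
x_approx = x_quant / x_scale
# Maximum round-trip error, relative to the per-row activation range.
print(((x - x_approx).abs() / x.abs().max(dim=-1, keepdim=True).values).max())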

Therefore, instead of quantizing the model post-training, injecting the weight scale after quantizing the weights should lead to a good enough “approximation” of the non-BitNet version of the model:

def _weight_quant(w):
    # Abs-mean scale, then round the weights to the ternary set {-1, 0, 1}.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-05)
    u = (w * scale).round().clamp_(-1, 1)
    return u, scale

state_dict_quant = {}

for param_name, param_value in state_dict.items():
    # Skip parameters that are never quantized (e.g. embeddings and normalization layers).
    if _is_param_to_not_quantize(param_name):
        continue

    # Quantize to ternary, then inject the weight scale back to approximate
    # the non-BitNet (bfloat16) weights.
    param_value, param_scale = _weight_quant(param_value)
    param_value = param_value / param_scale

    state_dict_quant[param_name] = param_value

We confirm this by running end-to-end evaluations on the bfloat16 variants of our 1B and 3B base models; the results are reported below:

| Model | IFEVAL | Math-Hard | GPQA | MuSR | BBH | MMLU-Pro | Avg. |
|---|---|---|---|---|---|---|---|
| Falcon-E-1B | 32.9 | 10.97 | 2.8 | 3.65 | 12.28 | 17.82 | 13.40 |
| Falcon-E-1B-bf16 | 29.89 | 11.23 | 1.8 | 3.32 | 12.27 | 18.04 | 12.75 |
| Falcon-E-3B | 36.67 | 13.45 | 8.67 | 4.14 | 19.83 | 27.16 | 18.32 |
| Falcon-E-3B-bf16 | 34.84 | 13.21 | 8.91 | 4.88 | 20.3 | 27.00 | 18.19 |

The bfloat16 counterparts of the models can be loaded directly via Hugging Face transformers by passing revision="bfloat16" to the from_pretrained function:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/Falcon-E-1B-Base"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    revision="bfloat16",
)

Falcon-Edge, a series of fine-tunable BitNet models

To the best of our knowledge, and with the exception of the most recent release from Microsoft, previous BitNet releases only shipped the final quantized model, making it usable only for inference. Similarly to Microsoft's release, we propose to extend the accessibility of research on and applications of BitNet models by releasing their pre-quantized weights. That way, users can either fine-tune on their target domain or perform continued pre-training of the BitNet checkpoint, as long as the nn.Linear layers are replaced by BitnetLinear layers and the model is quantized back to the BitNet format after training. Since the weights correspond to the pre-quantized weights, performing text generation without replacing the nn.Linear layers with BitnetLinear layers will produce gibberish output.

The pre-quantized weights can be downloaded via Hugging Face's transformers library by setting the revision argument to "prequantized":

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/Falcon-E-1B-Base"

tokenizer = AutoTokenizer.from_pretrained(model_id, revision="prequantized")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    revision="prequantized"
)

In this way, we hope to help foster an ecosystem around the first powerful 1-bit fine-tunes produced by the community. We provide the community with the tools to easily get started and fine-tune their own powerful BitNet models, by packaging all the utility methods needed to fine-tune the pre-quantized weights into a Python package called onebitllms, which we cover in the next section.

Introducing onebitllms - a lightweight Python package for 1-bit LLM training

Library logo

In this release, we also introduce onebitllms - a lightweight Python package that can be plugged into your favorite LLM fine-tuning tools in order to fine-tune any pre-quantized BitNet model. At the time of writing, onebitllms exposes the following main functionalities:

  • A utility method to convert pre-quantized model checkpoints into the BitNet training format so they can be passed to your favorite LLM fine-tuning framework. We have currently tested the library with Hugging Face's trl library.
  • A utility method to quantize the trained checkpoint into the BitNet format, as well as into the usual bfloat16 format.
  • For more fine-grained control: bare BitnetLinear layers and Triton kernels that can be injected into your pre-training framework.

Currently, only full fine-tuning is supported through this framework. While the model sizes in this release are relatively small, supporting Parameter-Efficient Fine-Tuning (PEFT) methods for BitNet models remains an exciting and impactful open question for upcoming BitNet models.

To get started, simply install the package through pip or from source, and take a look at the examples/ folder inside the source code.

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer
from onebitllms import replace_linear_with_bitnet_linear, quantize_to_1bit

model_id = "tiiuae/Falcon-E-1B-Base"

tokenizer = AutoTokenizer.from_pretrained(model_id, revision="prequantized")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    revision="prequantized"
)
# Swap the nn.Linear layers for BitnetLinear layers before fine-tuning.
model = replace_linear_with_bitnet_linear(model)

trainer = SFTTrainer(
    model,
    ...
)

trainer.train()

# Convert the fine-tuned checkpoint back to the BitNet (ternary) format.
quantize_to_1bit(output_directory)

With this package, we hope to accelerate research and development around ternary format LLMs, and hope to see many derivations of Falcon-Edge and other future powerful BitNet models developed by the community.

Going further

We believe this release opens up multiple interesting directions - among all possible follow-up directions, we currently think that the following open questions will make BitNet models much more impactful in the near future:

  • Writing more powerful GPU inference kernels for the BitNet architecture: leveraging the same core ideas behind bitnet.cpp, we hope that this release will convince the research community to focus on developing powerful BitNet inference kernels for faster inference on GPUs - thus making them faster than native models on GPUs.
  • Supporting PEFT methods for BitNet fine-tuning: this remains an unexplored research question that can open up multiple new possibilities for BitNet models.
  • More rigorous investigation of the universality of BitNet checkpoints: while we observe that simply injecting the weight scale yields a decent non-BitNet checkpoint, we believe more research can be done to minimize the performance gap between the BitNet checkpoint and its bfloat16 counterpart, ideally making the conversion fully degradation-free.
  • Multi-modal BitNet models: we hope these BitNet foundation models, together with the onebitllms package, can serve as foundational work for creating the first multi-modal BitNet VLMs (Vision Language Models).
  • More optimized BitNet training kernels: for our kernels, we took a two-stage approach that first computes the global maximum and then reuses it block-wise for normalization (see the sketch after this list). This approach can be revisited to write more efficient kernels. In our tests, we estimate the overhead of BitNet pre-training over non-BitNet pre-training to be around ~20%. We will soon release more extensive numbers on the training overhead introduced by BitNet.
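To make the two-stage idea concrete, here is a rough, pure-PyTorch illustration of it (the real kernels are written in Triton and process tiles in parallel; the block size below is arbitrary): a first reduction pass computes the per-row scale, and a second pass applies the quantization block by block, reusing that scale.

import torch

def two_stage_activation_quant(x, block_size=1024):
    # Stage 1: reduction pass computing the per-row abs-max scale.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    # Stage 2: quantization applied block-wise, reusing the precomputed scale,
    # which lets a kernel process the tensor in fixed-size tiles.
    out = torch.empty_like(x)
    for start in range(0, x.shape[-1], block_size):
        block = x[..., start:start + block_size]
        out[..., start:start + block_size] = (block * scale).round().clamp_(-128, 127)
    return out, scale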

Citation

If you find this work useful for your research and work, please consider citing our work, as well as citing all the foundational work behind BitNet models:

@misc{tiionebitllms,
    title = {Falcon-E, a series of powerful, universal and fine-tunable 1.58bit language models.},
    author = {Falcon-LLM Team},
    month = {May},
    url = {https://falcon-lm.github.io/blog/falcon-edge},
    year = {2025}
}
More References
@misc{ma2025bitnetb1582b4ttechnical,
      title={BitNet b1.58 2B4T Technical Report}, 
      author={Shuming Ma and Hongyu Wang and Shaohan Huang and Xingxing Zhang and Ying Hu and Ting Song and Yan Xia and Furu Wei},
      year={2025},
      eprint={2504.12285},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2504.12285}, 
}
@misc{wang2025bitnetcppefficientedgeinference,
      title={Bitnet.cpp: Efficient Edge Inference for Ternary LLMs}, 
      author={Jinheng Wang and Hansong Zhou and Ting Song and Shijie Cao and Yan Xia and Ting Cao and Jianyu Wei and Shuming Ma and Hongyu Wang and Furu Wei},
      year={2025},
      eprint={2502.11880},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2502.11880}, 
}
@misc{mekkouri2024extremequantization,
      title={1.58-Bit LLM: A New Era of Extreme Quantization}, 
      author={Mohamed Mekkouri and Marc Sun and Leandro von Werra and Thomas Wolf},
      year={2024},
}
@misc{ma2024era1bitllmslarge,
      title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits}, 
      author={Shuming Ma and Hongyu Wang and Lingxiao Ma and Lei Wang and Wenhui Wang and Shaohan Huang and Li Dong and Ruiping Wang and Jilong Xue and Furu Wei},
      year={2024},
      eprint={2402.17764},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2402.17764}, 
}
@misc{wang2023bitnetscaling1bittransformers,
      title={BitNet: Scaling 1-bit Transformers for Large Language Models}, 
      author={Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
      year={2023},
      eprint={2310.11453},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2310.11453}, 
}

Core Contributors

Younes Belkada

Qiyang Zhao

Hang Zou

DhiaEddine Rhaiem

Ilyas Chahed

Maksim Velikanov

Jingwei Zuo

Mikhail Lubinets

Hakim Hacid