Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
445ff0d
Adding support for custom protocols
giovannipizzi Nov 28, 2025
82f77d9
Adding possibility to specify cutoffs in the protocol
giovannipizzi Dec 3, 2025
0323343
Fixing pre-commit issues
giovannipizzi Dec 3, 2025
62957f2
Adding support for common protocols in VASP
giovannipizzi Dec 11, 2025
0ecc634
Fixing default protocol
giovannipizzi Dec 11, 2025
419a11a
Fix the type of a port
giovannipizzi Dec 16, 2025
27e6458
Merge branch 'master' of github.com:aiidateam/aiida-common-workflows …
giovannipizzi Dec 18, 2025
af69283
Adding support for explicit k-points in VASP generator
giovannipizzi Dec 18, 2025
a739c24
Draft support for non-collinear spins for the CommonRelaxInputGenerator
ahkole May 19, 2026
42f4c2e
Add non-collinear spin to siesta relax workchain
ahkole May 19, 2026
6147cc6
Correctly set spin parameter for siesta relax workchain for non-colli…
ahkole May 19, 2026
adb903f
Move to_spherical to separate util module + improve robustness of spi…
ahkole May 20, 2026
0b66c74
Clarifying documentation for 'magnetization_per_site'
ahkole May 20, 2026
01b5b5d
Merge branch 'master' of github.com:aiidateam/aiida-common-workflows …
giovannipizzi May 22, 2026
1b68aab
Refactor validation of magnetization_per_site
ahkole May 22, 2026
7229443
Refactor setting magnetization in siesta relax input generator
ahkole May 22, 2026
3de1c3a
Fix punctuation in docs
ahkole May 22, 2026
baa40ce
Add check for magnetization_per_site that line contains correct numbe…
ahkole May 22, 2026
7cc49cd
Merge pull request #2 from ahkole/relax-non-collinear-spin
giovannipizzi May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions src/aiida_common_workflows/workflows/relax/abinit/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def define(cls, spec):
(ElectronicType.METAL, ElectronicType.INSULATOR, ElectronicType.UNKNOWN)
)
spec.inputs['engines']['relax']['code'].valid_type = CodeType('abinit')
spec.inputs['protocol'].valid_type = ChoiceType(('fast', 'moderate', 'precise', 'verification-PBE-v1'))
spec.inputs['protocol'].valid_type = ChoiceType(
('fast', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
)

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR0912,PLR0915
"""Construct a process builder based on the provided keyword arguments.
Expand All @@ -62,6 +64,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
structure = kwargs['structure']
engines = kwargs['engines']
protocol = kwargs['protocol']
custom_protocol = kwargs.get('custom_protocol', None)
spin_type = kwargs['spin_type']
relax_type = kwargs['relax_type']
electronic_type = kwargs['electronic_type']
Expand All @@ -70,7 +73,14 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
threshold_stress = kwargs.get('threshold_stress', None)
reference_workchain = kwargs.get('reference_workchain', None)

protocol = copy.deepcopy(self.get_protocol(protocol))
if protocol == 'custom':
if custom_protocol is None:
raise ValueError(
'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
)
protocol = copy.deepcopy(custom_protocol)
else:
protocol = copy.deepcopy(self.get_protocol(protocol))
code = engines['relax']['code']

pseudo_family_label = protocol.pop('pseudo_family')
Expand All @@ -87,15 +97,31 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
recommended_ecut_wfc, recommended_ecut_rho = pseudo_family.get_recommended_cutoffs(
structure=structure, stringency=cutoff_stringency, unit='Eh'
)

# In both cases, if the protocol "hardcodes" the cutoff(s),
# I use that instead of the one from the pseudopotential family
# since it probably means the user really wanted that cutoff.
# I use try/except since I need to go deep into a dictionary and
# it is easier than using dict.get() a lot of times.
try:
protocol_ecut = protocol['base']['abinit']['parameters']['ecut']
except KeyError:
protocol_ecut = None

try:
protocol_pawecutdg = protocol['base']['abinit']['parameters']['pawecutdg']
except KeyError:
protocol_pawecutdg = None

if pseudo_type == 'pseudo.jthxml':
# JTH XML are PAW; we need `pawecutdg`
cutoff_parameters = {
'ecut': np.ceil(recommended_ecut_wfc),
'pawecutdg': np.ceil(recommended_ecut_rho),
'ecut': protocol_ecut if protocol_ecut is not None else np.ceil(recommended_ecut_wfc),
'pawecutdg': protocol_pawecutdg if protocol_pawecutdg is not None else np.ceil(recommended_ecut_rho),
}
else:
# All others are NC; no need for `pawecutdg`
cutoff_parameters = {'ecut': recommended_ecut_wfc}
cutoff_parameters = {'ecut': protocol_ecut if protocol_ecut is not None else np.ceil(recommended_ecut_wfc)}

override = {
'abinit': {
Expand Down Expand Up @@ -187,31 +213,9 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
warnings.warn(f'input magnetization per site was None, setting it to {magnetization_per_site}')
magnetization_per_site = np.array(magnetization_per_site)

sum_is_zero = np.isclose(sum(magnetization_per_site), 0.0)
all_are_zero = np.all(np.isclose(magnetization_per_site, 0.0))
non_zero_mags = magnetization_per_site[~np.isclose(magnetization_per_site, 0.0)]
all_non_zero_pos = np.all(non_zero_mags > 0.0)
all_non_zero_neg = np.all(non_zero_mags < 0.0)

if all_are_zero: # non-magnetic
warnings.warn(
'all of the initial magnetizations per site are close to zero; doing a non-spin-polarized '
'calculation'
)
elif (sum_is_zero and not all_are_zero) or (
not all_non_zero_pos and not all_non_zero_neg
): # antiferromagnetic
print('Detected antiferromagnetic!')
builder.abinit['parameters']['nsppol'] = 1 # antiferromagnetic system
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
elif not all_are_zero and (all_non_zero_pos or all_non_zero_neg): # ferromagnetic
print('Detected ferromagnetic!')
builder.abinit['parameters']['nsppol'] = 2 # collinear spin-polarization
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
else:
raise ValueError(f'Initial magnetization {magnetization_per_site} is ambiguous')
builder.abinit['parameters']['nsppol'] = 2 # collinear spin-polarization
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
elif spin_type == SpinType.NON_COLLINEAR:
if magnetization_per_site is None:
magnetization_per_site = get_initial_magnetization(structure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_stress(parameters):
def get_forces(parameters):
"""Return the forces array from the given parameters node."""
forces = orm.ArrayData()
forces.set_array(name='forces', array=np.array(parameters.base.attributes.get('forces')))
forces.set_array(name='forces', array=np.array(parameters.base.attributes.get('cart_forces')))
return forces


Expand Down
19 changes: 18 additions & 1 deletion src/aiida_common_workflows/workflows/relax/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ def validate_inputs(value, _):
if value.get('magnetization_per_site') is not None and value.get('fixed_total_cell_magnetization') is not None:
return 'the inputs `magnetization_per_site` and ' '`fixed_total_cell_magnetization` are mutually exclusive.'

if value.get('protocol') == 'custom' and value.get('custom_protocol') is None:
return 'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'

if value.get('protocol') != 'custom' and value.get('custom_protocol') is not None:
return 'the `custom_protocol` input can only be provided when the `protocol` input is set to `custom`.'

# TODO: ensure all plugins actually honor this new custom_protocol input! (only QE implemented for now)


class OptionalRelaxFeatures(OptionalFeature):
FIXED_MAGNETIZATION = 'fixed_total_cell_magnetization'
Expand Down Expand Up @@ -45,7 +53,7 @@ def define(cls, spec):
)
spec.input(
'protocol',
valid_type=ChoiceType(('fast', 'moderate', 'precise')),
valid_type=ChoiceType(('fast', 'moderate', 'precise', 'custom')),
default='moderate',
non_db=True,
help='The protocol to use for the automated input generation. This value indicates the level of precision '
Expand Down Expand Up @@ -82,6 +90,15 @@ def define(cls, spec):
'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
'(μB).',
)
spec.input(
'custom_protocol',
valid_type=dict,
non_db=True,
required=False,
default=None,
help='A custom protocol dictionary that can be provided when the `protocol` input is set to `custom`. '
'In that case, this dictionary will be used to override the default protocol settings.',
)
spec.input(
'fixed_total_cell_magnetization',
valid_type=OptionalFeatureType(float),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def define(cls, spec):
"""
super().define(spec)
spec.inputs['protocol'].valid_type = ChoiceType(
('fast', 'balanced', 'stringent', 'moderate', 'precise', 'verification-PBE-v1')
('fast', 'balanced', 'stringent', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
)
spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE, SpinType.COLLINEAR))
spec.inputs['relax_type'].valid_type = ChoiceType(
Expand Down Expand Up @@ -155,7 +155,15 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
# Currently, the `aiida-quantumespresso` workflows will expect one of the basic protocols to be passed to the
# `get_builder_from_protocol()` method. Here, we switch to using the default protocol for the
# `aiida-quantumespresso` plugin and pass the local protocols as `overrides`.
if (
if protocol == 'custom':
custom_protocol = kwargs.get('custom_protocol', None)
if custom_protocol is None:
raise ValueError(
'The `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
)
overrides = custom_protocol
protocol = self._default_protocol
elif (
protocol not in self.process_class._process_class.get_available_protocols()
and self.process_class._process_class._check_if_alias(protocol)
not in self.process_class._process_class.get_available_protocols()
Expand Down
Loading