class FormTool(StructuredTool, ABC):
form: BaseModel = None
state: Union[FormToolState | None] = None
skip_confirm: Optional[bool] = False
# Backup attributes for handling changes in the state
args_schema_: Optional[Type[BaseModel]] = None
description_: Optional[str] = None
name_: Optional[str] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.args_schema_ = None
self.name_ = None
self.description_ = None
self.init_state()
def init_state(self):
state_initializer = {
None: self.enter_inactive_state,
FormToolState.INACTIVE: self.enter_inactive_state,
FormToolState.ACTIVE: self.enter_active_state,
FormToolState.FILLED: self.enter_filled_state
}
state_initializer[self.state]()
def enter_inactive_state(self):
# Guard so that we don't overwrite the original args_schema if
# set_inactive_state is called multiple times
if not self.state == FormToolState.INACTIVE:
self.state = FormToolState.INACTIVE
self.name_ = self.name
self.name = f"{self.name_}Start"
self.description_ = self.description
self.description = f"Starts the form {self.name}, which {self.description_}"
self.args_schema_ = self.args_schema
self.args_schema = FormToolInactivePayload
def enter_active_state(self):
# if not self.state == FormToolState.ACTIVE:
self.state = FormToolState.ACTIVE
self.name = f"{self.name_}Update"
self.description = f"Updates data for form {self.name}, which {self.description_}"
self.args_schema = make_optional_model(self.args_schema_)
if not self.form:
self.form = self.args_schema()
elif isinstance(self.form, str):
self.form = self.args_schema(**json.loads(self.form))
def enter_filled_state(self):
self.state = FormToolState.FILLED
self.name = f"{self.name_}Finalize"
self.description = f"Finalizes form {self.name}, which {self.description_}"
self.args_schema = make_optional_model(self.args_schema_)
if not self.form:
self.form = self.args_schema()
elif isinstance(self.form, str):
self.form = self.args_schema(**json.loads(self.form))
self.args_schema = FormToolConfirmPayload
def activate(
self,
*args,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs
) -> FormToolOutcome:
self.enter_active_state()
return FormToolOutcome(
output=f"Starting form {self.name}. If the user as already provided some information, call {self.name}.",
active_form_tool=self,
tool_choice=self.name
)
def update(
self,
*args,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs
) -> FormToolOutcome:
self._update_form(**kwargs)
if self.is_form_filled():
self.enter_filled_state()
if self.skip_confirm:
return self.finalize(confirm=True)
else:
return FormToolOutcome(
active_form_tool=self,
output="Form is filled. Ask the user to confirm the information."
)
else:
return FormToolOutcome(
active_form_tool=self,
output="Form updated with the provided information. Ask the user for the next field."
)
def finalize(
self,
*args,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs
) -> FormToolOutcome:
if kwargs.get("confirm"):
# The FormTool could use self.form to get the data, but we pass it as kwargs to
# keep the signature consistent with _run
result = self._run_when_complete(**self.form.model_dump())
return FormToolOutcome(
active_form_tool=None,
output=result,
return_direct=self.return_direct
)
else:
self.enter_active_state()
return FormToolOutcome(
active_form_tool=self,
output="Ask the user to update the form."
)
def _run(
self,
*args,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs
) -> str:
match self.state:
case FormToolState.INACTIVE:
return self.activate(*args, **kwargs, run_manager=run_manager)
case FormToolState.ACTIVE:
return self.update(*args, **kwargs, run_manager=run_manager)
case FormToolState.FILLED:
return self.finalize(*args, **kwargs, run_manager=run_manager)
@abstractmethod
def _run_when_complete(self) -> str:
"""
Should raise an exception if something goes wrong.
The message should describe the error and will be sent back to the agent to try to fix it.
"""
def _update_form(self, **kwargs):
try:
model_class = type(self.form)
data = self.form.model_dump()
data.update(kwargs)
# Recreate the model with the new data merged to the old one
# This allows to validate multiple fields at once
self.form = model_class(**data)
except ValidationError as e:
raise ToolException(str(e))
def get_next_field_to_collect(
self,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""
The default implementation returns the first field that is not set.
"""
if self.state == FormToolState.FILLED:
return None
for field_name, field_info in self.args_schema.__fields__.items():
if not getattr(self.form, field_name):
return field_name
def is_form_filled(self) -> bool:
return self.get_next_field_to_collect() is None
def get_tool_start_message(self, input: dict) -> str:
message = ""
match self.state:
case FormToolState.INACTIVE:
message = f"Starting {self.name}"
case FormToolState.ACTIVE:
message = f"Updating form for {self.name}"
case FormToolState.FILLED:
message = f"Completed {self.name}"
return message